Skip to content

Commit

Permalink
Update native vector provider to use unsigned int7 values only (elast…
Browse files Browse the repository at this point in the history
…ic#108243)

This commit updates the native vector provider to reflect that Lucene's scalar quantization is unsigned int7, with a range of values from 0 to 127 inclusive. Stride has been pushed down into native, to allow other platforms to more easily select their own stride length.

Previously the implementation supported signed int8. We might want the more general signed int8 implementation in the future, but for now unsigned int7 is sufficient, and allows us to provide more efficient implementations on x64.
  • Loading branch information
ChrisHegarty authored May 4, 2024
1 parent 4daac77 commit 7f90a98
Show file tree
Hide file tree
Showing 19 changed files with 180 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ public void setup() throws IOException {
vec1 = new byte[dims];
vec2 = new byte[dims];

ThreadLocalRandom.current().nextBytes(vec1);
ThreadLocalRandom.current().nextBytes(vec2);
randomInt7BytesBetween(vec1);
randomInt7BytesBetween(vec2);
vec1Offset = ThreadLocalRandom.current().nextFloat();
vec2Offset = ThreadLocalRandom.current().nextFloat();

Expand All @@ -113,8 +113,8 @@ public void setup() throws IOException {
scoreCorrectionConstant
);
luceneSqrScorer = ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.EUCLIDEAN, scoreCorrectionConstant);
nativeDotScorer = factory.getScalarQuantizedVectorScorer(dims, size, scoreCorrectionConstant, DOT_PRODUCT, in).get();
nativeSqrScorer = factory.getScalarQuantizedVectorScorer(dims, size, scoreCorrectionConstant, EUCLIDEAN, in).get();
nativeDotScorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, scoreCorrectionConstant, DOT_PRODUCT, in).get();
nativeSqrScorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, scoreCorrectionConstant, EUCLIDEAN, in).get();

// sanity
var f1 = dotProductLucene();
Expand Down Expand Up @@ -185,4 +185,15 @@ public float squareDistanceScalar() {
float adjustedDistance = squareDistance * scoreCorrectionConstant;
return 1 / (1f + adjustedDistance);
}

// Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
static final byte MIN_INT7_VALUE = 0;
static final byte MAX_INT7_VALUE = 127;

/**
 * Overwrites every element of {@code bytes} with a random unsigned int7 value,
 * i.e. a value between {@code MIN_INT7_VALUE} and {@code MAX_INT7_VALUE}, both inclusive.
 *
 * @param bytes the array to fill in place
 */
static void randomInt7BytesBetween(byte[] bytes) {
    final ThreadLocalRandom rnd = ThreadLocalRandom.current();
    for (int idx = 0; idx < bytes.length; idx++) {
        // nextInt's upper bound is exclusive, hence the + 1 to make MAX_INT7_VALUE reachable
        int value = rnd.nextInt(MIN_INT7_VALUE, MAX_INT7_VALUE + 1);
        bytes[idx] = (byte) value;
    }
}
}
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ configurations {
}

var zstdVersion = "1.5.5"
var vecVersion = "1.0.3"
var vecVersion = "1.0.6"

repositories {
exclusiveContent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,24 @@
*/
public interface VectorSimilarityFunctions {
/**
* Produces a method handle returning the dot product of byte (signed int8) vectors.
* Produces a method handle returning the dot product of byte (unsigned int7) vectors.
*
* <p> Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
*
* <p> The type of the method handle will have {@code int} as return type, The type of
* its first and second arguments will be {@code MemorySegment}, whose contents are the
* vector data bytes. The third argument is the length of the vector data.
*/
MethodHandle dotProductHandle();
MethodHandle dotProductHandle7u();

/**
* Produces a method handle returning the square distance of byte (signed int8) vectors.
* Produces a method handle returning the square distance of byte (unsigned int7) vectors.
*
* <p> Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
*
* <p> The type of the method handle will have {@code int} as return type, The type of
* its first and second arguments will be {@code MemorySegment}, whose contents are the
* vector data bytes. The third argument is the length of the vector data.
*/
MethodHandle squareDistanceHandle();
MethodHandle squareDistanceHandle7u();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import java.lang.invoke.MethodType;

import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle;

Expand Down Expand Up @@ -51,139 +50,87 @@ public VectorSimilarityFunctions getVectorSimilarityFunctions() {

private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions {

static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT));
static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT));

static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));

// Stride of the native implementation - consumes this number of bytes per loop invocation.
// There must be at least this number of bytes/elements available when going native
static final int DOT_STRIDE = 32;
static final int SQR_STRIDE = 16;

static {
assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two";
assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two";
assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
}
static final MethodHandle dot7u$mh = downcallHandle("dot7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
static final MethodHandle sqr7u$mh = downcallHandle("sqr7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));

/**
* Computes the dot product of given byte vectors.
* Computes the dot product of given unsigned int7 byte vectors.
*
* <p> Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
*
* @param a address of the first vector
* @param b address of the second vector
* @param length the vector dimensions
*/
static int dotProduct(MemorySegment a, MemorySegment b, int length) {
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
assert length >= 0;
if (a.byteSize() != b.byteSize()) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
}
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
int i = 0;
int res = 0;
if (length >= DOT_STRIDE) {
i += length & ~(DOT_STRIDE - 1);
res = dot8s(a, b, i);
}

// tail
for (; i < length; i++) {
res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i);
}
assert i == length;
return res;
return dot7u(a, b, length);
}

/**
* Computes the square distance of given byte vectors.
* Computes the square distance of given unsigned int7 byte vectors.
*
* <p> Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
*
* @param a address of the first vector
* @param b address of the second vector
* @param length the vector dimensions
*/
static int squareDistance(MemorySegment a, MemorySegment b, int length) {
static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
assert length >= 0;
if (a.byteSize() != b.byteSize()) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
}
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
int i = 0;
int res = 0;
if (length >= SQR_STRIDE) {
i += length & ~(SQR_STRIDE - 1);
res = sqr8s(a, b, i);
}

// tail
for (; i < length; i++) {
int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
res += dist * dist;
}
assert i == length;
return res;
}

private static int dot8Stride() {
try {
return (int) dot8stride$mh.invokeExact();
} catch (Throwable t) {
throw new AssertionError(t);
}
}

private static int sqr8Stride() {
try {
return (int) sqr8stride$mh.invokeExact();
} catch (Throwable t) {
throw new AssertionError(t);
}
return sqr7u(a, b, length);
}

private static int dot8s(MemorySegment a, MemorySegment b, int length) {
private static int dot7u(MemorySegment a, MemorySegment b, int length) {
try {
return (int) dot8s$mh.invokeExact(a, b, length);
return (int) dot7u$mh.invokeExact(a, b, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

private static int sqr8s(MemorySegment a, MemorySegment b, int length) {
private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
try {
return (int) sqr8s$mh.invokeExact(a, b, length);
return (int) sqr7u$mh.invokeExact(a, b, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

static final MethodHandle DOT_HANDLE;
static final MethodHandle SQR_HANDLE;
static final MethodHandle DOT_HANDLE_7U;
static final MethodHandle SQR_HANDLE_7U;

static {
try {
var lookup = MethodHandles.lookup();
var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
DOT_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct", mt);
SQR_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance", mt);
DOT_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7u", mt);
SQR_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance7u", mt);
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}

@Override
public MethodHandle dotProductHandle() {
return DOT_HANDLE;
public MethodHandle dotProductHandle7u() {
return DOT_HANDLE_7U;
}

@Override
public MethodHandle squareDistanceHandle() {
return SQR_HANDLE;
public MethodHandle squareDistanceHandle7u() {
return SQR_HANDLE_7U;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@

public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {

// bounds of the range of values that can be seen by int7 scalar quantized vectors
static final byte MIN_INT7_VALUE = 0;
static final byte MAX_INT7_VALUE = 127;

static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;

static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 };
Expand Down Expand Up @@ -49,14 +53,14 @@ public static Iterable<Object[]> parametersFactory() {
return () -> IntStream.of(VECTOR_DIMS).boxed().map(i -> new Object[] { i }).iterator();
}

public void testBinaryVectors() {
public void testInt7BinaryVectors() {
assumeTrue(notSupportedMsg(), supported());
final int dims = size;
final int numVecs = randomIntBetween(2, 101);
var values = new byte[numVecs][dims];
var segment = arena.allocate((long) dims * numVecs);
for (int i = 0; i < numVecs; i++) {
random().nextBytes(values[i]);
randomBytesBetween(values[i], MIN_INT7_VALUE, MAX_INT7_VALUE);
MemorySegment.copy(MemorySegment.ofArray(values[i]), 0L, segment, (long) i * dims, dims);
}

Expand All @@ -65,29 +69,29 @@ public void testBinaryVectors() {
int first = randomInt(numVecs - 1);
int second = randomInt(numVecs - 1);
// dot product
int implDot = dotProduct(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims);
int implDot = dotProduct7u(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims);
int otherDot = dotProductScalar(values[first], values[second]);
assertEquals(otherDot, implDot);

int squareDist = squareDistance(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims);
int otherSq = squareDistanceScalar(values[first], values[second]);
assertEquals(otherSq, squareDist);
int implSqr = squareDistance7u(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims);
int otherSqr = squareDistanceScalar(values[first], values[second]);
assertEquals(otherSqr, implSqr);
}
}

public void testIllegalDims() {
assumeTrue(notSupportedMsg(), supported());
var segment = arena.allocate((long) size * 3);
var e = expectThrows(IAE, () -> dotProduct(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
var e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
assertThat(e.getMessage(), containsString("dimensions differ"));

e = expectThrows(IAE, () -> dotProduct(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
assertThat(e.getMessage(), containsString("greater than vector dimensions"));
}

int dotProduct(MemorySegment a, MemorySegment b, int length) {
int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
try {
return (int) getVectorDistance().dotProductHandle().invokeExact(a, b, length);
return (int) getVectorDistance().dotProductHandle7u().invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
Expand All @@ -99,9 +103,9 @@ int dotProduct(MemorySegment a, MemorySegment b, int length) {
}
}

int squareDistance(MemorySegment a, MemorySegment b, int length) {
int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
try {
return (int) getVectorDistance().squareDistanceHandle().invokeExact(a, b, length);
return (int) getVectorDistance().squareDistanceHandle7u().invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
Expand Down
8 changes: 7 additions & 1 deletion libs/vec/native/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@ apply plugin: 'c'

var os = org.gradle.internal.os.OperatingSystem.current()

// To update this library run publish_vec_binaries.sh
// To update this library run publish_vec_binaries.sh ( or ./gradlew vecSharedLibrary )
// Or
// For local development, build the docker image with:
// docker build --platform linux/arm64 --progress=plain .
// Grab the image id from the console output, then, e.g.
// docker run 9c9f36564c148b275aeecc42749e7b4580ded79dcf51ff6ccc008c8861e7a979 > build/libs/vec/shared/libvec.so
//
// To run tests and benchmarks on a locally built libvec,
// 1. Temporarily comment out the download in libs/native/libraries/build.gradle
// libs "org.elasticsearch:vec:${vecVersion}@zip"
// 2. Copy your locally built libvec binary, e.g.
// cp libs/vec/native/build/libs/vec/shared/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
//
// Look at the disassembly:
// objdump --disassemble-symbols=_dot7u build/libs/vec/shared/libvec.dylib
// Note: symbol decoration may differ on Linux, i.e. the leading underscore is not present
Expand Down
2 changes: 1 addition & 1 deletion libs/vec/native/publish_vec_binaries.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
exit 1;
fi

VERSION="1.0.3"
VERSION="1.0.6"
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
TEMP=$(mktemp -d)

Expand Down
Loading

0 comments on commit 7f90a98

Please sign in to comment.