Skip to content

Commit

Permalink
Improve int4 compressed comparisons performance (apache#13321)
Browse files Browse the repository at this point in the history
This updates the int4 dot-product comparison to have an optimized implementation for when one of the vectors is compressed (the most common search case). This change actually makes compressed search on ARM faster than uncompressed. However, on AVX512/256 it is still slightly slower than uncompressed, but it is still much faster with this optimization than before (which eagerly decompressed the vectors).

This optimization is tightly tied to how the vectors are actually compressed and stored; consequently, I added a new scorer that lives within the lucene99 codec.

So, this gives us 8x reduction over float32, well more than 2x faster queries than float32, and no need to rerank as the recall and accuracy are excellent.
  • Loading branch information
benwtrent authored May 1, 2024
1 parent e6dac34 commit e40e108
Show file tree
Hide file tree
Showing 18 changed files with 871 additions and 173 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ Optimizations
* GITHUB#13284: Per-field doc values and knn vectors readers now use a HashMap internally instead of
a TreeMap. (Adrien Grand)

* GITHUB#13321: Improve compressed int4 quantized vector search by utilizing SIMD inline with the decompression
process. (Ben Trent)

Bug Fixes
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,21 @@
value = 3,
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
public class VectorUtilBenchmark {
/**
 * Packs an int4 vector into half the space: element {@code i} of the first half of
 * {@code raw} lands in the high nibble of {@code compressed[i]}, and the matching
 * element from the second half lands in the low nibble. Assumes every value in
 * {@code raw} fits in 4 bits and {@code raw.length == 2 * compressed.length}.
 */
static void compressBytes(byte[] raw, byte[] compressed) {
  final int half = compressed.length;
  for (int idx = 0; idx < half; ++idx) {
    compressed[idx] = (byte) ((raw[idx] << 4) | raw[half + idx]);
  }
}

// Full-range random byte vectors for the plain byte dot-product benchmarks.
private byte[] bytesA;
private byte[] bytesB;
// Vectors whose values all fit in 4 bits (0..15), for the int4 benchmarks.
private byte[] halfBytesA;
private byte[] halfBytesB;
// halfBytesB with two int4 values packed per byte; only built when size is even
// (see init()), since packing requires an even number of elements.
private byte[] halfBytesBPacked;
// Random float vectors for the float benchmarks.
private float[] floatsA;
private float[] floatsB;
// Precomputed halfBytesA . halfBytesB, used by the packed benchmarks to verify
// that the packed dot product matches the unpacked result.
private int expectedhalfByteDotProduct;

// Vector dimension; each benchmark runs once per listed value.
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
int size;
Expand All @@ -55,11 +63,18 @@ public void init() {
random.nextBytes(bytesB);
// random half byte arrays for binary methods
// this means that all values must be between 0 and 15
expectedhalfByteDotProduct = 0;
halfBytesA = new byte[size];
halfBytesB = new byte[size];
for (int i = 0; i < size; ++i) {
halfBytesA[i] = (byte) random.nextInt(16);
halfBytesB[i] = (byte) random.nextInt(16);
expectedhalfByteDotProduct += halfBytesA[i] * halfBytesB[i];
}
// pack the half byte arrays
if (size % 2 == 0) {
halfBytesBPacked = new byte[(size + 1) >> 1];
compressBytes(halfBytesB, halfBytesBPacked);
}

// random float arrays for float methods
Expand Down Expand Up @@ -115,6 +130,31 @@ public int binaryHalfByteVector() {
return VectorUtil.int4DotProduct(halfBytesA, halfBytesB);
}

/**
 * Scalar int4 dot product where one operand is packed two-values-per-byte.
 * Verifies the result against the precomputed expected value so a broken
 * implementation fails loudly instead of producing a meaningless score.
 */
@Benchmark
public int binaryHalfByteScalarPacked() {
  if (size % 2 != 0) {
    throw new RuntimeException("Size must be even for this benchmark");
  }
  final int dotProduct = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
  if (dotProduct != expectedhalfByteDotProduct) {
    throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + dotProduct);
  }
  return dotProduct;
}

/**
 * Same packed int4 dot product as the scalar variant, but forked with the
 * incubating Vector API enabled so the SIMD code path is exercised. The result
 * is checked against the precomputed expected value.
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteVectorPacked() {
  if (size % 2 != 0) {
    throw new RuntimeException("Size must be even for this benchmark");
  }
  final int result = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
  if (result == expectedhalfByteDotProduct) {
    return result;
  }
  throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + result);
}

@Benchmark
public float floatCosineScalar() {
return VectorUtil.cosine(floatsA, floatsB);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.quantization.ScalarQuantizer;

Expand All @@ -35,6 +35,23 @@
*/
public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {

/**
 * Quantizes a float query vector so it can be compared against stored quantized vectors.
 *
 * @param query the raw query vector; never mutated (COSINE normalizes a copy)
 * @param quantizedQuery destination array for the quantized query bytes
 * @param similarityFunction the similarity in use; COSINE requires l2-normalization first
 * @param scalarQuantizer the quantizer that produced the stored vectors
 * @return the score-correction offset computed by the quantizer
 */
public static float quantizeQuery(
    float[] query,
    byte[] quantizedQuery,
    VectorSimilarityFunction similarityFunction,
    ScalarQuantizer scalarQuantizer) {
  final float[] toQuantize;
  if (similarityFunction == VectorSimilarityFunction.COSINE) {
    // COSINE expects unit-length vectors; normalize a copy so the caller's array is untouched.
    toQuantize = ArrayUtil.copyOfSubArray(query, 0, query.length);
    VectorUtil.l2normalize(toQuantize);
  } else {
    // EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT quantize the query as-is.
    toQuantize = query;
  }
  return scalarQuantizer.quantize(toQuantize, quantizedQuery, similarityFunction);
}

private final FlatVectorsScorer nonQuantizedDelegate;

public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) {
Expand Down Expand Up @@ -65,18 +82,21 @@ public RandomVectorScorer getRandomVectorScorer(
ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
byte[] targetBytes = new byte[target.length];
float offsetCorrection =
ScalarQuantizedRandomVectorScorer.quantizeQuery(
target, targetBytes, similarityFunction, scalarQuantizer);
quantizeQuery(target, targetBytes, similarityFunction, scalarQuantizer);
ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction,
scalarQuantizer.getConstantMultiplier(),
scalarQuantizer.getBits());
return new ScalarQuantizedRandomVectorScorer(
scalarQuantizedVectorSimilarity,
quantizedByteVectorValues,
targetBytes,
offsetCorrection);
return new RandomVectorScorer.AbstractRandomVectorScorer(quantizedByteVectorValues) {
@Override
public float score(int node) throws IOException {
byte[] nodeVector = quantizedByteVectorValues.vectorValue(node);
float nodeOffset = quantizedByteVectorValues.getScoreCorrectionConstant(node);
return scalarQuantizedVectorSimilarity.score(
targetBytes, offsetCorrection, nodeVector, nodeOffset);
}
};
}
// It is possible to get to this branch during initial indexing and flush
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
Expand All @@ -95,4 +115,60 @@ public RandomVectorScorer getRandomVectorScorer(
public String toString() {
return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')';
}

/**
 * Supplies {@link RandomVectorScorer}s over scalar-quantized byte vectors, scoring one
 * stored vector (identified by ordinal) against all others.
 *
 * @lucene.experimental
 */
public static class ScalarQuantizedRandomVectorScorerSupplier
    implements RandomVectorScorerSupplier {

  // Random access to the stored quantized vectors and their score-correction constants.
  private final RandomAccessQuantizedByteVectorValues values;
  // Similarity specialized for the quantizer's constant multiplier and bit count.
  private final ScalarQuantizedVectorSimilarity similarity;
  // Original similarity function, retained so copies stay equivalent.
  private final VectorSimilarityFunction vectorSimilarityFunction;

  /**
   * Builds the specialized quantized similarity from the given function and quantizer
   * parameters (constant multiplier and bits per value).
   */
  public ScalarQuantizedRandomVectorScorerSupplier(
      VectorSimilarityFunction similarityFunction,
      ScalarQuantizer scalarQuantizer,
      RandomAccessQuantizedByteVectorValues values) {
    this.similarity =
        ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
            similarityFunction,
            scalarQuantizer.getConstantMultiplier(),
            scalarQuantizer.getBits());
    this.values = values;
    this.vectorSimilarityFunction = similarityFunction;
  }

  // Used by copy() to reuse the already-built similarity instead of rebuilding it.
  private ScalarQuantizedRandomVectorScorerSupplier(
      ScalarQuantizedVectorSimilarity similarity,
      VectorSimilarityFunction vectorSimilarityFunction,
      RandomAccessQuantizedByteVectorValues values) {
    this.similarity = similarity;
    this.values = values;
    this.vectorSimilarityFunction = vectorSimilarityFunction;
  }

  /**
   * Returns a scorer whose query is the stored vector at {@code ord}; {@code score(node)}
   * then compares that query against the vector at {@code node}.
   */
  @Override
  public RandomVectorScorer scorer(int ord) throws IOException {
    // Read the query vector from `values` but candidate vectors from a separate copy:
    // NOTE(review): this split presumably exists because vectorValue() may return a
    // reused internal buffer, so reading candidates through the same reader would
    // clobber the query bytes — confirm against the
    // RandomAccessQuantizedByteVectorValues implementations.
    final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
    final byte[] queryVector = values.vectorValue(ord);
    final float queryOffset = values.getScoreCorrectionConstant(ord);
    return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {
      @Override
      public float score(int node) throws IOException {
        byte[] nodeVector = vectorsCopy.vectorValue(node);
        float nodeOffset = vectorsCopy.getScoreCorrectionConstant(node);
        return similarity.score(queryVector, queryOffset, nodeVector, nodeOffset);
      }
    };
  }

  /**
   * Returns an independent supplier over a copy of the vector values, so it can be
   * used concurrently with this one; the similarity instance is shared.
   */
  @Override
  public RandomVectorScorerSupplier copy() throws IOException {
    return new ScalarQuantizedRandomVectorScorerSupplier(
        similarity, vectorSimilarityFunction, values.copy());
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public byte[] vectorValue(int targetOrd) throws IOException {
return binaryValue;
}

@Override
public IndexInput getSlice() {
  // Expose the underlying IndexInput holding the raw vector data so callers
  // (e.g. optimized scorers) can read the stored bytes directly.
  return slice;
}

private void readValue(int targetOrd) throws IOException {
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public int size() {
return size;
}

@Override
public IndexInput getSlice() {
  // Expose the underlying IndexInput holding the raw vector data so callers
  // (e.g. optimized scorers) can read the stored bytes directly.
  return slice;
}

@Override
public float[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
Expand Down
Loading

0 comments on commit e40e108

Please sign in to comment.