Skip to content

Commit

Permalink
HSEARCH-5052 Filter out documents with vectors below a "similarity li…
Browse files Browse the repository at this point in the history
…mit"
  • Loading branch information
marko-bekhta committed Dec 22, 2023
1 parent e56cfb7 commit 47a71c5
Show file tree
Hide file tree
Showing 11 changed files with 336 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -883,4 +883,9 @@ SearchException vectorKnnMatchVectorTypeDiffersFromField(String absoluteFieldPat
value = "Elasticsearch distribution does not allow to apply constant score to a knn predicate."
+ " Constant score will not be applied.")
void elasticsearchKnnIgnoresConstantScore();

/**
 * Creates the exception reported when the {@code required minimum similarity} option
 * of a knn predicate is requested for an OpenSearch distribution.
 *
 * @return the exception carrying the message below
 */
@Message(id = ID_OFFSET + 188,
value = "An OpenSearch distribution does not allow specifying the `required minimum similarity` option. "
+ "This option is only applicable to an Elasticsearch distribution of an Elasticsearch backend.")
SearchException knnRequiredMinimumSimilarityUnsupportedOption();
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ public ElasticsearchKnnPredicateOptionsStep<ElasticsearchKnnPredicateOptionsStep
return this;
}

/**
 * Passes the required minimum similarity on to the underlying predicate builder.
 *
 * @param similarity the similarity value handed to the builder
 * @return this step, so that further options can be chained
 */
@Override
public ElasticsearchKnnPredicateOptionsStep<?> requiredMinimumSimilarity(float similarity) {
	builder.requiredMinimumSimilarity( similarity );
	return this;
}

@Override
public ElasticsearchKnnPredicateOptionsStep<ElasticsearchKnnPredicateOptionsStep<?>> matching(byte... vector) {
this.builder.vector( vector );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ public abstract class ElasticsearchKnnPredicate extends AbstractElasticsearchSin
protected final int k;
protected final JsonArray vector;
protected final Integer numberOfCandidates;
protected final Float similarity;

/**
 * Creates the predicate from the options accumulated in the given builder.
 * <p>
 * Once the {@code filter} and {@code vector} have been copied over,
 * the builder's own references to them are reset to {@code null}.
 *
 * @param builder the builder holding the knn predicate options
 */
private ElasticsearchKnnPredicate(AbstractKnnBuilder<?> builder) {
	super( builder );
	// Plain option values are simply copied.
	this.k = builder.k;
	this.numberOfCandidates = builder.numberOfCandidates;
	this.similarity = builder.similarity;
	// The filter and the vector are taken over: the builder no longer references them afterwards.
	this.filter = builder.filter;
	builder.filter = null;
	this.vector = builder.vector;
	builder.vector = null;
}
Expand Down Expand Up @@ -80,6 +82,7 @@ private abstract static class AbstractKnnBuilder<F> extends AbstractBuilder impl
private JsonArray vector;
private ElasticsearchSearchPredicate filter;
protected Integer numberOfCandidates;
protected Float similarity;

private AbstractKnnBuilder(ElasticsearchFieldCodec<F> codec, ElasticsearchSearchIndexScope<?> scope,
ElasticsearchSearchIndexValueFieldContext<F> field) {
Expand Down Expand Up @@ -150,6 +153,7 @@ private static class ElasticsearchImpl extends ElasticsearchKnnPredicate {
private static final JsonObjectAccessor FILTER_ACCESSOR = JsonAccessor.root().property( "filter" ).asObject();
private static final JsonAccessor<Integer> NUM_CANDIDATES_ACCESSOR =
JsonAccessor.root().property( "num_candidates" ).asInteger();
private static final JsonAccessor<Float> SIMILARITY_ACCESSOR = JsonAccessor.root().property( "similarity" ).asFloat();


private ElasticsearchImpl(Builder<?> builder) {
Expand Down Expand Up @@ -177,7 +181,9 @@ protected JsonObject doToJsonQuery(PredicateRequestContext context, JsonObject o
}
NUM_CANDIDATES_ACCESSOR.set( innerObject, numberOfCandidates != null ? numberOfCandidates : k );
QUERY_VECTOR_ACCESSOR.set( innerObject, vector );

if ( similarity != null ) {
SIMILARITY_ACCESSOR.set( innerObject, similarity );
}
return innerObject;
}

Expand Down Expand Up @@ -205,6 +211,11 @@ public void constantScore() {
log.elasticsearchKnnIgnoresConstantScore();
}

/**
 * Records the required minimum similarity in this builder's {@code similarity} field.
 *
 * @param similarity the similarity value to store
 */
@Override
public void requiredMinimumSimilarity(float similarity) {
this.similarity = similarity;
}

@Override
public SearchPredicate build() {
return new ElasticsearchImpl( this );
Expand Down Expand Up @@ -250,6 +261,11 @@ public void numberOfCandidates(int numberOfCandidates) {
throw log.knnNumberOfCandidatesUnsupportedOption();
}

/**
 * Always fails: this builder does not accept a required minimum similarity.
 *
 * @param similarity ignored; the method throws before using it
 */
@Override
public void requiredMinimumSimilarity(float similarity) {
// Unconditionally report the option as unsupported.
throw log.knnRequiredMinimumSimilarityUnsupportedOption();
}

@Override
public SearchPredicate build() {
return new OpenSearchImpl( this );
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
/*
* Hibernate Search, full-text search for your domain model
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.search.backend.lucene.lowlevel.query.impl;

import java.io.IOException;
import java.util.Objects;

import org.hibernate.search.util.common.AssertionFailure;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilterWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.VectorUtil;

/**
 * A query that wraps another query and only keeps the documents
 * whose score, as computed by the wrapped query, is greater than or equal to a minimum score.
 * <p>
 * Instances are created through the {@code create(...)} factory methods,
 * which convert the given similarity limit into that minimum score
 * using a formula that depends on the vector similarity function.
 */
public class VectorSimilarityFilterQuery extends Query {

	private final Query query;
	private final float similarityAsScore;

	/**
	 * Wraps a byte-vector knn query.
	 *
	 * @param query the knn query to wrap
	 * @param similarityLimit the similarity limit; treated as a distance and converted to a score
	 * @param vectorSimilarityFunction the similarity function selecting the conversion formula
	 * @return the filtering query
	 */
	public static VectorSimilarityFilterQuery create(KnnByteVectorQuery query, float similarityLimit,
			VectorSimilarityFunction vectorSimilarityFunction) {
		// We assume that `similarityLimit` is a distance so we need to convert it to the score using a formula from a
		// similarity function:
		return new VectorSimilarityFilterQuery(
				query, byteSimilarityDistanceToScore( similarityLimit, vectorSimilarityFunction ) );
	}

	/**
	 * Wraps a float-vector knn query.
	 *
	 * @param query the knn query to wrap
	 * @param similarityLimit the similarity limit; treated as a distance and converted to a score
	 * @param vectorSimilarityFunction the similarity function selecting the conversion formula
	 * @return the filtering query
	 */
	public static VectorSimilarityFilterQuery create(KnnFloatVectorQuery query, float similarityLimit,
			VectorSimilarityFunction vectorSimilarityFunction) {
		// We assume that `similarityLimit` is a distance so we need to convert it to the score using a formula from a
		// similarity function:
		return new VectorSimilarityFilterQuery(
				query, floatSimilarityDistanceToScore( similarityLimit, vectorSimilarityFunction ) );
	}

	private VectorSimilarityFilterQuery(Query query, float similarityAsScore) {
		this.query = query;
		this.similarityAsScore = similarityAsScore;
	}

	/**
	 * Rewrites the wrapped query and, if that produced a different query,
	 * returns a new filter around the rewritten one with the same minimum score.
	 */
	@Override
	public Query rewrite(IndexSearcher indexSearcher) throws IOException {
		Query rewritten = query.rewrite( indexSearcher );
		if ( rewritten == query ) {
			return this;
		}
		// Knn queries are rewritten and we need to use a rewritten one to get the weights and scores:
		return new VectorSimilarityFilterQuery( rewritten, this.similarityAsScore );
	}

	/**
	 * Creates the weight of the wrapped query and decorates it with the score filter.
	 */
	@Override
	public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
		// we've already converted distance/similarity to a score, but now if the underlying query is boosting the score,
		// we'd want to boost our converted one as well to get the expected matches:
		return new SimilarityWeight( query.createWeight( searcher, scoreMode, boost ), similarityAsScore * boost );
	}

	@Override
	public void visit(QueryVisitor visitor) {
		visitor.visitLeaf( this );
	}

	@Override
	public String toString(String field) {
		return getClass().getName() + "{" +
				"query=" + query +
				", similarityLimit=" + similarityAsScore +
				'}';
	}

	@Override
	public boolean equals(Object o) {
		if ( this == o ) {
			return true;
		}
		if ( o == null || getClass() != o.getClass() ) {
			return false;
		}
		VectorSimilarityFilterQuery that = (VectorSimilarityFilterQuery) o;
		return Float.compare( similarityAsScore, that.similarityAsScore ) == 0 && Objects.equals( query, that.query );
	}

	@Override
	public int hashCode() {
		return Objects.hash( query, similarityAsScore );
	}

	/**
	 * Converts a similarity limit into a minimum score for float vectors.
	 *
	 * @param distance the similarity limit
	 * @param similarityFunction the similarity function selecting the formula
	 * @return the corresponding minimum score
	 * @throws AssertionFailure if the similarity function is not one of the handled constants
	 */
	private static float floatSimilarityDistanceToScore(float distance, VectorSimilarityFunction similarityFunction) {
		switch ( similarityFunction ) {
			case EUCLIDEAN:
				return 1.0f / ( 1.0f + distance );
			case DOT_PRODUCT:
			case COSINE:
				// Never produce a negative minimum score.
				return Math.max( ( 1 + distance ) / 2, 0 );
			case MAXIMUM_INNER_PRODUCT:
				return VectorUtil.scaleMaxInnerProductScore( distance );
			default:
				throw new AssertionFailure( "Unknown similarity function: " + similarityFunction );
		}
	}

	/**
	 * Converts a similarity limit into a minimum score for byte vectors.
	 *
	 * @param distance the similarity limit
	 * @param similarityFunction the similarity function selecting the formula
	 * @return the corresponding minimum score
	 * @throws AssertionFailure if the similarity function is not one of the handled constants
	 */
	private static float byteSimilarityDistanceToScore(float distance, VectorSimilarityFunction similarityFunction) {
		switch ( similarityFunction ) {
			case EUCLIDEAN:
				return 1.0f / ( 1.0f + distance );
			case DOT_PRODUCT:
				// NOTE(review): the limit is used as the score unchanged here, unlike the other cases --
				// confirm this matches how scores are computed for byte dot products.
				return distance;
			case COSINE:
				return ( 1 + distance ) / 2;
			case MAXIMUM_INNER_PRODUCT:
				return VectorUtil.scaleMaxInnerProductScore( distance );
			default:
				throw new AssertionFailure( "Unknown similarity function: " + similarityFunction );
		}
	}

	/**
	 * A weight delegating to the wrapped query's weight,
	 * but rejecting documents whose score is below the minimum score.
	 */
	private static class SimilarityWeight extends FilterWeight {
		private final float similarityAsScore;

		protected SimilarityWeight(Weight weight, float similarityAsScore) {
			super( weight );
			this.similarityAsScore = similarityAsScore;
		}

		/**
		 * Returns the delegate's explanation, turned into a "no match"
		 * when the delegate matched with a value below the minimum score.
		 */
		@Override
		public Explanation explain(LeafReaderContext context, int doc) throws IOException {
			Explanation explanation = super.explain( context, doc );
			if ( explanation.isMatch() && similarityAsScore > explanation.getValue().floatValue() ) {
				return Explanation.noMatch( "Similarity limit is greater than the vector similarity.", explanation );
			}
			return explanation;
		}

		/**
		 * Wraps the delegate's scorer so that only documents reaching the minimum score match.
		 *
		 * @return the filtering scorer, or {@code null} if the delegate has no scorer for this segment
		 */
		@Override
		public Scorer scorer(LeafReaderContext context) throws IOException {
			Scorer scorer = super.scorer( context );
			if ( scorer == null ) {
				// A weight may return no scorer at all when nothing matches in a segment;
				// wrapping null would only fail later, on first use of the wrapper.
				return null;
			}
			return new MinScoreScorer( this, scorer, similarityAsScore );
		}
	}

	// An adapted version of `org.opensearch.common.lucene.search.function.MinScoreScorer`:
	/**
	 * A scorer delegating to another scorer and only matching documents
	 * whose score is greater than or equal to {@code minScore}.
	 */
	private static class MinScoreScorer extends Scorer {
		private final Scorer in;
		private final float minScore;
		// Score of the last document checked by the two-phase iterator's matches().
		private float curScore;

		MinScoreScorer(Weight weight, Scorer scorer, float minScore) {
			super( weight );
			this.in = scorer;
			this.minScore = minScore;
		}

		@Override
		public int docID() {
			return in.docID();
		}

		@Override
		public float score() {
			return curScore;
		}

		@Override
		public int advanceShallow(int target) throws IOException {
			return in.advanceShallow( target );
		}

		@Override
		public float getMaxScore(int upTo) throws IOException {
			return in.getMaxScore( upTo );
		}

		@Override
		public DocIdSetIterator iterator() {
			return TwoPhaseIterator.asDocIdSetIterator( twoPhaseIterator() );
		}

		@Override
		public TwoPhaseIterator twoPhaseIterator() {
			final TwoPhaseIterator inTwoPhase = this.in.twoPhaseIterator();
			final DocIdSetIterator approximation = inTwoPhase == null ? in.iterator() : inTwoPhase.approximation();
			return new TwoPhaseIterator( approximation ) {

				@Override
				public boolean matches() throws IOException {
					// we need to check the two-phase iterator first
					// otherwise calling score() is illegal
					if ( inTwoPhase != null && !inTwoPhase.matches() ) {
						return false;
					}
					curScore = in.score();
					return curScore >= minScore;
				}

				@Override
				public float matchCost() {
					return 1000f // random constant for the score computation
							+ ( inTwoPhase == null ? 0 : inTwoPhase.matchCost() );
				}
			};
		}
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.lang.reflect.Array;

import org.hibernate.search.backend.lucene.logging.impl.Log;
import org.hibernate.search.backend.lucene.lowlevel.query.impl.VectorSimilarityFilterQuery;
import org.hibernate.search.backend.lucene.search.common.impl.AbstractLuceneValueFieldSearchQueryElementFactory;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext;
Expand All @@ -20,6 +21,7 @@
import org.hibernate.search.util.common.AssertionFailure;
import org.hibernate.search.util.common.logging.impl.LoggerFactory;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
Expand All @@ -31,23 +33,29 @@ public class LuceneKnnPredicate extends AbstractLuceneSingleFieldPredicate imple
private final int k;
private final Object vector;
private final LuceneSearchPredicate filter;
private final Float similarity;
private final VectorSimilarityFunction similarityFunction;

/**
 * Creates the predicate by copying the knn options held by the given builder.
 *
 * @param builder the builder holding the knn predicate options
 */
private LuceneKnnPredicate(Builder<?> builder) {
	super( builder );
	// What to search for, and how many results to ask for.
	this.vector = builder.vector;
	this.k = builder.k;
	// Optional restrictions on the results.
	this.filter = builder.filter;
	this.similarity = builder.similarity;
	// Similarity function exposed by the builder, kept alongside the similarity value.
	this.similarityFunction = builder.similarityFunction;
}

@Override
protected Query doToQuery(PredicateRequestContext context) {
if ( vector instanceof byte[] ) {
return new KnnByteVectorQuery(
KnnByteVectorQuery query = new KnnByteVectorQuery(
absoluteFieldPath, (byte[]) vector, k, filter == null ? null : filter.toQuery( context ) );
return similarity == null ? query : VectorSimilarityFilterQuery.create( query, similarity, similarityFunction );
}
if ( vector instanceof float[] ) {
return new KnnFloatVectorQuery(
KnnFloatVectorQuery query = new KnnFloatVectorQuery(
absoluteFieldPath, (float[]) vector, k, filter == null ? null : filter.toQuery( context ) );
return similarity == null ? query : VectorSimilarityFilterQuery.create( query, similarity, similarityFunction );
}

throw new UnsupportedOperationException(
Expand All @@ -65,9 +73,11 @@ public KnnPredicateBuilder create(LuceneSearchIndexScope<?> scope, LuceneSearchI
private static class Builder<F> extends AbstractBuilder implements KnnPredicateBuilder {
private final Class<?> vectorElementsType;
private final int indexedVectorsDimension;
private final VectorSimilarityFunction similarityFunction;
private int k;
private Object vector;
private LuceneSearchPredicate filter;
private Float similarity;

private Builder(LuceneSearchIndexScope<?> scope, LuceneSearchIndexValueFieldContext<F> field) {
super( scope, field );
Expand All @@ -77,6 +87,7 @@ private Builder(LuceneSearchIndexScope<?> scope, LuceneSearchIndexValueFieldCont
LuceneVectorFieldCodec<F> vectorCodec = (LuceneVectorFieldCodec<F>) codec;
vectorElementsType = vectorCodec.vectorElementsType();
indexedVectorsDimension = vectorCodec.getConfiguredDimensions();
similarityFunction = vectorCodec.getVectorSimilarity();
}
else {
// shouldn't really happen as if someone tries this it should fail on `queryElementFactory` lookup.
Expand Down Expand Up @@ -111,6 +122,11 @@ public void filter(SearchPredicate filter) {
this.filter = LuceneSearchPredicate.from( scope, filter );
}

/**
 * Records the required minimum similarity in this builder's {@code similarity} field.
 *
 * @param similarity the similarity value to store
 */
@Override
public void requiredMinimumSimilarity(float similarity) {
this.similarity = similarity;
}

@Override
public SearchPredicate build() {
return new LuceneKnnPredicate( this );
Expand Down
Loading

0 comments on commit 47a71c5

Please sign in to comment.