Skip to content

Commit

Permalink
Add options to QueryRelevantDocuments, add a new WithScoreThreshold o…
Browse files Browse the repository at this point in the history
…ption

It is useful to retrieve either only the relevant documents or none at
all, to avoid tainting the conversation with documents that are not
relevant.

To allow that, this patch converts the existing Limit parameter to a
functional option and adds a new one, WithScoreThreshold, which sets
the minimum required similarity score (based on cosine distance) between
the document and the vectorized query.
  • Loading branch information
jhrozek committed Oct 31, 2024
1 parent c3a1ff3 commit f45d3f6
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
4 changes: 3 additions & 1 deletion examples/qdrant/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ func main() {
}

// Query the most relevant documents based on a given embedding
retrievedDocs, err := vectorDB.QueryRelevantDocuments(ctx, queryEmbedding, 5, collection_name)
retrievedDocs, err := vectorDB.QueryRelevantDocuments(
ctx, queryEmbedding, collection_name,
db.WithLimit(5), db.WithScoreThreshold(0.7))
if err != nil {
log.Fatalf("Failed to query documents: %v", err)
}
Expand Down
31 changes: 27 additions & 4 deletions pkg/db/qdrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,48 @@ func (qv *QdrantVector) SaveEmbeddings(ctx context.Context, docID string, embedd
return nil
}

// QueryOpt is a functional option for a query: a function that receives the
// *qdrant.QueryPoints request and may modify it in place. Query option
// constructors such as WithLimit and WithScoreThreshold return this type, and
// QueryRelevantDocuments applies each supplied option to the request it builds
// before sending it.
type QueryOpt func(*qdrant.QueryPoints)

// WithLimit returns a QueryOpt that sets the maximum number of documents a
// query should return, by pointing the request's Limit field at limit.
func WithLimit(limit uint64) QueryOpt {
	setLimit := func(req *qdrant.QueryPoints) {
		req.Limit = &limit
	}
	return setLimit
}

// WithScoreThreshold returns a QueryOpt that sets the score threshold for a
// query, by pointing the request's ScoreThreshold field at threshold.
// The higher the threshold, the more relevant the results.
//
// NOTE(review): how the score is interpreted is decided by Qdrant and the
// collection's configuration, not by this package; confirm against the
// Qdrant documentation.
func WithScoreThreshold(threshold float32) QueryOpt {
	setThreshold := func(req *qdrant.QueryPoints) {
		req.ScoreThreshold = &threshold
	}
	return setThreshold
}

// QueryRelevantDocuments retrieves the most relevant documents based on a given embedding.
//
// Parameters:
// - ctx: The context for the query.
// - embedding: The query embedding.
// - collection: The collection name to query.
// - queryOpts: Optional query options (e.g. WithLimit, WithScoreThreshold) applied to the query before it is sent.
//
// Returns:
// - A slice of Document structs containing the most relevant documents.
// - An error if the query fails.
func (qv *QdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int, colllection string) ([]Document, error) {
limitUint := uint64(limit) // Convert limit to uint64
func (qv *QdrantVector) QueryRelevantDocuments(
ctx context.Context, embedding []float32, collection string, queryOpts ...QueryOpt,
) ([]Document, error) {
query := &qdrant.QueryPoints{
CollectionName: colllection, // Replace with actual collection name
CollectionName: collection, // Replace with actual collection name
Query: qdrant.NewQuery(embedding...),
Limit: &limitUint,
WithPayload: qdrant.NewWithPayloadInclude("content"),
}

for _, opt := range queryOpts {
opt(query)
}

response, err := qv.client.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to search points: %w", err)
Expand Down
13 changes: 9 additions & 4 deletions pkg/db/qdrant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,19 @@ func (t *testQdrantVector) SaveEmbeddings(ctx context.Context, docID string, emb
return err
}

func (t *testQdrantVector) QueryRelevantDocuments(ctx context.Context, embedding []float32, limit int, collection string) ([]Document, error) {
limitUint := uint64(limit)
func (t *testQdrantVector) QueryRelevantDocuments(
ctx context.Context, embedding []float32, collection string, queryOpts ...QueryOpt,
) ([]Document, error) {
query := &qdrant.QueryPoints{
CollectionName: collection,
Query: qdrant.NewQuery(embedding...),
Limit: &limitUint,
WithPayload: qdrant.NewWithPayloadInclude("content"),
}

for _, opt := range queryOpts {
opt(query)
}

response, err := t.mockClient.Query(ctx, query)
if err != nil {
return nil, err
Expand Down Expand Up @@ -186,7 +190,8 @@ func TestQueryRelevantDocuments(t *testing.T) {
})).Return(mockResponse, nil)

// Test the QueryRelevantDocuments function
docs, err := qv.QueryRelevantDocuments(ctx, embedding, limit, collection)
docs, err := qv.QueryRelevantDocuments(ctx, embedding, collection,
WithLimit(5), WithScoreThreshold(0.7))
assert.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, "test content", docs[0].Metadata["content"])
Expand Down

0 comments on commit f45d3f6

Please sign in to comment.