Skip to content

Commit

Permalink
fix: fix grpc conversion bugs for sparse and multi vectors (#682)
Browse files Browse the repository at this point in the history
* fix: fix grpc conversion bugs for sparse and multi vectors

* fix: fix mypy
  • Loading branch information
joein committed Jul 8, 2024
1 parent 318260a commit 6c17384
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 3 deletions.
10 changes: 7 additions & 3 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2432,8 +2432,12 @@ def convert_vector(vector: Union[List[float], List[List[float]]]) -> grpc.Vector
vector[0], list
): # we can't say whether it is an empty dense or multi-dense vector
return grpc.Vector(
data=[inner_vector for multi_vector in model for inner_vector in multi_vector],
vectors_count=len(model),
data=[
inner_vector
for multi_vector in vector
for inner_vector in multi_vector # type: ignore
],
vectors_count=len(vector),
)
return grpc.Vector(data=vector)

Expand Down Expand Up @@ -2598,7 +2602,7 @@ def convert_prefetch_query(cls, model: rest.Prefetch) -> grpc.PrefetchQuery:

return grpc.PrefetchQuery(
prefetch=prefetch,
query=cls.convert_query(model.query) if model.query is not None else None,
query=cls.convert_query_interface(model.query) if model.query is not None else None,
using=model.using if model.using is not None else None,
filter=cls.convert_filter(model.filter) if model.filter is not None else None,
params=cls.convert_search_params(model.params) if model.params is not None else None,
Expand Down
65 changes: 65 additions & 0 deletions tests/congruence_tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,24 @@ def dense_query_text(self, client: QdrantBase) -> models.QueryResponse:
limit=10,
)

def dense_query_text_np_array(self, client: QdrantBase) -> models.QueryResponse:
return client.query_points(
collection_name=COLLECTION_NAME,
query=np.array(self.dense_vector_query_text),
using="text",
with_payload=True,
limit=10,
)

def dense_query_text_by_id(self, client: QdrantBase) -> models.QueryResponse:
return client.query_points(
collection_name=COLLECTION_NAME,
query=1,
using="text",
with_payload=True,
limit=10,
)

def dense_query_image(self, client: QdrantBase) -> models.QueryResponse:
return client.query_points(
collection_name=COLLECTION_NAME,
Expand Down Expand Up @@ -1100,3 +1118,50 @@ def test_query_with_nan():
print(local_client.query_points(COLLECTION_NAME, query=query))
with pytest.raises(UnexpectedResponse):
remote_client.query_points(COLLECTION_NAME, query=query)


@pytest.mark.parametrize("prefer_grpc", (False, True))
def test_flat_query_dense_interface(prefer_grpc):
fixture_points = generate_fixtures()

searcher = TestSimpleSearcher()

local_client = init_local()
init_client(local_client, fixture_points)

remote_client = init_remote(prefer_grpc=prefer_grpc)
init_client(remote_client, fixture_points)

compare_client_results(local_client, remote_client, searcher.dense_query_text)
compare_client_results(local_client, remote_client, searcher.dense_query_text_np_array)
compare_client_results(local_client, remote_client, searcher.dense_query_text_by_id)


@pytest.mark.parametrize("prefer_grpc", (False, True))
def test_flat_query_sparse_interface(prefer_grpc):
fixture_points = generate_sparse_fixtures()

searcher = TestSimpleSearcher()

local_client = init_local()
init_client(local_client, fixture_points, sparse_vectors_config=sparse_vectors_config)

remote_client = init_remote(prefer_grpc=prefer_grpc)
init_client(remote_client, fixture_points, sparse_vectors_config=sparse_vectors_config)

compare_client_results(local_client, remote_client, searcher.sparse_query_text)


@pytest.mark.parametrize("prefer_grpc", (True,))
def test_flat_query_multivector_interface(prefer_grpc):
fixture_points = generate_multivector_fixtures()

searcher = TestSimpleSearcher()

local_client = init_local()
init_client(local_client, fixture_points, vectors_config=multi_vector_config)

remote_client = init_remote(prefer_grpc=prefer_grpc)
init_client(remote_client, fixture_points, vectors_config=multi_vector_config)

compare_client_results(local_client, remote_client, searcher.multivec_query_text)
1 change: 1 addition & 0 deletions tests/conversions/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@
"sparse": grpc.Vector(
data=[1.0, 2.0, -1.0, -0.2], indices=SparseIndices(data=[1, 2, 3])
),
"multi": grpc.Vector(data=[1.0, 2.0, 3.0, 4.0], vectors_count=2),
}
)
)
Expand Down

0 comments on commit 6c17384

Please sign in to comment.