diff --git a/qdrant_client/conversions/conversion.py b/qdrant_client/conversions/conversion.py index d9c7639b..8205673a 100644 --- a/qdrant_client/conversions/conversion.py +++ b/qdrant_client/conversions/conversion.py @@ -2432,8 +2432,12 @@ def convert_vector(vector: Union[List[float], List[List[float]]]) -> grpc.Vector vector[0], list ): # we can't say whether it is an empty dense or multi-dense vector return grpc.Vector( - data=[inner_vector for multi_vector in model for inner_vector in multi_vector], - vectors_count=len(model), + data=[ + inner_vector + for multi_vector in vector + for inner_vector in multi_vector # type: ignore + ], + vectors_count=len(vector), ) return grpc.Vector(data=vector) @@ -2598,7 +2602,7 @@ def convert_prefetch_query(cls, model: rest.Prefetch) -> grpc.PrefetchQuery: return grpc.PrefetchQuery( prefetch=prefetch, - query=cls.convert_query(model.query) if model.query is not None else None, + query=cls.convert_query_interface(model.query) if model.query is not None else None, using=model.using if model.using is not None else None, filter=cls.convert_filter(model.filter) if model.filter is not None else None, params=cls.convert_search_params(model.params) if model.params is not None else None, diff --git a/tests/congruence_tests/test_query.py b/tests/congruence_tests/test_query.py index 5d489258..63fc0053 100644 --- a/tests/congruence_tests/test_query.py +++ b/tests/congruence_tests/test_query.py @@ -83,6 +83,24 @@ def dense_query_text(self, client: QdrantBase) -> models.QueryResponse: limit=10, ) + def dense_query_text_np_array(self, client: QdrantBase) -> models.QueryResponse: + return client.query_points( + collection_name=COLLECTION_NAME, + query=np.array(self.dense_vector_query_text), + using="text", + with_payload=True, + limit=10, + ) + + def dense_query_text_by_id(self, client: QdrantBase) -> models.QueryResponse: + return client.query_points( + collection_name=COLLECTION_NAME, + query=1, + using="text", + with_payload=True, + limit=10, + ) + def dense_query_image(self, client: QdrantBase) -> models.QueryResponse: return client.query_points( collection_name=COLLECTION_NAME, @@ -1100,3 +1118,50 @@ def test_query_with_nan(): print(local_client.query_points(COLLECTION_NAME, query=query)) with pytest.raises(UnexpectedResponse): remote_client.query_points(COLLECTION_NAME, query=query) + + +@pytest.mark.parametrize("prefer_grpc", (False, True)) +def test_flat_query_dense_interface(prefer_grpc): + fixture_points = generate_fixtures() + + searcher = TestSimpleSearcher() + + local_client = init_local() + init_client(local_client, fixture_points) + + remote_client = init_remote(prefer_grpc=prefer_grpc) + init_client(remote_client, fixture_points) + + compare_client_results(local_client, remote_client, searcher.dense_query_text) + compare_client_results(local_client, remote_client, searcher.dense_query_text_np_array) + compare_client_results(local_client, remote_client, searcher.dense_query_text_by_id) + + +@pytest.mark.parametrize("prefer_grpc", (False, True)) +def test_flat_query_sparse_interface(prefer_grpc): + fixture_points = generate_sparse_fixtures() + + searcher = TestSimpleSearcher() + + local_client = init_local() + init_client(local_client, fixture_points, sparse_vectors_config=sparse_vectors_config) + + remote_client = init_remote(prefer_grpc=prefer_grpc) + init_client(remote_client, fixture_points, sparse_vectors_config=sparse_vectors_config) + + compare_client_results(local_client, remote_client, searcher.sparse_query_text) + + +@pytest.mark.parametrize("prefer_grpc", (True,)) +def test_flat_query_multivector_interface(prefer_grpc): + fixture_points = generate_multivector_fixtures() + + searcher = TestSimpleSearcher() + + local_client = init_local() + init_client(local_client, fixture_points, vectors_config=multi_vector_config) + + remote_client = init_remote(prefer_grpc=prefer_grpc) + init_client(remote_client, fixture_points, vectors_config=multi_vector_config) + + compare_client_results(local_client, remote_client, searcher.multivec_query_text) diff --git a/tests/conversions/fixtures.py b/tests/conversions/fixtures.py index 86b21c09..0e0b6a90 100644 --- a/tests/conversions/fixtures.py +++ b/tests/conversions/fixtures.py @@ -568,6 +568,7 @@ "sparse": grpc.Vector( data=[1.0, 2.0, -1.0, -0.2], indices=SparseIndices(data=[1, 2, 3]) ), + "multi": grpc.Vector(data=[1.0, 2.0, 3.0, 4.0], vectors_count=2), } ) )