From 4480e53999b665af1a8f5e63a3f17c70de7f45af Mon Sep 17 00:00:00 2001
From: Michael Feil
Date: Tue, 31 Oct 2023 20:31:21 +0100
Subject: [PATCH] format test

---
 .../tests/end_to_end/test_fastembed.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/libs/infinity_emb/tests/end_to_end/test_fastembed.py b/libs/infinity_emb/tests/end_to_end/test_fastembed.py
index 3a31efd5..01b08a30 100644
--- a/libs/infinity_emb/tests/end_to_end/test_fastembed.py
+++ b/libs/infinity_emb/tests/end_to_end/test_fastembed.py
@@ -4,19 +4,17 @@
 
 import numpy as np
 import pytest
-import torch
 from asgi_lifespan import LifespanManager
 from httpx import AsyncClient
 from sentence_transformers import SentenceTransformer  # type: ignore
 
 from infinity_emb import create_server
-from infinity_emb.transformer.sentence_transformer import CT2SentenceTransformer
 from infinity_emb.transformer.utils import InferenceEngine
 
 PREFIX = "/v1_fastembed"
-MODEL: str = "BAAI/bge-base-en" # pytest.DEFAULT_BERT_MODEL # type: ignore
+MODEL: str = "BAAI/bge-base-en"  # pytest.DEFAULT_BERT_MODEL # type: ignore
 
-batch_size = 8 
+batch_size = 8
 
 app = create_server(
     model_name_or_path=MODEL,
@@ -69,11 +67,14 @@ async def test_embedding(client, model_base):
     want_embeddings = model_base.encode(inp)
 
     for embedding, st_embedding in zip(rdata["data"], want_embeddings):
-        cosine_sim = np.dot(embedding["embedding"], st_embedding)/(
-            np.linalg.norm(embedding["embedding"])*np.linalg.norm(st_embedding))
+        cosine_sim = np.dot(embedding["embedding"], st_embedding) / (
+            np.linalg.norm(embedding["embedding"]) * np.linalg.norm(st_embedding)
+        )
         # TODO: fastembed is not producing the correct results.
         assert cosine_sim > 0.95
-        np.testing.assert_almost_equal(embedding["embedding"], st_embedding, decimal=0)
+        np.testing.assert_almost_equal(
+            embedding["embedding"], st_embedding, decimal=0
+        )
 
 
 @pytest.mark.performance
@@ -84,7 +85,7 @@ async def test_batch_embedding(client, get_sts_bechmark_dataset, model_base):
         for item in d:
             sentences.append(item.texts[0])
     random.shuffle(sentences)
-    sentences = sentences[::16] 
+    sentences = sentences[::16]
     # sentences = sentences[:batch_size*2]
     dummy_sentences = ["test" * 512] * batch_size
 
@@ -121,13 +122,12 @@ async def _post_batch(inputs):
     encodings = model_base.encode(sentences, batch_size=batch_size).tolist()
     end = time.perf_counter()
     time_st = end - start
-    
+
     responses = np.array(responses)
     encodings = np.array(encodings)
-    
-    for r,e in zip(responses, encodings):
-        cosine_sim = np.dot(r, e)/(
-            np.linalg.norm(e)*np.linalg.norm(r))
+
+    for r, e in zip(responses, encodings):
+        cosine_sim = np.dot(r, e) / (np.linalg.norm(e) * np.linalg.norm(r))
         assert cosine_sim > 0.95
     np.testing.assert_almost_equal(np.array(responses), np.array(encodings), decimal=0)
     assert time_api / time_st < 2.5