diff --git a/libs/infinity_emb/tests/end_to_end/conftest.py b/libs/infinity_emb/tests/end_to_end/conftest.py index e443ee04..3ce61fc8 100644 --- a/libs/infinity_emb/tests/end_to_end/conftest.py +++ b/libs/infinity_emb/tests/end_to_end/conftest.py @@ -4,6 +4,8 @@ import numpy as np import pytest +from numpy import dot +from numpy.linalg import norm class Helpers: @@ -98,6 +100,10 @@ async def embedding_verify(client, model_base, prefix, model_name, decimal=3): embedding["embedding"], st_embedding, decimal=decimal ) + @staticmethod + def cosine_similarity(a, b): + return dot(a, b) / (norm(a) * norm(b)) + @pytest.fixture def helpers(): diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py index 80625b43..120e7c63 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py @@ -62,12 +62,11 @@ async def test_audio_single(client): @pytest.mark.anyio -@pytest.mark.skip("text only") async def test_audio_single_text_only(client): text = "a sound of a at" response = await client.post( - f"{PREFIX}/embeddings_audio", + f"{PREFIX}/embeddings", json={"model": MODEL, "input": text}, ) assert response.status_code == 200 @@ -79,6 +78,43 @@ async def test_audio_single_text_only(client): assert len(rdata_results[0]["embedding"]) > 0 +@pytest.mark.anyio +async def test_meta(client, helpers): + audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" + + text_input = ["a beep", "a horse", "a fish"] + audio_input = [audio_url] + response_text = await client.post( + f"{PREFIX}/embeddings", + json={"model": MODEL, "input": text_input}, + ) + response_audio = await client.post( + f"{PREFIX}/embeddings_audio", + json={"model": MODEL, "input": audio_input}, + ) + + assert response_text.status_code == 200 + assert response_audio.status_code == 200 + + rdata_text = response_text.json() + rdata_results_text = rdata_text["data"] + + rdata_audio = response_audio.json() + rdata_results_audio = rdata_audio["data"] + + embeddings_audio_beep = rdata_results_audio[0]["embedding"] + embeddings_text_beep = rdata_results_text[0]["embedding"] + embeddings_text_horse = rdata_results_text[1]["embedding"] + embeddings_text_fish = rdata_results_text[2]["embedding"] + + assert helpers.cosine_similarity( + embeddings_audio_beep, embeddings_text_beep + ) > helpers.cosine_similarity(embeddings_audio_beep, embeddings_text_fish) + assert helpers.cosine_similarity( + embeddings_audio_beep, embeddings_text_beep + ) > helpers.cosine_similarity(embeddings_audio_beep, embeddings_text_horse) + + @pytest.mark.anyio @pytest.mark.parametrize("no_of_audios", [1, 5, 10]) async def test_audio_multiple(client, no_of_audios): @@ -120,3 +156,12 @@ async def test_audio_empty(client): json={"model": MODEL, "input": audio_url_empty}, ) assert response_empty.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.anyio +async def test_unsupported_endpoints(client): + response_unsupported = await client.post( + f"{PREFIX}/classify", + json={"model": MODEL, "input": ["test"]}, + ) + assert response_unsupported.status_code == status.HTTP_400_BAD_REQUEST diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_vision.py b/libs/infinity_emb/tests/end_to_end/test_torch_vision.py index f6e5b8a1..e9a7d68a 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_vision.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_vision.py @@ -62,12 +62,11 @@ async def test_vision_single(client): @pytest.mark.anyio -@pytest.mark.skip("text only") async def test_vision_single_text_only(client): text = "a image of a cat" response = await client.post( - f"{PREFIX}/embeddings_image", + f"{PREFIX}/embeddings", json={"model": MODEL, "input": text}, ) assert response.status_code == 200 @@ -79,6 +78,43 @@ async def test_vision_single_text_only(client): assert len(rdata_results[0]["embedding"]) > 0 +@pytest.mark.anyio +async def test_meta(client, helpers): + image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + text_input = ["a cat", "a car", "a fridge"] + image_input = [image_url] + response_text = await client.post( + f"{PREFIX}/embeddings", + json={"model": MODEL, "input": text_input}, + ) + response_image = await client.post( + f"{PREFIX}/embeddings_image", + json={"model": MODEL, "input": image_input}, + ) + + assert response_text.status_code == 200 + assert response_image.status_code == 200 + + rdata_text = response_text.json() + rdata_results_text = rdata_text["data"] + + rdata_image = response_image.json() + rdata_results_image = rdata_image["data"] + + embeddings_image_cat = rdata_results_image[0]["embedding"] + embeddings_text_cat = rdata_results_text[0]["embedding"] + embeddings_text_car = rdata_results_text[1]["embedding"] + embeddings_text_fridge = rdata_results_text[2]["embedding"] + + assert helpers.cosine_similarity( + embeddings_image_cat, embeddings_text_cat + ) > helpers.cosine_similarity(embeddings_image_cat, embeddings_text_car) + assert helpers.cosine_similarity( + embeddings_image_cat, embeddings_text_cat + ) > helpers.cosine_similarity(embeddings_image_cat, embeddings_text_fridge) + + @pytest.mark.anyio @pytest.mark.parametrize("no_of_images", [1, 5, 10]) async def test_vision_multiple(client, no_of_images): @@ -119,3 +155,12 @@ async def test_vision_empty(client): json={"model": MODEL, "input": image_url_empty}, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.anyio +async def test_unsupported_endpoints(client): + response_unsupported = await client.post( + f"{PREFIX}/classify", + json={"model": MODEL, "input": ["test"]}, + ) + assert response_unsupported.status_code == status.HTTP_400_BAD_REQUEST