Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow up PR for Audio End to End testing #390

Merged
merged 16 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions libs/infinity_emb/tests/end_to_end/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np
import pytest
from numpy import dot
from numpy.linalg import norm


class Helpers:
Expand Down Expand Up @@ -98,6 +100,10 @@ async def embedding_verify(client, model_base, prefix, model_name, decimal=3):
embedding["embedding"], st_embedding, decimal=decimal
)

@staticmethod
def cosine_similarity(a, b):
return dot(a, b) / (norm(a) * norm(b))


@pytest.fixture
def helpers():
Expand Down
49 changes: 47 additions & 2 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ async def test_audio_single(client):


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_audio_single_text_only(client):
text = "a sound of a at"

response = await client.post(
f"{PREFIX}/embeddings_audio",
f"{PREFIX}/embeddings",
json={"model": MODEL, "input": text},
)
assert response.status_code == 200
Expand All @@ -79,6 +78,43 @@ async def test_audio_single_text_only(client):
assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_meta(client, helpers):
audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"

text_input = ["a beep", "a horse", "a fish"]
audio_input = [audio_url]
response_text = await client.post(
f"{PREFIX}/embeddings",
json={"model": MODEL, "input": text_input},
)
response_audio = await client.post(
f"{PREFIX}/embeddings_audio",
json={"model": MODEL, "input": audio_input},
)

assert response_text.status_code == 200
assert response_audio.status_code == 200

rdata_text = response_text.json()
rdata_results_text = rdata_text["data"]

rdata_audio = response_audio.json()
rdata_results_audio = rdata_audio["data"]

embeddings_audio_beep = rdata_results_audio[0]["embedding"]
embeddings_text_beep = rdata_results_text[0]["embedding"]
embeddings_text_horse = rdata_results_text[1]["embedding"]
embeddings_text_fish = rdata_results_text[2]["embedding"]

assert helpers.cosine_similarity(
embeddings_audio_beep, embeddings_text_beep
) > helpers.cosine_similarity(embeddings_audio_beep, embeddings_text_fish)
assert helpers.cosine_similarity(
embeddings_audio_beep, embeddings_text_beep
) > helpers.cosine_similarity(embeddings_audio_beep, embeddings_text_horse)


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_audios", [1, 5, 10])
async def test_audio_multiple(client, no_of_audios):
Expand Down Expand Up @@ -120,3 +156,12 @@ async def test_audio_empty(client):
json={"model": MODEL, "input": audio_url_empty},
)
assert response_empty.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_unsupported_endpoints(client):
response_unsupported = await client.post(
f"{PREFIX}/classify",
json={"model": MODEL, "input": ["test"]},
)
assert response_unsupported.status_code == status.HTTP_400_BAD_REQUEST
49 changes: 47 additions & 2 deletions libs/infinity_emb/tests/end_to_end/test_torch_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ async def test_vision_single(client):


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_vision_single_text_only(client):
text = "a image of a cat"

response = await client.post(
f"{PREFIX}/embeddings_image",
f"{PREFIX}/embeddings",
json={"model": MODEL, "input": text},
)
assert response.status_code == 200
Expand All @@ -79,6 +78,43 @@ async def test_vision_single_text_only(client):
assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_meta(client, helpers):
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

text_input = ["a cat", "a car", "a fridge"]
image_input = [image_url]
response_text = await client.post(
f"{PREFIX}/embeddings",
json={"model": MODEL, "input": text_input},
)
response_image = await client.post(
f"{PREFIX}/embeddings_image",
json={"model": MODEL, "input": image_input},
)

assert response_text.status_code == 200
assert response_image.status_code == 200

rdata_text = response_text.json()
rdata_results_text = rdata_text["data"]

rdata_image = response_image.json()
rdata_results_image = rdata_image["data"]

embeddings_image_cat = rdata_results_image[0]["embedding"]
embeddings_text_cat = rdata_results_text[0]["embedding"]
embeddings_text_car = rdata_results_text[1]["embedding"]
embeddings_text_fridge = rdata_results_text[2]["embedding"]

assert helpers.cosine_similarity(
embeddings_image_cat, embeddings_text_cat
) > helpers.cosine_similarity(embeddings_image_cat, embeddings_text_car)
assert helpers.cosine_similarity(
embeddings_image_cat, embeddings_text_cat
) > helpers.cosine_similarity(embeddings_image_cat, embeddings_text_fridge)


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_images", [1, 5, 10])
async def test_vision_multiple(client, no_of_images):
Expand Down Expand Up @@ -119,3 +155,12 @@ async def test_vision_empty(client):
json={"model": MODEL, "input": image_url_empty},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_unsupported_endpoints(client):
response_unsupported = await client.post(
f"{PREFIX}/classify",
json={"model": MODEL, "input": ["test"]},
)
assert response_unsupported.status_code == status.HTTP_400_BAD_REQUEST
Loading