diff --git a/README.md b/README.md index 6196a8eb..2c7d7ed4 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Infinity is a high-throughput, low-latency REST API for serving vector embedding ## Why Infinity: Infinity provides the following features: * **Deploy any model from MTEB**: deploy the model you know from [SentenceTransformers](https://github.com/UKPLab/sentence-transformers/) -* **Fast inference backends**: The inference server is built on top of [torch](https://github.com/pytorch/pytorch), [optimum(onnx/tensorrt)](https://github.com/qdrant/fastembed) and [CTranslate2](https://github.com/OpenNMT/CTranslate2), using FlashAttention to get the most out of **CUDA**, **ROCM**, **CPU** or **MPS** chips. +* **Fast inference backends**: The inference server is built on top of [torch](https://github.com/pytorch/pytorch), [optimum(onnx/tensorrt)](https://huggingface.co/docs/optimum/index) and [CTranslate2](https://github.com/OpenNMT/CTranslate2), using FlashAttention to get the most out of **CUDA**, **ROCM**, **CPU** or **MPS** devices. * **Dynamic batching**: New embedding requests are queued while the GPU is busy with the previous ones. New requests are squeezed into your device as soon as it is ready. Similar max throughput on GPU as text-embeddings-inference. * **Correct and tested implementation**: Unit and end-to-end tested. Embeddings via infinity are identical to [SentenceTransformers](https://github.com/UKPLab/sentence-transformers/) (up to numerical precision). Lets API users create embeddings till infinity and beyond. * **Easy to use**: The API is built on top of [FastAPI](https://fastapi.tiangolo.com/); [Swagger](https://swagger.io/) makes it fully documented. The API is aligned to [OpenAI's Embedding specs](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings). See below on how to get started. diff --git a/docs/docs/index.md b/docs/docs/index.md index a7774ce3..2dfe5776 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -5,7 +5,7 @@ Infinity is a high-throughput, low-latency REST API for serving vector embedding Infinity provides the following features: * **Deploy any model from MTEB**: deploy the model you know from [SentenceTransformers](https://github.com/UKPLab/sentence-transformers/) -* **Fast inference backends**: The inference server is built on top of [torch](https://github.com/pytorch/pytorch), [optimum(onnx/tensorrt)](https://github.com/qdrant/fastembed) and [CTranslate2](https://github.com/OpenNMT/CTranslate2), using FlashAttention to get the most out of **CUDA**, **ROCM**, **CPU** or **MPS** chips. +* **Fast inference backends**: The inference server is built on top of [torch](https://github.com/pytorch/pytorch), [optimum(onnx/tensorrt)](https://huggingface.co/docs/optimum/index) and [CTranslate2](https://github.com/OpenNMT/CTranslate2), using FlashAttention to get the most out of **CUDA**, **ROCM**, **CPU** or **MPS** devices. * **Dynamic batching**: New embedding requests are queued while the GPU is busy with the previous ones. New requests are squeezed into your device as soon as it is ready. Similar max throughput on GPU as text-embeddings-inference. * **Correct and tested implementation**: Unit and end-to-end tested. Embeddings via infinity are identical to [SentenceTransformers](https://github.com/UKPLab/sentence-transformers/) (up to numerical precision). Lets API users create embeddings till infinity and beyond.
* **Easy to use**: The API is built on top of [FastAPI](https://fastapi.tiangolo.com/); [Swagger](https://swagger.io/) makes it fully documented. The API is aligned to [OpenAI's Embedding specs](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings). See below on how to get started. diff --git a/docs/docs/python_engine.md b/docs/docs/python_engine.md index a1538932..1c189eed 100644 --- a/docs/docs/python_engine.md +++ b/docs/docs/python_engine.md @@ -1,30 +1,40 @@ -Enhancing the document involves improving clarity, structure, and adding helpful context where necessary. Here's an enhanced version: - # Python Engine Integration ## Launching Embedding generation with Python -Use asynchronous programming in Python using `asyncio` for flexible and efficient embedding processing with Infinity. This advanced method allows for concurrent execution, making it ideal for high-throughput embedding generation. +Use asynchronous programming with Python's `asyncio` for flexible and efficient embedding processing with Infinity. This approach allows concurrent execution of different requests, making it ideal for high-throughput embedding generation. ```python import asyncio from infinity_emb import AsyncEmbeddingEngine, EngineArgs +from infinity_emb.log_handler import logger +logger.setLevel(5) # Debug # Define sentences for embedding sentences = ["Embed this sentence via Infinity.", "Paris is in France."] # Initialize the embedding engine with model specifications engine = AsyncEmbeddingEngine.from_args( - EngineArgs(model_name_or_path="BAAI/bge-small-en-v1.5", engine="torch", - lengths_via_tokenize=True + EngineArgs( + model_name_or_path="BAAI/bge-small-en-v1.5", + engine="torch", + lengths_via_tokenize=True ) ) async def main(): async with engine: # Context manager initializes and terminates the engine + + job1 = asyncio.create_task(engine.embed(sentences=sentences)) + # submit a second job in parallel + job2 = asyncio.create_task(engine.embed(sentences=["Hello world"])) # usage is total token count according to tokenizer. - embeddings, usage = await engine.embed(sentences=sentences) - # Embeddings are now available for use -asyncio.run(main()) + embeddings, usage = await job1 + embeddings2, usage2 = await job2 + # Embeddings are now available for use - they ran in the same batch.
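+        # embeddings contains one vector per input sentence; usage/usage2 are the token counts of each request, counted via the tokenizer (lengths_via_tokenize=True).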
+ print(f"for {sentences}, generated embeddings {len(embeddings)} with tot_tokens={usage}") +asyncio.run( + main() +) ``` ## Reranker diff --git a/libs/infinity_emb/infinity_emb/inference/select_model.py b/libs/infinity_emb/infinity_emb/inference/select_model.py index ae6e7c35..a4d1ccb8 100644 --- a/libs/infinity_emb/infinity_emb/inference/select_model.py +++ b/libs/infinity_emb/infinity_emb/inference/select_model.py @@ -18,7 +18,7 @@ def get_engine_type_from_config( engine_args: EngineArgs, ) -> Union[EmbedderEngine, RerankEngine]: - if engine_args.engine in [InferenceEngine.debugengine, InferenceEngine.fastembed]: + if engine_args.engine in [InferenceEngine.debugengine]: return EmbedderEngine.from_inference_engine(engine_args.engine) if Path(engine_args.model_name_or_path).is_dir(): diff --git a/libs/infinity_emb/infinity_emb/primitives.py b/libs/infinity_emb/infinity_emb/primitives.py index 6240ef1a..de0513c9 100644 --- a/libs/infinity_emb/infinity_emb/primitives.py +++ b/libs/infinity_emb/infinity_emb/primitives.py @@ -23,7 +23,6 @@ class InferenceEngine(enum.Enum): torch = "torch" ctranslate2 = "ctranslate2" optimum = "optimum" - fastembed = "fastembed" debugengine = "dummytransformer" diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/fastembed.py b/libs/infinity_emb/infinity_emb/transformer/embedder/fastembed.py deleted file mode 100644 index b3f4d565..00000000 --- a/libs/infinity_emb/infinity_emb/transformer/embedder/fastembed.py +++ /dev/null @@ -1,76 +0,0 @@ -import copy -from typing import Dict, List - -import numpy as np - -from infinity_emb.args import EngineArgs -from infinity_emb.log_handler import logger -from infinity_emb.primitives import Device, EmbeddingReturnType, PoolingMethod -from infinity_emb.transformer.abstract import BaseEmbedder - -try: - from fastembed.embedding import TextEmbedding # type: ignore - - from infinity_emb.transformer.utils_optimum import normalize - - FASTEMBED_AVAILABLE = True -except ImportError: - FASTEMBED_AVAILABLE = False - - -class Fastembed(BaseEmbedder): - def __init__(self, *, engine_args: EngineArgs) -> None: - if not FASTEMBED_AVAILABLE: - raise ImportError( - "fastembed is not installed." "`pip install infinity-emb[fastembed]`" - ) - logger.warning( - "deprecated: fastembed inference" - " is deprecated and will be removed in the future." 
- ) - - providers = ["CPUExecutionProvider"] - - if engine_args.device != Device.cpu: - providers = ["CUDAExecutionProvider"] + providers - - if engine_args.revision is not None: - logger.warning("revision is not used for fastembed") - - self.model = TextEmbedding( - model_name=engine_args.model_name_or_path, cache_dir=None - ).model - if self.model is None: - raise ValueError("fastembed model is not available") - if engine_args.pooling_method != PoolingMethod.auto: - logger.warning("pooling_method is not used for fastembed") - self._infinity_tokenizer = copy.deepcopy(self.model.tokenizer) - self.model.model.set_providers(providers) - - def encode_pre(self, sentences: List[str]) -> Dict[str, np.ndarray]: - encoded = self.model.tokenizer.encode_batch(sentences) - input_ids = np.array([e.ids for e in encoded]) - attention_mask = np.array([e.attention_mask for e in encoded]) - - onnx_input = { - "input_ids": np.array(input_ids, dtype=np.int64), - "attention_mask": np.array(attention_mask, dtype=np.int64), - "token_type_ids": np.array( - [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 - ), - } - return onnx_input - - def encode_core(self, features: Dict[str, np.ndarray]) -> np.ndarray: - model_output = self.model.model.run(None, features) - last_hidden_state = model_output[0][:, 0] - return last_hidden_state - - def encode_post(self, embedding: np.ndarray) -> EmbeddingReturnType: - return normalize(embedding).astype(np.float32) - - def tokenize_lengths(self, sentences: List[str]) -> List[int]: - tks = self._infinity_tokenizer.encode_batch( - sentences, - ) - return [len(t.tokens) for t in tks] diff --git a/libs/infinity_emb/infinity_emb/transformer/utils.py b/libs/infinity_emb/infinity_emb/transformer/utils.py index cc0d8357..8709c63f 100644 --- a/libs/infinity_emb/infinity_emb/transformer/utils.py +++ b/libs/infinity_emb/infinity_emb/transformer/utils.py @@ -11,13 +11,11 @@ ) from infinity_emb.transformer.embedder.ct2 import CT2SentenceTransformer from infinity_emb.transformer.embedder.dummytransformer import DummyTransformer -from infinity_emb.transformer.embedder.fastembed import Fastembed from infinity_emb.transformer.embedder.optimum import OptimumEmbedder from infinity_emb.transformer.embedder.sentence_transformer import ( SentenceTransformerPatched, ) -# from infinity_emb.transformer.fastembed import FastEmbed __all__ = [ "length_tokenizer", "get_lengths_with_tokenize", @@ -28,7 +26,6 @@ class EmbedderEngine(Enum): torch = SentenceTransformerPatched ctranslate2 = CT2SentenceTransformer - fastembed = Fastembed debugengine = DummyTransformer optimum = OptimumEmbedder @@ -38,8 +35,6 @@ def from_inference_engine(engine: InferenceEngine): return EmbedderEngine.torch elif engine == InferenceEngine.ctranslate2: return EmbedderEngine.ctranslate2 - elif engine == InferenceEngine.fastembed: - return EmbedderEngine.fastembed elif engine == InferenceEngine.debugengine: return EmbedderEngine.debugengine elif engine == InferenceEngine.optimum: diff --git a/libs/infinity_emb/poetry.lock b/libs/infinity_emb/poetry.lock index 055f073f..eb4c125e 100644 --- a/libs/infinity_emb/poetry.lock +++ b/libs/infinity_emb/poetry.lock @@ -487,7 +487,7 @@ files = [ name = "coloredlogs" version = "15.0.1" description = "Colored terminal output for Python's logging module" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = 
"sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"}, @@ -801,7 +801,7 @@ all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)" name = "fastembed" version = "0.2.1" description = "Fast, light, accurate library built for retrieval embedding generation" -optional = true +optional = false python-versions = ">=3.8.0,<3.13" files = [ {file = "fastembed-0.2.1-py3-none-any.whl", hash = "sha256:7e3d71e449a54e1c47bd502246ee9367259002691caba5d3c25b9ebb47d1f621"}, @@ -841,7 +841,7 @@ typing = ["typing-extensions (>=4.8)"] name = "flatbuffers" version = "23.5.26" description = "The FlatBuffers serialization format for Python" -optional = true +optional = false python-versions = "*" files = [ {file = "flatbuffers-23.5.26-py2.py3-none-any.whl", hash = "sha256:c0ff356da363087b915fde4b8b45bdda73432fc17cddb3c8157472eab1422ad1"}, @@ -1221,7 +1221,7 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t name = "humanfriendly" version = "10.0" description = "Human friendly output for text interfaces using Python" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"}, @@ -1322,7 +1322,7 @@ files = [ name = "loguru" version = "0.7.2" description = "Python logging made (stupidly) simple" -optional = true +optional = false python-versions = ">=3.5" files = [ {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, @@ -1582,7 +1582,7 @@ beautifulsoup4 = ">=4.11.1" name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" -optional = true +optional = false python-versions = "*" files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, @@ -2003,7 +2003,7 @@ files = [ name = "onnx" version = "1.15.0" description = "Open Neural Network Exchange" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "onnx-1.15.0-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:51cacb6aafba308aaf462252ced562111f6991cdc7bc57a6c554c3519453a8ff"}, @@ -2044,7 +2044,7 @@ reference = ["Pillow", "google-re2"] name = "onnxruntime" version = "1.17.0" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" -optional = true +optional = false python-versions = "*" files = [ {file = "onnxruntime-1.17.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:d2b22a25a94109cc983443116da8d9805ced0256eb215c5e6bc6dcbabefeab96"}, @@ -2721,7 +2721,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyreadline3" version = "3.4.1" description = "A python implementation of GNU readline." 
-optional = true +optional = false python-versions = "*" files = [ {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, @@ -3449,7 +3449,7 @@ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyam name = "sympy" version = "1.12" description = "Computer algebra system (CAS) in Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, @@ -3487,7 +3487,7 @@ files = [ name = "tokenizers" version = "0.15.2" description = "" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"}, @@ -4245,7 +4245,7 @@ files = [ name = "win32-setctime" version = "1.1.0" description = "A small Python utility to set file creation time on Windows" -optional = true +optional = false python-versions = ">=3.5" files = [ {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, @@ -4491,10 +4491,9 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["ctranslate2", "diskcache", "fastapi", "fastembed", "optimum", "orjson", "prometheus-fastapi-instrumentator", "pydantic", "rich", "sentence-transformers", "torch", "typer", "uvicorn"] +all = ["ctranslate2", "diskcache", "fastapi", "optimum", "orjson", "prometheus-fastapi-instrumentator", "pydantic", "rich", "sentence-transformers", "torch", "typer", "uvicorn"] cache = ["diskcache"] ct2 = ["ctranslate2", "sentence-transformers", "torch", "transformers"] -fastembed = ["fastembed"] logging = ["rich"] onnxruntime-gpu = ["onnxruntime-gpu"] optimum = ["optimum"] @@ -4505,4 +4504,4 @@ torch = ["hf_transfer", "sentence-transformers", "torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "5c6c808163e5b2a6f835cd00d859c506b9a023be90f6c928a6fe903fc534aaf6" +content-hash = "144b7d01f541181d84159f3daa4881feffa59f51a1be94e18c5e3993a4403724" diff --git a/libs/infinity_emb/pyproject.toml b/libs/infinity_emb/pyproject.toml index 9a502a72..357ade88 100644 --- a/libs/infinity_emb/pyproject.toml +++ b/libs/infinity_emb/pyproject.toml @@ -27,7 +27,6 @@ sentence-transformers = {version = "^2.4.0", optional=true} transformers = {version = ">4.8.0", optional=true} ctranslate2 = {version = "^4.0.0", optional=true} optimum = {version = ">=1.16.2", optional=true, extras=["onnxruntime"]} -fastembed = {version = ">=0.2.1", optional=true} hf_transfer = {version=">=0.1.5", optional=true} # cache diskcache = {version = "*", optional=true} @@ -47,6 +46,7 @@ anyio = "*" trio = "*" coverage = {extras = ["toml"], version = "^7.3.2"} mypy = "^1.5.1" +fastembed = ">=0.2.1" [tool.poetry.group.codespell.dependencies] codespell = "^2.2.0" @@ -68,12 +68,11 @@ mypy-protobuf = "^3.0.0" [tool.poetry.extras] ct2=["ctranslate2","sentence-transformers","torch","transformers"] optimum=["optimum"] -fastembed=["fastembed"] torch=["sentence-transformers","torch","hf_transfer"] logging=["rich"] cache=["diskcache"] server=["fastapi", "pydantic", 
"orjson", "prometheus-fastapi-instrumentator", "uvicorn", "typer","rich"] -all=["ctranslate2", "fastapi", "fastembed", "optimum", "orjson", "prometheus-fastapi-instrumentator", "pydantic", "rich", "sentence-transformers", "torch", "typer", "uvicorn","diskcache"] +all=["ctranslate2", "fastapi", "optimum", "orjson", "prometheus-fastapi-instrumentator", "pydantic", "rich", "sentence-transformers", "torch", "typer", "uvicorn","diskcache"] # non-default gpu tensorrt=["tensorrt"] onnxruntime-gpu=["onnxruntime-gpu"] diff --git a/libs/infinity_emb/tests/end_to_end/test_fastembed.py b/libs/infinity_emb/tests/end_to_end/test_fastembed.py deleted file mode 100644 index 22fae4a1..00000000 --- a/libs/infinity_emb/tests/end_to_end/test_fastembed.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest -from asgi_lifespan import LifespanManager -from httpx import AsyncClient -from sentence_transformers import SentenceTransformer # type: ignore - -from infinity_emb import create_server -from infinity_emb.args import EngineArgs -from infinity_emb.primitives import Device, InferenceEngine - -PREFIX = "/v1_fastembed" -MODEL: str = "BAAI/bge-small-en-v1.5" # pytest.DEFAULT_BERT_MODEL # type: ignore - -batch_size = 8 - - -app = create_server( - url_prefix=PREFIX, - engine_args=EngineArgs( - model_name_or_path=MODEL, - batch_size=batch_size, - engine=InferenceEngine.fastembed, - device=Device.cpu, - ), -) - - -@pytest.fixture -def model_base() -> SentenceTransformer: - return SentenceTransformer(MODEL, device="cpu") - - -@pytest.fixture() -async def client(): - async with AsyncClient( - app=app, base_url="http://test", timeout=20 - ) as client, LifespanManager(app): - yield client - - -@pytest.mark.anyio -async def test_model_route(client): - response = await client.get(f"{PREFIX}/models") - assert response.status_code == 200 - rdata = response.json() - assert "data" in rdata - assert rdata["data"][0].get("id", "") == MODEL - assert isinstance(rdata["data"][0].get("stats"), dict) - - -@pytest.mark.anyio -async def test_embedding(client, model_base, helpers): - await helpers.embedding_verify(client, model_base, prefix=PREFIX, model_name=MODEL) - - -@pytest.mark.performance -@pytest.mark.anyio -async def test_batch_embedding(client, get_sts_bechmark_dataset, model_base, helpers): - await helpers.util_batch_embedding( - client=client, - sts_bechmark_dataset=get_sts_bechmark_dataset, - model_base=model_base, - prefix=PREFIX, - model_name=MODEL, - batch_size=batch_size, - downsample=32, - ) diff --git a/libs/infinity_emb/tests/unit_test/inference/test_select_model.py b/libs/infinity_emb/tests/unit_test/inference/test_select_model.py index 40189859..3e0f1ca9 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_select_model.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_select_model.py @@ -10,11 +10,7 @@ def test_engine(engine): select_model( EngineArgs( engine=engine, - model_name_or_path=( - "BAAI/bge-small-en-v1.5" - if engine == InferenceEngine.fastembed - else pytest.DEFAULT_BERT_MODEL - ), + model_name_or_path=(pytest.DEFAULT_BERT_MODEL), batch_size=4, device=Device.cpu, ) diff --git a/libs/infinity_emb/tests/unit_test/test_engine.py b/libs/infinity_emb/tests/unit_test/test_engine.py index c02a8daf..4895beb8 100644 --- a/libs/infinity_emb/tests/unit_test/test_engine.py +++ b/libs/infinity_emb/tests/unit_test/test_engine.py @@ -188,26 +188,6 @@ async def test_async_api_torch_usage(): assert embeddings.shape[1] >= 10 -@pytest.mark.anyio -async def test_async_api_fastembed(): - sentences = ["Hi", "how"] 
- engine = AsyncEmbeddingEngine.from_args( - EngineArgs( - model_name_or_path="BAAI/bge-small-en-v1.5", - engine=InferenceEngine.fastembed, - device="cpu", - model_warmup=False, - ) - ) - async with engine: - embeddings, usage = await engine.embed(sentences) - embeddings = np.array(embeddings) - assert usage == sum([len(s) for s in sentences]) - assert embeddings.shape[0] == len(sentences) - assert embeddings.shape[1] >= 10 - assert not engine.is_overloaded() - - @pytest.mark.anyio async def test_async_api_failing(): sentences = ["Hi", "how"] diff --git a/libs/infinity_emb/tests/unit_test/transformer/embedder/test_torch.py b/libs/infinity_emb/tests/unit_test/transformer/embedder/test_torch.py new file mode 100644 index 00000000..51519c66 --- /dev/null +++ b/libs/infinity_emb/tests/unit_test/transformer/embedder/test_torch.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest +import torch + +from infinity_emb.args import EngineArgs +from infinity_emb.transformer.embedder.sentence_transformer import ( + SentenceTransformerPatched, +) + +try: + from fastembed import TextEmbedding # type: ignore + + FASTEMBED_AVAILABLE = True +except ImportError: + FASTEMBED_AVAILABLE = False + + +@pytest.mark.skipif(not FASTEMBED_AVAILABLE, reason="fastembed not available") +def test_sentence_transformer_equals_fastembed( + text=( + "This is a test sentence for the sentence transformer and fastembed. The PyTorch API of nested tensors is in prototype stage and will change in the near future." + ), + model_name="BAAI/bge-small-en-v1.5", +) -> None: + model_st = SentenceTransformerPatched( + engine_args=EngineArgs(model_name_or_path=model_name) + ) + model_fastembed = TextEmbedding(model_name) + + embedding_st = model_st.encode_post( + model_st.encode_core(model_st.encode_pre([text])) + ) + embedding_fast = np.array(list(model_fastembed.embed(documents=[text]))) + + assert embedding_fast.shape == embedding_st.shape + assert np.allclose(embedding_st[0], embedding_fast[0], atol=1e-3) + # cosine similarity + sim = torch.nn.functional.cosine_similarity( + torch.tensor(embedding_st), torch.tensor(embedding_fast) + ) + assert sim > 0.99
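
Note for downstream users of the removed `fastembed` backend: after this change, `InferenceEngine.fastembed` and the `fastembed` extra are gone, while the `torch`, `optimum` (ONNX) and `ctranslate2` backends remain. The snippet below is a hypothetical migration sketch, not part of this diff; it reuses the `EngineArgs` fields shown elsewhere in this changeset and assumes the model in question can be served by the `optimum` backend.

```python
# Hypothetical migration sketch (not part of this diff): swap the removed
# engine="fastembed" for one of the remaining backends, e.g. "optimum" (ONNX).
import asyncio

from infinity_emb import AsyncEmbeddingEngine, EngineArgs

engine = AsyncEmbeddingEngine.from_args(
    EngineArgs(
        model_name_or_path="BAAI/bge-small-en-v1.5",
        engine="optimum",  # previously: engine="fastembed"
        device="cpu",
        model_warmup=False,
    )
)

async def main():
    async with engine:  # context manager starts and stops the batch handler
        embeddings, usage = await engine.embed(sentences=["Hello world"])
        print(f"{len(embeddings)} embeddings, {usage} tokens")

asyncio.run(main())
```

Embedding consistency is still covered after the removal: fastembed stays available as a development/test dependency (see the `pyproject.toml` change), and the new `tests/unit_test/transformer/embedder/test_torch.py` asserts that the torch backend and fastembed produce near-identical embeddings (cosine similarity > 0.99) for `BAAI/bge-small-en-v1.5`, so it can be run with plain `pytest`.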