From 06fd1f4d8f0a869f4482fc1c78b62a75ccbb66a1 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Mon, 23 Sep 2024 20:58:26 -0700 Subject: [PATCH] async requests enabled (#376) * async requests enabled * aiohttp checks * fmt --- .../infinity_emb/_optional_imports.py | 3 +- .../infinity_emb/inference/batch_handler.py | 5 +- .../infinity_emb/transformer/vision/utils.py | 67 ++++++++++++------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/_optional_imports.py b/libs/infinity_emb/infinity_emb/_optional_imports.py index a1e53f14..6ae4a392 100644 --- a/libs/infinity_emb/infinity_emb/_optional_imports.py +++ b/libs/infinity_emb/infinity_emb/_optional_imports.py @@ -66,7 +66,8 @@ def _raise_error(self) -> None: CHECK_SENTENCE_TRANSFORMERS = OptionalImports("sentence_transformers", "torch") CHECK_TRANSFORMERS = OptionalImports("transformers", "torch") CHECK_TORCH = OptionalImports("torch.nn", "torch") -CHECK_REQUESTS = OptionalImports("requests", "server") +# CHECK_REQUESTS = OptionalImports("requests", "server") +CHECK_AIOHTTP = OptionalImports("aiohttp", "server") CHECK_PIL = OptionalImports("PIL", "vision") CHECK_SOUNDFILE = OptionalImports("soundfile", "audio") CHECK_PYDANTIC = OptionalImports("pydantic", "server") diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 2aaea0de..031bb19c 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -233,7 +233,7 @@ async def image_embed( f"options are {self.model_worker.capabilities}." ) - items = await asyncio.to_thread(resolve_images, images) + items = await resolve_images(images) embeddings, usage = await self._schedule(items) return embeddings, usage @@ -262,8 +262,7 @@ async def audio_embed( f"options are {self.model_worker.capabilities}." ) - items = await asyncio.to_thread( - resolve_audios, + items = await resolve_audios( audios, getattr(self.model_worker._model, "sampling_rate", -42), ) diff --git a/libs/infinity_emb/infinity_emb/transformer/vision/utils.py b/libs/infinity_emb/infinity_emb/transformer/vision/utils.py index 3bb555c3..08ff04ae 100644 --- a/libs/infinity_emb/infinity_emb/transformer/vision/utils.py +++ b/libs/infinity_emb/infinity_emb/transformer/vision/utils.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2023-now michaelfeil +import asyncio import io from typing import List, Union -from infinity_emb._optional_imports import CHECK_PIL, CHECK_REQUESTS, CHECK_SOUNDFILE +from infinity_emb._optional_imports import CHECK_AIOHTTP, CHECK_PIL, CHECK_SOUNDFILE from infinity_emb.primitives import ( AudioCorruption, AudioSingle, @@ -13,11 +14,12 @@ ImageSingle, ) +if CHECK_AIOHTTP.is_available: + import aiohttp + if CHECK_PIL.is_available: from PIL import Image # type: ignore -if CHECK_REQUESTS.is_available: - import requests # type: ignore if CHECK_SOUNDFILE.is_available: import soundfile as sf # type: ignore @@ -27,17 +29,20 @@ def resolve_from_img_obj(img_obj: "ImageClassType") -> ImageSingle: return ImageSingle(image=img_obj) -def resolve_from_img_url(img_url: str) -> ImageSingle: +async def resolve_from_img_url( + img_url: str, session: "aiohttp.ClientSession" +) -> ImageSingle: """Resolve an image from an URL.""" try: - downloaded_img = requests.get(img_url, stream=True).raw + # requests.get(img_url, stream=True).raw + downloaded_img = await (await session.get(img_url)).read() except Exception as e: raise ImageCorruption( f"error opening an image in your request image from url: {e}" ) try: - img = Image.open(downloaded_img) + img = Image.open(io.BytesIO(downloaded_img)) if img.size[0] < 3 or img.size[1] < 3: # https://upload.wikimedia.org/wikipedia/commons/c/ca/1x1.png raise ImageCorruption( @@ -50,37 +55,48 @@ def resolve_from_img_url(img_url: str) -> ImageSingle: ) -def resolve_image(img: Union[str, "ImageClassType"]) -> ImageSingle: +async def resolve_image( + img: Union[str, "ImageClassType"], session: "aiohttp.ClientSession" +) -> ImageSingle: """Resolve a single image.""" if isinstance(img, Image.Image): return resolve_from_img_obj(img) elif isinstance(img, str): - return resolve_from_img_url(img) + return await resolve_from_img_url(img, session=session) else: raise ValueError( f"Invalid image type: {img} is neither str nor ImageClassType object" ) -def resolve_images(images: List[Union[str, "ImageClassType"]]) -> List[ImageSingle]: +async def resolve_images( + images: List[Union[str, "ImageClassType"]] +) -> List[ImageSingle]: """Resolve images from URLs or ImageClassType Objects using multithreading.""" # TODO: improve parallel requests, safety, error handling - CHECK_REQUESTS.mark_required() + CHECK_AIOHTTP.mark_required() CHECK_PIL.mark_required() resolved_imgs = [] - for img in images: - try: - resolved_imgs.append(resolve_image(img)) - except Exception as e: - raise ImageCorruption( - f"Failed to resolve image: {img}.\nError msg: {str(e)}" + + try: + async with aiohttp.ClientSession(trust_env=True) as session: + resolved_imgs = await asyncio.gather( + *[resolve_image(img, session) for img in images] ) + except Exception as e: + raise ImageCorruption( + f"Failed to resolve image: {images}.\nError msg: {str(e)}" + ) return resolved_imgs -def resolve_audio(audio: Union[str, bytes], allowed_sampling_rate: int) -> AudioSingle: +async def resolve_audio( + audio: Union[str, bytes], + allowed_sampling_rate: int, + session: "aiohttp.ClientSession", +) -> AudioSingle: if isinstance(audio, bytes): try: audio_bytes = io.BytesIO(audio) @@ -88,7 +104,8 @@ def resolve_audio(audio: Union[str, bytes], allowed_sampling_rate: int) -> Audio raise AudioCorruption(f"Error opening audio: {e}") else: try: - downloaded = requests.get(audio, stream=True).content + downloaded = await (await session.get(audio)).read() + # downloaded = requests.get(audio, stream=True).content audio_bytes = io.BytesIO(downloaded) except Exception as e: raise AudioCorruption(f"Error downloading audio.\nError msg: {str(e)}") @@ -104,18 +121,22 @@ def resolve_audio(audio: Union[str, bytes], allowed_sampling_rate: int) -> Audio raise AudioCorruption(f"Error opening audio: {e}.\nError msg: {str(e)}") -def resolve_audios( +async def resolve_audios( audio_urls: list[Union[str, bytes]], allowed_sampling_rate: int ) -> list[AudioSingle]: """Resolve audios from URLs.""" - CHECK_REQUESTS.mark_required() + CHECK_AIOHTTP.mark_required() CHECK_SOUNDFILE.mark_required() resolved_audios: list[AudioSingle] = [] - for audio in audio_urls: + async with aiohttp.ClientSession(trust_env=True) as session: try: - audio_single = resolve_audio(audio, allowed_sampling_rate) - resolved_audios.append(audio_single) + resolved_audios = await asyncio.gather( + *[ + resolve_audio(audio, allowed_sampling_rate, session) + for audio in audio_urls + ] + ) except Exception as e: raise AudioCorruption(f"Failed to resolve audio: {e}")