Model: Device-id and data-parallel inference in CLI and Torch (#452)
* initial commit

* fmt

* lint

* fix: typos, more docs
michaelfeil authored Nov 11, 2024
1 parent a178460 commit 941105f
Showing 24 changed files with 1,375 additions and 135 deletions.
6 changes: 6 additions & 0 deletions docs/docs/cli_v2.md
@@ -67,6 +67,12 @@ $ infinity_emb v2 --help
 │                                  the model forward pass.        │
 │                                  [env var: `INFINITY_DEVICE`]   │
 │                                  [default: auto]                │
+│ --device-id                TEXT  device id defines the model    │
+│                                  placement. e.g. `0,1` will     │
+│                                  place the model on             │
+│                                  MPS/CUDA/GPU 0 and 1 each      │
+│                                  [env var:                      │
+│                                  `INFINITY_DEVICE_ID`]          │
 │ --lengths-via-tokenize    --no-lengths-via-tokenize    if True, returned tokens is │
 │                                  based on actual tokenizer      │
 │                                  count. If false, uses          │
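For example, with two GPUs visible, an invocation like the following (using the existing `--model-id` option; the model name is chosen purely for illustration) places one copy of the model on device 0 and one on device 1, per the new help text above:

    infinity_emb v2 --model-id BAAI/bge-small-en-v1.5 --device-id 0,1

Setting `INFINITY_DEVICE_ID=0,1` in the environment is the documented equivalent.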
32 changes: 30 additions & 2 deletions libs/infinity_emb/infinity_emb/args.py
@@ -2,22 +2,27 @@
 # Copyright (c) 2023-now michaelfeil

 import sys
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from itertools import zip_longest
 from typing import Optional
+from copy import deepcopy

+
 from infinity_emb._optional_imports import CHECK_PYDANTIC
 from infinity_emb.env import MANAGER
 from infinity_emb.primitives import (
     Device,
+    DeviceID,
     Dtype,
     EmbeddingDtype,
     InferenceEngine,
     PoolingMethod,
+    LoadingStrategy,
 )

 if CHECK_PYDANTIC.is_available:
     from pydantic.dataclasses import dataclass as dataclass_pydantic
+    from pydantic import ConfigDict
 # if python>=3.10 use kw_only
 dataclass_args = {"kw_only": True} if sys.version_info >= (3, 10) else {}
@@ -37,6 +42,8 @@ class EngineArgs:
         vector_disk_cache_path, str: file path to folder of cache.
             Defaults to "" - default no caching.
         device, Device or str: device to use for inference. Defaults to Device.auto.
+        device_id, DeviceID or str: device index to use for inference.
+            Defaults to [], no preferred placement.
         compile, bool: compile model for better performance. Defaults to False.
         bettertransformer, bool: use bettertransformer. Defaults to True.
         dtype, Dtype or str: data type to use for inference. Defaults to Dtype.auto.
@@ -53,6 +60,7 @@
     model_warmup: bool = MANAGER.model_warmup[0]
     vector_disk_cache_path: str = ""
     device: Device = Device[MANAGER.device[0]]
+    device_id: DeviceID = field(default_factory=lambda: DeviceID(MANAGER.device_id[0]))
     compile: bool = MANAGER.compile[0]
     bettertransformer: bool = MANAGER.bettertransformer[0]
     dtype: Dtype = Dtype[MANAGER.dtype[0]]
@@ -61,6 +69,8 @@
     embedding_dtype: EmbeddingDtype = EmbeddingDtype[MANAGER.embedding_dtype[0]]
     served_model_name: str = MANAGER.served_model_name[0]

+    _loading_strategy: Optional[LoadingStrategy] = None
+
     def __post_init__(self):
         # convert the following strings to enums
         # so they don't need to be exported to the external interface
@@ -71,6 +81,8 @@ def __post_init__(self):
             object.__setattr__(self, "device", Device.auto)
         else:
             object.__setattr__(self, "device", Device[self.device])
+        if not isinstance(self.device_id, DeviceID):
+            object.__setattr__(self, "device_id", DeviceID(self.device_id))
         if not isinstance(self.dtype, Dtype):
             object.__setattr__(self, "dtype", Dtype[self.dtype])
         if not isinstance(self.pooling_method, PoolingMethod):
@@ -100,7 +112,9 @@ def __post_init__(self):
         if CHECK_PYDANTIC.is_available:
             # convert to pydantic dataclass
             # and check if the dataclass is valid
-            @dataclass_pydantic(frozen=True, **dataclass_args)
+            @dataclass_pydantic(
+                frozen=True, config=ConfigDict(arbitrary_types_allowed=True), **dataclass_args
+            )
             class EngineArgsPydantic(EngineArgs):
                 def __post_init__(self):
                     # overwrite the __post_init__ method
@@ -109,10 +123,24 @@ def __post_init__(self):

         # validate
         EngineArgsPydantic(**self.__dict__)
+        if self._loading_strategy is None:
+            self.update_loading_strategy()
+        elif isinstance(self._loading_strategy, dict):
+            object.__setattr__(self, "_loading_strategy", LoadingStrategy(**self._loading_strategy))

     def to_dict(self):
         return asdict(self)

+    def update_loading_strategy(self):
+        """Assign a device id to the EngineArgs object."""
+        from infinity_emb.inference import loading_strategy  # type: ignore
+
+        object.__setattr__(self, "_loading_strategy", loading_strategy.get_loading_strategy(self))
+        return self
+
+    def copy(self):
+        return deepcopy(self)
+
     @classmethod
     def from_env(cls) -> list["EngineArgs"]:
         """Create a list of EngineArgs from environment variables."""
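Taken together, the args.py changes let device placement be requested directly on EngineArgs. A minimal sketch of the new surface (the `model_name_or_path` field is part of the existing EngineArgs API and not shown in this diff; the exact string parsing done by DeviceID is an assumption here):

    from infinity_emb.args import EngineArgs

    args = EngineArgs(
        model_name_or_path="BAAI/bge-small-en-v1.5",  # pre-existing field, assumed
        device="cuda",    # __post_init__ coerces the string to Device.cuda
        device_id="0,1",  # __post_init__ coerces the string to a DeviceID (assumed)
    )
    # Because _loading_strategy defaults to None, __post_init__ calls
    # update_loading_strategy(), which delegates to
    # infinity_emb.inference.loading_strategy.get_loading_strategy(args).
    args2 = args.copy()  # new helper: a deepcopy of the frozen dataclass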
8 changes: 5 additions & 3 deletions libs/infinity_emb/infinity_emb/engine.py
@@ -52,7 +52,9 @@ def __init__(

         self.running = False
         self._running_sepamore: Optional[Semaphore] = None
-        self._model, self._min_inference_t, self._max_inference_t = select_model(self._engine_args)
+        self._model_replicas, self._min_inference_t, self._max_inference_t = select_model(
+            self._engine_args
+        )

     @classmethod
     def from_args(
@@ -85,7 +87,7 @@ async def astart(self):
         self.running = True
         self._batch_handler = BatchHandler(
             max_batch_size=self._engine_args.batch_size,
-            model=self._model,
+            model_replicas=self._model_replicas,
             batch_delay=self._min_inference_t / 2,
             vector_disk_cache_path=self._engine_args.vector_disk_cache_path,
             verbose=logger.level <= 10,
@@ -122,7 +124,7 @@ def is_running(self) -> bool:

     @property
     def capabilities(self) -> set[ModelCapabilites]:
-        return self._model.capabilities
+        return self._model_replicas[0].capabilities

     @property
     def engine_args(self) -> EngineArgs:
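The public engine API is unchanged by the replica plumbing above; `select_model` now returns one model per requested device and `BatchHandler` receives the whole list. A hedged end-to-end sketch (`from_args` and `capabilities` appear in this diff; the `async with` and `embed` usage follows infinity_emb's existing documented API):

    import asyncio

    from infinity_emb import AsyncEmbeddingEngine
    from infinity_emb.args import EngineArgs

    async def main() -> None:
        engine = AsyncEmbeddingEngine.from_args(
            EngineArgs(model_name_or_path="BAAI/bge-small-en-v1.5", device_id="0,1")
        )
        async with engine:  # astart() spawns one ModelWorker per replica
            embeddings, usage = await engine.embed(sentences=["hello", "world"])
        print(engine.capabilities)  # read from replica 0, per the diff above

    asyncio.run(main())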
7 changes: 6 additions & 1 deletion libs/infinity_emb/infinity_emb/env.py
@@ -11,6 +11,7 @@
 from infinity_emb.log_handler import logger
 from infinity_emb.primitives import (
     Device,
+    DeviceIDProxy,
     Dtype,
     EmbeddingDtype,
     EnumType,
@@ -224,7 +225,7 @@ def log_level(self):

     def _typed_multiple(self, name: str, cls: type["EnumTypeLike"]) -> list["str"]:
         result = self._optional_infinity_var_multiple(name, default=[cls.default_value()])
-        assert all(cls(v) for v in result)
+        tuple(cls(v) for v in result)  # check if all values are valid
         return result

     @cached_property
@@ -243,6 +244,10 @@ def pooling_method(self) -> list[str]:
     def device(self) -> list[str]:
         return self._typed_multiple("device", Device)

+    @cached_property
+    def device_id(self):
+        return self._typed_multiple("device_id", DeviceIDProxy)
+
     @cached_property
     def embedding_dtype(self) -> list[str]:
         return self._typed_multiple("embedding_dtype", EmbeddingDtype)
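The replaced `assert all(cls(v) for v in result)` was fragile: `assert` statements are stripped entirely under `python -O`, silently skipping validation, while the check only needs the constructor calls for their side effect of raising on bad input. Materializing the generator with `tuple(...)` always invokes every constructor. A toy illustration (this `Device` enum is a stand-in, not the one from infinity_emb.primitives):

    from enum import Enum

    class Device(Enum):  # stand-in enum for illustration
        cpu = "cpu"
        cuda = "cuda"

    values = ["cpu", "gpu"]
    # Runs even under `python -O`; raises ValueError: 'gpu' is not a valid Device
    tuple(Device(v) for v in values)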
63 changes: 32 additions & 31 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -58,7 +58,7 @@ def submit(self, *args, **kwargs):
 class BatchHandler:
     def __init__(
         self,
-        model: BaseTransformer,
+        model_replicas: list[BaseTransformer],
         max_batch_size: int,
         max_queue_wait: int = MANAGER.queue_size,
         batch_delay: float = 5e-3,
@@ -101,16 +101,19 @@ def __init__(
         )
         self._result_store = ResultKVStoreFuture(cache)
         # model
-        self.model_worker = ModelWorker(
-            max_batch_size=max_batch_size,
-            shutdown=ShutdownReadOnly(self._shutdown),
-            model=model,
-            threadpool=ThreadPoolExecutorReadOnly(self._threadpool),
-            input_q=self._queue_prio,
-            output_q=self._result_queue,
-            verbose=verbose,
-            batch_delay=batch_delay,
-        )
+        self.model_worker = [
+            ModelWorker(
+                max_batch_size=max_batch_size,
+                shutdown=ShutdownReadOnly(self._shutdown),
+                model=model_replica,
+                threadpool=ThreadPoolExecutorReadOnly(self._threadpool),
+                input_q=self._queue_prio,
+                output_q=self._result_queue,
+                verbose=verbose,
+                batch_delay=batch_delay,
+            )
+            for model_replica in model_replicas
+        ]

         if batch_delay > 0.1:
             logger.warning(f"high batch delay of {batch_delay}")
@@ -135,10 +138,9 @@ async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"]
             list["EmbeddingReturnType"]: list of embedding as 1darray
             int: token usage
         """
-        if "embed" not in self.model_worker.capabilities:
+        if "embed" not in self.capabilities:
             raise ModelNotDeployedError(
-                "the loaded moded cannot fullyfill `embed`. "
-                f"Options are {self.model_worker.capabilities}."
+                "the loaded moded cannot fullyfill `embed`. " f"Options are {self.capabilities}."
             )
         input_sentences = [EmbeddingSingle(sentence=s) for s in sentences]

@@ -169,10 +171,9 @@ async def rerank(
             list[float]: list of scores
             int: token usage
         """
-        if "rerank" not in self.model_worker.capabilities:
+        if "rerank" not in self.capabilities:
             raise ModelNotDeployedError(
-                "the loaded moded cannot fullyfill `rerank`. "
-                f"Options are {self.model_worker.capabilities}."
+                "the loaded moded cannot fullyfill `rerank`. " f"Options are {self.capabilities}."
             )
         rerankables = [ReRankSingle(query=query, document=doc) for doc in docs]
         scores, usage = await self._schedule(rerankables)
@@ -209,10 +210,9 @@ async def classify(
             list[ClassifyReturnType]: list of class encodings
             int: token usage
         """
-        if "classify" not in self.model_worker.capabilities:
+        if "classify" not in self.capabilities:
             raise ModelNotDeployedError(
-                "the loaded moded cannot fullyfill `classify`. "
-                f"Options are {self.model_worker.capabilities}."
+                "the loaded moded cannot fullyfill `classify`. " f"Options are {self.capabilities}."
             )
         items = [PredictSingle(sentence=s) for s in sentences]
         classifications, usage = await self._schedule(items)
@@ -242,10 +242,10 @@ async def image_embed(
             int: token usage
         """

-        if "image_embed" not in self.model_worker.capabilities:
+        if "image_embed" not in self.capabilities:
             raise ModelNotDeployedError(
                 "the loaded moded cannot fullyfill `image_embed`. "
-                f"Options are {self.model_worker.capabilities}."
+                f"Options are {self.capabilities}."
             )

         items = await resolve_images(images)
@@ -271,15 +271,15 @@ async def audio_embed(
             int: token usage
         """

-        if "audio_embed" not in self.model_worker.capabilities:
+        if "audio_embed" not in self.capabilities:
             raise ModelNotDeployedError(
                 "the loaded moded cannot fullyfill `audio_embed`. "
-                f"Options are {self.model_worker.capabilities}."
+                f"Options are {self.capabilities}."
             )

         items = await resolve_audios(
             audios,
-            getattr(self.model_worker._model, "sampling_rate", -42),
+            getattr(self.model_worker[0]._model, "sampling_rate", -42),
         )
         embeddings, usage = await self._schedule(items)
         return embeddings, usage
@@ -308,7 +308,7 @@ async def _schedule(self, list_queueitem: Sequence[AbstractSingle]) -> tuple[lis
     @property
     def capabilities(self) -> set[ModelCapabilites]:
         # TODO: try to remove inheritance here and return upon init.
-        return self.model_worker.capabilities
+        return self.model_worker[0].capabilities

     def is_overloaded(self) -> bool:
         """checks if more items can be queued.
@@ -344,7 +344,7 @@ async def _get_prios_usage(self, items: Sequence[AbstractSingle]) -> tuple[list[
             get_lengths_with_tokenize,
             self._threadpool,
             _sentences=[it.str_repr() for it in items],
-            tokenize=self.model_worker.tokenize_lengths,
+            tokenize=self.model_worker[0].tokenize_lengths,
         )

     @staticmethod
@@ -391,7 +391,8 @@ async def spawn(self):
                 ShutdownReadOnly(self._shutdown), self._result_queue, self._threadpool
             )
         )
-        self.model_worker.spawn()
+        for worker in self.model_worker:
+            worker.spawn()

     async def shutdown(self):
         """
@@ -575,14 +576,14 @@ def _postprocess_batch(self):
         # if not self._shutdown.is_set():
         #     logger.debug("Sending a warm up through embedding.")
         #     try:
-        #         if "embed" in self.model_worker.capabilities:
+        #         if "embed" in self.capabilities:
         #             # await self.embed(sentences=["test"] * self.max_batch_size)
         #             self.
-        #         if "rerank" in self.model_worker.capabilities:
+        #         if "rerank" in self.capabilities:
         #             # await self.rerank(
         #             #     query="query", docs=["test"] * self.max_batch_size
         #             # )
-        #         if "classify" in self.model_worker.capabilities:
+        #         if "classify" in self.capabilities:
         #             # await self.classify(sentences=["test"] * self.max_batch_size)
         #     except Exception:
         #         pass
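The list-of-workers change above is the data-parallel core of this commit: every ModelWorker drains the same `input_q` and publishes to the same `output_q`, so batches are load-balanced across replicas by queue semantics alone, with no routing logic in BatchHandler. A self-contained pattern sketch using plain threads (stand-in callables, not the actual ModelWorker internals):

    import queue
    import threading

    def worker(replica, input_q: queue.Queue, output_q: queue.Queue) -> None:
        while True:
            batch = input_q.get()
            if batch is None:  # shutdown sentinel, one per worker
                break
            output_q.put(replica(batch))  # whichever worker is idle takes the batch

    input_q: queue.Queue = queue.Queue()
    output_q: queue.Queue = queue.Queue()
    replicas = [lambda batch, i=i: f"replica-{i} handled {batch}" for i in range(2)]
    threads = [threading.Thread(target=worker, args=(r, input_q, output_q)) for r in replicas]
    for t in threads:
        t.start()
    for b in ("batch-0", "batch-1", "batch-2"):
        input_q.put(b)
    for _ in range(3):
        print(output_q.get())
    for _ in threads:
        input_q.put(None)
    for t in threads:
        t.join()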