Merge pull request #12 from michaelfeil/feature-refactor
Improve real-time / sleep strategy, async await for queues and result futures
michaelfeil authored Oct 22, 2023
2 parents 10cd8e3 + 4443006 commit 493311d
Showing 6 changed files with 209 additions and 173 deletions.
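As context for the diff below: the commit replaces per-request threading events with asyncio futures that are awaited for results. A minimal, self-contained sketch of that pattern (illustrative only, not the repository's code; the worker thread and names here are made up) looks like this:

```python
import asyncio
import threading

async def main() -> None:
    loop = asyncio.get_running_loop()
    # one future per pending request; callers simply await them
    futures = [loop.create_future() for _ in range(3)]

    def worker() -> None:
        # pretend this thread just finished post-processing a batch and
        # hands each result back to the event loop thread-safely
        for i, fut in enumerate(futures):
            loop.call_soon_threadsafe(fut.set_result, f"embedding-{i}")

    threading.Thread(target=worker).start()
    print(await asyncio.gather(*futures))

asyncio.run(main())
```

In the diff, EmbeddingResult.future plays the role of these futures and ResultKVStoreFuture.extend sets their results.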
201 changes: 128 additions & 73 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -5,6 +5,7 @@
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from operator import attrgetter
from typing import Dict, List, Union

from infinity_emb.inference.primitives import (
@@ -13,121 +14,159 @@
OverloadStatus,
PrioritizedQueueItem,
)
from infinity_emb.inference.threading_asyncio import EventTS
from infinity_emb.inference.threading_asyncio import to_thread
from infinity_emb.log_handler import logger
from infinity_emb.transformer.abstract import BaseTransformer
from infinity_emb.transformer.utils import get_lengths_with_tokenize


class CustomPrioQueue:
def __init__(self, sort_on_arrival=True) -> None:
self._lock_queue_event = threading.Lock() # lock queue and
# queue is always sorted
def __init__(self) -> None:
""""""
self._lock_queue_event = threading.Lock()
self._queue: List[PrioritizedQueueItem] = []
# event that indicates items in queue.
self._sync_event = threading.Event()
self._sort_on_arrival = sort_on_arrival

def __len__(self):
return len(self._queue)

async def extend(self, items: List[PrioritizedQueueItem]):
with self._lock_queue_event:
if self._sort_on_arrival:
for item in items:
bisect.insort(self._queue, item)
else:
self._queue.extend(items)
for item in items:
bisect.insort(self._queue, item)

self._sync_event.set()

def pop_optimal_batch(
self, size: int, timeout=0.2
self, size: int, timeout=0.2, latest_first=False
) -> Union[List[EmbeddingResult], None]:
"""
pop a batch of up to `size` contiguous (sorted) items from the queue
Args:
size (int): max size of batch
timeout (float, optional): timeout until None is returned. Defaults to 0.2.
latest_first (bool, optional): if True, guarantees that the oldest item in the queue is processed.
Finding the oldest item requires an argmin over the created timestamps,
which is slow. Defaults to False.
returns:
None: if there is not a single item in self._queue
None: if there is not a single item in self._queue after timeout
else: List[EmbeddingResult] with 1 <= len <= size
"""
if not self._queue:
if not self._sync_event.wait(timeout):
return None

if len(self._queue) > size:
# pick a random continuous slice, at beginning or at the end.
start = random.randrange(-size, len(self._queue))
if latest_first:
# slower operation: we have a large list
# and can spend some time computing the argmin
# pick oldest one -> argmin timestamp
start = self._queue.index(min(self._queue, key=attrgetter("timestamp")))
else:
# pick a random continuous slice, at the beginning or at the end.
start = random.randrange(-size, len(self._queue))
start = max(0, min(len(self._queue) - size, start))
else:
start = 0
end = start + size

if not self._sort_on_arrival:
self._queue.sort()

with self._lock_queue_event:
new_items = self._queue[start:end]
self._queue = self._queue[:start] + self._queue[end:]
if not self._queue:
self._sync_event.clear()
# assert 1 <= len(new_items) <= size
return [i.item for i in new_items]
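For illustration, a simplified sketch of this slicing strategy (toy data, not the repository's code) on a queue already sorted by token length: pick a start index, clamp it to a valid window, and take up to `size` neighbouring items so that one batch contains similarly long sentences.

```python
import random

queue = sorted([12, 3, 44, 7, 25, 9])  # toy "token lengths", kept sorted
size = 3

start = random.randrange(-size, len(queue))    # slice may touch either end
start = max(0, min(len(queue) - size, start))  # clamp into a valid window
batch, queue = queue[start:start + size], queue[:start] + queue[start + size:]
print(batch, queue)
```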


# class ResultKVStore:
# def __init__(self) -> None:
# self._kv: Dict[str, NpEmbeddingType] = {}

# def __len__(self):
# return len(self._kv)

# async def wait_for_response(self, uuid: str, event: EventTS) -> NpEmbeddingType:
# await event.wait()
# response = self._kv[uuid]
# del self._kv[uuid]
# return response

return list(n.item for n in new_items)
# async def extend(self, batch: List[EmbeddingResult]) -> None:
# """extend store with results"""
# _update = {item.uuid: item.embedding for item in batch}

# self._kv.update(_update)
# for item in batch:
# # all done, mark EmbeddingResult for collection
# item.event.set()

class ResultKVStore:

class ResultKVStoreFuture:
# TODO: test if this works.
def __init__(self) -> None:
self._lock = threading.Lock()
self._kv: Dict[str, NpEmbeddingType] = {}

def __len__(self):
return len(self._kv)

async def wait_for_response(self, uuid: str, event: EventTS) -> NpEmbeddingType:
await event.wait()
# with self._lock:
response = self._kv[uuid]
del self._kv[uuid]
async def wait_for_response(self, fut: asyncio.Future) -> NpEmbeddingType:
"""wait for future to return"""
response = await fut
return response

async def extend(self, batch: List[EmbeddingResult]) -> None:
"""extend store with results"""
# # with self._lock:
# for item in batch:
# # first create (while _lock)
# self._kv[item.uuid] = item.embedding
# # all done, mark EmbeddingResult for collection
# item.event.set()
_update = {item.uuid: item.embedding for item in batch}

self._kv.update(_update)
for item in batch:
# all done, mark EmbeddingResult for collection
item.event.set()
item.future.set_result(item.embedding)


class BatchHandler:
def __init__(
self,
model: BaseTransformer,
max_batch_size: int,
threadpool: ThreadPoolExecutor,
max_queue_wait: int = 64_000,
batch_delay: float = 5e-3,
verbose=False,
) -> None:
"""
Performs batching around the model.
model: BaseTransformer, implements encode_(pre|core|post)
max_batch_size: max batch size of the model
max_queue_wait: max number of items to hold in the queue, default 64_000 sentences
batch_delay: sleep in seconds, wait time for the pre/post methods.
Best results: set it to 1/2 or 1/3 of the minimal expected
time of the encode_core method / "gpu inference".
Don't set it above 1x the minimal expected inference time.
Should not be 0, so that the polling loops do not hog Python's GIL.
"""
self.model = model
self.max_batch_size = max_batch_size
self.max_queue_wait = max_queue_wait
self._verbose = verbose
self._shutdown = threading.Event()
self._queue_prio = CustomPrioQueue()
self._result_store = ResultKVStore()
self._result_store = ResultKVStoreFuture()
self._feature_queue: queue.Queue = queue.Queue(4)
self._postprocess_queue: queue.Queue = queue.Queue(4)
self.max_batch_size = max_batch_size
self.model = model
self.max_queue_wait = max_queue_wait
self._threadpool = threadpool
self._batch_delay = float(max(1e-5, batch_delay))
self._threadpool = ThreadPoolExecutor()
self._ready = False
self._verbose = verbose
self._last_inference = time.time()

def shutdown(self):
self._shutdown.set()
if batch_delay > 0.5:
logger.warn(f"high batch delay of {self._batch_delay}")
if max_batch_size > max_queue_wait * 10:
logger.warn(
f"queue_size={self.max_queue_wait} to small "
f"over batch_size={self.max_batch_size}."
" Consider increasing queue size"
)
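As a rough illustration of the batch_delay advice in the docstring above, a hypothetical helper (not part of the library; `warmup_features` is assumed to already be pre-encoded input) could time encode_core and take roughly one third of the fastest run:

```python
import time

def suggest_batch_delay(model, warmup_features, runs: int = 5) -> float:
    # time the core encode a few times and derive ~1/3 of the fastest run,
    # never going below the 1e-5 floor used by BatchHandler
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        model.encode_core(warmup_features)
        timings.append(time.perf_counter() - start)
    return max(1e-5, min(timings) / 3)
```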

async def schedule(self, sentences: List[str]) -> tuple[List[NpEmbeddingType], int]:
"""Schedule a sentence to be embedded. Awaits until embedded.
Expand All @@ -140,34 +179,28 @@ async def schedule(self, sentences: List[str]) -> tuple[List[NpEmbeddingType], i
NpEmbeddingType: embedding as 1darray
"""
# add a unique identifier
uuid_event = []
prioqueue = []

prios, usage = get_lengths_with_tokenize(
sentences
) # , self.model.tokenize_lengths)
prios, usage = get_lengths_with_tokenize(sentences)

futures = [self.loop.create_future() for _ in prios]

for s, p in zip(sentences, prios):
inner = EmbeddingResult(sentence=s, event=EventTS(self._threadpool))
item = PrioritizedQueueItem(item=inner, priority=p)
uuid_event.append((inner.uuid, inner.event))
for s, p, fut in zip(sentences, prios, futures):
item = PrioritizedQueueItem(
priority=p, item=EmbeddingResult(sentence=s, future=fut)
)
prioqueue.append(item)
await self._queue_prio.extend(prioqueue)

gather_results = [
self._result_store.wait_for_response(uuid, event)
for uuid, event in uuid_event
]
embeddings = await asyncio.gather(*gather_results)
embeddings = await asyncio.gather(
*[self._result_store.wait_for_response(fut) for fut in futures]
)
return embeddings, usage
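A hedged usage sketch for schedule(): assuming a BatchHandler has already been spawned inside a running event loop, several requests can be awaited concurrently, and each call resolves to an (embeddings, usage) tuple.

```python
import asyncio

async def embed_many(handler, requests):
    # handler is assumed to be a spawned BatchHandler;
    # each schedule() call resolves to (embeddings, token_usage)
    results = await asyncio.gather(
        *(handler.schedule(sentences) for sentences in requests)
    )
    for embeddings, usage in results:
        print(len(embeddings), "embeddings,", usage, "tokens")
```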

def is_overloaded(self) -> bool:
# start consuming
"""checks if more items can be queued."""
return len(self._queue_prio) > self.max_queue_wait

def ready(self) -> bool:
return self._ready

def overload_status(self) -> OverloadStatus:
"""
returns info about the queue status
@@ -194,11 +227,14 @@ def _preprocess_batch(self):
or (len(self._queue_prio) < self.max_batch_size * 4)
):
# add some stochastic delay
time.sleep(2e-3)
time.sleep(self._batch_delay)
continue
# decision to attempt to pop a batch
# -> will happen if a single datapoint is available
batch = self._queue_prio.pop_optimal_batch(self.max_batch_size)

batch = self._queue_prio.pop_optimal_batch(
self.max_batch_size, latest_first=False
)
if not batch:
# not a single sentence available / len=0, wait for more
continue
@@ -234,6 +270,7 @@ def _core_batch(self):
except queue.Empty:
continue
(feat, batch) = core_batch
self._last_inference = time.time()
embed = self.model.encode_core(feat)
if self._verbose:
logger.debug("[🏃] Inference done on batch_size=%s", len(batch))
Expand All @@ -260,17 +297,26 @@ async def _postprocess_batch(self):
while not self._shutdown.is_set():
try:
post_batch = self._postprocess_queue.get_nowait()
if not self._postprocess_queue.qsize():
# queue is not in a hurry
# give the CPU some time to focus
# on moving the next batch to GPU on the forward pass
# before proceeding
await asyncio.sleep(1e-3)
except queue.Empty:
# 7 ms, assuming this is below
# instead, use async await to get the next batch
try:
post_batch = await to_thread(
self._postprocess_queue.get, self._threadpool, timeout=1
)
except queue.Empty:
# in case of timeout start again
continue

if (
not self._postprocess_queue.qsize()
and self._last_inference < time.time() + self._batch_delay * 2
):
# 5 ms, assuming this is below
# 3-50ms for inference on avg.
await asyncio.sleep(5e-3)
continue
# give the CPU some time to focus
# on moving the next batch to GPU on the forward pass
# before proceeding
await asyncio.sleep(self._batch_delay)
embed, batch = post_batch
embeddings = self.model.encode_post(embed).tolist()
for i, item in enumerate(batch):
@@ -284,6 +330,15 @@ async def spawn(self):
async def spawn(self):
"""set up the resources in batch"""
logger.info("creating batching engine")
self.loop = asyncio.get_event_loop() # asyncio.events._get_running_loop()
self._threadpool.submit(self._preprocess_batch)
self._threadpool.submit(self._core_batch)
asyncio.create_task(self._postprocess_batch())

def shutdown(self):
"""
Set the shutdown event and close the threadpool.
Blocks until shutdown is complete.
"""
self._shutdown.set()
self._threadpool.shutdown(wait=True)
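The to_thread helper imported at the top of this file is not part of the diff. A plausible minimal version, assuming it only off-loads a blocking call such as queue.Queue.get to the given ThreadPoolExecutor, might look like this (a sketch, not the repository's implementation):

```python
import asyncio
import queue
from concurrent.futures import ThreadPoolExecutor

async def to_thread(func, tp: ThreadPoolExecutor, **kwargs):
    # run a blocking callable in the pool without blocking the event loop
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(tp, lambda: func(**kwargs))

async def demo() -> None:
    q: queue.Queue = queue.Queue()
    q.put("post-batch-0")
    with ThreadPoolExecutor() as tp:
        print(await to_thread(q.get, tp, timeout=1))

asyncio.run(demo())
```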
30 changes: 23 additions & 7 deletions libs/infinity_emb/infinity_emb/inference/primitives.py
@@ -1,32 +1,48 @@
import asyncio
import time
from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4

import numpy as np

from infinity_emb.inference.threading_asyncio import EventTS
# from infinity_emb.inference.threading_asyncio import EventTS

NpEmbeddingType = np.ndarray


@dataclass
@dataclass(order=True)
class EmbeddingResult:
sentence: str
event: EventTS
uuid: str = field(default_factory=lambda: str(uuid4()))
created: float = field(default_factory=time.time)
embedding: Optional[NpEmbeddingType] = None
sentence: str = field(compare=False)
uuid: str = field(default_factory=lambda: str(uuid4()), compare=False)
embedding: Optional[NpEmbeddingType] = field(default=None, compare=False)
future: Optional[asyncio.Future] = field(default=None, compare=False)
# event: Optional[EventTS] = field(default=None, compare=False)


@dataclass(order=True)
class PrioritizedQueueItem:
priority: int
item: EmbeddingResult = field(compare=False)
timestamp: float = field(default_factory=time.time, compare=False)


@dataclass
class OverloadStatus:
queue_fraction: float
queue_absolute: int
results_absolute: int


if __name__ == "__main__":
# tiny smoke test for ordering and insertion of queue items
import bisect

r1 = EmbeddingResult(sentence="hello")
r2 = EmbeddingResult(sentence="hello_")
print(r1 < r2)  # False: every field is compare=False
l1: list = []
bisect.insort(l1, r1)
bisect.insort(l1, r2)
print([r.sentence for r in l1])
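To connect the snippet above with the queue logic in batch_handler.py, here is a small sketch using a stand-in dataclass (not the real PrioritizedQueueItem): with order=True, only fields that keep compare=True take part in sorting, and the oldest item can still be found with an argmin over timestamps, as pop_optimal_batch(latest_first=True) does.

```python
import bisect
import time
from dataclasses import dataclass, field
from operator import attrgetter

@dataclass(order=True)
class QueueItem:  # stand-in for PrioritizedQueueItem
    priority: int
    payload: str = field(compare=False)
    timestamp: float = field(default_factory=time.time, compare=False)

items: list = []
for prio, text in [(3, "ccc"), (1, "a"), (2, "bb")]:
    bisect.insort(items, QueueItem(prio, text))

print([i.payload for i in items])                 # sorted by priority only
oldest = min(items, key=attrgetter("timestamp"))  # argmin over creation time
print(oldest.payload)
```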
