diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 5e1b62352d14c..3462f7de020ef 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -10,7 +10,6 @@ stop the prefill instance when the decode instance is slow. """ import threading -import time from collections import deque from typing import Deque, List, Optional, Union @@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase): def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float): """ - signal_pipe: on CPU - - NOTE: on-device recv will block all threads in the process, making the - KV cache producer unable to listen to new request while transmitting - KV cache. Luckily CPU recv only blocks the current thread so we use + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use CPU recv to listen to new request. - + data_pipe: on device (e.g. GPU) """ @@ -43,7 +42,7 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, self.buffer_size = 0 self.buffer_size_threshold = buffer_size_thresh - self.buffer_lock = threading.Lock() + self.buffer_cv = threading.Condition() self.signal_pipe = signal_pipe self.data_pipe = data_pipe self.request_handling_thread: Optional[threading.Thread] = None @@ -116,11 +115,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, hidden = hidden.clone() buffer_item = [input_tokens, roi, key, value, hidden] + data_size = sum([self._get_element_size(data) for data in buffer_item]) + + with self.buffer_cv: + if self.buffer_size + data_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size + data_size > self.buffer_size_threshold: + self.buffer_cv.wait() - with self.buffer_lock: - for data in buffer_item: - self.buffer_size += self._get_element_size(data) + self.buffer_size += data_size self.buffer.append(buffer_item) + self.buffer_cv.notify() def _is_end_signal(self, signal): return signal is None @@ -143,35 +150,31 @@ def drop_select_handler(self): roi = (roi > 0.5) tokens_roi_recver = [input_tokens, roi] - matched_length = 0 - - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - with self.buffer_lock: - + def is_buffer_available( + tokens_roi_recver: List[torch.Tensor], ) -> bool: + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], - tokens_roi_recver) - if temp_length > 0: - matched_length = temp_length - break + if self._matches(self.buffer[0], + tokens_roi_recver) > 0: + return True # rotate the element we just accessed to the end self.buffer.rotate(-1) - - if matched_length > 0: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - - else: - # no match, just send None - for _ in range(5): - self.data_pipe.send_tensor(None) + return False + + with self.buffer_cv: + while not is_buffer_available(tokens_roi_recver): + logger.debug( + "KV transfer buffer is not available. Waiting...") + self.buffer_cv.wait() + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + self.buffer_cv.notify() except RuntimeError as e: if 'Connection closed by peer' not in str(e): @@ -208,20 +211,10 @@ def drop_select( return [input_tokens, roi, key, value, hidden] - def full_handler(self): - time.sleep(0.001) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: - if self.buffer_size > self.buffer_size_threshold: - # log outside the while loop to avoid this message being logged - # repeatedly. - logger.debug("KV transfer buffer is full. Handling...") - while self.buffer_size > self.buffer_size_threshold: - self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender