Skip to content

Commit

Permalink
[Bugfix] Fix disagg hang caused by the prefill and decode communication issues (#12723)
Browse files Browse the repository at this point in the history

Signed-off-by: Lu Fang <[email protected]>
  • Loading branch information
houseroad authored Feb 8, 2025
1 parent 932c6b7 commit 45cbc49
Showing 1 changed file with 40 additions and 47 deletions.
87 changes: 40 additions & 47 deletions vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
stop the prefill instance when the decode instance is slow.
"""
import threading
import time
from collections import deque
from typing import Deque, List, Optional, Union

Expand All @@ -29,21 +28,21 @@ class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float):
"""
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request.
data_pipe: on device (e.g. GPU)
"""

self.buffer: Deque[List[torch.Tensor]] = deque()

self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_lock = threading.Lock()
self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None
Expand Down Expand Up @@ -116,11 +115,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
hidden = hidden.clone()

buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])

with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()

with self.buffer_lock:
for data in buffer_item:
self.buffer_size += self._get_element_size(data)
self.buffer_size += data_size
self.buffer.append(buffer_item)
self.buffer_cv.notify()

def _is_end_signal(self, signal):
return signal is None
Expand All @@ -143,35 +150,31 @@ def drop_select_handler(self):
roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi]

matched_length = 0

# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
with self.buffer_lock:

def is_buffer_available(
tokens_roi_recver: List[torch.Tensor], ) -> bool:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
for _ in range(len(self.buffer)):

temp_length = self._matches(self.buffer[0],
tokens_roi_recver)
if temp_length > 0:
matched_length = temp_length
break
if self._matches(self.buffer[0],
tokens_roi_recver) > 0:
return True
# rotate the element we just accessed to the end
self.buffer.rotate(-1)

if matched_length > 0:
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)

else:
# no match, just send None
for _ in range(5):
self.data_pipe.send_tensor(None)
return False

with self.buffer_cv:
while not is_buffer_available(tokens_roi_recver):
logger.debug(
"KV transfer buffer is not available. Waiting...")
self.buffer_cv.wait()
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()

except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
Expand Down Expand Up @@ -208,20 +211,10 @@ def drop_select(

return [input_tokens, roi, key, value, hidden]

def full_handler(self):
time.sleep(0.001)

def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:

if self.buffer_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()

self._add_to_buffer(input_tokens, roi, key, value, hidden)

# when calling the insert, the current process is a sender
Expand Down

0 comments on commit 45cbc49

Please sign in to comment.