# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.

Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Uses a CPU signal pipe to avoid race conditions.
- Handles buffer size constraints and provides a backpressure mechanism to
  stop the prefill instance when the decode instance is slow.
"""

import threading
from collections import deque

import torch

from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger

logger = init_logger(__name__)
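
# A minimal usage sketch (illustrative only; the pipe objects below are
# hypothetical placeholders for concrete KVPipeBase implementations):
#
#   # prefill (producer) instance: non-blocking insert
#   producer = SimpleBuffer(signal_pipe, data_pipe, buffer_size_thresh=1e9)
#   producer.insert(input_tokens, roi, key, value, hidden)
#
#   # decode (consumer) instance: blocking drop_select
#   consumer = SimpleBuffer(signal_pipe, data_pipe, buffer_size_thresh=1e9)
#   tokens, roi, key, value, hidden = consumer.drop_select(input_tokens, roi)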


class SimpleBuffer(KVLookupBufferBase):
    def __init__(
        self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float
    ):
        """
        signal_pipe: on CPU

        NOTE: on-device recv will block all threads in the process, making the
        KV cache producer unable to listen to new requests while transmitting
        KV cache. Luckily, CPU recv only blocks the current thread, so we use
        CPU recv to listen to new requests.

        data_pipe: on device (e.g. GPU)
        """

        self.buffer: deque[list[torch.Tensor]] = deque()

        self.buffer_size = 0
        self.buffer_size_threshold = buffer_size_thresh
        self.buffer_cv = threading.Condition()
        self.signal_pipe = signal_pipe
        self.data_pipe = data_pipe
        self.request_handling_thread: threading.Thread | None = None
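
        # The "normal" signal is a zero tensor on CPU announcing a regular
        # drop-select request; sending None through the signal pipe marks
        # the end of the stream (see _is_end_signal and close()).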
        self.normal_signal = torch.tensor([0], device="cpu")
        self.end_signal = None

    def _matches(
        self,
        tokens_roi_sender: list[torch.Tensor],
        tokens_roi_recver: list[torch.Tensor],
    ):
        # tokens_roi_sender: tokens and roi of the producer (in the buffer)
        # tokens_roi_recver: tokens and roi of the consumer (query)

        tokens_sender = tokens_roi_sender[0]
        tokens_recver = tokens_roi_recver[0]
        roi_sender = tokens_roi_sender[1]
        roi_recver = tokens_roi_recver[1]

        if tokens_recver is None:
            # consumer sends an empty request
            # semantics: DROP SELECT * LIMIT 1
            # so any of the data in the buffer can be drop-selected
            return True

        # Assuming that roi is a binary mask on tokens
        tokens_sender = tokens_sender[roi_sender]
        tokens_recver = tokens_recver[roi_recver]
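
        # Worked example (illustrative): with masked sender tokens
        # [1, 2, 3, 4] and masked consumer tokens [1, 2, 3], the common
        # prefix has length 3, so _matches returns 3; if the prefixes
        # disagree (e.g. [1, 2, 3] vs [1, 9]), it returns 0.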
        # simple common prefix matching
        min_length = min(len(tokens_sender), len(tokens_recver))
        if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
            return min_length

        return 0

    def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None:
        assert tensor is not None, "Use self.data_pipe.send(None) instead"
        self.buffer_size -= tensor.element_size() * tensor.numel()
        if tensor.dtype == torch.bool:
            # PyNccl cannot send bool tensors, so ship the mask as float;
            # the receiver re-binarizes it with `> 0.5`.
            tensor = tensor.float()
        self.data_pipe.send_tensor(tensor)

    def _get_element_size(self, data: list | torch.Tensor | None):
        if isinstance(data, torch.Tensor):
            return data.element_size() * data.numel()
        if not data:
            # cannot perform `not data` on a tensor,
            # so this check needs to go after the check above
            return 0

        raise AssertionError(f"Unknown data type {type(data)}")

    def _add_to_buffer(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ):
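        # `insert` is non-blocking, so clone everything before buffering to
        # decouple the buffered copy from tensors the caller may reuse or
        # free after `insert` returns.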
        if isinstance(input_tokens, torch.Tensor):
            input_tokens = input_tokens.clone()
        if isinstance(roi, torch.Tensor):
            roi = roi.clone()
        if isinstance(key, torch.Tensor):
            key = key.clone()
        if isinstance(value, torch.Tensor):
            value = value.clone()
        if isinstance(hidden, torch.Tensor):
            hidden = hidden.clone()

        buffer_item = [input_tokens, roi, key, value, hidden]
        data_size = sum([self._get_element_size(data) for data in buffer_item])
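
        # Backpressure: if buffering this item would exceed the threshold,
        # block until drop_select_handler pops items and notifies us. This
        # is what stalls a fast prefill instance behind a slow decoder.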
        with self.buffer_cv:
            if self.buffer_size + data_size > self.buffer_size_threshold:
                # log outside the while loop to avoid this message being
                # logged repeatedly.
                logger.debug("KV transfer buffer is full. Handling...")
                while self.buffer_size + data_size > self.buffer_size_threshold:
                    self.buffer_cv.wait()

            self.buffer_size += data_size
            self.buffer.append(buffer_item)
            self.buffer_cv.notify()

    def _is_end_signal(self, signal):
        return signal is None
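
    # Producer-side wire protocol (as implemented below and mirrored by
    # drop_select): receive the CPU signal, then tokens and roi on the data
    # pipe, then answer with five tensors: input_tokens, roi, key, value,
    # hidden.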
    def drop_select_handler(self):
        try:
            while True:
                signal = self.signal_pipe.recv_tensor()
                if self._is_end_signal(signal):
                    logger.info("Received end signal!")
                    break

                input_tokens = self.data_pipe.recv_tensor()

                roi = self.data_pipe.recv_tensor()
                assert roi is not None, (
                    "Please provide the roi when sending drop-select request"
                )
                # the roi arrives as a float tensor (PyNccl cannot send bool
                # tensors), so convert it back to a bool mask
                roi = roi > 0.5
                tokens_roi_recver = [input_tokens, roi]

                def is_buffer_available(
                    tokens_roi_recver: list[torch.Tensor],
                ) -> bool:
                    # perform input tokens and roi matching
                    # FIXME: this matching is O(n), ideally it should be O(1)
                    # but this buffer size won't (and shouldn't) be too large,
                    # so the fix is not urgent.
                    for _ in range(len(self.buffer)):
                        if self._matches(self.buffer[0], tokens_roi_recver) > 0:
                            return True
                        # rotate the element we just accessed to the end
                        self.buffer.rotate(-1)
                    return False

                with self.buffer_cv:
                    while not is_buffer_available(tokens_roi_recver):
                        logger.debug("KV transfer buffer is not available. Waiting...")
                        self.buffer_cv.wait()
                    # the buffered tensors were cloned when inserted (see
                    # _add_to_buffer), so sending them here is safe even if
                    # the producer frees its originals before sending finishes
                    matched_item = self.buffer.popleft()
                    for tensor in matched_item:
                        self._send_tensor_and_dec_size(tensor)
                    self.buffer_cv.notify()

        except RuntimeError as e:
            if "Connection closed by peer" not in str(e):
                raise e

        logger.debug("Closing drop_select_handler")
    def drop_select(
        self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
    ) -> list[torch.Tensor | None]:
        assert self.request_handling_thread is None, (
            "drop_select should be called by the KV cache consumer "
            "(e.g. the decode vLLM instance)"
        )

        if isinstance(input_tokens, torch.Tensor):
            input_tokens = input_tokens.clone()
        if isinstance(roi, torch.Tensor):
            roi = roi.clone().float()

        self.signal_pipe.send_tensor(self.normal_signal)
        self.data_pipe.send_tensor(input_tokens)
        self.data_pipe.send_tensor(roi)

        input_tokens = self.data_pipe.recv_tensor()
        roi = self.data_pipe.recv_tensor()
        if roi is not None:
            # convert from float tensor to bool tensor,
            # as PyNccl does not support sending bool tensors
            roi = roi > 0.5
        key = self.data_pipe.recv_tensor()
        value = self.data_pipe.recv_tensor()
        hidden = self.data_pipe.recv_tensor()

        return [input_tokens, roi, key, value, hidden]

    def insert(
        self,
        input_tokens: torch.Tensor,
        roi: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        hidden: torch.Tensor,
    ) -> None:
        self._add_to_buffer(input_tokens, roi, key, value, hidden)

        # the caller of insert is the KV cache producer (sender), so launch
        # the request handler thread and start listening to requests.
        if self.request_handling_thread is None:
            self.request_handling_thread = threading.Thread(
                target=self.drop_select_handler
            )
            self.request_handling_thread.start()
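
    # close() behaves differently on each side: the producer joins its
    # handler thread, while the consumer sends the end signal so the remote
    # handler can shut down.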
    def close(self):
        if (
            hasattr(self, "request_handling_thread")
            and self.request_handling_thread is not None
        ):
            self.request_handling_thread.join()

        else:
            # TODO: have an explicit close signal and an explicit way to
            # check whether the current instance is the requester
            self.signal_pipe.send_tensor(self.end_signal)