This commit is contained in:
2025-08-13 19:46:19 +08:00
commit 5d2e7edf78
1232 changed files with 361215 additions and 0 deletions

View File

@@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.
These classes above are abstracted behind class `KVCacheBufferBase`.
"""
from abc import ABC, abstractmethod
from typing import Optional
import torch
class KVCacheBufferBase(ABC):
"""
Abstract base class for a KVCache buffer.
"""
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
KVCache buffer when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVLookupBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache lookup buffer.
This class provides an abstraction for a key-value (KV) cache lookup buffer.
The key of the lookup buffer:
- input_tokens: token IDs of the request
- roi: a binary mask on top of input_tokens.
- Purpose of roi: Since KV cache may only be available for a subset of
tokens in the input (for example, when vLLM is connected to an external
KV cache service), roi specifies the subset of tokens that the KV cache
is associated with.
- NOTE: roi can be further extended to describe which part of KV the
current process is holding (each process may only hold a part of KV
due to TP and PP). This is not implemented for now.
The value of the lookup buffer:
- key: the key tensor in the KV cache
- value: the value tensor in the KV cache
- hidden: the final hidden state generated by model forwarding. This allows
vLLM to bypass further model forwarding by transmitting the hidden state.
"""
@abstractmethod
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
"""Insert into the lookup buffer.
The functionality is similar to the following python statement
```
buffer[input_tokens, roi] = [key, value, hidden]
```
FIXME: in the future, we should only have two arguments, key and value,
where key is a tensor dict and value is a tensor dict.
FIXME: we should transmit both sampler outputs and the hidden states.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
key (torch.Tensor): The key tensor in the KV cache.
value (torch.Tensor): The value tensor in the KV cache.
hidden (torch.Tensor): The final hidden state tensor generated
during model forwarding to bypass model
forwarding.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
If `input_tokens` and `roi` is `None`, it means selecting any of the
KV caches in the buffer, return, and remove it from the buffer, useful
when offloading KV cache to KV cache storage service.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
list[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVStoreBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache storage buffer with key-value semantics.
This class provides a simple key-value storage buffer abstract with basic
put/get operations, which enables flexible KVCache transfer granular
control.
The functionality is similar to a distributed key-value store, where:
- Key: A unique string identifier for the cached entry
- Value:
- Tensor to be stored and retrieved
- None (indicating deletion or empty value)
"""
@abstractmethod
def put(
self,
key: str,
value: Optional[torch.Tensor],
) -> None:
"""Store a key-value pair in the buffer.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
value (Optional[torch.Tensor]): Tensor to be stored.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def get(
self,
key: str,
) -> Optional[torch.Tensor]:
"""Retrieve a value from the buffer by key.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
Returns:
Optional[torch.Tensor]: Stored tensor if exists, None otherwise.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

View File

@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""
import json
import os
from dataclasses import dataclass
from typing import Optional
import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVStoreBufferBase)
from vllm.logger import init_logger
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
logger = init_logger(__name__)
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: int
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
@staticmethod
def from_file(file_path: str) -> 'MooncakeStoreConfig':
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=config.get("global_segment_size",
DEFAULT_GLOBAL_SEGMENT_SIZE),
local_buffer_size=config.get("local_buffer_size",
DEFAULT_LOCAL_BUFFER_SIZE),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
)
@staticmethod
def load_from_env() -> 'MooncakeStoreConfig':
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_file_path)
class MooncakeStore(KVStoreBufferBase):
def __init__(
self,
config: VllmConfig,
):
try:
from mooncake.store import MooncakeDistributedStore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
try:
self.store = MooncakeDistributedStore()
self.config = MooncakeStoreConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
self.store.setup(self.config.local_hostname,
self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol, self.config.device_name,
self.config.master_server_address)
except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error(
"An error occurred while loading the configuration: %s", exc)
raise
def close(self):
# MooncakeDistributedStore will automatically call the destructor, so
# it is unnecessary to close it manually.
pass
def put(
self,
key: str,
value: Optional[torch.Tensor],
) -> None:
# A message queue needs to be introduced before making it asynchronous.
if value is not None:
self._put_impl(key, value)
def get(
self,
key: str,
) -> Optional[torch.Tensor]:
# A message queue needs to be introduced before making it asynchronous.
value = self._get_impl(key)
return value
def _put_impl(
self,
key: str,
value: torch.Tensor,
) -> None:
"""Put KVCache to Mooncake Store"""
device_id = value.device.index if value.device.type == 'cuda' else -1
device_tensor = torch.tensor(device_id, dtype=torch.int32)
value_bytes = safetensors_save({
"tensor": value,
"device_id": device_tensor
})
try:
self.store.put(key, value_bytes)
except TypeError as err:
logger.error("Failed to put value into Mooncake Store: %s", err)
raise TypeError("Mooncake Store Put Type Error.") from err
def _get_impl(
self,
key: str,
) -> Optional[torch.Tensor]:
"""Get KVCache from Mooncake Store"""
try:
data = self.store.get(key)
except TypeError as err:
logger.error("Failed to get value from Mooncake Store: %s", err)
raise TypeError("Mooncake Store Get Type Error.") from err
if data:
loaded_tensors = safetensors_load(data)
tensor = loaded_tensors["tensor"]
device_id_tensor = loaded_tensors["device_id"]
device_id = int(device_id_tensor.item())
device = torch.device(
'cuda', device_id) if device_id >= 0 else torch.device('cpu')
return tensor.to(device)
return None

View File

@@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid racing condition
- Handles buffer size constraints and provide backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading
from collections import deque
from typing import Optional, Union
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVLookupBufferBase)
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float):
"""
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request.
data_pipe: on device (e.g. GPU)
"""
self.buffer: deque[list[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(self, tokens_roi_sender: list[torch.Tensor],
tokens_roi_recver: list[torch.Tensor]):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
if tokens_recver is None:
# consumer sends an empty request
# semantics: DROP SELECT * LIMIT 1
# so any of the data in the buffer can be drop-selected
return True
# Assuming that roi is a binary mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]
# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length],
tokens_recver[:min_length]):
return min_length
return 0
def _send_tensor_and_dec_size(self,
tensor: Optional[torch.Tensor]) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
if not data:
# cannot perform `not data` on a tensor
# so this check needs to go after the check above
return 0
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor):
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone()
if isinstance(key, torch.Tensor):
key = key.clone()
if isinstance(value, torch.Tensor):
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])
with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()
self.buffer_size += data_size
self.buffer.append(buffer_item)
self.buffer_cv.notify()
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
assert roi is not None, "Please provide the roi when sending "\
"drop-select request"
roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi]
def is_buffer_available(
tokens_roi_recver: list[torch.Tensor], ) -> bool:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
for _ in range(len(self.buffer)):
if self._matches(self.buffer[0],
tokens_roi_recver) > 0:
return True
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
return False
with self.buffer_cv:
while not is_buffer_available(tokens_roi_recver):
logger.debug(
"KV transfer buffer is not available. Waiting...")
self.buffer_cv.wait()
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()
except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
raise e
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
assert self.request_handling_thread is None, \
"drop_select should be called by the KV cache consumer "\
"(e.g. the decode vLLM instance)"
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = (roi > 0.5)
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden]
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler)
self.request_handling_thread.start()
def close(self):
if hasattr(self, "request_handling_thread"
) and self.request_handling_thread is not None:
self.request_handling_thread.join()
else:
# TODO: have a explicit close signal and have a explicit way to
# check if it's requester
self.signal_pipe.send_tensor(self.end_signal)