init
This commit is contained in:
175
distributed/kv_transfer/kv_lookup_buffer/base.py
Normal file
175
distributed/kv_transfer/kv_lookup_buffer/base.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains a new class `KVLookupBufferBase` that allows developers to
|
||||
think of KV cache operations as inserting new KV cache entries (`insert`)
|
||||
into the lookup buffer and querying existing KV caches (`drop_select`)
|
||||
from the lookup buffer.
|
||||
|
||||
This file also contains a new class `KVStoreBufferBase` that allows developers
|
||||
to manage the KVCache buffer as a simple key-value storage buffer with basic
|
||||
put/get operations.
|
||||
|
||||
These classes above are abstracted behind class `KVCacheBufferBase`.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class KVCacheBufferBase(ABC):
|
||||
"""
|
||||
Abstract base class for a KVCache buffer.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the buffer and release resources.
|
||||
|
||||
This method is responsible for cleaning up resources related to the
|
||||
KVCache buffer when it is no longer needed.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVLookupBufferBase(KVCacheBufferBase):
|
||||
"""
|
||||
Abstract base class for a KVCache lookup buffer.
|
||||
|
||||
This class provides an abstraction for a key-value (KV) cache lookup buffer.
|
||||
|
||||
The key of the lookup buffer:
|
||||
- input_tokens: token IDs of the request
|
||||
- roi: a binary mask on top of input_tokens.
|
||||
- Purpose of roi: Since KV cache may only be available for a subset of
|
||||
tokens in the input (for example, when vLLM is connected to an external
|
||||
KV cache service), roi specifies the subset of tokens that the KV cache
|
||||
is associated with.
|
||||
- NOTE: roi can be further extended to describe which part of KV the
|
||||
current process is holding (each process may only hold a part of KV
|
||||
due to TP and PP). This is not implemented for now.
|
||||
|
||||
The value of the lookup buffer:
|
||||
- key: the key tensor in the KV cache
|
||||
- value: the value tensor in the KV cache
|
||||
- hidden: the final hidden state generated by model forwarding. This allows
|
||||
vLLM to bypass further model forwarding by transmitting the hidden state.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
"""Insert into the lookup buffer.
|
||||
|
||||
The functionality is similar to the following python statement
|
||||
```
|
||||
buffer[input_tokens, roi] = [key, value, hidden]
|
||||
```
|
||||
|
||||
FIXME: in the future, we should only have two arguments, key and value,
|
||||
where key is a tensor dict and value is a tensor dict.
|
||||
|
||||
FIXME: we should transmit both sampler outputs and the hidden states.
|
||||
|
||||
Args:
|
||||
input_tokens (torch.Tensor): token IDs.
|
||||
roi (torch.Tensor): A binary mask on top of the input tokens
|
||||
key (torch.Tensor): The key tensor in the KV cache.
|
||||
value (torch.Tensor): The value tensor in the KV cache.
|
||||
hidden (torch.Tensor): The final hidden state tensor generated
|
||||
during model forwarding to bypass model
|
||||
forwarding.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def drop_select(
|
||||
self, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
|
||||
"""Select and *drop* KV cache entries from the lookup buffer.
|
||||
|
||||
The functionality is similar to the following python statements
|
||||
```
|
||||
ret = buffer.pop(input_tokens, roi)
|
||||
return ret
|
||||
```
|
||||
|
||||
If `input_tokens` and `roi` is `None`, it means selecting any of the
|
||||
KV caches in the buffer, return, and remove it from the buffer, useful
|
||||
when offloading KV cache to KV cache storage service.
|
||||
|
||||
Args:
|
||||
input_tokens (torch.Tensor): token IDs.
|
||||
roi (torch.Tensor): A binary mask on top of the input tokens
|
||||
|
||||
Returns:
|
||||
list[Optional[torch.Tensor]]: A list of tensors. Can be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVStoreBufferBase(KVCacheBufferBase):
|
||||
"""
|
||||
Abstract base class for a KVCache storage buffer with key-value semantics.
|
||||
This class provides a simple key-value storage buffer abstract with basic
|
||||
put/get operations, which enables flexible KVCache transfer granular
|
||||
control.
|
||||
|
||||
The functionality is similar to a distributed key-value store, where:
|
||||
- Key: A unique string identifier for the cached entry
|
||||
- Value:
|
||||
- Tensor to be stored and retrieved
|
||||
- None (indicating deletion or empty value)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
value: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
"""Store a key-value pair in the buffer.
|
||||
|
||||
Args:
|
||||
key (str): Unique identifier for a tensor, this tensor could be the
|
||||
key cache tensor, value cache tensor, or hidden state tensor
|
||||
generated during model forwarding.
|
||||
|
||||
value (Optional[torch.Tensor]): Tensor to be stored.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Retrieve a value from the buffer by key.
|
||||
|
||||
Args:
|
||||
key (str): Unique identifier for a tensor, this tensor could be the
|
||||
key cache tensor, value cache tensor, or hidden state tensor
|
||||
generated during model forwarding.
|
||||
|
||||
Returns:
|
||||
Optional[torch.Tensor]: Stored tensor if exists, None otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
161
distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
Normal file
161
distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file contains a new class `MooncakeStore` that allows developers to
|
||||
think of KV cache transfer operations as putting new KV cache entries
|
||||
into a remote KVStore-based lookup buffer and getting existing KV caches
|
||||
from this remote lookup buffer.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load as safetensors_load
|
||||
from safetensors.torch import save as safetensors_save
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
|
||||
KVStoreBufferBase)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeStoreConfig:
|
||||
local_hostname: str
|
||||
metadata_server: str
|
||||
global_segment_size: int
|
||||
local_buffer_size: int
|
||||
protocol: str
|
||||
device_name: str
|
||||
master_server_address: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> 'MooncakeStoreConfig':
|
||||
"""Load the config from a JSON file."""
|
||||
with open(file_path) as fin:
|
||||
config = json.load(fin)
|
||||
return MooncakeStoreConfig(
|
||||
local_hostname=config.get("local_hostname"),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
global_segment_size=config.get("global_segment_size",
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE),
|
||||
local_buffer_size=config.get("local_buffer_size",
|
||||
DEFAULT_LOCAL_BUFFER_SIZE),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
master_server_address=config.get("master_server_address"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> 'MooncakeStoreConfig':
|
||||
"""Load config from a file specified in the environment variable."""
|
||||
config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
|
||||
if config_file_path is None:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeStoreConfig.from_file(config_file_path)
|
||||
|
||||
|
||||
class MooncakeStore(KVStoreBufferBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VllmConfig,
|
||||
):
|
||||
|
||||
try:
|
||||
from mooncake.store import MooncakeDistributedStore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
|
||||
try:
|
||||
self.store = MooncakeDistributedStore()
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
logger.info("Mooncake Configuration loaded successfully.")
|
||||
|
||||
self.store.setup(self.config.local_hostname,
|
||||
self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol, self.config.device_name,
|
||||
self.config.master_server_address)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("Configuration loading failed: %s", e)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
# MooncakeDistributedStore will automatically call the destructor, so
|
||||
# it is unnecessary to close it manually.
|
||||
pass
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
value: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
# A message queue needs to be introduced before making it asynchronous.
|
||||
if value is not None:
|
||||
self._put_impl(key, value)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# A message queue needs to be introduced before making it asynchronous.
|
||||
value = self._get_impl(key)
|
||||
return value
|
||||
|
||||
def _put_impl(
|
||||
self,
|
||||
key: str,
|
||||
value: torch.Tensor,
|
||||
) -> None:
|
||||
"""Put KVCache to Mooncake Store"""
|
||||
device_id = value.device.index if value.device.type == 'cuda' else -1
|
||||
device_tensor = torch.tensor(device_id, dtype=torch.int32)
|
||||
value_bytes = safetensors_save({
|
||||
"tensor": value,
|
||||
"device_id": device_tensor
|
||||
})
|
||||
try:
|
||||
self.store.put(key, value_bytes)
|
||||
except TypeError as err:
|
||||
logger.error("Failed to put value into Mooncake Store: %s", err)
|
||||
raise TypeError("Mooncake Store Put Type Error.") from err
|
||||
|
||||
def _get_impl(
|
||||
self,
|
||||
key: str,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Get KVCache from Mooncake Store"""
|
||||
try:
|
||||
data = self.store.get(key)
|
||||
except TypeError as err:
|
||||
logger.error("Failed to get value from Mooncake Store: %s", err)
|
||||
raise TypeError("Mooncake Store Get Type Error.") from err
|
||||
|
||||
if data:
|
||||
loaded_tensors = safetensors_load(data)
|
||||
tensor = loaded_tensors["tensor"]
|
||||
device_id_tensor = loaded_tensors["device_id"]
|
||||
device_id = int(device_id_tensor.item())
|
||||
device = torch.device(
|
||||
'cuda', device_id) if device_id >= 0 else torch.device('cpu')
|
||||
return tensor.to(device)
|
||||
|
||||
return None
|
||||
237
distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
Normal file
237
distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Implements a distributed key-value (KV) cache transfer mechanism.
|
||||
|
||||
Key Features:
|
||||
- Distributed KV cache transmission using PyNccl pipes.
|
||||
- Non-blocking `insert`, blocking `drop_select`.
|
||||
- Use CPU signal pipe to avoid racing condition
|
||||
- Handles buffer size constraints and provide backpressure mechanism to
|
||||
stop the prefill instance when the decode instance is slow.
|
||||
"""
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
|
||||
KVLookupBufferBase)
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SimpleBuffer(KVLookupBufferBase):
|
||||
|
||||
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
|
||||
buffer_size_thresh: float):
|
||||
"""
|
||||
signal_pipe: on CPU
|
||||
|
||||
NOTE: on-device recv will block all threads in the process, making the
|
||||
KV cache producer unable to listen to new request while transmitting
|
||||
KV cache. Luckily CPU recv only blocks the current thread so we use
|
||||
CPU recv to listen to new request.
|
||||
|
||||
data_pipe: on device (e.g. GPU)
|
||||
"""
|
||||
|
||||
self.buffer: deque[list[torch.Tensor]] = deque()
|
||||
|
||||
self.buffer_size = 0
|
||||
self.buffer_size_threshold = buffer_size_thresh
|
||||
self.buffer_cv = threading.Condition()
|
||||
self.signal_pipe = signal_pipe
|
||||
self.data_pipe = data_pipe
|
||||
self.request_handling_thread: Optional[threading.Thread] = None
|
||||
|
||||
self.normal_signal = torch.tensor([0], device="cpu")
|
||||
self.end_signal = None
|
||||
|
||||
def _matches(self, tokens_roi_sender: list[torch.Tensor],
|
||||
tokens_roi_recver: list[torch.Tensor]):
|
||||
|
||||
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
|
||||
# tokens_roi_recver: tokens and roi of the consumer (query)
|
||||
|
||||
tokens_sender = tokens_roi_sender[0]
|
||||
tokens_recver = tokens_roi_recver[0]
|
||||
roi_sender = tokens_roi_sender[1]
|
||||
roi_recver = tokens_roi_recver[1]
|
||||
|
||||
if tokens_recver is None:
|
||||
# consumer sends an empty request
|
||||
# semantics: DROP SELECT * LIMIT 1
|
||||
# so any of the data in the buffer can be drop-selected
|
||||
return True
|
||||
|
||||
# Assuming that roi is a binary mask on tokens
|
||||
tokens_sender = tokens_sender[roi_sender]
|
||||
tokens_recver = tokens_recver[roi_recver]
|
||||
|
||||
# simple common prefix matching
|
||||
min_length = min(len(tokens_sender), len(tokens_recver))
|
||||
if torch.allclose(tokens_sender[:min_length],
|
||||
tokens_recver[:min_length]):
|
||||
return min_length
|
||||
|
||||
return 0
|
||||
|
||||
def _send_tensor_and_dec_size(self,
|
||||
tensor: Optional[torch.Tensor]) -> None:
|
||||
|
||||
assert tensor is not None, "Use self.data_pipe.send(None) instead"
|
||||
self.buffer_size -= tensor.element_size() * tensor.numel()
|
||||
if tensor.dtype == torch.bool:
|
||||
tensor = tensor.float()
|
||||
self.data_pipe.send_tensor(tensor)
|
||||
|
||||
def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]):
|
||||
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.element_size() * data.numel()
|
||||
if not data:
|
||||
# cannot perform `not data` on a tensor
|
||||
# so this check needs to go after the check above
|
||||
return 0
|
||||
|
||||
raise AssertionError(f"Unknown data type {type(data)}")
|
||||
|
||||
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor):
|
||||
|
||||
if isinstance(input_tokens, torch.Tensor):
|
||||
input_tokens = input_tokens.clone()
|
||||
if isinstance(roi, torch.Tensor):
|
||||
roi = roi.clone()
|
||||
if isinstance(key, torch.Tensor):
|
||||
key = key.clone()
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = value.clone()
|
||||
if isinstance(hidden, torch.Tensor):
|
||||
hidden = hidden.clone()
|
||||
|
||||
buffer_item = [input_tokens, roi, key, value, hidden]
|
||||
data_size = sum([self._get_element_size(data) for data in buffer_item])
|
||||
|
||||
with self.buffer_cv:
|
||||
if self.buffer_size + data_size > self.buffer_size_threshold:
|
||||
# log outside the while loop to avoid this message being logged
|
||||
# repeatedly.
|
||||
logger.debug("KV transfer buffer is full. Handling...")
|
||||
while self.buffer_size + data_size > self.buffer_size_threshold:
|
||||
self.buffer_cv.wait()
|
||||
|
||||
self.buffer_size += data_size
|
||||
self.buffer.append(buffer_item)
|
||||
self.buffer_cv.notify()
|
||||
|
||||
def _is_end_signal(self, signal):
|
||||
return signal is None
|
||||
|
||||
def drop_select_handler(self):
|
||||
|
||||
try:
|
||||
|
||||
while True:
|
||||
signal = self.signal_pipe.recv_tensor()
|
||||
if self._is_end_signal(signal):
|
||||
logger.info("Received end signal!")
|
||||
break
|
||||
|
||||
input_tokens = self.data_pipe.recv_tensor()
|
||||
|
||||
roi = self.data_pipe.recv_tensor()
|
||||
assert roi is not None, "Please provide the roi when sending "\
|
||||
"drop-select request"
|
||||
roi = (roi > 0.5)
|
||||
tokens_roi_recver = [input_tokens, roi]
|
||||
|
||||
def is_buffer_available(
|
||||
tokens_roi_recver: list[torch.Tensor], ) -> bool:
|
||||
# perform input tokens and roi matching
|
||||
# FIXME: this matching is O(n), ideally it should be O(1)
|
||||
# but this buffer size won't (and shouldn't) be too large so
|
||||
# the fix is not urgent.
|
||||
for _ in range(len(self.buffer)):
|
||||
if self._matches(self.buffer[0],
|
||||
tokens_roi_recver) > 0:
|
||||
return True
|
||||
# rotate the element we just accessed to the end
|
||||
self.buffer.rotate(-1)
|
||||
return False
|
||||
|
||||
with self.buffer_cv:
|
||||
while not is_buffer_available(tokens_roi_recver):
|
||||
logger.debug(
|
||||
"KV transfer buffer is not available. Waiting...")
|
||||
self.buffer_cv.wait()
|
||||
# need to clone the tensor
|
||||
# in case the tensor is freed before sending finishes
|
||||
matched_item = self.buffer.popleft()
|
||||
for tensor in matched_item:
|
||||
self._send_tensor_and_dec_size(tensor)
|
||||
self.buffer_cv.notify()
|
||||
|
||||
except RuntimeError as e:
|
||||
if 'Connection closed by peer' not in str(e):
|
||||
raise e
|
||||
|
||||
logger.debug("Closing drop_select_handler")
|
||||
|
||||
def drop_select(
|
||||
self, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
|
||||
|
||||
assert self.request_handling_thread is None, \
|
||||
"drop_select should be called by the KV cache consumer "\
|
||||
"(e.g. the decode vLLM instance)"
|
||||
|
||||
if isinstance(input_tokens, torch.Tensor):
|
||||
input_tokens = input_tokens.clone()
|
||||
if isinstance(roi, torch.Tensor):
|
||||
roi = roi.clone().float()
|
||||
|
||||
self.signal_pipe.send_tensor(self.normal_signal)
|
||||
self.data_pipe.send_tensor(input_tokens)
|
||||
self.data_pipe.send_tensor(roi)
|
||||
|
||||
input_tokens = self.data_pipe.recv_tensor()
|
||||
roi = self.data_pipe.recv_tensor()
|
||||
if roi is not None:
|
||||
# convert from float tensor to bool tensor
|
||||
# as PyNccl does not support sending bool tensor
|
||||
roi = (roi > 0.5)
|
||||
key = self.data_pipe.recv_tensor()
|
||||
value = self.data_pipe.recv_tensor()
|
||||
hidden = self.data_pipe.recv_tensor()
|
||||
|
||||
return [input_tokens, roi, key, value, hidden]
|
||||
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
|
||||
self._add_to_buffer(input_tokens, roi, key, value, hidden)
|
||||
|
||||
# when calling the insert, the current process is a sender
|
||||
# need to launch the request handler and start listening to request.
|
||||
if self.request_handling_thread is None:
|
||||
self.request_handling_thread = threading.Thread(
|
||||
target=self.drop_select_handler)
|
||||
self.request_handling_thread.start()
|
||||
|
||||
def close(self):
|
||||
|
||||
if hasattr(self, "request_handling_thread"
|
||||
) and self.request_handling_thread is not None:
|
||||
self.request_handling_thread.join()
|
||||
|
||||
else:
|
||||
# TODO: have a explicit close signal and have a explicit way to
|
||||
# check if it's requester
|
||||
self.signal_pipe.send_tensor(self.end_signal)
|
||||
Reference in New Issue
Block a user