This commit is contained in:
root
2026-03-05 18:06:10 +08:00
commit 809cecae09
2569 changed files with 478204 additions and 0 deletions

View File

@@ -0,0 +1,29 @@
# Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances.
Currently the main use case is for disaggregated prefilling.
## Abstractions
The KV cache transfer contains three layer of abstractions:
- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`.
- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics).
- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`.
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or
RDMA database).
NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates.
## Disaggregated prefilling
The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh).
Here is the diagram of how we run disaggregated prefilling.
![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg)

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_transfer_state import (
KVConnectorBaseType,
ensure_kv_transfer_initialized,
ensure_kv_transfer_shutdown,
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
__all__ = [
"get_kv_transfer_group",
"has_kv_transfer_group",
"is_v1_kv_transfer_group",
"ensure_kv_transfer_initialized",
"ensure_kv_transfer_shutdown",
"KVConnectorBaseType",
]

Binary file not shown.

After

Width:  |  Height:  |  Size: 139 KiB

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Defines the base type for KV cache connectors."""
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
KVConnectorBase = KVConnectorBase_V1
KVConnectorBaseType = KVConnectorBase_V1
__all__ = ["KVConnectorBase", "KVConnectorBaseType"]

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING, Optional, cast
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
KVConnectorBaseType,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorRole,
supports_hma,
)
from vllm.logger import init_logger
from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__)
class KVConnectorFactory:
_registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}
@classmethod
def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
"""Register a connector with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> type[KVConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector(
cls,
config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
) -> KVConnectorBase:
kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls, compat_sig = cls._get_connector_class_with_compat(
kv_transfer_config
)
# check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
if hma_enabled and not supports_hma(connector_cls):
raise ValueError(
f"Connector {connector_cls.__name__} does not support HMA but "
f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
)
logger.info(
"Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__,
kv_transfer_config.engine_id,
)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process
# - Should only be used inside the Scheduler class
# Worker connector:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
if compat_sig:
# Old signature: __init__(self, vllm_config, role)
return connector_cls(config, role)
else:
# New signature: __init__(self, vllm_config, role, kv_cache_config)
return connector_cls(config, role, kv_cache_config)
@classmethod
def get_connector_class_by_name(
cls, connector_name: str
) -> type[KVConnectorBaseType]:
"""Get a registered connector class by name.
Raises ValueError if the connector is not registered.
Args:
connector_name: Name of the registered connector.
Returns:
The connector class.
"""
if connector_name not in cls._registry:
raise ValueError(f"Connector '{connector_name}' is not registered.")
return cls._registry[connector_name]()
@classmethod
def _get_connector_class_with_compat(
cls, kv_transfer_config: "KVTransferConfig"
) -> tuple[type[KVConnectorBaseType], bool]:
connector_name = kv_transfer_config.kv_connector
if connector_name is None:
raise ValueError("Connector name is not set in KVTransferConfig")
compat_sig = False
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
try:
connector_cls = getattr(connector_module, connector_name)
except AttributeError as e:
raise AttributeError(
f"Class {connector_name} not found in {connector_module_path}"
) from e
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
if not supports_kw(connector_cls, "kv_cache_config"):
compat_sig = True
logger.warning(
"Connector %s uses deprecated signature with 2 required arguments. "
"Please update to include kv_cache_config as the second argument.",
connector_cls.__name__,
)
return connector_cls, compat_sig
@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
return connector_cls
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
"SharedStorageConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector",
)
KVConnectorFactory.register_connector(
"P2pNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
"P2pNcclConnector",
)
KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
"LMCacheConnectorV1",
)
KVConnectorFactory.register_connector(
"LMCacheMPConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector",
"LMCacheMPConnector",
)
KVConnectorFactory.register_connector(
"NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"NixlConnector",
)
KVConnectorFactory.register_connector(
"MultiConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
"MultiConnector",
)
KVConnectorFactory.register_connector(
"OffloadingConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
"OffloadingConnector",
)
KVConnectorFactory.register_connector(
"DecodeBenchConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector",
)

View File

@@ -0,0 +1,268 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV cache helper for store.
"""
from typing import TYPE_CHECKING, Literal
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
logger = init_logger(__name__)
class model_aware_kv_ops_helper:
def __init__(self, config: VllmConfig):
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
self.tp_size = config.parallel_config.tensor_parallel_size
def get_model_args(self, model_executable: torch.nn.Module):
model_config = model_executable.model.config
self.model_executable = model_executable
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/v1/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim", None)
if head_size is None:
head_size = int(hidden_size // num_attention_heads)
return num_heads, head_size
def get_kv_from_cache(self, kv_cache, num_heads, head_size):
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
return key_cache, value_cache
def put_kv_to_cache(
self,
model_executable: torch.nn.Module,
keys,
values,
layer,
kv_cache,
slot_mapping,
start_pos,
end_pos,
):
model_config = model_executable.model.config
if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
k_c_normed_k_pe = keys.squeeze(1)
k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank]
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :]
ops.concat_and_cache_mla(
k_c_normed.to(kv_cache.device),
k_pe.to(kv_cache.device),
kv_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys.to(key_cache.device),
values.to(value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# used for faster transfer.
vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config
if kv_config is not None:
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config)
if required_kvcache_layout is not None:
return required_kvcache_layout
logger.info_once(
"Connectors do not specify a kv cache layout, defaulting to NHD."
)
return "NHD"
class KVOutputAggregator:
"""Utility class to aggregate the output of all workers into a single
output corresponding to Rank 0 for scheduler."""
def __init__(self, expected_finished_count: int):
# Complete transfer tracker. Used to track finished requests
# [req_id -> n_remaining_workers]
self._recv_remaining_count = dict[str, int]()
self._send_remaining_count = dict[str, int]()
self._expected_finished_count = expected_finished_count
@classmethod
def from_connector(cls, connector: "KVConnectorBase", world_size: int):
return cls(connector.get_finished_count() or world_size)
def aggregate(
self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
) -> ModelRunnerOutput | None:
if not outputs[output_rank]:
return None
# Aggregate kv_connector_output from all workers
def update_finished_set(
req_ids: set[str] | None,
remaining_count_dict: dict[str, int],
finished_set: set[str],
) -> None:
for req_id in req_ids or ():
remaining_count = remaining_count_dict.get(
req_id, self._expected_finished_count
)
remaining_count_dict[req_id] = remaining_count - 1
if remaining_count_dict[req_id] == 0:
finished_set.add(req_id)
del remaining_count_dict[req_id]
finished_sending = set[str]()
finished_recving = set[str]()
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
assert model_runner_output is not None
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
# Allow the worker to dynamically update the expected number of
# finished sending/recving for new requests.
if (
kv_output.expected_finished_count > 0
and kv_output.expected_finished_count != self._expected_finished_count
):
logger.debug(
"Expected finished requests updated from %d to %d",
self._expected_finished_count,
kv_output.expected_finished_count,
)
self._expected_finished_count = kv_output.expected_finished_count
update_finished_set(
kv_output.finished_sending, self._send_remaining_count, finished_sending
)
update_finished_set(
kv_output.finished_recving, self._recv_remaining_count, finished_recving
)
# Aggregate kv_connector_stats from all workers.
if aggregated_kv_connector_stats is None:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = kv_output.kv_connector_stats
elif kv_connector_stats := kv_output.kv_connector_stats:
if aggregated_kv_connector_stats is None:
aggregated_kv_connector_stats = kv_connector_stats
else:
assert isinstance(
aggregated_kv_connector_stats, type(kv_connector_stats)
)
aggregated_kv_connector_stats = (
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)
invalid_block_ids |= kv_output.invalid_block_ids
# select output of the worker specified by output_rank
output = outputs[output_rank]
assert output is not None
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count,
)
return output
def _make_src_and_dst_indices(
src_block_ids: list[int],
dst_block_ids: list[int],
src_device: torch.device | str,
dst_device: torch.device | str,
) -> tuple[torch.Tensor, torch.Tensor]:
src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64)
dst_indices = torch.tensor(dst_block_ids, device=dst_device, dtype=torch.int64)
return src_indices, dst_indices
def copy_kv_blocks(
src_kv_caches: dict[str, torch.Tensor],
dst_kv_caches: dict[str, torch.Tensor],
src_block_ids: list[int],
dst_block_ids: list[int],
direction: Literal["h2d", "d2h"],
) -> None:
"""Copy kv blocks between different buffers."""
if (
not src_kv_caches
or not dst_kv_caches
or not src_block_ids
or not dst_block_ids
or len(src_block_ids) != len(dst_block_ids)
):
return
src_device = next(iter(src_kv_caches.values())).device
dst_device = next(iter(dst_kv_caches.values())).device
src_indices, dst_indices = _make_src_and_dst_indices(
src_block_ids=src_block_ids,
dst_block_ids=dst_block_ids,
src_device=src_device,
dst_device=dst_device,
)
from vllm.platforms import current_platform
if direction == "h2d":
copy_fn = current_platform.insert_blocks_to_device
else:
copy_fn = current_platform.swap_out_blocks_to_host
for layer_name in src_kv_caches:
src_tensor = src_kv_caches[layer_name]
dst_tensor = dst_kv_caches[layer_name]
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)

View File

@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorRole,
SupportsHMA,
supports_hma,
)
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import ( # noqa E:501
DecodeBenchConnector,
)
__all__ = [
"KVConnectorRole",
"KVConnectorBase_V1",
"supports_hma",
"SupportsHMA",
"DecodeBenchConnector",
]

View File

@@ -0,0 +1,546 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1
The class provides the following primitives:
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save KV cache.
get_num_new_matched_tokens() - get number of new tokens
that exist in the remote KV cache. Might be called multiple
times for a given request and should be side-effect free.
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
update_connector_output() - update KVConnector state after
output is received from worker-side connectors.
request_finished() - called once when a request is finished,
with the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or if the
connector now assumes responsibility for freeing the
the blocks asynchronously. Also optionally returns KV
transfer params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done
save_kv_layer() - starts saving KV for layer i (maybe async)
wait_for_save() - blocks until all saves are done
get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
"""
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal, Optional
import torch
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
CopyBlocksOp = Callable[
[
dict[str, torch.Tensor],
dict[str, torch.Tensor],
list[int],
list[int],
Literal["h2d", "d2h"],
],
None,
]
logger = init_logger(__name__)
class SupportsHMA(ABC):
"""
The class that indicates the corresponding connector supports hybrid memory
allocator (HMA).
This is required to use the connector together with hybrid memory allocator.
"""
@abstractmethod
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished for all kv cache groups,
before its blocks are freed for each group.
NOTE(Kuntai): This function is only supported by connectors that support HMA.
The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
raise NotImplementedError
def supports_hma(connector: Any) -> bool:
if isinstance(connector, type):
return issubclass(connector, SupportsHMA)
else:
return isinstance(connector, SupportsHMA)
class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
# Connector running in the worker process
WORKER = 1
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
"""
Metadata used for out of band connector handshake between
P/D workers. This needs to serializeable.
"""
pass
class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
"""
pass
class KVConnectorBase_V1(ABC):
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design."
)
self._connector_metadata: KVConnectorMetadata | None = None
self._vllm_config = vllm_config
if vllm_config.kv_transfer_config is not None:
self._kv_transfer_config = vllm_config.kv_transfer_config
else:
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
self._kv_cache_config = kv_cache_config
if self._kv_cache_config is None:
logger.warning(
"KVConnectorBase_V1 initialized without kv_cache_config. "
"This is deprecated - please update your connector to accept "
"kv_cache_config as the third constructor argument and pass it "
"to super().__init__()."
)
self._role = role
@property
def role(self) -> KVConnectorRole:
return self._role
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
before the model execution. The metadata will be used for runtime
KV cache loading and saving.
Args:
connector_metadata (dict): the connector metadata.
"""
self._connector_metadata = connector_metadata
def clear_connector_metadata(self) -> None:
"""Clear the connector metadata.
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = None
def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata
def has_connector_metadata(self) -> bool:
"""Check whether the connector metadata is currently set.
Returns:
bool: True if connector metadata exists, False otherwise.
"""
return self._connector_metadata is not None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches: dictionary of layer names, kv cache
"""
return
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""
Set the xPU-specific ops for copying KV between host and device.
Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
"""
return
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
pass
@abstractmethod
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
pass
@abstractmethod
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
pass
@abstractmethod
def wait_for_save(self):
"""
Block until all the save operations is done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
def get_block_ids_with_load_errors(self) -> set[int]:
"""
Get the set of block IDs that failed to load.
Returns:
Set of block IDs that encountered load errors.
Empty set if no load errors occurred.
Notes:
- Applies to both sync- and async-loading requests.
- Async loading: failed blocks may be reported in any forward pass
up to and including the pass where the request ID is returned by
`get_finished()`. Even if failures occur, the request must still
be reported via `get_finished()`, and the failed block IDs must
appear here no later than that same pass.
- Sync loading: failed blocks should be reported in the forward
pass in which they are detected.
"""
return set()
def shutdown(self):
"""
Shutdown the connector. This is called when the worker process
is shutting down to ensure that all the async operations are
completed and the connector is cleaned up properly.
"""
return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
Get the KV connector stats collected during the last interval.
"""
return None
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================
@abstractmethod
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
Notes:
The connector should only consider the largest prefix of prompt-
tokens for which KV cache is actually available at the time of the
call. If the cache cannot be loaded for some tokens (e.g., due to
connectivity issues or eviction), those tokens must not be taken
into account.
"""
pass
@abstractmethod
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.
Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
pass
@abstractmethod
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
pass
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
return
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return ()
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if cls is KVConnectorBase_V1:
raise TypeError(
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
return None
def get_finished_count(self) -> int | None:
"""
Get the count of requests expected to complete send/receive operations
via this connector. This method is used to initialize the
KVOutputAggregator, overwriting the default world_size.
Returns:
int: expected sending or receiving completion count.
"""
return None
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return None
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
"""
return None
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> Optional["KVConnectorPromMetrics"]:
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
expose connector transfer stats via Prometheus.
"""
return None

View File

@@ -0,0 +1,419 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DecodeBenchConnector: A KV Connector for decode instance performance testing.
This connector emulates a prefill-decode disaggregated setting by filling
the KV cache with dummy values, allowing measurement of decoder performance
under larger input sequence lengths (ISL) in resource-limited environments.
Usage:
To use this connector for benchmarking, configure it in the kv_transfer_config:
Example:
vllm serve <model> --kv-transfer-config '{
"kv_connector": "DecodeBenchConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"fill_mean": 0.015,
"fill_std": 0.0
}
}'
Then run your benchmark with desired input/output lengths:
vllm bench serve --base-url http://127.0.0.1:8000 --model <model> \\
--dataset-name random --random-input-len 40000 \\
--random-output-len 100 --max-concurrency 10
Configuration options (via kv_connector_extra_config):
- fill_mean (float): Mean value for random normal fill (default: 0.015)
- fill_std (float): Standard deviation for random fill (default: 0.0)
Set to 0 for constant values, >0 for random sampling
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class DecodeBenchConnectorMetadata(KVConnectorMetadata):
"""Metadata for DecodeBenchConnector.
Contains information about which requests need their KV cache filled
with dummy values for benchmarking purposes.
"""
# request_id -> (block_ids_per_group, num_tokens_to_fill)
# block_ids_per_group is a tuple of lists, one per KV cache group
# For standard attention: single group, e.g., ([1, 2, 3],)
# For MLA: multiple groups, e.g., ([1, 2], [1, 2])
reqs_to_fill: dict[str, tuple[tuple[list[int], ...], int]]
class DecodeBenchConnector(KVConnectorBase_V1):
"""
A KV Connector for decode instance performance testing.
This connector fills the KV cache with dummy (non-zero) values to
emulate a prefill-decode disaggregated setting, enabling performance
testing of the decoder with larger input sequence lengths.
"""
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
self.connector_worker: DecodeBenchConnectorWorker | None = None
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = DecodeBenchConnectorScheduler(vllm_config)
elif role == KVConnectorRole.WORKER:
self.connector_worker = DecodeBenchConnectorWorker(vllm_config)
# ==============================
# Worker-side methods
# ==============================
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, DecodeBenchConnectorMetadata)
self.connector_worker.start_fill_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
# All operations are synchronous, so nothing to wait for
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
# This connector doesn't save KV cache (benchmarking only)
pass
def wait_for_save(self):
# This connector doesn't save KV cache (benchmarking only)
pass
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens
)
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens
)
def build_connector_meta(
self, scheduler_output: "SchedulerOutput"
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
self.connector_scheduler.request_finished(request)
return False, None
class DecodeBenchConnectorScheduler:
"""Scheduler-side implementation for DecodeBenchConnector."""
def __init__(self, vllm_config: "VllmConfig"):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
# Track which requests have already been filled
self._filled_requests: set[str] = set()
# Track pending fills for the current scheduler step
# request_id -> (block_ids_per_group, num_tokens_to_fill)
# Note: _pending_fills doesn't need explicit cleanup - it's cleared
# after build_connector_meta() is called in the same scheduler step
self._pending_fills: dict[str, tuple[tuple[list[int], ...], int]] = {}
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
For new requests, return the number of tokens that should be filled
with dummy KV cache values.
Returns:
(num_tokens_to_fill, is_async)
- num_tokens_to_fill: number of uncomputed tokens minus 1
(we fill everything except the last token for decode)
- is_async: False (synchronous filling)
"""
req_id = request.request_id
# Only fill once per request on first scheduling
if req_id in self._filled_requests:
return 0, False
# Calculate how many tokens we need to fill
# Fill all uncomputed tokens except the last one (which will be decoded)
# This simulates having processed a long prefill
num_uncomputed_tokens = request.num_tokens - num_computed_tokens
num_tokens_to_fill = max(0, num_uncomputed_tokens - 1)
if num_tokens_to_fill == 0:
return 0, False
# Return False for synchronous operation - the fill is fast enough
# that async overhead isn't worth it
return num_tokens_to_fill, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Called after blocks are allocated. Store the block IDs so we can
fill them with dummy values.
Supports both standard attention (single KV cache group) and MLA
(multiple KV cache groups).
"""
req_id = request.request_id
if num_external_tokens == 0:
return
# Get the block IDs that were allocated
# block_groups is a tuple of lists, one per KV cache group
# For standard attention: 1 group
# For MLA: multiple groups (one per attention type)
block_groups = blocks.get_block_ids()
# Calculate how many blocks we need to fill
# num_external_tokens are the tokens we said we'd provide
num_blocks_to_fill = cdiv(num_external_tokens, self.block_size)
# Extract the first num_blocks_to_fill blocks from each group
# All groups should have the same block IDs for the same request
block_ids_per_group = tuple(
group_blocks[:num_blocks_to_fill] for group_blocks in block_groups
)
# Store the blocks to fill for all group. _pending_fills doesn't need cleanup
# as it's cleared after build_connector_meta
self._pending_fills[req_id] = (
block_ids_per_group,
num_external_tokens,
)
self._filled_requests.add(req_id)
logger.debug(
"DecodeBenchConnector: Allocated %d blocks across %d KV cache groups "
"for request %s",
num_blocks_to_fill,
len(block_groups),
req_id,
)
def build_connector_meta(
self, scheduler_output: "SchedulerOutput"
) -> KVConnectorMetadata:
"""
Build metadata containing information about which blocks to fill
with dummy KV values.
"""
meta = DecodeBenchConnectorMetadata(reqs_to_fill=self._pending_fills.copy())
# Clear pending fills after building metadata
self._pending_fills.clear()
return meta
def request_finished(self, request: "Request"):
"""
Called when a request has finished. Clean up any state.
"""
self._filled_requests.discard(request.request_id)
class DecodeBenchConnectorWorker:
"""Worker-side implementation for DecodeBenchConnector."""
def __init__(self, vllm_config: "VllmConfig"):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
# Get fill parameters from extra config
kv_transfer_config = vllm_config.kv_transfer_config
assert kv_transfer_config is not None
self.fill_mean = kv_transfer_config.get_from_extra_config("fill_mean", 0.015)
self.fill_std = kv_transfer_config.get_from_extra_config("fill_std", 0.0)
# Will be populated via register_kv_caches
self.kv_caches: dict[str, torch.Tensor] | None = None
# Mapping from KV cache group index to list of layer names in that group
self.group_to_layers: dict[int, list[str]] | None = None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Store references to the KV cache tensors and build group mapping."""
self.kv_caches = kv_caches
# For simplicity, assume all layers belong to group 0 (standard attention)
# For MLA models with multiple groups, the metadata will handle the mapping
# We just need to fill the blocks specified in the metadata
self.group_to_layers = {0: list(kv_caches.keys())}
logger.debug(
"DecodeBenchConnector: Registered %d KV cache layers",
len(kv_caches),
)
def start_fill_kv(self, metadata: DecodeBenchConnectorMetadata):
"""
Fill the allocated KV cache blocks with dummy (non-zero) values.
This simulates having a populated KV cache from a prefill phase,
allowing decode performance testing with larger context sizes.
Supports both standard attention (single group) and MLA (multiple groups).
"""
if not metadata.reqs_to_fill:
return
assert self.kv_caches is not None, "KV caches must be registered before filling"
assert self.group_to_layers is not None, "Group mapping must be initialized"
for req_id, (block_ids_per_group, num_tokens) in metadata.reqs_to_fill.items():
# Fill blocks for each KV cache group
for group_idx, block_ids in enumerate(block_ids_per_group):
self._fill_blocks(group_idx, block_ids, num_tokens)
logger.debug(
"DecodeBenchConnector: Filled %d blocks (%d tokens) across %d groups "
"for request %s",
len(block_ids_per_group[0]) if block_ids_per_group else 0,
num_tokens,
len(block_ids_per_group),
req_id,
)
def _fill_blocks(self, group_idx: int, block_ids: list[int], num_tokens: int):
"""
Fill specified blocks with dummy non-zero values for a specific KV cache group.
Args:
group_idx: The KV cache group index to fill
block_ids: List of block IDs to fill in this group
num_tokens: Total number of tokens to fill across these blocks
"""
if not block_ids:
return
assert self.kv_caches is not None
assert self.group_to_layers is not None
# Get the layers that belong to this group
layer_names = self.group_to_layers.get(group_idx, [])
# Fill only the layers in this group
for layer_name in layer_names:
if layer_name not in self.kv_caches:
logger.warning(
"DecodeBenchConnector: Layer %s not found in KV caches", layer_name
)
continue
kv_cache = self.kv_caches[layer_name]
# Convert block_ids to tensor on device
block_ids_tensor = torch.tensor(
block_ids, dtype=torch.long, device=kv_cache.device
)
# Filter invalid block IDs
valid_mask = block_ids_tensor < kv_cache.shape[0]
valid_block_ids = block_ids_tensor[valid_mask]
if len(valid_block_ids) == 0:
continue
# Create fill values - either constant or random
block_shape = kv_cache.shape[1:]
if self.fill_std > 0:
# Random normal sampling
fill_values = torch.normal(
mean=self.fill_mean,
std=self.fill_std,
size=(len(valid_block_ids),) + block_shape,
dtype=kv_cache.dtype,
device=kv_cache.device,
)
else:
# Constant fill value
fill_values = torch.full(
(len(valid_block_ids),) + block_shape,
self.fill_mean,
dtype=kv_cache.dtype,
device=kv_cache.device,
)
# Batch fill operation
kv_cache[valid_block_ids] = fill_values
logger.debug(
"DecodeBenchConnector: Filled %d blocks in group %d with %s values "
"(mean=%.3f, std=%.3f)",
len(block_ids),
group_idx,
"random" if self.fill_std > 0 else "constant",
self.fill_mean,
self.fill_std,
)

View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any
import torch
from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
assert vllm_config.kv_transfer_config is not None
use_native = vllm_config.kv_transfer_config.get_from_extra_config(
"use_native", False
)
if use_native:
logger.info("Initializing native LMCache connector")
# lazy import
from vllm.distributed.kv_transfer.kv_connector.v1 import lmcache_integration
_adapter = lmcache_integration.vllm_v1_adapter
cls = _adapter.LMCacheConnectorV1Impl
else:
logger.info("Initializing latest dev LMCache connector")
cls = LMCacheConnectorLatestImpl
self._lmcache_engine = cls(vllm_config, role, self)
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
self._lmcache_engine.wait_for_layer_load(layer_name)
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""
Start saving the a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
self._lmcache_engine.save_kv_layer(
layer_name, kv_layer, attn_metadata, **kwargs
)
def wait_for_save(self):
"""
Block until all the save operations is done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
self._lmcache_engine.wait_for_save()
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return self._lmcache_engine.get_finished(finished_req_ids)
def get_block_ids_with_load_errors(self) -> set[int]:
"""
Get the set of block IDs that failed to load.
Returns:
Set of block IDs that encountered load errors.
Empty set if no load errors occurred.
"""
method = getattr(self._lmcache_engine, "get_block_ids_with_load_errors", None)
if callable(method):
return method()
# Fallback for older versions that don't support this method
return set()
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
return self._lmcache_engine.get_num_new_matched_tokens(
request, num_computed_tokens
), False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
"""
self._lmcache_engine.update_state_after_alloc(request, num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
return self._lmcache_engine.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return self._lmcache_engine.request_finished(request, block_ids)

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from . import multi_process_adapter, vllm_v1_adapter
from .multi_process_adapter import (
LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter,
LoadStoreOp,
)
__all__ = [
"vllm_v1_adapter",
"multi_process_adapter",
"LMCacheMPSchedulerAdapter",
"LMCacheMPWorkerAdapter",
"LoadStoreOp",
]

View File

@@ -0,0 +1,379 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any
import torch
import zmq
from lmcache.utils import _lmcache_nvtx_annotate, init_logger
from lmcache.v1.multiprocess.custom_types import (
CudaIPCWrapper,
IPCCacheEngineKey,
KVCache,
)
from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
logger = init_logger(__name__)
def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache:
logger.info("KV caches keys are %s", list(kv_caches.keys()))
return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()]
def send_lmcache_request(
mq_client: MessageQueueClient,
request_type: RequestType,
payloads: list[Any],
) -> MessagingFuture[Any]:
future = mq_client.submit_request(
request_type, payloads, get_response_class(request_type)
)
return future
def get_lmcache_chunk_size(
mq_client: MessageQueueClient,
) -> int:
future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, [])
chunk_size = future.result()
return chunk_size
def striding_block_hashes(
block_hashes: list[bytes],
blocks_in_chunk,
) -> Iterable[bytes]:
"""Striding the block hashes to get the block hashes for each chunk.
For example, if blocks_in_chunk is 16, then we will get the block hashes
for the 16th, 32nd, 48th, ... blocks.
"""
return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
@dataclass
class LoadStoreOp:
block_hashes: list[bytes]
block_ids: list[int]
def __len__(self) -> int:
return len(self.block_hashes)
def __post_init__(self):
assert len(self.block_hashes) == len(self.block_ids), (
"The number of block hashes should be equal to the number of block ids "
f"But got {len(self.block_hashes)} and {len(self.block_ids)}"
)
StoreResult = bool
RetrieveResult = list[bool]
LookupResult = list[bool]
class LMCacheMPSchedulerAdapter:
def __init__(
self,
server_url: str,
context: zmq.Context,
model_name: str,
world_size: int,
kv_rank: int,
vllm_block_size: int,
):
"""
Args:
server_url: The server URL for the LMCache message queue
context: The ZMQ context
model_name: The model name used for LMCache keys
world_size: The world size used for LMCache keys
kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM
"""
self.mq_client = MessageQueueClient(server_url, context)
# Request futures
self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}
self.model_name = model_name
self.world_size = world_size
self.worker_id = kv_rank
# Read chunk size from lmcache
self.chunk_size = get_lmcache_chunk_size(self.mq_client)
assert self.chunk_size % vllm_block_size == 0, (
"LMCache chunk size should be a multiple of vLLM block size"
)
self.blocks_in_chunk = self.chunk_size // vllm_block_size
@_lmcache_nvtx_annotate
def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]):
if request_id in self.lookup_futures:
# Skip if there is already a lookup request
return
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
keys = [self._create_key(block_hash) for block_hash in s]
future = send_lmcache_request(
self.mq_client,
RequestType.LOOKUP,
[keys, True],
)
self.lookup_futures[request_id] = future
@_lmcache_nvtx_annotate
def check_lookup_result(self, request_id: str) -> int | None:
assert request_id in self.lookup_futures, (
f"Lookup request for request_id={request_id} has not been submitted"
)
future = self.lookup_futures[request_id]
if not future.query():
return None
result = future.result()
num_chunks = sum(result)
return num_chunks * self.chunk_size
def num_blocks_per_chunk(self) -> int:
"""
Returns:
The number of vllm blocks in a LMCache data chunk
"""
return self.blocks_in_chunk
# Helper functions
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
"""Convert a block hash to an IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=self.worker_id,
chunk_hash=block_hash,
)
class LMCacheMPWorkerAdapter:
def __init__(
self,
server_url: str,
context: zmq.Context,
model_name: str,
world_size: int,
kv_rank: int,
vllm_block_size: int,
):
self.mq_client = MessageQueueClient(server_url, context)
# Instance id for GPU worker
self.instance_id = os.getpid()
# Registered kv caches from vLLM
self.kv_caches: dict[str, torch.Tensor] = {}
# Request futures
# request_id -> (future, other merged requests)
self.store_futures: dict[
str, tuple[MessagingFuture[StoreResult], list[str]]
] = {}
self.retrieve_futures: dict[
str, tuple[MessagingFuture[RetrieveResult], list[str]]
] = {}
self.finished_stores: set[str] = set()
self.previously_finished: set[str] = set()
self.model_name = model_name
self.world_size = world_size
self.worker_id = kv_rank
# Read chunk size from lmcache
chunk_size = get_lmcache_chunk_size(self.mq_client)
assert chunk_size % vllm_block_size == 0, (
"LMCache chunk size should be a multiple of vLLM block size"
)
self.blocks_in_chunk = chunk_size // vllm_block_size
def register_kv_caches(self, kv_caches: dict[str, KVCache]):
# Register kv cache and send the request
self.kv_caches = kv_caches
logger.info("Registering kv caches")
future = send_lmcache_request(
self.mq_client,
RequestType.REGISTER_KV_CACHE,
[self.instance_id, wrap_kv_caches(kv_caches)],
)
future.result()
@_lmcache_nvtx_annotate
def submit_store_request(
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
keys = self._block_hashes_to_keys(op.block_hashes)
future = send_lmcache_request(
self.mq_client,
RequestType.STORE,
[keys, self.instance_id, op.block_ids, event.ipc_handle()],
).to_cuda_future()
self.store_futures[request_id] = (future, [])
@_lmcache_nvtx_annotate
def submit_retrieve_request(
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
keys = self._block_hashes_to_keys(op.block_hashes)
future = send_lmcache_request(
self.mq_client,
RequestType.RETRIEVE,
[keys, self.instance_id, op.block_ids, event.ipc_handle()],
).to_cuda_future()
self.retrieve_futures[request_id] = (future, [])
@_lmcache_nvtx_annotate
def batched_submit_store_requests(
self,
request_ids: list[str],
ops: list[LoadStoreOp],
event: torch.cuda.Event,
):
keys = []
block_ids = []
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes))
block_ids.extend(op.block_ids)
future = send_lmcache_request(
self.mq_client,
RequestType.STORE,
[keys, self.instance_id, block_ids, event.ipc_handle()],
).to_cuda_future()
self.store_futures[request_ids[0]] = (future, request_ids[1:])
@_lmcache_nvtx_annotate
def batched_submit_retrieve_requests(
self,
request_ids: list[str],
ops: list[LoadStoreOp],
event: torch.cuda.Event,
):
keys = []
block_ids = []
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes))
block_ids.extend(op.block_ids)
future = send_lmcache_request(
self.mq_client,
RequestType.RETRIEVE,
[keys, self.instance_id, block_ids, event.ipc_handle()],
).to_cuda_future()
self.retrieve_futures[request_ids[0]] = (future, request_ids[1:])
@_lmcache_nvtx_annotate
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
finished_stores = set()
finished_retrieves = set()
for request_id, (future, other_reqs) in self.store_futures.items():
if not future.query():
continue
result = future.result()
finished_stores.add(request_id)
finished_stores.update(other_reqs)
if not result:
# TODO: add error handling here
logger.error(
"Something went wrong when processing the "
"store request for request_id=%s",
request_id,
)
for request_id, (future, other_reqs) in self.retrieve_futures.items():
if not future.query():
continue
result = future.result()
finished_retrieves.add(request_id)
finished_retrieves.update(other_reqs)
if not all(result):
# TODO: add error handing here
logger.error(
"Something went wrong when processing the "
"retrieve request for request_id=%s, result=%s",
request_id,
result,
)
logger.info("Retrieve request for request_id=%s finished", request_id)
# Remove the finished requests from the tracking dicts
for request_id in finished_stores:
self.store_futures.pop(request_id, None)
for request_id in finished_retrieves:
self.retrieve_futures.pop(request_id, None)
# Update the internal states
self.finished_stores.update(finished_stores)
ret_stores = set()
for req_id in finished_req_ids:
if req_id in self.finished_stores or req_id in self.store_futures:
self.previously_finished.add(req_id)
else:
ret_stores.add(req_id)
# Calculate the final finished stores
ret_stores.update(self._update_and_get_finished_store())
return ret_stores, finished_retrieves
def num_blocks_per_chunk(self) -> int:
"""
Returns:
The number of vllm blocks in a LMCache data chunk
"""
return self.blocks_in_chunk
def shutdown(self):
# Unregister kv cache
logger.info("Unregistering kv caches")
send_lmcache_request(
self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
).result()
self.mq_client.close()
# Helper functions
def _update_and_get_finished_store(
self,
) -> set[str]:
"""Converge the internal states about finished stores
and returns the 'safe finished store request ids' back
"""
safe_finished_s = self.finished_stores.intersection(self.previously_finished)
self.finished_stores.difference_update(self.previously_finished)
self.previously_finished.difference_update(safe_finished_s)
return safe_finished_s
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
"""Convert a block hash to an IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=self.worker_id,
chunk_hash=block_hash,
)
def _block_hashes_to_keys(
self, block_hashes: list[bytes]
) -> list[IPCCacheEngineKey]:
"""Convert block hashes to IPC cache engine keys"""
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
return [self._create_key(block_hash) for block_hash in s]

View File

@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import os
import threading
from typing import TYPE_CHECKING, Union
import torch
from lmcache.config import LMCacheEngineConfig as Config
from lmcache.logging import init_logger
from lmcache.v1.config import LMCacheEngineConfig as V1Config
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.request import Request
logger = init_logger(__name__)
ENGINE_NAME = "vllm-instance"
# Thread-safe singleton storage
_config_instance: Config | V1Config | None = None
_config_lock = threading.Lock()
def is_false(value: str) -> bool:
"""Check if the given string value is equivalent to 'false'."""
return value.lower() in ("false", "0", "no", "n", "off")
def lmcache_get_or_create_config() -> Config | V1Config:
"""Get the LMCache configuration from the environment variable
`LMCACHE_CONFIG_FILE`. If the environment variable is not set, this
function will return the default configuration.
This function is thread-safe and implements singleton pattern,
ensuring the configuration is loaded only once.
"""
global _config_instance
# Double-checked locking for thread-safe singleton
if _config_instance is None:
with _config_lock:
if _config_instance is None: # Check again within lock
if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")):
logger.warning(
"Detected LMCACHE_USE_EXPERIMENTAL is set to False. "
"Using legacy configuration is deprecated and will "
"be remove soon! Please set LMCACHE_USE_EXPERIMENTAL "
"to True."
)
LMCacheEngineConfig = Config # type: ignore[assignment]
else:
LMCacheEngineConfig = V1Config # type: ignore[assignment]
if "LMCACHE_CONFIG_FILE" not in os.environ:
logger.warning(
"No LMCache configuration file is set. Trying to read"
" configurations from the environment variables."
)
logger.warning(
"You can set the configuration file through "
"the environment variable: LMCACHE_CONFIG_FILE"
)
_config_instance = LMCacheEngineConfig.from_env()
else:
config_file = os.environ["LMCACHE_CONFIG_FILE"]
logger.info("Loading LMCache config file %s", config_file)
_config_instance = LMCacheEngineConfig.from_file(config_file)
# Update config from environment variables
_config_instance.update_config_from_env()
return _config_instance
def hex_hash_to_int16(s: str) -> int:
"""
Convert a hex hash string to a 16-bit integer.
"""
return int(s, 16) & 0xFFFF
def apply_mm_hashes_to_token_ids(
token_ids: torch.Tensor,
mm_hashes: list[str],
mm_positions: list["PlaceholderRange"],
) -> torch.Tensor:
"""
Overwrite token_ids in-place for multimodal placeholders using
efficient slice assignments.
"""
n = token_ids.size(0)
for hash_str, placeholder in zip(mm_hashes, mm_positions):
start, length = placeholder.offset, placeholder.length
if start >= n:
continue
end = min(start + length, n)
token_ids[start:end] = hex_hash_to_int16(hash_str)
return token_ids
def mla_enabled(model_config: "ModelConfig") -> bool:
return (
hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
and model_config.use_mla
)
def create_lmcache_metadata(
vllm_config=None, model_config=None, parallel_config=None, cache_config=None
):
"""
Create LMCacheEngineMetadata from vLLM configuration.
This function extracts common metadata creation logic that was duplicated
across multiple files.
Args:
vllm_config (VllmConfig): vLLM configuration object containing model,
parallel, and cache configs (alternative to
individual config parameters)
model_config (ModelConfig): Model configuration (alternative to
vllm_config)
parallel_config (ParallelConfig): Parallel configuration (alternative
to vllm_config)
cache_config (CacheConfig): Cache configuration (alternative to
vllm_config)
"""
# Third Party
# First Party
from lmcache.config import LMCacheEngineMetadata
from vllm.utils.torch_utils import get_kv_cache_torch_dtype
config = lmcache_get_or_create_config()
# Support both vllm_config object and individual config parameters
if vllm_config is not None:
model_cfg = vllm_config.model_config
parallel_cfg = vllm_config.parallel_config
cache_cfg = vllm_config.cache_config
else:
if model_config is None or parallel_config is None or cache_config is None:
raise ValueError(
"Either vllm_config must be provided, or all of "
"model_config, parallel_config, and cache_config must be provided."
)
model_cfg = model_config
parallel_cfg = parallel_config
cache_cfg = cache_config
# Get KV cache dtype
kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)
# Check if MLA is enabled
use_mla = mla_enabled(model_cfg)
# Construct KV shape (for memory pool)
num_layer = model_cfg.get_num_layers(parallel_cfg)
chunk_size = config.chunk_size
num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
head_size = model_cfg.get_head_size()
kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
# Create metadata
metadata = LMCacheEngineMetadata(
model_cfg.model,
parallel_cfg.world_size,
parallel_cfg.rank,
"vllm",
kv_dtype,
kv_shape,
use_mla,
)
return metadata, config
def extract_mm_features(
request: Union["Request", "NewRequestData"], modify: bool = False
) -> tuple[list[str], list["PlaceholderRange"]]:
"""
Normalize multimodal information from a Request into parallel lists.
This helper reads either:
1) `request.mm_features` (objects each exposing `.identifier` and
`.mm_position`), or
2) legacy fields `request.mm_hashes` and `request.mm_positions`.
It returns two equally sized lists: the multimodal hash identifiers and
their corresponding positions. If the request contains no multimodal info,
it returns `([], [])`.
Args:
request (Request): The source object.
modify (bool):
Controls copy semantics for the legacy-path return values.
- If True and legacy fields are used, shallow-copies are returned so
the caller can mutate the lists without affecting `request`.
- If False, the original legacy sequences are returned as-is
(zero-copy); treat them as read-only.
Returns:
tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
May be `([], [])` when no multimodal data is present.
"""
if getattr(request, "mm_features", None):
mm_hashes, mm_positions = zip(
*((f.identifier, f.mm_position) for f in request.mm_features)
)
return (list(mm_hashes), list(mm_positions))
elif getattr(request, "mm_hashes", None):
if modify:
return (
request.mm_hashes.copy(), # type: ignore
request.mm_positions.copy(), # type: ignore
)
else:
return (request.mm_hashes, request.mm_positions) # type: ignore
else:
return ([], [])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,867 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
import zmq
from lmcache.utils import init_logger as lmcache_init_logger
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
LMCacheMPSchedulerAdapter,
LMCacheMPWorkerAdapter,
LoadStoreOp,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = lmcache_init_logger(__name__)
# Helper functions
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
if block_ids is None:
return []
assert isinstance(block_ids, tuple), (
f"Expected block_ids to be a tuple of lists, but got {type(block_ids)}"
)
if len(block_ids) > 1:
raise RuntimeError(
"LMCacheMPConnector only works without hybrid kv cache manager. "
"Please pass --disable-hybrid-kv-cache-manager when starting vllm"
)
return block_ids[0]
def create_scheduler_adapter(
server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPSchedulerAdapter:
# TODO: have a helper function to calculate the correct rank and
# world size for the MLA and other models
return LMCacheMPSchedulerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
vllm_config.parallel_config.world_size,
vllm_config.parallel_config.rank,
vllm_config.cache_config.block_size,
)
def create_worker_adapter(
server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPWorkerAdapter:
# TODO: have a helper function to calculate the correct rank and
# world size for the MLA and other models
return LMCacheMPWorkerAdapter(
server_url,
zmq_context,
vllm_config.model_config.model,
vllm_config.parallel_config.world_size,
vllm_config.parallel_config.rank,
vllm_config.cache_config.block_size,
)
def convert_block_hashes_to_bytes(
block_hashes: list["BlockHash"],
) -> list[bytes]:
return cast(list[bytes], block_hashes)
class LMCacheMPRequestState(enum.Enum):
"""
State machine:
PREFETCHING -- update_state_after_alloc --> WAITING_FOR_LOAD
WAITING_FOR_LOAD -- process_loading_requests --> READY
"""
PREFETCHING = enum.auto()
WAITING_FOR_LOAD = enum.auto()
READY = enum.auto()
@dataclass
class LMCacheMPRequestTracker:
# NOTE: this class used vLLM data structures, should be part of
# vLLM integration code
request_id: str
# Read-only lists to track the token ids and block hashes
all_token_ids: ConstantList[int]
block_hashes: ConstantList["BlockHash"]
# Block ids and hashes will be updated at update_states_after_alloc and
# during the generation
allocated_block_ids: list[int] = field(default_factory=list)
# Number of scheduled tokens in this request. We keep tracking this to
# avoid saving half-full blocks.
num_scheduled_tokens: int = 0
# Number of blocks stored will be initialized when lookup the external
# hit tokens and will be updated when processing new requests and cached
# requests.
num_stored_blocks: int = 0
# Staging load operation -- save vllm and lmcache hit tokens during lookup
num_vllm_hit_blocks: int = 0
num_lmcache_hit_blocks: int = 0
# Main state
state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING
def __init__(self, request: "Request"):
self.request_id = request.request_id
self.all_token_ids = request.all_token_ids
self.block_hashes = ConstantList(request.block_hashes)
self.allocated_block_ids = []
self.num_stored_blocks = 0
self.num_vllm_hit_blocks = 0
self.num_lmcache_hit_blocks = 0
self.state = LMCacheMPRequestState.PREFETCHING
####
# Check the state of the request
####
def needs_retrieve(self) -> bool:
"""Check whether the current request needs retrieve, will be used
update_stage_after_alloc"""
return (
self.num_lmcache_hit_blocks > self.num_vllm_hit_blocks
and self.state != LMCacheMPRequestState.READY
)
def is_ready_for_retrieving(self) -> bool:
"""Check whether the current request is ready for retrieving,
will be used in process_loading_requests"""
return (
self.state == LMCacheMPRequestState.WAITING_FOR_LOAD
and self.needs_retrieve()
)
####
# Update internal states
####
def increase_num_scheduled_tokens(self, num_new_tokens: int):
self.num_scheduled_tokens += num_new_tokens
def increase_num_stored_blocks(self, num_new_blocks: int):
"""Increase the number of stored blocks for the current request
This function will be called when processing the cached requests.
"""
self.num_stored_blocks += num_new_blocks
def update_block_ids(
self,
new_block_ids: list[int],
):
"""Update the block ids for the current request
This function will be called when processing the cached requests.
"""
self.allocated_block_ids.extend(new_block_ids)
####
# For debugging
####
def __repr__(self) -> str:
return (
f"LMCacheMPRequestTracker(request_id={self.request_id}, "
f"num_tokens={len(self.all_token_ids)}, "
f"num_block_hashes={len(self.block_hashes)}, "
f"num_allocated_blocks={len(self.allocated_block_ids)}, "
f"num_stored_blocks={self.num_stored_blocks}, "
f"vllm_hit_blocks={self.num_vllm_hit_blocks}, "
f"lmcache_hit_blocks={self.num_lmcache_hit_blocks}, "
f"state={self.state})"
)
def __str__(self) -> str:
return self.__repr__()
@dataclass
class LMCacheMPRequestMetadata:
request_id: str
direction: Literal["STORE", "RETRIEVE"]
op: LoadStoreOp
@staticmethod
def GetStoreMetadata(
tracker: LMCacheMPRequestTracker,
blocks_in_chunk: int,
vllm_block_size: int,
) -> "LMCacheMPRequestMetadata | None":
"""
Generate the store metadata for the current request tracker.
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
"""
# Store the blocks that has block hashes
# NOTE: the invariant here is that `num_stored_blocks` should
# always be a multiple of `blocks_in_chunk`
# TODO: This should be checked everytime we update the num_stored_blocks
min_available_blocks = min(
len(tracker.block_hashes),
len(tracker.allocated_block_ids),
tracker.num_scheduled_tokens // vllm_block_size,
)
num_staging_blocks = min_available_blocks - tracker.num_stored_blocks
num_chunks = num_staging_blocks // blocks_in_chunk
if num_chunks >= 1:
start = tracker.num_stored_blocks
end = start + num_chunks * blocks_in_chunk
block_hashes = convert_block_hashes_to_bytes(
tracker.block_hashes[start:end]
)
block_ids = tracker.allocated_block_ids[start:end]
ret = LMCacheMPRequestMetadata(
request_id=tracker.request_id,
direction="STORE",
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
)
# Update the request tracker
tracker.increase_num_stored_blocks(end - start)
return ret
return None
@staticmethod
def GetRetrieveMetadata(
tracker: LMCacheMPRequestTracker,
blocks_in_chunk: int,
) -> "LMCacheMPRequestMetadata | None":
"""
Generate the retrieve metadata for the current request tracker.
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
"""
if not tracker.is_ready_for_retrieving():
return None
# |---------------------|-----------------|----------------|
# | num_vllm_hit_blocks |
# | lmcache chunk 1 | lmcache chunk 2 |
# | need to retrieve |
start = tracker.num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk
end = tracker.num_lmcache_hit_blocks
assert end % blocks_in_chunk == 0, (
"The number of LMCache hit blocks should be a multiple of the "
"number of blocks in a lmcache chunk. "
)
assert len(tracker.block_hashes) >= end, (
"The number of block hashes should be greater than or equal to the "
"number of LMCache hit blocks. "
)
if end > start:
block_hashes = convert_block_hashes_to_bytes(
tracker.block_hashes[start:end]
)
block_ids = tracker.allocated_block_ids[start:end]
ret = LMCacheMPRequestMetadata(
request_id=tracker.request_id,
direction="RETRIEVE",
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
)
return ret
return None
class LMCacheMPConnectorMetadata(KVConnectorMetadata):
def __init__(self):
super().__init__()
self.requests: list[LMCacheMPRequestMetadata] = []
def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata):
self.requests.append(request_metadata)
def __len__(self):
return len(self.requests)
# For debugging
def __str__(self):
request_strs = []
for req_meta in self.requests:
request_strs.append(
f"RequestMetadata(request_id={req_meta.request_id}, "
f"direction={req_meta.direction}, "
f"num_blocks={len(req_meta.op)}, "
f"block_ids={req_meta.op.block_ids})"
)
return "[" + "\n".join(request_strs) + "]"
def __repr__(self):
return self.__str__()
class LMCacheMPConnector(KVConnectorBase_V1):
"""
The connector for LMCache multi-process mode.
Extra configs (kv_transfer_config.extra_config):
- lmcache.mp.host: the host of the LMCache server.
- lmcache.mp.port: the port of the LMCache server.
"""
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
server_host = vllm_config.kv_transfer_config.get_from_extra_config(
"lmcache.mp.host", "tcp://localhost"
)
server_port = vllm_config.kv_transfer_config.get_from_extra_config(
"lmcache.mp.port", 5555
)
server_url = f"{server_host}:{server_port}"
zmq_context = zmq.Context.instance()
if self.role == KVConnectorRole.SCHEDULER:
self.scheduler_adapter = create_scheduler_adapter(
server_url, zmq_context, vllm_config
)
self.request_trackers: dict[str, LMCacheMPRequestTracker] = {}
elif self.role == KVConnectorRole.WORKER:
self.worker_adapter = create_worker_adapter(
server_url, zmq_context, vllm_config
)
else:
raise ValueError(f"Unknown KVConnectorRole: {self.role}")
self.vllm_block_size = vllm_config.cache_config.block_size
@property
def role(self) -> KVConnectorRole:
return self._role
# ==============================
# Worker-side methods
# ==============================
def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches: dictionary of layer names, kv cache
"""
logger.info("Registering kv caches!")
self.worker_adapter.register_kv_caches(kv_caches)
return
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = []
ops = []
for meta in metadata.requests:
if meta.direction != "RETRIEVE":
continue
request_ids.append(meta.request_id)
ops.append(meta.op)
if len(request_ids) > 0:
logger.info(
"HERE! SUBMITTING THE BATCHED RETRIEVE REQUESTS %s", request_ids
)
self.worker_adapter.batched_submit_retrieve_requests(
request_ids, ops, event
)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
return
def wait_for_save(self):
"""
Block until all the save operations is done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = []
ops = []
for meta in metadata.requests:
if meta.direction != "STORE":
continue
request_ids.append(meta.request_id)
ops.append(meta.op)
if len(request_ids) > 0:
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
val = self.worker_adapter.get_finished(finished_req_ids)
# logger.error("Finished req ids: %s, %s", val[0], val[1])
return val
def get_block_ids_with_load_errors(self) -> set[int]:
"""
Get the set of block IDs that failed to load.
Returns:
Set of block IDs that encountered load errors.
Empty set if no load errors occurred.
Notes:
- Applies to both sync- and async-loading requests.
- Async loading: failed blocks may be reported in any forward pass
up to and including the pass where the request ID is returned by
`get_finished()`. Even if failures occur, the request must still
be reported via `get_finished()`, and the failed block IDs must
appear here no later than that same pass.
- Sync loading: failed blocks should be reported in the forward
pass in which they are detected.
"""
# TODO: add error tracking
return set()
def shutdown(self):
"""
Shutdown the connector. This is called when the worker process
is shutting down to ensure that all the async operations are
completed and the connector is cleaned up properly.
"""
if hasattr(self, "worker_adapter"):
self.worker_adapter.shutdown()
return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
Get the KV connector stats collected during the last interval.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
Notes:
The connector should only consider the largest prefix of prompt-
tokens for which KV cache is actually available at the time of the
call. If the cache cannot be loaded for some tokens (e.g., due to
connectivity issues or eviction), those tokens must not be taken
into account.
"""
tracker = self._get_or_create_request_tracker(request)
self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
)
ret = self.scheduler_adapter.check_lookup_result(request.request_id)
if ret is None:
return None, True
if ret == 0:
return 0, False
assert (
ret % (self.scheduler_adapter.num_blocks_per_chunk() * self.vllm_block_size)
== 0
)
# Update num stored blocks for the tracker
num_vllm_blocks = num_computed_tokens // self.vllm_block_size
num_lmcache_blocks = ret // self.vllm_block_size
tracker.increase_num_stored_blocks(num_lmcache_blocks)
# Save the vllm and lmcache hit tokens
tracker.num_vllm_hit_blocks = num_vllm_blocks
tracker.num_lmcache_hit_blocks = num_lmcache_blocks
need_to_load = max(0, ret - num_computed_tokens)
logger.debug(
"vLLM hit is: %d, Need to load is %d", num_computed_tokens, need_to_load
)
return need_to_load, need_to_load > 0
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.
Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
# NOTE: the `blocks` are NEW BLOCKS allocated for this request.
tracker = self._get_request_tracker(request.request_id)
block_ids = reformat_block_ids(blocks.get_block_ids())
# No matter we need to retrieve or not, we need to update
# the block ids into the tracker
tracker.update_block_ids(block_ids)
# Update the state of the tracker
condition = tracker.needs_retrieve()
if tracker.state == LMCacheMPRequestState.PREFETCHING:
# If need to retrieve, change to WAITING_FOR_LOAD
# Otherwise, change to READY
tracker.state = (
LMCacheMPRequestState.WAITING_FOR_LOAD
if condition
else LMCacheMPRequestState.READY
)
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
metadata = LMCacheMPConnectorMetadata()
self._process_retrieve_requests(metadata)
self._process_new_requests(scheduler_output, metadata)
self._process_cached_requests(scheduler_output, metadata)
if len(metadata) > 0:
logger.debug("Final connector metadata: %s", metadata)
return metadata
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
return
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return True, None
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return ()
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if cls is KVConnectorBase_V1:
raise TypeError(
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
return None
def get_finished_count(self) -> int | None:
"""
Get the count of requests expected to complete send/receive operations
via this connector. This method is used to initialize the
KVOutputAggregator, overwriting the default world_size.
Returns:
int: expected sending or receiving completion count.
"""
return None
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return None
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> Optional["KVConnectorPromMetrics"]:
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
expose connector transfer stats via Prometheus.
"""
return None
##############################
# Helper functions
##############################
def _process_retrieve_requests(
self,
metadata: LMCacheMPConnectorMetadata,
) -> None:
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
for request_tracker in self.request_trackers.values():
if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD:
continue
r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata(
request_tracker, blocks_per_chunk
)
if r_metadata is not None:
metadata.add_request_metadata(r_metadata)
request_tracker.state = LMCacheMPRequestState.READY
def _process_new_requests(
self,
scheduler_output: SchedulerOutput,
metadata: LMCacheMPConnectorMetadata,
) -> None:
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
for new_request in scheduler_output.scheduled_new_reqs:
request_tracker = self._get_request_tracker(new_request.req_id)
num_new_tokens = scheduler_output.num_scheduled_tokens[new_request.req_id]
request_tracker.increase_num_scheduled_tokens(num_new_tokens)
r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
request_tracker, blocks_per_chunk, self.vllm_block_size
)
if r_meta is not None:
metadata.add_request_metadata(r_meta)
def _process_cached_requests(
self,
scheduler_output: SchedulerOutput,
metadata: LMCacheMPConnectorMetadata,
) -> None:
blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()
cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, request_id in enumerate(cached_reqs.req_ids):
request_tracker = self._get_request_tracker(request_id)
# Update block ids
new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
request_tracker.update_block_ids(new_block_ids)
# Update new scheduled tokens
num_new_tokens = cached_reqs.num_computed_tokens[idx]
request_tracker.increase_num_scheduled_tokens(num_new_tokens)
r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
request_tracker, blocks_per_chunk, self.vllm_block_size
)
if r_meta is not None:
metadata.add_request_metadata(r_meta)
def _get_request_tracker(self, request_id: str) -> LMCacheMPRequestTracker:
assert request_id in self.request_trackers, (
f"Request tracker for request_id {request_id} not found. "
)
return self.request_trackers[request_id]
def _get_or_create_request_tracker(
self, request: "Request"
) -> LMCacheMPRequestTracker:
request_id = request.request_id
if request_id not in self.request_trackers:
new_tracker = LMCacheMPRequestTracker(request)
self.request_trackers[request_id] = new_tracker
return self.request_trackers[request_id]

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any, TypeAlias, TypeVar
from prometheus_client import Counter, Gauge, Histogram
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group
from vllm.logger import init_logger
PromMetric: TypeAlias = Gauge | Counter | Histogram
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)
logger = init_logger(__name__)
@dataclass
class KVConnectorStats:
"""
Base class for KV Connector Stats, a container for transfer performance
metrics or otherwise important telemetry from the connector.
All sub-classes need to be serializable as stats are sent from worker to
logger process.
"""
data: dict[str, Any] = field(default_factory=dict)
def reset(self):
"""Reset the stats, clear the state."""
raise NotImplementedError
def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
"""
Aggregate stats with another `KVConnectorStats` object.
"""
raise NotImplementedError
def reduce(self) -> dict[str, int | float]:
"""
Reduce the observations collected during a time interval to one or
more representative values (eg avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
raise NotImplementedError
def is_empty(self) -> bool:
"""Return True if the stats are empty."""
raise NotImplementedError
class KVConnectorLogging:
def __init__(self, kv_tranfer_config: KVTransferConfig):
# This should be called on frontend process.
assert not has_kv_transfer_group()
# Instantiate the connector's stats class.
if kv_tranfer_config and kv_tranfer_config.kv_connector:
self.connector_cls = KVConnectorFactory.get_connector_class(
kv_tranfer_config
)
self.reset()
def reset(self):
self.transfer_stats_accumulator: KVConnectorStats | None = None
def observe(self, transfer_stats_data: dict[str, Any]):
# Should not be called when a KVConnector is not configured.
assert self.connector_cls is not None
# Called periodically when connector syncs with the scheduler.
# Note that this is not the same as the logging interval.
# We expect transfer_stats_data to be aggregated across all workers and
# consist of observations from a single connector or a MultiConnector.
transfer_stats = self.connector_cls.build_kv_connector_stats(
transfer_stats_data
)
if transfer_stats is None:
logger.warning_once(
"The connector %s is collecting stats but "
"does not implement the "
"`build_kv_connector_stats` method. "
"Stats will not be logged.",
self.connector_cls,
)
return
if self.transfer_stats_accumulator is None:
self.transfer_stats_accumulator = transfer_stats
else:
# Accumulate last interval stats.
self.transfer_stats_accumulator = self.transfer_stats_accumulator.aggregate(
transfer_stats
)
def log(self, log_fn=logger.info):
"""Log transfer metrics periodically, similar to throughput logging"""
if (
self.transfer_stats_accumulator
and not self.transfer_stats_accumulator.is_empty()
):
# Produce a single cumulative stats object for the last time
# interval from the recorded observations.
xfer_metrics = self.transfer_stats_accumulator.reduce()
xfer_metrics_str = ", ".join(f"{k}={v}" for k, v in xfer_metrics.items())
log_fn("KV Transfer metrics: %s", xfer_metrics_str)
# Reset metrics for next interval
self.reset()
class KVConnectorPromMetrics:
"""
A base class for per-connector Prometheus metric registration
and recording.
"""
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self._kv_transfer_config = vllm_config.kv_transfer_config
self._gauge_cls = metric_types[Gauge]
self._counter_cls = metric_types[Counter]
self._histogram_cls = metric_types[Histogram]
self._labelnames = labelnames
self._per_engine_labelvalues = per_engine_labelvalues
def make_per_engine(self, metric: PromMetric) -> PromMetric:
"""
Create a per-engine child of a prometheus_client.Metric with
the appropriate labels set. The parent metric must be created
using the labelnames list.
"""
return {
idx: metric.labels(*labelvalues)
for idx, labelvalues in self._per_engine_labelvalues.items()
}
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Record the supplied transfer statistics to Prometheus metrics. These
statistics are engine-specific, and should be recorded to a metric
with the appropriate 'engine' label. These metric instances can be
created using the make_per_engine() helper method.
"""
raise NotImplementedError
class KVConnectorPrometheus:
"""
Support for registering per-connector Prometheus metrics, and
recording transfer statistics to those metrics. Uses
KVConnectorBase.build_prom_metrics().
"""
_gauge_cls = Gauge
_counter_cls = Counter
_histogram_cls = Histogram
def __init__(
self,
vllm_config: VllmConfig,
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self.prom_metrics: KVConnectorPromMetrics | None = None
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config and kv_transfer_config.kv_connector:
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
metric_types = {
Gauge: self._gauge_cls,
Counter: self._counter_cls,
Histogram: self._histogram_cls,
}
self.prom_metrics = connector_cls.build_prom_metrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
if self.prom_metrics is None:
return
self.prom_metrics.observe(transfer_stats_data, engine_idx)

View File

@@ -0,0 +1,454 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import torch
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
metadata: tuple[KVConnectorMetadata, ...]
extra_async_saves: dict[str, int] | None = None
@dataclass
class MultiKVConnectorStats(KVConnectorStats):
"""
Maintain a dict of KVConnectorStats objects, one for each connector.
This is used to aggregate the stats from all connectors separately.
"""
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
for connector_id, stats in other.data.items():
if connector_id not in self.data:
self[connector_id] = stats
else:
assert isinstance(stats, type(self.data[connector_id]))
self[connector_id] = self[connector_id].aggregate(stats)
return self
def reset(self):
for stats in self.data.values():
stats.reset()
def reduce(self) -> dict[str, Any]:
# TODO (NickLucche) Adjust for logging on separate lines
return {
connector_id: stats.reduce() for connector_id, stats in self.data.items()
}
def is_empty(self) -> bool:
return all(stats.is_empty() for stats in self.data.values())
def __getitem__(self, connector_id: str) -> KVConnectorStats:
return self.data[connector_id]
def __setitem__(self, connector_id: str, stats: KVConnectorStats):
self.data[connector_id] = stats
class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: "VllmConfig",
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
prom_metrics: dict[str, KVConnectorPromMetrics],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
self._prom_metrics = prom_metrics
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for connector_id, stats_data in transfer_stats_data.items():
assert connector_id in self._prom_metrics, (
f"{connector_id} is not contained in the list of registered connectors "
f"with Prometheus metrics support: {self._prom_metrics.keys()}"
)
self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)
class MultiConnector(KVConnectorBase_V1):
"""
A wrapper for using multiple KVConnectors at the same time.
The current logic is:
- Load KV from the first connector that advertises available tokens from
get_num_new_matched_tokens(), based on the order in the config.
- Save to all connectors.
"""
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
for connector_cls, temp_config in self._get_connector_classes_and_configs(
vllm_config
):
self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
self._requests_to_connector: dict[str, int] = {}
# Keeps track of *additional* remaining async saves (beyond 1) to be
# finished per request. Not needed for async loads since we only allow
# a single connector to load.
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}
@classmethod
def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig"
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
ret.append(
(
KVConnectorFactory.get_connector_class(
temp_config.kv_transfer_config
),
temp_config,
)
)
return ret
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)
# We must override the base class method here because we need to bind
# the metadata to each connector in the order of the connectors in the
# MultiKVConnectorMetadata.
#
# Note: Call the base class method to ensure metadata is also set on the
# MultiConnector instance itself; otherwise, `has_connector_metadata()` will
# always return False.
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
if connector_metadata.extra_async_saves:
self._extra_async_saves.update(connector_metadata.extra_async_saves)
for c, cm in zip(self._connectors, connector_metadata.metadata):
c.bind_connector_metadata(cm)
super().bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
for c in self._connectors:
c.clear_connector_metadata()
super().clear_connector_metadata()
def shutdown(self):
exception: Exception | None = None
for c in self._connectors:
try:
c.shutdown()
except Exception as e:
logger.exception(
"Exception during connector %s shutdown.", c.__class__.__name__
)
exception = e
if exception:
raise exception
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
for c in self._connectors:
c.start_load_kv(forward_context, **kwargs)
def wait_for_layer_load(self, layer_name: str) -> None:
for c in self._connectors:
c.wait_for_layer_load(layer_name)
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs,
) -> None:
for c in self._connectors:
c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
def wait_for_save(self):
for c in self._connectors:
c.wait_for_save()
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
finished_sending: set[str] = set()
finished_recving: set[str] = set()
for c in self._connectors:
sending, recving = c.get_finished(finished_req_ids)
if not recving and not sending:
continue
# Aggregate finished recving request ids.
finished_recving.update(recving or ())
# Aggregate finished sending request ids - only include
# once we've drained the "extra" count (for cases where
# more than one connector is async-saving the same request).
for req_id in sending or ():
extra_pending = self._extra_async_saves.get(req_id)
if extra_pending is None:
finished_sending.add(req_id)
continue
assert extra_pending > 0
if extra_pending == 1:
del self._extra_async_saves[req_id]
else:
self._extra_async_saves[req_id] = extra_pending - 1
return finished_sending or None, finished_recving or None
def get_block_ids_with_load_errors(self) -> set[int]:
agg_block_ids: set[int] = set()
for c in self._connectors:
agg_block_ids |= c.get_block_ids_with_load_errors()
return agg_block_ids
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
to_return = (0, False)
for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens
)
# If there is a connector still looking up the matches,
# we return None to indicate that we are not done yet.
if toks is None:
return (None, False)
# The first connector that has new matched tokens will be assigned
# to this request.
if to_return[0] == 0 and toks > 0:
self._requests_to_connector[request.request_id] = i
to_return = (toks, load_async)
return to_return
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
chosen_connector = self._requests_to_connector.get(request.request_id, -1)
empty_blocks = blocks.new_empty()
for i, c in enumerate(self._connectors):
if i == chosen_connector:
# Forward call to the chosen connector (if any).
c.update_state_after_alloc(request, blocks, num_external_tokens)
else:
# Call with empty blocks for other connectors.
c.update_state_after_alloc(request, empty_blocks, 0)
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> MultiKVConnectorMetadata:
metadata = MultiKVConnectorMetadata(
metadata=tuple(
c.build_connector_meta(scheduler_output) for c in self._connectors
)
)
if self._extra_async_saves:
metadata.extra_async_saves = self._extra_async_saves
self._extra_async_saves = {}
return metadata
def update_connector_output(self, connector_output: KVConnectorOutput):
for c in self._connectors:
c.update_connector_output(connector_output)
def request_finished(
self,
request: "Request",
blocks: list[int],
) -> tuple[bool, dict[str, Any] | None]:
async_saves = 0
kv_txfer_params = None
for c in self._connectors:
async_save, txfer_params = c.request_finished(request, blocks)
if async_save:
async_saves += 1
if txfer_params is not None:
if kv_txfer_params is not None:
# TODO we can probably change this to merge the dicts here,
# checking for key clashes.
raise RuntimeError(
"Only one connector can produce KV transfer params"
)
kv_txfer_params = txfer_params
if async_saves > 1:
self._extra_async_saves[request.request_id] = async_saves - 1
# Clean up other state for this request.
self._requests_to_connector.pop(request.request_id, None)
return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable["KVCacheEvent"]:
for c in self._connectors:
yield from c.take_events()
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
layouts: set[str] = set()
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
temp_config
)
if required_kvcache_layout is not None:
layouts.add(required_kvcache_layout)
if len(layouts) > 1:
raise ValueError(
f"KV cache layout mismatch: "
f"found {len(layouts)} different layouts "
f"({', '.join(layouts)})."
f"All connectors must use the same layout."
)
return next(iter(layouts), None)
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
if data is None:
return MultiKVConnectorStats()
# data is a dict mapping connector name to their stats data.
# The stats data can be either:
# 1. Already-instantiated KVConnectorStats objects (same process)
# 2. Serialized dicts (cross-process after serialization)
# We need to reconstruct proper KVConnectorStats objects from dicts
reconstructed_data = {}
for connector_name, stats_value in data.items():
# If already a KVConnectorStats object, use it directly
if isinstance(stats_value, KVConnectorStats):
reconstructed_data[connector_name] = stats_value
continue
# Otherwise, reconstruct from serialized dict
# Get the connector class to reconstruct its stats
connector_cls = KVConnectorFactory.get_connector_class_by_name(
connector_name
)
# stats_value is the serialized dataclass which contains {'data': {...}}
# We need to extract the inner 'data' field to avoid double-nesting
assert isinstance(stats_value, dict) and "data" in stats_value, (
f"Expected a dict with a 'data' field, got {stats_value}"
)
inner_data = stats_value["data"]
# Use the connector's build_kv_connector_stats to reconstruct
if reconstructed_stats := connector_cls.build_kv_connector_stats(
data=inner_data
):
reconstructed_data[connector_name] = reconstructed_stats
return MultiKVConnectorStats(data=reconstructed_data)
def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
# Group connector stats by connector type.
stats_by_connector: MultiKVConnectorStats | None = None
for c in self._connectors:
stats = c.get_kv_connector_stats()
if stats is None:
continue
if stats_by_connector is None:
# Lazy init to allow optional return value.
stats_by_connector = MultiKVConnectorStats()
stats_by_connector[c.__class__.__name__] = stats
return stats_by_connector
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
connector_prom = connector_cls.build_prom_metrics(
temp_config, metric_types, labelnames, per_engine_labelvalues
)
if connector_prom is not None:
prom_metrics[connector_cls.__name__] = connector_prom
return MultiKVConnectorPromMetrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
prom_metrics,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,504 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import Any
import torch
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request
ReqId = str
logger = init_logger(__name__)
@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec]
reqs_to_store: dict[ReqId, TransferSpec]
class OffloadingConnector(KVConnectorBase_V1):
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: KVCacheConfig | None = None,
):
super().__init__(vllm_config, role, kv_cache_config)
spec = OffloadingSpecFactory.create_spec(vllm_config)
self.connector_scheduler: OffloadingConnectorScheduler | None = None
self.connector_worker: OffloadingConnectorWorker | None = None
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = OffloadingConnectorScheduler(spec)
elif role == KVConnectorRole.WORKER:
self.connector_worker = OffloadingConnectorWorker(spec)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs,
) -> None:
pass
def wait_for_save(self):
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.start_store_kv(self._connector_metadata)
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
assert self.connector_worker is not None
return self.connector_worker.get_finished(finished_req_ids)
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens
)
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens
)
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def update_connector_output(self, connector_output: KVConnectorOutput):
assert self.connector_scheduler is not None
self.connector_scheduler.update_connector_output(connector_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def take_events(self) -> Iterable[KVCacheEvent]:
assert self.connector_scheduler is not None
return self.connector_scheduler.take_events()
class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, spec: OffloadingSpec):
self.gpu_block_size = spec.gpu_block_size
self.offloaded_block_size = spec.offloaded_block_size
self.block_size_factor = self.offloaded_block_size // self.gpu_block_size
self.manager: OffloadingManager = spec.get_manager()
self._requests: dict[ReqId, Request] = {}
# list of GPU block IDs per request
self._request_block_ids: dict[ReqId, list[int]] = {}
# requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
# request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload
self._next_stored_block_idx: dict[ReqId, int] = {}
# request ID -> set(block hashes being stored/load)
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
def _get_block_hashes(
self,
req: Request,
start_idx: int = 0,
end_idx: int | None = None,
) -> Iterable[BlockHash]:
return islice(
req.block_hashes,
self.block_size_factor * start_idx + self.block_size_factor - 1,
self.block_size_factor * end_idx if end_idx else None,
self.block_size_factor,
)
def get_num_new_matched_tokens(
self, request: Request, num_computed_tokens: int
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded beyond the
num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded beyond what is
already computed.
- `True` if tokens will be loaded asynchronously
(between scheduler steps).
"""
num_blocks = request.num_tokens // self.offloaded_block_size
assert len(request.block_hashes) // self.block_size_factor == num_blocks
block_hashes = self._get_block_hashes(request)
self.manager.touch(block_hashes)
full_block_tokens = self.offloaded_block_size * num_blocks
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
# we can load less than a block, skip
return 0, False
start_block_idx = num_computed_tokens // self.offloaded_block_size
hits = self.manager.lookup(
self._get_block_hashes(request, start_idx=start_block_idx)
)
if hits == 0:
return 0, False
num_hit_tokens = (
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
)
logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id,
num_hit_tokens,
num_computed_tokens,
)
if num_hit_tokens < self.offloaded_block_size:
return 0, False
return num_hit_tokens, True
def update_state_after_alloc(
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
self._requests[request.request_id] = request
# the block ids are updated in _get_reqs_to_store
self._request_block_ids[request.request_id] = []
if num_external_tokens == 0:
return
block_groups = blocks.get_block_ids()
block_ids = block_groups[0]
num_computed_gpu_blocks = sum(
block.block_hash is not None for block in blocks.blocks[0]
)
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
full_block_tokens = num_computed_tokens + num_external_tokens
assert full_block_tokens % self.offloaded_block_size == 0
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
start_block_idx = num_computed_tokens // self.offloaded_block_size
num_blocks = full_block_tokens // self.offloaded_block_size
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
)
src_spec = self.manager.prepare_load(block_hashes)
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
)
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
self._reqs_being_loaded[request.request_id].update(block_hashes)
self._next_stored_block_idx[request.request_id] = num_blocks
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
if preempted:
self._request_block_ids[req_id] = []
if new_block_id_groups:
new_block_ids = new_block_id_groups[0]
self._request_block_ids[req_id] += new_block_ids
block_ids = self._request_block_ids[req_id]
req = self._requests[req_id]
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
total_tokens = req.num_computed_tokens + new_tokens
num_blocks = total_tokens // self.offloaded_block_size
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
num_new_blocks = num_blocks - start_block_idx
if num_new_blocks <= 0:
continue
# NOTE: In async scheduling, placeholders may temporarily make
# len(req.block_hashes) < num_blocks * self.block_size_factor.
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks
)
store_output = self.manager.prepare_store(new_block_hashes)
if store_output is None:
logger.warning(
"Request %s: cannot store %s blocks", req_id, num_new_blocks
)
continue
self._next_stored_block_idx[req_id] = num_blocks
if not store_output.block_hashes_to_store:
continue
block_hashes_to_store = set(store_output.block_hashes_to_store)
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
self.manager.touch(block_hashes)
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks
)
dst_spec = store_output.store_spec
src_block_ids: list[int] = []
for idx, blk_hash in enumerate(new_block_hashes):
if blk_hash not in block_hashes_to_store:
continue
offloaded_block_idx = start_block_idx + idx
gpu_block_idx = offloaded_block_idx * self.block_size_factor
for i in range(self.block_size_factor):
src_block_ids.append(block_ids[gpu_block_idx + i])
src_spec = GPULoadStoreSpec(src_block_ids)
reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= block_hashes_to_store
logger.debug(
"Request %s offloading %s blocks starting from block #%d",
req_id,
len(block_hashes_to_store),
start_block_idx,
)
return reqs_to_store
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output),
)
self._reqs_to_load = {}
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
for req_id in connector_output.finished_sending or []:
block_hashes = self._reqs_being_stored.pop(req_id, None)
if block_hashes:
self.manager.complete_store(block_hashes)
for req_id in connector_output.finished_recving or []:
block_hashes = self._reqs_being_loaded.pop(req_id, None)
if block_hashes:
self.manager.complete_load(block_hashes)
def request_finished(
self,
request: Request,
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
self._requests.pop(req_id, None)
self._request_block_ids.pop(req_id, None)
self._next_stored_block_idx.pop(req_id, None)
request_being_stored = req_id in self._reqs_being_stored
return request_being_stored, None
def take_events(self) -> Iterable[KVCacheEvent]:
"""Take the KV cache events from the connector.
Returns:
A list of KV cache events.
"""
for event in self.manager.take_events():
if event.removed:
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
else:
yield BlockStored(
block_hashes=event.block_hashes,
parent_block_hash=None,
token_ids=[],
lora_id=None,
block_size=event.block_size,
medium=event.medium,
)
class OffloadingConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, spec: OffloadingSpec):
self.spec = spec
self.worker = OffloadingWorker()
self._job_counter = 0
# req_id -> (job_id, store)
self._jobs: dict[int, tuple[ReqId, bool]] = {}
# req_id -> active job IDs
self._load_job: dict[ReqId, int] = {}
# req_id -> set(active job IDs)
self._store_jobs = defaultdict[ReqId, set[int]](set)
self._finished_reqs_waiting_for_store: set[ReqId] = set()
def _generate_job_id(self) -> int:
job_id = self._job_counter
self._job_counter = job_id + 1
return job_id
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches):
self.worker.register_handler(src_cls, dst_cls, handler)
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, False)
assert req_id not in self._load_job
self._load_job[req_id] = job_id
assert self.worker.transfer_async(job_id, transfer_spec)
def start_store_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_store.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, True)
self._store_jobs[req_id].add(job_id)
assert self.worker.transfer_async(job_id, transfer_spec)
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns a list of request IDs that finished loading or storing.
Returns:
ids of requests that have finished asynchronous transfer
tuple of (sending/saving ids, recving/loading ids).
"""
finished_sending = set()
finished_recving = set()
for job_id, success in self.worker.get_finished():
# we currently do not support job failures
assert success
req_id, store = self._jobs.pop(job_id)
if store:
req_jobs = self._store_jobs[req_id]
req_jobs.remove(job_id)
if req_jobs:
continue
if req_id in self._finished_reqs_waiting_for_store:
self._finished_reqs_waiting_for_store.remove(req_id)
finished_sending.add(req_id)
del self._store_jobs[req_id]
else:
req_job = self._load_job[req_id]
assert job_id == req_job
del self._load_job[req_id]
finished_recving.add(req_id)
for req_id in finished_req_ids:
pending_req_jobs = self._store_jobs.get(req_id)
if pending_req_jobs:
self._finished_reqs_waiting_for_store.add(req_id)
elif pending_req_jobs is not None:
finished_sending.add(req_id)
del self._store_jobs[req_id]
return finished_sending, finished_recving
def yield_req_data(
scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
"""
Yields:
(req_id, new_block_id_groups, preempted)
"""
# new requests
for req_data in scheduler_output.scheduled_new_reqs:
yield req_data.req_id, req_data.block_ids, False
# cached requests
cached_reqs = scheduler_output.scheduled_cached_reqs
yield from zip(
cached_reqs.req_ids,
cached_reqs.new_block_ids,
(req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
)

View File

@@ -0,0 +1,531 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine,
)
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request Id
request_id: str
# Request block ids
block_ids: torch.Tensor
# Request num tokens
num_tokens: int
@staticmethod
def make_meta(
request_id: str, token_ids: list[int], block_ids: list[int], block_size: int
) -> "ReqMeta":
block_ids_tensor = torch.tensor(block_ids)
return ReqMeta(
request_id=request_id,
block_ids=block_ids_tensor,
num_tokens=len(token_ids),
)
@dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta]
def __init__(self):
self.requests = []
def add_request(
self,
request_id: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
) -> None:
self.requests.append(
ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)
)
class P2pNcclConnector(KVConnectorBase_V1):
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
self.is_producer = self._kv_transfer_config.is_kv_producer
self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}
self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
self._local_rank = (
get_world_group().local_rank if role == KVConnectorRole.WORKER else 0
)
self.p2p_nccl_engine = (
P2pNcclEngine(
local_rank=self._local_rank,
config=self._kv_transfer_config,
hostname="",
port_offset=self._rank,
)
if role == KVConnectorRole.WORKER
else None
)
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
# Only consumer/decode loads KV Cache
if self.is_producer:
return
assert self.p2p_nccl_engine is not None
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
def inject_kv_into_layer(
layer: torch.Tensor,
kv_cache: torch.Tensor,
block_ids: torch.Tensor,
request_id: str,
) -> None:
"""
Inject KV cache data into a given attention layer tensor.
This function updates `layer` in-place with values from `kv_cache`,
handling different backend layouts:
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
If the number of provided block IDs does not match the number of KV
blocks, only the overlapping portion is updated, and a warning is
logged.
Args:
layer (torch.Tensor): The attention layer KV tensor to update.
kv_cache (torch.Tensor): The KV cache tensor to inject.
block_ids (torch.Tensor): Indices of the blocks to update.
request_id (str): Request identifier used for logging.
Returns:
None. The function modifies `layer` in-place.
"""
if (
isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2
): # MLA or FlashInfer
num_block = kv_cache.shape[0]
self.check_tensors_except_dim(layer, kv_cache, 0)
if len(block_ids) == num_block:
layer[block_ids, ...] = kv_cache
else:
layer[block_ids[:num_block], ...] = kv_cache
logger.warning(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s",
len(block_ids),
num_block,
request_id,
)
elif layer.shape[0] == 2: # FlashAttention
num_block = kv_cache.shape[1]
self.check_tensors_except_dim(layer, kv_cache, 1)
if len(block_ids) == num_block:
layer[:, block_ids, ...] = kv_cache
else:
layer[:, block_ids[:num_block], ...] = kv_cache
logger.warning(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s",
len(block_ids),
num_block,
request_id,
)
# Get the metadata
metadata: KVConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, P2pNcclConnectorMetadata)
if metadata is None:
return
# Load the KV for each request each layer
for request in metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE
kv_cache = getattr(layer, "kv_cache", None)
if kv_cache is None:
continue
layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name, remote_address
)
if kv_cache is None:
logger.warning("🚧kv_cache is None, %s", request.request_id)
continue
inject_kv_into_layer(
layer, kv_cache, request.block_ids, request.request_id
)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
# Only producer/prefill saves KV Cache
if not self.is_producer:
return
assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
block_ids: torch.Tensor,
) -> torch.Tensor:
"""
Extract KV cache slices from a given attention layer tensor.
This function handles multiple backend layouts:
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
Args:
layer (torch.Tensor): The KV cache from the attention layer.
block_ids (torch.Tensor): Indices of blocks to extract.
Returns:
torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported.
"""
if (
isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2
): # MLA or FlashInfer
return layer[block_ids, ...]
if layer.shape[0] == 2: # FlashAttention
return layer[:, block_ids, ...]
return None
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
self.p2p_nccl_engine.send_tensor(
request_id + "#" + layer_name, kv_cache, remote_address
)
def wait_for_save(self):
if self.is_producer:
assert self.p2p_nccl_engine is not None
self.p2p_nccl_engine.wait_for_sent()
def get_finished(
self, finished_req_ids: set[str], **kwargs: Any
) -> tuple[set[str] | None, set[str] | None]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer,
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
assert self.p2p_nccl_engine is not None
no_compile_layers = self._vllm_config.compilation_config.static_forward_context
return self.p2p_nccl_engine.get_finished(finished_req_ids, no_compile_layers)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if self.is_producer:
return 0, False
prompt_token_ids = request.prompt_token_ids or []
num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens
if num_external_tokens < 0:
num_external_tokens = 0
return num_external_tokens, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
"""
if not self.is_producer and num_external_tokens > 0:
self._requests_need_load[request.request_id] = (
request,
blocks.get_block_ids()[0],
)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = P2pNcclConnectorMetadata()
for new_req in scheduler_output.scheduled_new_reqs:
if self.is_producer:
num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[
new_req.req_id
]
num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
# the request's prompt is chunked prefill
if num_tokens < len(new_req.prompt_token_ids or []):
# 'CachedRequestData' has no attribute 'prompt_token_ids'
self.chunked_prefill[new_req.req_id] = (
new_req.block_ids[0],
new_req.prompt_token_ids,
)
continue
# the request's prompt is not chunked prefill
meta.add_request(
request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids or [],
block_ids=new_req.block_ids[0],
block_size=self._block_size,
)
continue
if new_req.req_id in self._requests_need_load:
meta.add_request(
request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids or [],
block_ids=new_req.block_ids[0],
block_size=self._block_size,
)
self._requests_need_load.pop(new_req.req_id)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
if self.is_producer:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens + num_computed_tokens
assert req_id in self.chunked_prefill
assert new_block_ids is not None
block_ids = new_block_ids[0]
if not resumed_from_preemption:
block_ids = self.chunked_prefill[req_id][0] + block_ids
prompt_token_ids = self.chunked_prefill[req_id][1]
assert prompt_token_ids is not None
# the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids):
self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
continue
# the request's prompt is all prefilled finally
meta.add_request(
request_id=req_id,
token_ids=prompt_token_ids,
block_ids=block_ids,
block_size=self._block_size,
)
self.chunked_prefill.pop(req_id, None)
continue
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not resumed_from_preemption:
break
if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(req_id)
total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
meta.add_request(
request_id=req_id,
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
)
self._requests_need_load.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
self.chunked_prefill.pop(request.request_id, None)
return False, None
# ==============================
# Static methods
# ==============================
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
# Regular expression to match the string hostname and integer port
if is_prefill:
pattern = r"___decode_addr_(.*):(\d+)"
else:
pattern = r"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match = re.search(pattern, request_id)
if match:
# Extract the ranks
ip = match.group(1)
port = int(match.group(2))
return ip, port
raise ValueError(f"Request id {request_id} does not contain hostname and port")
@staticmethod
def check_tensors_except_dim(tensor1, tensor2, dim):
shape1 = tensor1.size()
shape2 = tensor2.size()
if len(shape1) != len(shape2) or not all(
s1 == s2 for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim
):
raise NotImplementedError(
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
"and others will be supported in future PRs."
)

View File

@@ -0,0 +1,632 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import logging
import os
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
import msgpack
import torch
import zmq
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary,
buffer_type,
cudaStream_t,
ncclComm_t,
ncclDataTypeEnum,
)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool,
)
from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import current_stream
logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32
@contextmanager
def set_p2p_nccl_context(num_channels: str):
original_values: dict[str, Any] = {}
env_vars = [
"NCCL_MAX_NCHANNELS",
"NCCL_MIN_NCHANNELS",
"NCCL_CUMEM_ENABLE",
"NCCL_BUFFSIZE",
"NCCL_PROTO", # LL,LL128,SIMPLE
"NCCL_ALGO", # RING,TREE
]
for var in env_vars:
original_values[var] = os.environ.get(var)
logger.info("set_p2p_nccl_context, original_values: %s", original_values)
try:
os.environ["NCCL_MAX_NCHANNELS"] = num_channels
os.environ["NCCL_MIN_NCHANNELS"] = num_channels
os.environ["NCCL_CUMEM_ENABLE"] = "1"
yield
finally:
for var in env_vars:
if original_values[var] is not None:
os.environ[var] = original_values[var]
else:
os.environ.pop(var, None)
@dataclass
class SendQueueItem:
tensor_id: str
remote_address: str
tensor: torch.Tensor
class P2pNcclEngine:
def __init__(
self,
local_rank: int,
config: KVTransferConfig,
hostname: str = "",
port_offset: int = 0,
library_path: str | None = None,
) -> None:
self.config = config
self.rank = port_offset
self.local_rank = local_rank
self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path)
if not hostname:
hostname = get_ip()
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = hostname
self._port = port
# Each card corresponds to a ZMQ address.
self.zmq_address = f"{self._hostname}:{self._port}"
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
self.http_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
# the `http_port` must be consistent with the port of OpenAI.
http_port = self.config.get_from_extra_config("http_port", None)
if http_port is None:
example_cfg = {
"kv_connector": "P2pNcclConnector",
"kv_connector_extra_config": {"http_port": 8000},
}
example = (
f"--port=8000 --kv-transfer-config='{json.dumps(example_cfg)}'"
)
raise ValueError(
"kv_connector_extra_config.http_port is required. "
f"Example: {example}"
)
self.http_address = f"{self._hostname}:{http_port}"
self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition()
self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream()
mem_pool_size_gb = float(
self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB
)
)
self.pool = TensorMemoryPool(
max_block_size=int(mem_pool_size_gb * 1024**3)
) # GB
# The sending type includes tree mutually exclusive options:
# PUT, GET, PUT_ASYNC.
self.send_type = self.config.get_from_extra_config("send_type", "PUT_ASYNC")
if self.send_type == "GET":
# tensor_id: torch.Tensor
self.send_store: dict[str, torch.Tensor] = {}
else:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[SendQueueItem] = deque()
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(
target=self.send_async, daemon=True
)
self._send_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {}
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.socks: dict[str, Any] = {} # remote_address: client socket
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
self.buffer_size = 0
self.buffer_size_threshold = float(self.config.kv_buffer_size)
self.nccl_num_channels = self.config.get_from_extra_config(
"nccl_num_channels", "8"
)
self._listener_thread = threading.Thread(
target=self.listen_for_requests, daemon=True
)
self._listener_thread.start()
self._ping_thread = None
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self.ping, daemon=True)
self._ping_thread.start()
logger.info(
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
"threshold:%.2f, nccl_num_channels:%s",
self.rank,
self.local_rank,
self.http_address,
self.zmq_address,
self.proxy_address,
self.send_type,
self.buffer_size_threshold,
self.nccl_num_channels,
)
def create_connect(self, remote_address: str | None = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock
if remote_address in self.comms:
logger.info(
"👋comm exists, remote_address:%s, comms:%s",
remote_address,
self.comms,
)
return sock, self.comms[remote_address]
unique_id = self.nccl.ncclGetUniqueId()
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
sock.send(msgpack.dumps(data))
with torch.cuda.device(self.device):
rank = 0
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank)
self.comms[remote_address] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
self.zmq_address,
remote_address,
rank,
)
return self.socks[remote_address], self.comms[remote_address]
def send_tensor(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: str | None = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
return True
item = SendQueueItem(
tensor_id=tensor_id, remote_address=remote_address, tensor=tensor
)
if self.send_type == "PUT":
return self.send_sync(item)
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append(item)
self.send_queue_cv.notify()
return True
# GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
if tensor_size > self.buffer_size_threshold:
logger.warning(
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
"buffer size threshold :%d, skip send to %s, rank:%d",
tensor_id,
tensor_size,
self.buffer_size_threshold,
remote_address,
self.rank,
)
return False
while self.buffer_size + tensor_size > self.buffer_size_threshold:
assert len(self.send_store) > 0
oldest_tensor_id = next(iter(self.send_store))
oldest_tensor = self.send_store.pop(oldest_tensor_id)
oldest_tensor_size = (
oldest_tensor.element_size() * oldest_tensor.numel()
)
self.buffer_size -= oldest_tensor_size
logger.debug(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
remote_address,
tensor_id,
tensor_size,
self.buffer_size,
oldest_tensor_size,
self.rank,
)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address,
tensor_id,
tensor_size,
tensor.shape,
self.rank,
self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100,
)
return True
def recv_tensor(
self,
tensor_id: str,
remote_address: str | None = None,
) -> torch.Tensor:
if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
start_time = time.time()
with self.recv_store_cv:
while tensor_id not in self.recv_store:
self.recv_store_cv.wait()
tensor = self.recv_store[tensor_id]
if tensor is not None:
if isinstance(tensor, tuple):
addr, dtype, shape = tensor
tensor = self.pool.load_tensor(addr, dtype, shape, self.device)
else:
self.buffer_size -= tensor.element_size() * tensor.numel()
else:
duration = time.time() - start_time
logger.warning(
"🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, rank:%d",
remote_address,
tensor_id,
duration * 1000,
self.rank,
)
return tensor
# GET
if remote_address is None:
return None
if remote_address not in self.socks:
self.create_connect(remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
data = {"cmd": "GET", "tensor_id": tensor_id}
sock.send(msgpack.dumps(data))
message = sock.recv()
data = msgpack.loads(message)
if data["ret"] != 0:
logger.warning(
"🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
remote_address,
tensor_id,
data["ret"],
)
return None
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(
data["shape"], dtype=getattr(torch, data["dtype"]), device=self.device
)
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
return tensor
def listen_for_requests(self):
while True:
socks = dict(self.poller.poll())
if self.router_socket not in socks:
continue
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank
)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address,
remote_address.decode(),
rank,
)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(
data["shape"],
dtype=getattr(torch, data["dtype"]),
device=self.device,
)
self.router_socket.send_multipart([remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if self.buffer_size + tensor_size > self.buffer_size_threshold:
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d",
self.zmq_address,
remote_address.decode(),
data,
addr,
)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart([remote_address, b"1"])
tensor = None
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, data:%s",
self.zmq_address,
remote_address.decode(),
data,
)
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", ""),
}
# LRU
self.send_store[tensor_id] = tensor
self.have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart([remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
else:
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address,
data,
)
def have_sent_tensor_id(self, tensor_id: str):
request_id = tensor_id.split("#")[0]
if request_id not in self.send_request_id_to_tensor_ids:
self.send_request_id_to_tensor_ids[request_id] = set()
self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
def have_received_tensor_id(self, tensor_id: str):
request_id = tensor_id.split("#")[0]
if request_id not in self.recv_request_id_to_tensor_ids:
self.recv_request_id_to_tensor_ids[request_id] = set()
self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
def send_async(self):
while True:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
item = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
self.send_sync(item)
def wait_for_sent(self):
if self.send_type == "PUT_ASYNC":
start_time = time.time()
with self.send_queue_cv:
while self.send_queue:
self.send_queue_cv.wait()
duration = time.time() - start_time
logger.debug(
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d",
duration * 1000,
self.rank,
)
def send_sync(self, item: SendQueueItem) -> bool:
if item.remote_address is None:
return False
if item.remote_address not in self.socks:
self.create_connect(item.remote_address)
tensor = item.tensor
sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address]
data = {
"cmd": "PUT",
"tensor_id": item.tensor_id,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", ""),
}
sock.send(msgpack.dumps(data))
response = sock.recv()
if response != b"0":
logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address,
item.remote_address,
rank,
data,
tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3,
response.decode(),
)
return False
self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC":
self.have_sent_tensor_id(item.tensor_id)
return True
def get_finished(
self, finished_req_ids: set[str], no_compile_layers
) -> tuple[set[str] | None, set[str] | None]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer,
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
# Clear the buffer upon request completion.
for request_id in finished_req_ids:
for layer_name in no_compile_layers:
tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store:
with self.recv_store_cv:
tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(request_id, None)
self.recv_request_id_to_tensor_ids.pop(request_id, None)
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.pool.free(addr)
# TODO:Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set()
# TODO:Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set()
return finished_sending or None, finished_recving or None
def ping(self):
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address)
sock.connect(f"tcp://{self.proxy_address}")
data = {
"type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"zmq_address": self.zmq_address,
}
while True:
sock.send(msgpack.dumps(data))
time.sleep(3)
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
with torch.cuda.stream(stream):
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
dst,
comm,
cudaStream_t(stream.cuda_stream),
)
stream.synchronize()
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
with torch.cuda.stream(stream):
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
src,
comm,
cudaStream_t(stream.cuda_stream),
)
stream.synchronize()
def close(self) -> None:
self._listener_thread.join()
if self.send_type == "PUT_ASYNC":
self._send_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()

View File

@@ -0,0 +1,273 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import ctypes
import math
from dataclasses import dataclass
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class MemoryBlock:
size: int
addr: int
"""A memory pool for managing pinned host memory allocations for tensors.
This class implements a buddy allocation system to efficiently manage pinned
host memory for tensor storage. It supports allocation, deallocation, and
tensor storage/retrieval operations.
Key Features:
- Uses power-of-two block sizes for efficient buddy allocation
- Supports splitting and merging of memory blocks
- Provides methods to store CUDA tensors in pinned host memory
- Allows loading tensors from pinned memory back to device
- Automatically cleans up memory on destruction
Attributes:
max_block_size (int): Maximum block size (rounded to nearest power of two)
min_block_size (int): Minimum block size (rounded to nearest power of two)
free_lists (dict): Dictionary of free memory blocks by size
allocated_blocks (dict): Dictionary of currently allocated blocks
base_tensor (torch.Tensor): Base pinned memory tensor
base_address (int): Base memory address of the pinned memory region
Example:
>>> pool = TensorMemoryPool(max_block_size=1024*1024)
>>> tensor = torch.randn(100, device='cuda')
>>> addr = pool.store_tensor(tensor)
>>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
... tensor.shape, 'cuda')
>>> pool.free(addr)
"""
class TensorMemoryPool:
"""Initializes the memory pool with given size constraints.
Args:
max_block_size (int): Maximum size of memory blocks to manage
min_block_size (int, optional): Minimum size of memory blocks
to manage. Defaults to 512.
Raises:
ValueError: If block sizes are invalid or max_block_size is less
than min_block_size
"""
def __init__(self, max_block_size: int, min_block_size: int = 512):
if max_block_size <= 0 or min_block_size <= 0:
raise ValueError("Block sizes must be positive")
if max_block_size < min_block_size:
raise ValueError("Max block size must be greater than min block size")
self.max_block_size = self._round_to_power_of_two(max_block_size)
self.min_block_size = self._round_to_power_of_two(min_block_size)
self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
self.allocated_blocks: dict[int, MemoryBlock] = {}
self._initialize_free_lists()
self._allocate_pinned_memory()
atexit.register(self.cleanup)
def _round_to_power_of_two(self, size: int) -> int:
return 1 << (size - 1).bit_length()
def _initialize_free_lists(self):
size = self.max_block_size
while size >= self.min_block_size:
self.free_lists[size] = {}
size //= 2
def _allocate_pinned_memory(self):
self.base_tensor = torch.empty(
self.max_block_size // 4, dtype=torch.float32, pin_memory=True
)
self.base_address = self.base_tensor.data_ptr()
initial_block = MemoryBlock(size=self.max_block_size, addr=self.base_address)
self.free_lists[self.max_block_size][initial_block.addr] = initial_block
logger.debug(
"TensorMemoryPool, base_address:%d, max_block_size:%d",
self.base_address,
self.max_block_size,
)
def allocate(self, size: int) -> int:
"""Allocates a memory block of at least the requested size.
Args:
size (int): Minimum size of memory to allocate
Returns:
int: Address of the allocated memory block
Raises:
ValueError: If size is invalid or insufficient memory is available
"""
if size <= 0:
raise ValueError("Allocation size must be positive")
required_size = self._round_to_power_of_two(max(size, self.min_block_size))
if required_size > self.max_block_size:
raise ValueError("Requested size exceeds maximum block size")
current_size = required_size
while current_size <= self.max_block_size:
if self.free_lists[current_size]:
_, block = self.free_lists[current_size].popitem()
self._split_block(block, required_size)
self.allocated_blocks[block.addr] = block
return block.addr
current_size *= 2
raise ValueError("Insufficient memory")
def _split_block(self, block: MemoryBlock, required_size: int):
while block.size > required_size and block.size // 2 >= self.min_block_size:
buddy_size = block.size // 2
buddy_addr = block.addr + buddy_size
buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
block.size = buddy_size
self.free_lists[buddy_size][buddy.addr] = buddy
def free(self, addr: int):
"""Frees an allocated memory block.
Args:
addr (int): Address of the block to free
Raises:
ValueError: If address is invalid or not allocated
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to free")
block = self.allocated_blocks.pop(addr)
self._merge_buddies(block)
def _merge_buddies(self, block: MemoryBlock):
MAX_MERGE_DEPTH = 30
depth = 0
while depth < MAX_MERGE_DEPTH:
buddy_offset = (
block.size
if (block.addr - self.base_address) % (2 * block.size) == 0
else -block.size
)
buddy_addr = block.addr + buddy_offset
buddy = self.free_lists[block.size].get(buddy_addr)
if buddy:
del self.free_lists[buddy.size][buddy.addr]
merged_addr = min(block.addr, buddy.addr)
merged_size = block.size * 2
block = MemoryBlock(size=merged_size, addr=merged_addr)
depth += 1
else:
break
self.free_lists[block.size][block.addr] = block
def store_tensor(self, tensor: torch.Tensor) -> int:
"""Stores a CUDA tensor in pinned host memory.
Args:
tensor (torch.Tensor): CUDA tensor to store
Returns:
int: Address where the tensor is stored
Raises:
ValueError: If tensor is not on CUDA or allocation fails
"""
if not tensor.is_cuda:
raise ValueError("Only CUDA tensors can be stored")
size = tensor.element_size() * tensor.numel()
addr = self.allocate(size)
block = self.allocated_blocks[addr]
if block.size < size:
self.free(addr)
raise ValueError(
f"Allocated block size {block.size} is smaller than "
f"required size {size}"
)
try:
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
cpu_tensor = torch.frombuffer(
buffer, dtype=tensor.dtype, count=tensor.numel()
).reshape(tensor.shape)
except ValueError as err:
self.free(addr)
raise ValueError(f"Failed to create tensor view: {err}") from err
cpu_tensor.copy_(tensor)
return addr
def load_tensor(
self,
addr: int,
dtype: torch.dtype,
shape: tuple[int, ...],
device: torch.device,
) -> torch.Tensor:
"""Loads a tensor from pinned host memory to the specified device.
Args:
addr (int): Address where tensor is stored
dtype (torch.dtype): Data type of the tensor
shape (tuple[int, ...]): Shape of the tensor
device: Target device for the loaded tensor
Returns:
torch.Tensor: The loaded tensor on the specified device
Raises:
ValueError: If address is invalid or sizes don't match
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to load")
block = self.allocated_blocks[addr]
num_elements = math.prod(shape)
dtype_size = torch.tensor([], dtype=dtype).element_size()
required_size = num_elements * dtype_size
if required_size > block.size:
raise ValueError("Requested tensor size exceeds block size")
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
cpu_tensor = torch.frombuffer(buffer, dtype=dtype, count=num_elements).reshape(
shape
)
cuda_tensor = torch.empty(shape, dtype=dtype, device=device)
cuda_tensor.copy_(cpu_tensor)
return cuda_tensor
def cleanup(self):
"""Cleans up all memory resources and resets the pool state."""
self.free_lists.clear()
self.allocated_blocks.clear()
if hasattr(self, "base_tensor"):
del self.base_tensor
def __del__(self):
self.cleanup()

View File

@@ -0,0 +1,450 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
mm_hashes: list[str]
@staticmethod
def make_meta(
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = (
block_offsets.reshape((1, block_size))
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
)
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
mm_hashes=mm_hashes,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes)
)
class SharedStorageConnector(KVConnectorBase_V1):
# NOTE: This is Simple debug implementation of the KV connector.
# It save / load the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU
# - to remove the overhead, need to add some "mask" in the ReqMeta class
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
self._storage_path = self._kv_transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp"
)
logger.info(self._kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1
)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1
)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata)
if metadata is None:
logger.warning(
"In connector.start_load_kv, but the connector metadata is None"
)
return
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
logger.warning("In connector.start_load_kv, but the attn_metadata is None")
return
# Load the KV for each request each layer
for request in metadata.requests:
if request.is_store:
continue
logger.info(
"Inject KV cache of %d tokens to the paged memory",
len(request.slot_mapping),
)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE/MLP etc.
kv_cache_attr = getattr(layer, "kv_cache", None)
if kv_cache_attr is None:
continue
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
def wait_for_save(self):
return
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0, False
logger.info("External Cache Hit!")
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
token_ids = request.prompt_token_ids or []
num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
if num_external_tokens > 0:
self._requests_need_load[request.request_id] = request
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
token_ids = new_req.prompt_token_ids or []
mm_hashes = [f.identifier for f in new_req.mm_features]
if new_req.req_id in self._requests_need_load:
meta.add_request(
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False,
mm_hashes=mm_hashes,
)
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_prompt(token_ids, mm_hashes):
meta.add_request(
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True,
mm_hashes=mm_hashes,
)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
if not resumed_from_preemption or req_id not in self._requests_need_load:
continue
num_computed_tokens = cached_reqs.num_computed_tokens[i]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
new_block_ids = cached_reqs.new_block_ids[i]
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
meta.add_request(
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in request.mm_features],
)
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_request(
self,
request: "Request",
) -> bool:
"""Check if the cache is hit for the request."""
return self._found_match_for_prompt(
list(request.prompt_token_ids or []),
[f.identifier for f in request.mm_features],
)
def _found_match_for_prompt(
self,
prompt_token_ids: list[int],
mm_hashes: list[str],
) -> bool:
num_tokens_to_check = align_to_block_size(
len(prompt_token_ids) - 1, self._block_size
)
foldername = self._generate_foldername_debug(
torch.tensor(prompt_token_ids)[:num_tokens_to_check],
mm_hashes,
create_folder=False,
)
return os.path.exists(foldername)
def _generate_foldername_debug(
self,
token_ids: torch.Tensor,
mm_hashes: list[str],
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
token_bytes = token_ids.numpy().tobytes()
# Add mm_hashes to the bytes being hashed to avoid path traversal and
# to create a canonical key.
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode("utf-8")
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(
self,
layer_name: str,
token_ids: torch.Tensor,
mm_hashes: list[str],
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(
token_ids, mm_hashes=mm_hashes, create_folder=True
)
return os.path.join(foldername, f"{layer_name}.safetensors")
def align_to_block_size(num_tokens: int, block_size) -> int:
"""Align the number of tokens to the block size."""
return (num_tokens - 1) // block_size * block_size

View File

@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.
These classes above are abstracted behind class `KVCacheBufferBase`.
"""
from abc import ABC, abstractmethod
import torch
class KVCacheBufferBase(ABC):
"""
Abstract base class for a KVCache buffer.
"""
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
KVCache buffer when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVLookupBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache lookup buffer.
This class provides an abstraction for a key-value (KV) cache lookup buffer.
The key of the lookup buffer:
- input_tokens: token IDs of the request
- roi: a binary mask on top of input_tokens.
- Purpose of roi: Since KV cache may only be available for a subset of
tokens in the input (for example, when vLLM is connected to an external
KV cache service), roi specifies the subset of tokens that the KV cache
is associated with.
- NOTE: roi can be further extended to describe which part of KV the
current process is holding (each process may only hold a part of KV
due to TP and PP). This is not implemented for now.
The value of the lookup buffer:
- key: the key tensor in the KV cache
- value: the value tensor in the KV cache
- hidden: the final hidden state generated by model forwarding. This allows
vLLM to bypass further model forwarding by transmitting the hidden state.
"""
@abstractmethod
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
) -> None:
"""Insert into the lookup buffer.
The functionality is similar to the following python statement
```
buffer[input_tokens, roi] = [key, value, hidden]
```
FIXME: in the future, we should only have two arguments, key and value,
where key is a tensor dict and value is a tensor dict.
FIXME: we should transmit both sampler outputs and the hidden states.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
key (torch.Tensor): The key tensor in the KV cache.
value (torch.Tensor): The value tensor in the KV cache.
hidden (torch.Tensor): The final hidden state tensor generated
during model forwarding to bypass model
forwarding.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def drop_select(
self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
) -> list[torch.Tensor | None]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
If `input_tokens` and `roi` is `None`, it means selecting any of the
KV caches in the buffer, return, and remove it from the buffer, useful
when offloading KV cache to KV cache storage service.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
list[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVStoreBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache storage buffer with key-value semantics.
This class provides a simple key-value storage buffer abstract with basic
put/get operations, which enables flexible KVCache transfer granular
control.
The functionality is similar to a distributed key-value store, where:
- Key: A unique string identifier for the cached entry
- Value:
- Tensor to be stored and retrieved
- None (indicating deletion or empty value)
"""
@abstractmethod
def put(
self,
key: str,
value: torch.Tensor | None,
) -> None:
"""Store a key-value pair in the buffer.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
value (Optional[torch.Tensor]): Tensor to be stored.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def get(
self,
key: str,
) -> torch.Tensor | None:
"""Retrieve a value from the buffer by key.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
Returns:
Optional[torch.Tensor]: Stored tensor if exists, None otherwise.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

View File

@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""
import json
import os
from dataclasses import dataclass
import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase
from vllm.logger import init_logger
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
logger = init_logger(__name__)
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: int
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=config.get(
"global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
),
local_buffer_size=config.get(
"local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
)
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
)
return MooncakeStoreConfig.from_file(config_file_path)
class MooncakeStore(KVStoreBufferBase):
def __init__(
self,
config: VllmConfig,
):
try:
from mooncake.store import MooncakeDistributedStore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector."
) from e
try:
self.store = MooncakeDistributedStore()
self.config = MooncakeStoreConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
self.store.setup(
self.config.local_hostname,
self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
)
except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error("An error occurred while loading the configuration: %s", exc)
raise
def close(self):
# MooncakeDistributedStore will automatically call the destructor, so
# it is unnecessary to close it manually.
pass
def put(
self,
key: str,
value: torch.Tensor | None,
) -> None:
# A message queue needs to be introduced before making it asynchronous.
if value is not None:
self._put_impl(key, value)
def get(
self,
key: str,
) -> torch.Tensor | None:
# A message queue needs to be introduced before making it asynchronous.
value = self._get_impl(key)
return value
def _put_impl(
self,
key: str,
value: torch.Tensor,
) -> None:
"""Put KVCache to Mooncake Store"""
device_id = value.device.index if value.device.type == "cuda" else -1
device_tensor = torch.tensor(device_id, dtype=torch.int32)
value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor})
try:
self.store.put(key, value_bytes)
except TypeError as err:
logger.error("Failed to put value into Mooncake Store: %s", err)
raise TypeError("Mooncake Store Put Type Error.") from err
def _get_impl(
self,
key: str,
) -> torch.Tensor | None:
"""Get KVCache from Mooncake Store"""
try:
data = self.store.get(key)
except TypeError as err:
logger.error("Failed to get value from Mooncake Store: %s", err)
raise TypeError("Mooncake Store Get Type Error.") from err
if data:
loaded_tensors = safetensors_load(data)
tensor = loaded_tensors["tensor"]
device_id_tensor = loaded_tensors["device_id"]
device_id = int(device_id_tensor.item())
device = (
torch.device("cuda", device_id)
if device_id >= 0
else torch.device("cpu")
)
return tensor.to(device)
return None

View File

@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid racing condition
- Handles buffer size constraints and provide backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading
from collections import deque
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class SimpleBuffer(KVLookupBufferBase):
def __init__(
self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float
):
"""
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request.
data_pipe: on device (e.g. GPU)
"""
self.buffer: deque[list[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: threading.Thread | None = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(
self,
tokens_roi_sender: list[torch.Tensor],
tokens_roi_recver: list[torch.Tensor],
):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
if tokens_recver is None:
# consumer sends an empty request
# semantics: DROP SELECT * LIMIT 1
# so any of the data in the buffer can be drop-selected
return True
# Assuming that roi is a binary mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]
# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
return min_length
return 0
def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: list | torch.Tensor | None):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
if not data:
# cannot perform `not data` on a tensor
# so this check needs to go after the check above
return 0
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
):
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone()
if isinstance(key, torch.Tensor):
key = key.clone()
if isinstance(value, torch.Tensor):
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])
with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()
self.buffer_size += data_size
self.buffer.append(buffer_item)
self.buffer_cv.notify()
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
assert roi is not None, (
"Please provide the roi when sending drop-select request"
)
roi = roi > 0.5
tokens_roi_recver = [input_tokens, roi]
def is_buffer_available(
tokens_roi_recver: list[torch.Tensor],
) -> bool:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
for _ in range(len(self.buffer)):
if self._matches(self.buffer[0], tokens_roi_recver) > 0:
return True
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
return False
with self.buffer_cv:
while not is_buffer_available(tokens_roi_recver):
logger.debug("KV transfer buffer is not available. Waiting...")
self.buffer_cv.wait()
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()
except RuntimeError as e:
if "Connection closed by peer" not in str(e):
raise e
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
) -> list[torch.Tensor | None]:
assert self.request_handling_thread is None, (
"drop_select should be called by the KV cache consumer "
"(e.g. the decode vLLM instance)"
)
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = roi > 0.5
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden]
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
) -> None:
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler
)
self.request_handling_thread.start()
def close(self):
if (
hasattr(self, "request_handling_thread")
and self.request_handling_thread is not None
):
self.request_handling_thread.join()
else:
# TODO: have a explicit close signal and have a explicit way to
# check if it's requester
self.signal_pipe.send_tensor(self.end_signal)

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.
All classes instantiated from this interface are assumed to be a FIFO pipe.
If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""
from abc import ABC, abstractmethod
import torch
class KVPipeBase(ABC):
"""
This class provides an interface for sending and receiving tensors, or
None, by distributed communications.
"""
@abstractmethod
def send_tensor(self, tensor: torch.Tensor | None) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
TODO: add a `key` argument so that we can use traditional
key-value database as the distributed communication mechanism behind
the pipe.
Args:
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def recv_tensor(self) -> torch.Tensor | None:
"""Receive a tensor (can be None) from the pipeline.
Returns:
Optional[torch.Tensor]: The tensor received from the pipeline. Can
be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the pipeline and release resources.
This method is responsible for closing the communication pipeline
and releasing any resources associated with it.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

View File

@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
import torch
import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port
logger = init_logger(__name__)
NONE_INT = -150886311
@dataclass
class MooncakeTransferEngineConfig:
prefill_url: str
decode_url: str
metadata_backend: str | None
metadata_server: str
protocol: str
device_name: str
@staticmethod
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeTransferEngineConfig(
prefill_url=config.get("prefill_url"),
decode_url=config.get("decode_url"),
metadata_backend=config.get("metadata_backend", None),
metadata_server=config.get("metadata_server"),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
)
@staticmethod
def load_from_env() -> "MooncakeTransferEngineConfig":
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
)
return MooncakeTransferEngineConfig.from_file(config_file_path)
class MooncakeTransferEngine:
"""Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""
def __init__(self, kv_rank: int, local_rank: int):
try:
from mooncake.engine import TransferEngine
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector."
) from e
self.engine = TransferEngine()
self.local_rank = local_rank
try:
self.config = MooncakeTransferEngineConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
except ValueError as e:
logger.error(e)
raise
except Exception as exc:
logger.error("An error occurred while loading the configuration: %s", exc)
raise
prefill_host, base_prefill_port = split_host_port(self.config.prefill_url)
decode_host, base_decode_port = split_host_port(self.config.decode_url)
# Avoid ports conflict when running prefill and decode on the same node
if prefill_host == decode_host and base_prefill_port == base_decode_port:
base_decode_port = base_decode_port + 100
prefill_port = base_prefill_port + self.local_rank
decode_port = base_decode_port + self.local_rank
self.prefill_url = join_host_port(prefill_host, prefill_port)
self.decode_url = join_host_port(decode_host, decode_port)
self.initialize(
self.prefill_url if kv_rank == 0 else self.decode_url,
self.config.metadata_server,
self.config.protocol,
self.config.device_name,
self.config.metadata_backend,
)
self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url
# Initialize ZeroMQ context and sockets
self.context = zmq.Context() # type: ignore[attr-defined]
self.sender_socket = self.context.socket(zmq.constants.PUSH)
self.receiver_socket = self.context.socket(zmq.constants.PULL)
self.sender_ack = self.context.socket(zmq.constants.PULL)
self.receiver_ack = self.context.socket(zmq.constants.PUSH)
self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
self._setup_metadata_sockets(
kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port
)
def _setup_metadata_sockets(
self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int
) -> None:
"""Set up ZeroMQ sockets for sending and receiving data."""
# Offsets < 8 are left for initialization in case tp and pp are enabled
p_rank_offset = p_port + 8 + self.local_rank * 2
d_rank_offset = d_port + 8 + self.local_rank * 2
if kv_rank == 0:
self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1))
self.receiver_socket.connect(
make_zmq_path("tcp", d_host, d_rank_offset + 1)
)
self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2))
self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2))
else:
self.receiver_socket.connect(
make_zmq_path("tcp", p_host, p_rank_offset + 1)
)
self.sender_socket.bind(make_zmq_path("tcp", d_host, d_rank_offset + 1))
self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2))
self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2))
def initialize(
self,
local_hostname: str,
metadata_server: str,
protocol: str,
device_name: str,
metadata_backend: str | None,
) -> None:
"""Initialize the mooncake instance."""
if metadata_backend is None:
self.engine.initialize(
local_hostname, metadata_server, protocol, device_name
)
else:
supported_backend = ["etcd", "redis"]
metadata_backend = metadata_backend.lower()
if metadata_backend not in supported_backend:
raise ValueError(
"Mooncake Configuration error. `metadata_backend`"
f" should be one of {supported_backend}."
)
self.engine.initialize_ext(
local_hostname, metadata_server, protocol, device_name, metadata_backend
)
def allocate_managed_buffer(self, length: int) -> int:
"""Allocate a managed buffer of the specified length."""
ret = self.engine.allocate_managed_buffer(length)
if ret <= 0:
logger.error("Allocation Return Error")
raise Exception("Allocation Return Error")
return ret
def free_managed_buffer(self, buffer: int, length: int) -> int:
"""Free a previously allocated managed buffer."""
return self.engine.free_managed_buffer(buffer, length)
def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int:
"""Synchronously transfer data to the specified address."""
ret = self.engine.transfer_sync_read(
self.remote_url, buffer, peer_buffer_address, length
)
if ret < 0:
logger.error("Transfer Return Error")
raise Exception("Transfer Return Error")
return ret
def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int:
"""Write bytes to the allocated buffer."""
return self.engine.write_bytes_to_buffer(buffer, user_data, length)
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
"""Read bytes from the allocated buffer."""
return self.engine.read_bytes_from_buffer(buffer, length)
def wait_for_ack(self, src_ptr: int, length: int) -> None:
"""Asynchronously wait for ACK from the receiver."""
ack = self.sender_ack.recv()
if ack != b"ACK":
logger.error("Failed to receive ACK from the receiver")
self.free_managed_buffer(src_ptr, length)
def send_bytes(self, user_data: bytes) -> None:
"""Send bytes to the remote process."""
length = len(user_data)
src_ptr = self.allocate_managed_buffer(length)
self.write_bytes_to_buffer(src_ptr, user_data, length)
self.sender_socket.send_multipart(
[struct.pack("!Q", src_ptr), struct.pack("!Q", length)]
)
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
def recv_bytes(self) -> bytes:
"""Receive bytes from the remote process."""
data = self.receiver_socket.recv_multipart()
src_ptr = struct.unpack("!Q", data[0])[0]
length = struct.unpack("!Q", data[1])[0]
dst_ptr = self.allocate_managed_buffer(length)
self.transfer_sync(dst_ptr, src_ptr, length)
ret = self.read_bytes_from_buffer(dst_ptr, length)
# Buffer cleanup
self.receiver_ack.send(b"ACK")
self.free_managed_buffer(dst_ptr, length)
return ret
class MooncakePipe(KVPipeBase):
"""MooncakeTransferEngine based Pipe implementation."""
def __init__(
self, local_rank: int, config: KVTransferConfig, device: str | None = None
):
"""Initialize the mooncake pipe and set related parameters."""
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
assert self.kv_rank is not None
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank)
self.transport_thread: ThreadPoolExecutor | None = None
self.none_tensor = torch.tensor([NONE_INT], device=self.device)
def _select_device(self, device: str) -> torch.device:
"""Select available device (CUDA or CPU)."""
logger.info("Selecting device: %s", device)
if device == "cuda":
return torch.device(f"cuda:{self.local_rank}")
else:
return torch.device("cpu")
def tensor_hash(self, tensor: torch.Tensor) -> int:
"""Calculate the hash value of the tensor."""
return hash(tensor.data_ptr())
def _send_impl(self, tensor: torch.Tensor) -> None:
"""Implement the tensor sending logic using safetensors."""
self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))
def _recv_impl(self) -> torch.Tensor:
"""Implement the tensor receiving logic using safetensors."""
data = self.transfer_engine.recv_bytes()
return safetensors_load(data)["tensor"].to(self.device)
def send_tensor(self, tensor: torch.Tensor | None) -> None:
"""Send tensor to the target process."""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
tensor = tensor if tensor is not None else self.none_tensor
assert len(tensor.shape) > 0
self.transport_thread.submit(self._send_impl, tensor)
def recv_tensor(self) -> torch.Tensor | None:
"""Receive tensor from other processes."""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
tensor = self.transport_thread.submit(self._recv_impl).result()
if tensor.numel() == 1 and tensor.item() == NONE_INT:
return None
else:
return tensor
def close(self) -> None:
"""Cleanup logic when closing the pipe."""
self.transfer_engine.sender_socket.close()
self.transfer_engine.receiver_socket.close()
self.transfer_engine.sender_ack.close()
self.transfer_engine.receiver_ack.close()
self.transfer_engine.context.term() # Terminate the ZMQ context
logger.info("Closed the transfer engine and cleaned up resources.")

View File

@@ -0,0 +1,285 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.
Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
import torch
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
logger = init_logger(__name__)
class BrokenPipeException(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
Metadata = dict[str, torch.Tensor | None]
class PyNcclPipe(KVPipeBase):
METADATA_LENGTH = 16
MAX_TENSOR_DIMENSIONS = 14
METADATA_DTYPE = torch.int64
def __init__(
self,
local_rank: int,
config: KVTransferConfig,
device: str | None = None,
port_offset: int = 0,
):
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
assert self.kv_rank is not None
self.kv_parallel_size = self.config.kv_parallel_size
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
# build distributed connection and send/recv implementation
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
rank=self.kv_rank,
world_size=self.kv_parallel_size,
store_timeout=store_timeout,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
impl = self._get_device_send_recv_impl(self.group)
self.device_send_func, self.device_recv_func = impl
# set target rank
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
# transportation-related variables
self.transport_thread: ThreadPoolExecutor | None = None
self.buffer_size = 0
self.buffer_size_lock = threading.Lock()
self.buffer_size_thresh = self.config.kv_buffer_size
def _get_device_send_recv_impl(
self, group: StatelessProcessGroup
) -> tuple[
Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None]
]:
send: Callable[[torch.Tensor, int], None]
recv: Callable[[torch.Tensor, int], None]
if self.device.type == "cuda":
# use PyNCCL for send / recv
comm = PyNcclCommunicator(group, device=self.local_rank)
comm.disabled = False
send, recv = comm.send, comm.recv # type: ignore
else:
# This send / recv implementation here is NOT intended to transfer
# KV caches (and should NOT be repurposed to transfer KV caches).
# Currently it is only used to transmit control-plane messages
# for PyNcclBuffer.
send = group.send_obj
def my_recv(x, src):
x[...] = group.recv_obj(src)
recv = my_recv
return send, recv
def _select_device(self, device: str):
logger.info("Selecting device: %s", device)
if device == "cuda":
return torch.device(f"cuda:{self.local_rank}")
else:
return torch.device("cpu")
def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata:
"""
Create the metadata as a dictionary based on the input tensor.
Args:
tensor: The input tensor or None if no tensor is provided.
Returns:
metadata: A dictionary with the following keys:
- "dtype": The data type of the tensor or None.
- "shape": The shape of the tensor or None.
"""
if tensor is None:
return {"dtype": None, "shape": None}
else:
return {"dtype": tensor.dtype, "shape": tensor.shape}
def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
"""
Create a buffer to receive the tensor based on the provided metadata.
Args:
metadata: A dictionary with keys "dtype" and "shape",
describing the tensor's data type and shape.
Returns:
buffer: A tensor of the specified type and shape,
allocated on `self.device`.
"""
return torch.empty(
metadata["shape"], dtype=metadata["dtype"], device=self.device
)
def _send_metadata(self, metadata: Metadata):
"""
Send the metadata dictionary to the target rank.
Args:
metadata: A dictionary with keys "dtype" and "shape".
"""
self.group.send_obj(metadata, self.target_rank_for_send)
def _recv_metadata(self) -> Metadata:
"""
Receive the metadata dictionary from the target rank.
Returns:
metadata: A dictionary with keys "dtype" and "shape"
describing the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
def _send_impl(self, tensor: torch.Tensor | None) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
target rank.
Args:
tensor: The input tensor to be sent, or `None` if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
self._send_metadata(metadata)
if tensor is not None:
self.device_send_func(tensor.to(self.device), self.target_rank_for_send)
def _recv_impl(self) -> torch.Tensor | None:
"""
The actual implementation of receiving a tensor and its metadata from
the target rank.
Returns:
buffer: The received tensor, or `None` if no tensor is received.
"""
metadata = self._recv_metadata()
if metadata["dtype"] is None:
return None
buffer = self._prepare_recv_buffer(metadata)
self.device_recv_func(buffer, self.target_rank_for_recv)
return buffer
def send_tensor_wrapper(
self, tensor: torch.Tensor | None, tensor_size: int
) -> None:
"""
Wrapper for _send_impl to handle exceptions and update buffer size.
"""
try:
self._send_impl(tensor)
with self.buffer_size_lock:
self.buffer_size -= tensor_size
except Exception as e:
logger.error(
"[rank%d]: Exception when trying to send %s, msg: %s",
torch.distributed.get_rank(),
str(tensor),
str(e),
)
import traceback
traceback.print_exc()
def block_if_full(self):
"""
Block the current thread if the buffer size is larger than the
threshold.
"""
while self.buffer_size > self.buffer_size_thresh:
logger.debug("KV cache transfer pipe is full. Waiting...")
time.sleep(0.05)
def send_tensor(self, tensor: torch.Tensor | None) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
Args:
tensor: The tensor to send, or `None` if no tensor is being sent.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
if tensor is not None:
tensor_size = tensor.element_size() * tensor.numel()
else:
tensor_size = 0
self.block_if_full()
with self.buffer_size_lock:
self.buffer_size += tensor_size
self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size)
def recv_tensor(self) -> torch.Tensor | None:
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
The received tensor, or `None` if no tensor is received.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
future = self.transport_thread.submit(self._recv_impl)
try:
tensor = future.result()
except Exception as e:
logger.error("Encountering exception in KV receiving thread")
logger.error("%s", e)
logger.error("My device: %s", self.device)
import traceback
traceback.print_exc()
raise e
return tensor
def close(self):
"""
Close the pipe and release associated resources.
"""
if hasattr(self, "transport_thread") and self.transport_thread is not None:
self.transport_thread.shutdown()

View File

@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None
def get_kv_transfer_group() -> KVConnectorBaseType:
assert _KV_CONNECTOR_AGENT is not None, (
"disaggregated KV cache transfer parallel group is not initialized"
)
return _KV_CONNECTOR_AGENT
def has_kv_transfer_group() -> bool:
return _KV_CONNECTOR_AGENT is not None
def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> bool:
"""Check if the KV connector is the v1 connector.
If the argument is None, it will check the global KV connector
Args:
connector: The KV connector to check. If None, it will check the
global KV connector.
Note:
This function will no-longer be needed after the v1 KV connector
becomes the default.
"""
if connector is None:
connector = _KV_CONNECTOR_AGENT
if connector is None:
return False
return isinstance(connector, KVConnectorBase_V1)
def ensure_kv_transfer_initialized(
vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None
) -> None:
"""
Initialize KV cache transfer parallel group.
"""
global _KV_CONNECTOR_AGENT
if vllm_config.kv_transfer_config is None:
return
if (
vllm_config.kv_transfer_config.is_kv_transfer_instance
and _KV_CONNECTOR_AGENT is None
):
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
config=vllm_config,
role=KVConnectorRole.WORKER,
kv_cache_config=kv_cache_config,
)
def ensure_kv_transfer_shutdown() -> None:
global _KV_CONNECTOR_AGENT
if _KV_CONNECTOR_AGENT is not None:
_KV_CONNECTOR_AGENT.shutdown()
_KV_CONNECTOR_AGENT = None