Add minimal vLLM 0.16.1 build repo for BI-V150

vllm/distributed/kv_transfer/kv_connector/base.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Defines the base type for KV cache connectors."""

from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1

KVConnectorBase = KVConnectorBase_V1
KVConnectorBaseType = KVConnectorBase_V1

__all__ = ["KVConnectorBase", "KVConnectorBaseType"]

vllm/distributed/kv_transfer/kv_connector/factory.py (new file, 203 lines)
@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING, cast

from vllm.distributed.kv_transfer.kv_connector.base import (
    KVConnectorBase,
    KVConnectorBaseType,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import (
    KVConnectorRole,
    supports_hma,
)
from vllm.logger import init_logger
from vllm.utils.func_utils import supports_kw

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.kv_transfer import KVTransferConfig
    from vllm.v1.kv_cache_interface import KVCacheConfig

logger = init_logger(__name__)


class KVConnectorFactory:
    _registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a connector with a lazy-loading module and class name."""
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[KVConnectorBase]:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_connector(
        cls,
        config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig | None" = None,
    ) -> KVConnectorBase:
        kv_transfer_config = config.kv_transfer_config
        if kv_transfer_config is None:
            raise ValueError("kv_transfer_config must be set to create a connector")
        connector_cls, compat_sig = cls._get_connector_class_with_compat(
            kv_transfer_config
        )

        # check if the connector supports HMA
        hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
        if hma_enabled and not supports_hma(connector_cls):
            raise ValueError(
                f"Connector {connector_cls.__name__} does not support HMA but "
                f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
            )

        logger.info(
            "Creating v1 connector with name: %s and engine_id: %s",
            connector_cls.__name__,
            kv_transfer_config.engine_id,
        )
        # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
        # Scheduler connector:
        #   - Co-locate with scheduler process
        #   - Should only be used inside the Scheduler class
        # Worker connector:
        #   - Co-locate with worker process
        #   - Should only be used inside the forward context & attention layer
        # We build separately to enforce strict separation
        if compat_sig:
            # Old signature: __init__(self, vllm_config, role)
            return connector_cls(config, role)
        else:
            # New signature: __init__(self, vllm_config, role, kv_cache_config)
            return connector_cls(config, role, kv_cache_config)

    @classmethod
    def get_connector_class_by_name(
        cls, connector_name: str
    ) -> type[KVConnectorBaseType]:
        """Get a registered connector class by name.

        Raises ValueError if the connector is not registered.

        Args:
            connector_name: Name of the registered connector.

        Returns:
            The connector class.
        """
        if connector_name not in cls._registry:
            raise ValueError(f"Connector '{connector_name}' is not registered.")
        return cls._registry[connector_name]()

    @classmethod
    def _get_connector_class_with_compat(
        cls, kv_transfer_config: "KVTransferConfig"
    ) -> tuple[type[KVConnectorBaseType], bool]:
        connector_name = kv_transfer_config.kv_connector
        if connector_name is None:
            raise ValueError("Connector name is not set in KVTransferConfig")
        compat_sig = False
        if connector_name in cls._registry:
            connector_cls = cls._registry[connector_name]()
        else:
            connector_module_path = kv_transfer_config.kv_connector_module_path
            if connector_module_path is None:
                raise ValueError(f"Unsupported connector type: {connector_name}")
            connector_module = importlib.import_module(connector_module_path)
            try:
                connector_cls = getattr(connector_module, connector_name)
            except AttributeError as e:
                raise AttributeError(
                    f"Class {connector_name} not found in {connector_module_path}"
                ) from e
        connector_cls = cast(type[KVConnectorBaseType], connector_cls)
        if not supports_kw(connector_cls, "kv_cache_config"):
            compat_sig = True
            logger.warning(
                "Connector %s uses deprecated signature with 2 required arguments. "
                "Please update to include kv_cache_config as the third argument.",
                connector_cls.__name__,
            )
        return connector_cls, compat_sig

    @classmethod
    def get_connector_class(
        cls, kv_transfer_config: "KVTransferConfig"
    ) -> type[KVConnectorBaseType]:
        """Get the connector class by name."""
        connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
        return connector_cls


# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.

KVConnectorFactory.register_connector(
    "ExampleConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.example_connector",
    "ExampleConnector",
)

KVConnectorFactory.register_connector(
    "P2pNcclConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
    "P2pNcclConnector",
)

KVConnectorFactory.register_connector(
    "LMCacheConnectorV1",
    "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
    "LMCacheConnectorV1",
)

KVConnectorFactory.register_connector(
    "LMCacheMPConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector",
    "LMCacheMPConnector",
)

KVConnectorFactory.register_connector(
    "NixlConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
    "NixlConnector",
)

KVConnectorFactory.register_connector(
    "MultiConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
    "MultiConnector",
)

KVConnectorFactory.register_connector(
    "MoRIIOConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
    "MoRIIOConnector",
)

KVConnectorFactory.register_connector(
    "OffloadingConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
    "OffloadingConnector",
)

KVConnectorFactory.register_connector(
    "DecodeBenchConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
    "DecodeBenchConnector",
)

KVConnectorFactory.register_connector(
    "MooncakeConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
    "MooncakeConnector",
)
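
Registration above only records a lazy loader; nothing is imported until the connector is actually selected. A hedged sketch of how an out-of-tree connector would plug into the same factory (the `MyConnector` name and `my_pkg.my_connector` module are hypothetical):

KVConnectorFactory.register_connector(
    "MyConnector",          # name referenced from kv_transfer_config
    "my_pkg.my_connector",  # module imported lazily, on first lookup
    "MyConnector",          # class resolved inside that module
)

# At engine startup vLLM would then resolve and build it, once per role
# (sketch only; vllm_config is assumed to be a fully constructed VllmConfig):
connector = KVConnectorFactory.create_connector(
    vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)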

vllm/distributed/kv_transfer/kv_connector/utils.py (new file, 502 lines)
@@ -0,0 +1,502 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Helper utilities for KV cache transfer: layout selection, worker output
aggregation, and block copy/permute routines.
"""

from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast

import torch

from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase

logger = init_logger(__name__)

EngineId = str


def get_kv_connector_cache_layout():
    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
    # used for faster transfer.
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    if kv_config is not None:
        connector_cls = KVConnectorFactory.get_connector_class(kv_config)
        required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config)
        if required_kvcache_layout is not None:
            return required_kvcache_layout
    logger.info_once(
        "Connectors do not specify a kv cache layout, defaulting to NHD."
    )
    return "NHD"


class KVOutputAggregator:
    """Utility class to aggregate the output of all workers into a single
    output corresponding to Rank 0 for the scheduler."""

    def __init__(self, expected_finished_count: int):
        # Complete transfer tracker. Used to track finished requests
        # [req_id -> n_remaining_workers]
        self._recv_remaining_count = dict[str, int]()
        self._send_remaining_count = dict[str, int]()
        self._expected_finished_count = expected_finished_count

    @classmethod
    def from_connector(cls, connector: "KVConnectorBase", world_size: int):
        return cls(connector.get_finished_count() or world_size)

    def aggregate(
        self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
    ) -> ModelRunnerOutput | None:
        if not outputs[output_rank]:
            return None

        # Aggregate kv_connector_output from all workers

        def update_finished_set(
            req_ids: set[str] | None,
            remaining_count_dict: dict[str, int],
            finished_set: set[str],
        ) -> None:
            for req_id in req_ids or ():
                remaining_count = remaining_count_dict.get(
                    req_id, self._expected_finished_count
                )
                remaining_count_dict[req_id] = remaining_count - 1
                if remaining_count_dict[req_id] == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]

        finished_sending = set[str]()
        finished_recving = set[str]()
        aggregated_kv_connector_stats = None
        combined_kv_cache_events = None
        invalid_block_ids = set[int]()
        for model_runner_output in outputs:
            assert model_runner_output is not None
            kv_output = model_runner_output.kv_connector_output
            if not kv_output:
                continue
            # Allow the worker to dynamically update the expected number of
            # finished sending/recving for new requests.
            if (
                kv_output.expected_finished_count > 0
                and kv_output.expected_finished_count != self._expected_finished_count
            ):
                logger.debug(
                    "Expected finished requests updated from %d to %d",
                    self._expected_finished_count,
                    kv_output.expected_finished_count,
                )
                self._expected_finished_count = kv_output.expected_finished_count

            update_finished_set(
                kv_output.finished_sending, self._send_remaining_count, finished_sending
            )
            update_finished_set(
                kv_output.finished_recving, self._recv_remaining_count, finished_recving
            )

            # Aggregate kv_connector_stats from all workers.
            if aggregated_kv_connector_stats is None:
                # Use the first worker's kv_connector_stats as accumulator.
                aggregated_kv_connector_stats = kv_output.kv_connector_stats
            elif kv_connector_stats := kv_output.kv_connector_stats:
                if aggregated_kv_connector_stats is None:
                    aggregated_kv_connector_stats = kv_connector_stats
                else:
                    assert isinstance(
                        aggregated_kv_connector_stats, type(kv_connector_stats)
                    )
                    aggregated_kv_connector_stats = (
                        aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                    )

            # Combine kv_cache_events from all workers.
            if combined_kv_cache_events is None:
                # Use the first worker's kv_cache events as start event list.
                combined_kv_cache_events = kv_output.kv_cache_events
            elif kv_cache_events := kv_output.kv_cache_events:
                assert isinstance(
                    combined_kv_cache_events,
                    type(kv_cache_events),
                )
                worker_kv_cache_events = kv_cache_events.get_all_events()
                combined_kv_cache_events.add_events(worker_kv_cache_events)
                combined_kv_cache_events.increment_workers(1)

            invalid_block_ids |= kv_output.invalid_block_ids

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

        assert output is not None
        output.kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending or None,
            finished_recving=finished_recving or None,
            kv_connector_stats=aggregated_kv_connector_stats or None,
            kv_cache_events=combined_kv_cache_events or None,
            invalid_block_ids=invalid_block_ids,
            expected_finished_count=self._expected_finished_count,
        )

        return output
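
The countdown in `update_finished_set` is what makes the aggregation correct: a request is only reported as finished once every expected worker has reported it. A self-contained sketch of that semantics, using plain dicts instead of vLLM types:

def countdown(reported_ids, remaining, expected, finished):
    # Mirrors update_finished_set: decrement per worker report, finish at zero.
    for req_id in reported_ids:
        remaining[req_id] = remaining.get(req_id, expected) - 1
        if remaining[req_id] == 0:
            finished.add(req_id)
            del remaining[req_id]

remaining, finished = {}, set()
countdown({"req-1"}, remaining, 2, finished)   # worker 0 reports req-1
assert finished == set()                       # still waiting on worker 1
countdown({"req-1"}, remaining, 2, finished)   # worker 1 reports req-1
assert finished == {"req-1"}                   # all workers done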


def _make_src_and_dst_indices(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    src_device: torch.device | str,
    dst_device: torch.device | str,
) -> tuple[torch.Tensor, torch.Tensor]:
    src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64)
    dst_indices = torch.tensor(dst_block_ids, device=dst_device, dtype=torch.int64)
    return src_indices, dst_indices


def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy kv blocks between different buffers."""
    if (
        not src_kv_caches
        or not dst_kv_caches
        or not src_block_ids
        or not dst_block_ids
        or len(src_block_ids) != len(dst_block_ids)
    ):
        return

    src_device = next(iter(src_kv_caches.values())).device
    dst_device = next(iter(dst_kv_caches.values())).device

    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=src_device,
        dst_device=dst_device,
    )

    if direction == "h2d":
        copy_fn = current_platform.insert_blocks_to_device
    else:
        copy_fn = current_platform.swap_out_blocks_to_host
    for layer_name in src_kv_caches:
        src_tensor = src_kv_caches[layer_name]
        dst_tensor = dst_kv_caches[layer_name]
        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
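
`copy_kv_blocks` defers to platform-specific copy ops, but per layer the operation amounts to an index-based gather/scatter. A pure-torch illustration with made-up shapes (not the platform kernels themselves):

import torch

num_blocks, block_size, n_heads, head_dim = 8, 4, 2, 8
src = torch.randn(num_blocks, block_size, n_heads, head_dim)  # e.g. device cache
dst = torch.zeros(4, block_size, n_heads, head_dim)           # e.g. host buffer
src_idx = torch.tensor([4, 7], dtype=torch.int64)
dst_idx = torch.tensor([0, 1], dtype=torch.int64)
# Gather the selected source blocks and scatter them into destination slots.
dst.index_copy_(0, dst_idx, src.index_select(0, src_idx))
assert torch.equal(dst[0], src[4]) and torch.equal(dst[1], src[7])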


def kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio):
    """
    Transforms the layout of received KV cache blocks to the local block_size.
    (Only works for local blocksize > remote blocksize)

    Example:
        local blocksize = 16 tokens, remote blocksize = 4 tokens
        local block[0] = remote block[0, 1, 2, 3]
        remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
        local is  |h0-b0..................|h1-b0..................|...
        The permute is:
        1. view    => view remote as n_blocks * remote_shape(H, remoteN, D)
        2. permute => (H, nblocks, remoteN, D)
        3. flatten => (H, localN, D)
    """
    blocks_to_update = cache.index_select(0, indices)
    # use physical order
    blocks_to_update = blocks_to_update.permute(0, 2, 1, 3)
    n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
    remote_block_size = block_size // block_size_ratio
    n_blocks = block_size_ratio

    permuted_blocks = (
        blocks_to_update.reshape(-1, n_blocks, n_kv_heads, remote_block_size, head_size)
        .permute(0, 2, 1, 3, 4)
        .flatten(2, 3)
    )
    permuted_blocks = permuted_blocks.permute(0, 2, 1, 3)
    cache.index_copy_(0, indices, permuted_blocks)


def kv_postprocess_layout_on_receive(cache, indices):
    """Transforms the layout of received KV cache blocks to the local format.

    This method corrects layout mismatches from direct memory copies by
    permuting the tensor dimensions.

    - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
    - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

    Implementation:
    - x = blocks_to_update.reshape(src_shape)  # view local kv with sender layout
    - permuted_blocks = x.permute(*inv_order)  # transpose n_kv_heads, block_size
    - cache.index_copy_(0, indices, permuted_blocks)  # copy permuted kv back
    """
    blocks_to_update = cache.index_select(0, indices)
    target_shape = list(blocks_to_update.shape)
    target_shape[0] = -1
    inv_order = [0, 2, 1, 3]
    src_shape = tuple(target_shape[i] for i in inv_order)
    permuted_blocks = blocks_to_update.reshape(src_shape).permute(*inv_order)
    cache.index_copy_(0, indices, permuted_blocks)
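
One way to see the fix-up in action: write data in the sender's HND order into an NHD-shaped cache (emulating the raw byte copy), run the function, and check the result is the proper NHD permutation. Sizes below are arbitrary illustration values:

import torch

num_blocks, n_heads, block_size, head_dim = 2, 4, 8, 16
# KV data as the sender produced it: [num_blocks, H, N, D]
sent = torch.arange(num_blocks * n_heads * block_size * head_dim, dtype=torch.float32)
sent = sent.reshape(num_blocks, n_heads, block_size, head_dim)
# A raw copy drops those bytes into the receiver's NHD-shaped cache unchanged:
cache = sent.reshape(num_blocks, block_size, n_heads, head_dim).clone()
kv_postprocess_layout_on_receive(cache, torch.arange(num_blocks))
# After the fix-up, the cache holds the proper NHD view of the sent data.
assert torch.equal(cache, sent.permute(0, 2, 1, 3))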


def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio):
    """
    Transforms the layout of received KV cache to the local block_size and HND.
    (Only works for local blocksize > remote blocksize)

    prefill is HND, smaller block_size
    decode (local) is NHD, larger block_size
    """
    blocks_to_update = cache.index_select(0, indices)

    block_size, n_kv_heads, head_size = blocks_to_update.shape[1:]
    remote_block_size = block_size // block_size_ratio
    n_blocks = block_size_ratio

    permuted_blocks = (
        blocks_to_update.reshape(-1, n_blocks, n_kv_heads, remote_block_size, head_size)
        .permute(0, 1, 3, 2, 4)
        .flatten(1, 2)
    )
    cache.index_copy_(0, indices, permuted_blocks)


def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """
    Yields:
        (req_id, new_block_id_groups, preempted)
    """
    # new requests
    for req_data in scheduler_output.scheduled_new_reqs:
        yield req_data.req_id, req_data.block_ids, False

    # cached requests
    cached_reqs = scheduler_output.scheduled_cached_reqs
    yield from zip(
        cached_reqs.req_ids,
        cached_reqs.new_block_ids,
        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
    )


@dataclass
class TpKVTopology:
    """
    Helper class for tensor parallel and KV topology information, for
    mapping between local and remote TP workers.
    """

    tp_rank: int
    remote_tp_size: dict[EngineId, int]
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    engine_id: EngineId
    remote_block_size: dict[EngineId, int]
    tensor_shape: torch.Size | None = None

    def __post_init__(self):
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks. This is used to register the memory regions correctly.
        _MOCK_BLOCK_SIZE = 16
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
        )
        logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
        # Non-MLA backends' caches have 5 dims [2, num_blocks, H, N, D];
        # we just mock num_blocks to 1 for the dimension check below.
        self._is_kv_layout_blocks_first = (
            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
        )

        self._cross_layers_blocks = False
        if self.tensor_shape is not None:
            self._cross_layers_blocks = (
                len(self.tensor_shape) == len(kv_cache_shape) + 1
            )

        if self._cross_layers_blocks:
            logger.debug("Using cross-layer KV cache")
            # prepend layers dimension
            _MOCK_NUM_LAYERS = 80
            kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
        try:
            kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
                include_num_layers_dimension=self._cross_layers_blocks
            )
        except (AttributeError, NotImplementedError):
            kv_cache_stride_order = tuple(range(len(self.tensor_shape)))

        # In case of cross layers, permute kv_cache_shape according to
        # stride_order to retrieve the physical position of block_size.
        kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

        # In the default non-cross-layers layout the block_size position
        # is logical, while in the cross-layers case it is the physical
        # position. This matches the shape of the actual kv cache tensors
        # passed at register_kv_caches()/register_cross_layers_kv_cache().
        block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)

        assert block_size_position is not None
        self._block_size_position = -(len(kv_cache_shape) - block_size_position)

    @property
    def is_kv_layout_blocks_first(self) -> bool:
        return self._is_kv_layout_blocks_first

    @property
    def split_k_and_v(self) -> bool:
        # Whether to register regions for K and V separately (when present).
        return not (
            self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
        )

    @property
    def tp_size(self) -> int:
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        return self.remote_block_size[self.engine_id]

    @property
    def cross_layers_blocks(self) -> bool:
        return self._cross_layers_blocks

    @property
    def block_size_position(self) -> int:
        return self._block_size_position

    def tp_ratio(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Calculate the tensor parallel ratio between local and remote TP.
        We can think of it as the number of local TP workers per remote TP
        worker. Local workers will read from the same remote TP worker in
        groups of size `tp_ratio`. If remote tp_size > local tp_size, the
        ratio is flipped (remote_size/local_size) and the returned value is
        negative.
        """
        if self.tp_size >= remote_tp_size:
            assert self.tp_size % remote_tp_size == 0, (
                f"Local tensor parallel size {self.tp_size} is not divisible "
                f"by remote tensor parallel size {remote_tp_size}."
            )
            return self.tp_size // remote_tp_size

        assert remote_tp_size % self.tp_size == 0, (
            f"Remote tensor parallel size {remote_tp_size} is not divisible "
            f"by local tensor parallel size {self.tp_size}."
        )
        # P TP > D TP case, return the ratio as negative
        return -remote_tp_size // self.tp_size

    def block_size_ratio(
        self,
        remote_block_size: int,
    ) -> int:
        """
        Calculate the block size ratio between local and remote TP.
        """
        assert self.block_size % remote_block_size == 0, (
            f"Local block size {self.block_size} is not divisible "
            f"by remote block size {remote_block_size}."
        )
        return self.block_size // remote_block_size

    def tp_ratio_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> int:
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.tp_ratio(remote_tp_size)

    def block_size_ratio_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> int:
        remote_block_size = self.remote_block_size[remote_engine_id]
        return self.block_size_ratio(remote_block_size)

    def is_kv_replicated(self, engine_id: EngineId) -> bool:
        """
        Whether the KV cache is replicated across TP workers due to the
        number of TP workers being greater than the number of KV heads.
        """
        tp_size = self.remote_tp_size[engine_id]
        return tp_size // self.total_num_kv_heads >= 1

    def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
        # MLA is always replicated as the hidden dim can't be split.
        return self.is_mla or self.is_kv_replicated(remote_engine_id)

    def get_target_remote_ranks(
        self,
        remote_tp_size: int,
    ) -> list[int]:
        """
        Get the remote TP rank (on P) that the current local TP rank
        (on D) will read from. When remote tp_size > local tp_size, we
        read from multiple remote ranks.
        """
        tp_ratio = self.tp_ratio(remote_tp_size)
        if tp_ratio > 0:
            return [self.tp_rank // tp_ratio]

        # P TP > D TP case, D reads from |tp_ratio| remote workers.
        tp_ratio = -tp_ratio
        return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]

    def get_target_remote_ranks_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> list[int]:
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.get_target_remote_ranks(remote_tp_size)
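
The sign convention in `tp_ratio` drives `get_target_remote_ranks`: positive ratios group several local workers onto one remote rank, negative ratios fan one local worker out to several remote ranks. A standalone sketch of just that arithmetic (names are illustrative, not part of the file):

def target_remote_ranks(tp_rank: int, tp_size: int, remote_tp_size: int) -> list[int]:
    # Mirrors TpKVTopology.tp_ratio() + get_target_remote_ranks().
    if tp_size >= remote_tp_size:
        tp_ratio = tp_size // remote_tp_size
        return [tp_rank // tp_ratio]
    tp_ratio = remote_tp_size // tp_size
    return [tp_rank * tp_ratio + i for i in range(tp_ratio)]

# D TP=4 reading from P TP=2: local ranks pair up on each remote rank.
assert [target_remote_ranks(r, 4, 2) for r in range(4)] == [[0], [0], [1], [1]]
# D TP=2 reading from P TP=4: each local rank reads from two remote ranks.
assert [target_remote_ranks(r, 2, 4) for r in range(2)] == [[0, 1], [2, 3]]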


def get_current_attn_backend(vllm_config: VllmConfig):
    layer_type = cast(type[Any], AttentionLayerBase)
    layers = get_layers_from_vllm_config(vllm_config, layer_type, None)
    if layers:
        backend = next(iter(layers.values())).get_attn_backend()
    else:
        # Fallback for tests, when static_forward_context is empty.
        logger.debug(
            "No layers found in the vLLM config. "
            "Falling back to default attention backend."
        )
        from vllm.v1.attention.selector import get_attn_backend

        backend = get_attn_backend(
            head_size=vllm_config.model_config.get_head_size(),
            dtype=vllm_config.model_config.dtype,
            kv_cache_dtype=vllm_config.cache_config.cache_dtype,
            block_size=vllm_config.cache_config.block_size,
            use_mla=vllm_config.model_config.use_mla,
        )
    return backend

vllm/distributed/kv_transfer/kv_connector/v1/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorRole,
    SupportsHMA,
    supports_hma,
)
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (  # noqa: E501
    DecodeBenchConnector,
)

__all__ = [
    "KVConnectorRole",
    "KVConnectorBase_V1",
    "supports_hma",
    "SupportsHMA",
    "DecodeBenchConnector",
]

vllm/distributed/kv_transfer/kv_connector/v1/base.py (new file, 607 lines)
@@ -0,0 +1,607 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1

The class provides the following primitives:
    Scheduler-side: runs in the scheduler, binds metadata, which
    is used by the worker-side to load/save KV cache.
        get_num_new_matched_tokens() - get number of new tokens
            that exist in the remote KV cache. Might be called multiple
            times for a given request and should be side-effect free.
        update_state_after_alloc() - update KVConnector state after
            temporary buffer alloc by the CacheManager.
        update_connector_output() - update KVConnector state after
            output is received from worker-side connectors.
        request_finished() - called once when a request is finished,
            with the computed kv cache blocks for the request.
            Returns whether KV cache should be freed now or if the
            connector now assumes responsibility for freeing the
            blocks asynchronously. Also optionally returns KV
            transfer params.
        take_events() - returns new KV events that were collected
            by the connector since the last call.

    Worker-side: runs in each worker, loads/saves KV cache to/from
    the Connector based on the metadata.
        handle_preemptions() - called if there are preempted requests,
            before their blocks are overwritten

        start_load_kv() - starts loading all KVs (maybe async)
        wait_for_layer_load() - blocks until layer i load is done

        save_kv_layer() - starts saving KV for layer i (maybe async)
        wait_for_save() - blocks until all saves are done

        get_finished() - called with ids of finished requests, returns
            ids of requests that have completed async sending/recving.
"""

import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal

import torch

from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorPromMetrics,
        KVConnectorStats,
        PromMetric,
        PromMetricT,
    )
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
CopyBlocksOp = Callable[
    [
        dict[str, torch.Tensor],
        dict[str, torch.Tensor],
        list[int],
        list[int],
        Literal["h2d", "d2h"],
    ],
    None,
]

logger = init_logger(__name__)


class SupportsHMA(ABC):
    """
    Indicates that the corresponding connector supports the hybrid memory
    allocator (HMA). This is required to use the connector together with
    the hybrid memory allocator.
    """

    @abstractmethod
    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished for all kv cache groups,
        before its blocks are freed for each group.

        NOTE(Kuntai): This function is only supported by connectors that support HMA.

        The connector may assume responsibility for freeing the blocks
        asynchronously by returning True.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        raise NotImplementedError


def supports_hma(connector: Any) -> bool:
    if isinstance(connector, type):
        return issubclass(connector, SupportsHMA)
    else:
        return isinstance(connector, SupportsHMA)
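
`supports_hma` accepts either a connector class or an instance, which is what lets the factory validate HMA support before instantiating anything. A small hedged sketch (the demo class is hypothetical):

class _DemoHMAConnector(SupportsHMA):
    # Marker-only subclass for illustration; a real connector would also
    # derive from KVConnectorBase_V1 and actually free blocks per group.
    def request_finished_all_groups(self, request, block_ids):
        return False, None

assert supports_hma(_DemoHMAConnector)      # class check, as used by the factory
assert supports_hma(_DemoHMAConnector())    # instance check
assert not supports_hma(object)             # unrelated types are rejected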


class KVConnectorRole(enum.Enum):
    # Connector running in the scheduler process
    SCHEDULER = 0

    # Connector running in the worker process
    WORKER = 1


class KVConnectorHandshakeMetadata(ABC):  # noqa: B024
    """
    Metadata used for out-of-band connector handshake between
    P/D workers. This needs to be serializable.
    """

    pass


class KVConnectorMetadata(ABC):  # noqa: B024
    """
    Abstract Metadata used to communicate between the
    Scheduler KVConnector and Worker KVConnector.
    """

    pass


class KVConnectorBase_V1(ABC):
    """
    Base class for KV connectors.
    """

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """
        Indicates whether this connector prefers KV blocks that hold KV data for all
        layers, which can speed up KV data transfers. Defaults to False.
        """
        return False

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig | None" = None,
    ):
        logger.warning(
            "Initializing KVConnectorBase_V1. This API is experimental and "
            "subject to change in the future as we iterate the design."
        )
        self._connector_metadata: KVConnectorMetadata | None = None
        self._vllm_config = vllm_config
        if vllm_config.kv_transfer_config is not None:
            self._kv_transfer_config = vllm_config.kv_transfer_config
        else:
            raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
        self._kv_cache_config = kv_cache_config
        if self._kv_cache_config is None:
            logger.warning(
                "KVConnectorBase_V1 initialized without kv_cache_config. "
                "This is deprecated - please update your connector to accept "
                "kv_cache_config as the third constructor argument and pass it "
                "to super().__init__()."
            )
        self._role = role
    @property
    def role(self) -> KVConnectorRole:
        return self._role

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        KV cache loading and saving.

        Args:
            connector_metadata (KVConnectorMetadata): the connector metadata.
        """
        self._connector_metadata = connector_metadata

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._connector_metadata = None

    def _get_connector_metadata(self) -> KVConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            KVConnectorMetadata: the connector metadata.
        """
        # Should only be called while set to valid metadata.
        assert self._connector_metadata is not None
        return self._connector_metadata

    def has_connector_metadata(self) -> bool:
        """Check whether the connector metadata is currently set.

        Returns:
            bool: True if connector metadata exists, False otherwise.
        """
        return self._connector_metadata is not None
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: dictionary mapping layer names to kv caches
        """
        return

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
    ):
        """
        Initialize with a single KV cache tensor used by all layers.
        The first dimension should be num_layers.
        This function will only be called for models with uniform layers,
        and only if prefer_cross_layer_blocks is set to True.
        Only one of the functions
        {register_kv_caches, register_cross_layers_kv_cache} will be called.

        Args:
            kv_cache: a cross-layers kv cache tensor
            attn_backend: The attention backend that corresponds to all layers
        """
        return

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        """
        Set the xPU-specific ops for copying KV between host and device.
        Needed when a host buffer is used for kv transfer (e.g., in NixlConnector).
        """
        return

    def handle_preemptions(self, preempted_req_ids: set[str]):
        """
        Handle preempted requests BEFORE their blocks are overwritten.
        Needed for connectors which use async saves (e.g., OffloadingConnector).
        """
        return
    @abstractmethod
    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        pass

    @abstractmethod
    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within the attention layer to
        ensure async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        pass

    @abstractmethod
    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within the attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        pass

    @abstractmethod
    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward pass.

        This prevents overwrites of the paged KV buffer before saving is done.
        """
        pass
    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from request_finished()),
            as a tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """
        return None, None

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.

        Notes:
            - Applies to both sync- and async-loading requests.
            - Async loading: failed blocks may be reported in any forward pass
              up to and including the pass where the request ID is returned by
              `get_finished()`. Even if failures occur, the request must still
              be reported via `get_finished()`, and the failed block IDs must
              appear here no later than that same pass.
            - Sync loading: failed blocks should be reported in the forward
              pass in which they are detected.
        """
        return set()

    def shutdown(self):
        """
        Shutdown the connector. This is called when the worker process
        is shutting down to ensure that all the async operations are
        completed and the connector is cleaned up properly.
        """
        return None
    def get_kv_connector_stats(self) -> "KVConnectorStats | None":
        """
        Get the KV connector stats collected during the last interval.
        """
        return None

    def get_kv_connector_kv_cache_events(self) -> "KVConnectorKVEvents | None":
        """
        Get the KV connector kv cache events collected during the last interval.
        This function should be called by the model runner every time after the
        model execution and before cleanup.
        """
        return None

    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
        """
        Get the KVConnector handshake metadata for this connector.
        This metadata is used for out-of-band connector handshake
        between P/D workers.

        Returns:
            KVConnectorHandshakeMetadata: the handshake metadata.
            None if no handshake metadata is available.
        """
        return None

    # ==============================
    # Scheduler-side methods
    # ==============================
    @abstractmethod
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - An optional number of tokens that can be loaded from the
                  external KV cache beyond what is already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if external KV cache tokens will be loaded
                  asynchronously (between scheduler steps). Must be
                  `False` if the first element is 0.

        Notes:
            The connector should only consider the largest prefix of prompt
            tokens for which KV cache is actually available at the time of the
            call. If the cache cannot be loaded for some tokens (e.g., due to
            connectivity issues or eviction), those tokens must not be taken
            into account.
        """
        pass

    @abstractmethod
    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If get_num_new_matched_tokens previously returned True for a
        request, this function may be called twice for that same request -
        first when blocks are allocated for the connector tokens to be
        asynchronously loaded into, and second when any additional blocks
        are allocated, after the load/transfer is complete.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        pass

    @abstractmethod
    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        pass
    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        return

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished, before its blocks are
        freed.

        The connector may assume responsibility for freeing the blocks
        asynchronously by returning True.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return False, None
    def take_events(self) -> Iterable["KVCacheEvent"]:
        """
        Take the KV cache events from the connector.

        Yields:
            New KV cache events since the last call.
        """
        return ()

    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
        """
        Get the required KV cache layout for this connector.

        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout, e.g. HND or NHD.
            None if the connector does not require a specific layout.
        """
        if cls is KVConnectorBase_V1:
            raise TypeError(
                "get_required_kvcache_layout should not be called "
                "on the abstract base class"
            )
        return None

    def get_finished_count(self) -> int | None:
        """
        Get the count of requests expected to complete send/receive operations
        via this connector. This method is used to initialize the
        KVOutputAggregator, overwriting the default world_size.

        Returns:
            int: expected sending or receiving completion count.
        """
        return None

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> "KVConnectorStats | None":
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,
        which can implement custom aggregation logic on the data dict.
        """
        return None

    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        """
        Set the KV connector handshake metadata for this connector.

        Args:
            metadata (dict[int, KVConnectorHandshakeMetadata]): the handshake
                metadata to set.
        """
        return None

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> "KVConnectorPromMetrics | None":
        """
        Create a KVConnectorPromMetrics subclass which should register
        per-connector Prometheus metrics and implement observe() to
        expose connector transfer stats via Prometheus.
        """
        return None

    def reset_cache(self) -> bool | None:
        """
        Reset the connector's internal cache.

        Returns:
            bool: True if the cache was successfully reset, False otherwise.
        """
        logger.debug(
            "Connector cache reset requested, but %s does not implement reset_cache().",
            type(self).__name__,
        )
        return None
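
Seven methods of the base class are abstract (start_load_kv, wait_for_layer_load, save_kv_layer, wait_for_save, get_num_new_matched_tokens, update_state_after_alloc, build_connector_meta); everything else defaults to a no-op. A hedged sketch of the smallest connector this base admits (hypothetical class, performs no transfers):

class NoOpConnector(KVConnectorBase_V1):
    """Hypothetical minimal connector: abstract methods only, no transfers."""

    def start_load_kv(self, forward_context, **kwargs) -> None:
        pass  # nothing to load

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass  # no async loads pending

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs) -> None:
        pass  # nothing to save

    def wait_for_save(self):
        pass  # no async saves pending

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        return 0, False  # never supplies external KV tokens

    def update_state_after_alloc(self, request, blocks, num_external_tokens):
        pass  # no state to track

    def build_connector_meta(self, scheduler_output):
        return KVConnectorMetadata()  # empty metadata each step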

vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py (new file, 419 lines)
@@ -0,0 +1,419 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DecodeBenchConnector: A KV Connector for decode instance performance testing.

This connector emulates a prefill-decode disaggregated setting by filling
the KV cache with dummy values, allowing measurement of decoder performance
under larger input sequence lengths (ISL) in resource-limited environments.

Usage:
    To use this connector for benchmarking, configure it in the kv_transfer_config:

Example:
    vllm serve <model> --kv-transfer-config '{
        "kv_connector": "DecodeBenchConnector",
        "kv_role": "kv_both",
        "kv_connector_extra_config": {
            "fill_mean": 0.015,
            "fill_std": 0.0
        }
    }'

    Then run your benchmark with desired input/output lengths:
    vllm bench serve --base-url http://127.0.0.1:8000 --model <model> \\
        --dataset-name random --random-input-len 40000 \\
        --random-output-len 100 --max-concurrency 10

Configuration options (via kv_connector_extra_config):
    - fill_mean (float): Mean value for random normal fill (default: 0.015)
    - fill_std (float): Standard deviation for random fill (default: 0.0)
      Set to 0 for constant values, >0 for random sampling
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch

from vllm.distributed.kv_transfer.kv_connector.v1 import (
    KVConnectorBase_V1,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionMetadata

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


@dataclass
class DecodeBenchConnectorMetadata(KVConnectorMetadata):
    """Metadata for DecodeBenchConnector.

    Contains information about which requests need their KV cache filled
    with dummy values for benchmarking purposes.
    """

    # request_id -> (block_ids_per_group, num_tokens_to_fill)
    # block_ids_per_group is a tuple of lists, one per KV cache group
    # For standard attention: single group, e.g., ([1, 2, 3],)
    # For MLA: multiple groups, e.g., ([1, 2], [1, 2])
    reqs_to_fill: dict[str, tuple[tuple[list[int], ...], int]]


class DecodeBenchConnector(KVConnectorBase_V1):
    """
    A KV Connector for decode instance performance testing.

    This connector fills the KV cache with dummy (non-zero) values to
    emulate a prefill-decode disaggregated setting, enabling performance
    testing of the decoder with larger input sequence lengths.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig | None" = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
        self.connector_worker: DecodeBenchConnectorWorker | None = None

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = DecodeBenchConnectorScheduler(vllm_config)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = DecodeBenchConnectorWorker(vllm_config)

    # ==============================
    # Worker-side methods
    # ==============================

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, DecodeBenchConnectorMetadata)
        self.connector_worker.start_fill_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        # All operations are synchronous, so nothing to wait for
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        # This connector doesn't save KV cache (benchmarking only)
        pass

    def wait_for_save(self):
        # This connector doesn't save KV cache (benchmarking only)
        pass

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self, scheduler_output: "SchedulerOutput"
    ) -> KVConnectorMetadata:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        assert self.connector_scheduler is not None
        self.connector_scheduler.request_finished(request)
        return False, None
class DecodeBenchConnectorScheduler:
|
||||
"""Scheduler-side implementation for DecodeBenchConnector."""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
# Track which requests have already been filled
|
||||
self._filled_requests: set[str] = set()
|
||||
|
||||
# Track pending fills for the current scheduler step
|
||||
# request_id -> (block_ids_per_group, num_tokens_to_fill)
|
||||
# Note: _pending_fills doesn't need explicit cleanup - it's cleared
|
||||
# after build_connector_meta() is called in the same scheduler step
|
||||
self._pending_fills: dict[str, tuple[tuple[list[int], ...], int]] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
For new requests, return the number of tokens that should be filled
|
||||
with dummy KV cache values.
|
||||
|
||||
Returns:
|
||||
(num_tokens_to_fill, is_async)
|
||||
- num_tokens_to_fill: number of uncomputed tokens minus 1
|
||||
(we fill everything except the last token for decode)
|
||||
- is_async: False (synchronous filling)
|
||||
"""
|
||||
req_id = request.request_id
|
||||
|
||||
# Only fill once per request on first scheduling
|
||||
if req_id in self._filled_requests:
|
||||
return 0, False
|
||||
|
||||
# Calculate how many tokens we need to fill
|
||||
# Fill all uncomputed tokens except the last one (which will be decoded)
|
||||
# This simulates having processed a long prefill
|
||||
num_uncomputed_tokens = request.num_tokens - num_computed_tokens
|
||||
num_tokens_to_fill = max(0, num_uncomputed_tokens - 1)
|
||||
|
||||
if num_tokens_to_fill == 0:
|
||||
return 0, False
|
||||
|
||||
# Return False for synchronous operation - the fill is fast enough
|
||||
# that async overhead isn't worth it
|
||||
return num_tokens_to_fill, False
|
||||
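    # Illustration (hypothetical numbers): a fresh request with
    # request.num_tokens == 1024 and num_computed_tokens == 0 reports
    # 1023 fillable tokens, leaving exactly one token for the decode step.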

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Called after blocks are allocated. Store the block IDs so we can
        fill them with dummy values.

        Supports both standard attention (single KV cache group) and MLA
        (multiple KV cache groups).
        """
        req_id = request.request_id

        if num_external_tokens == 0:
            return

        # Get the block IDs that were allocated
        # block_groups is a tuple of lists, one per KV cache group
        # For standard attention: 1 group
        # For MLA: multiple groups (one per attention type)
        block_groups = blocks.get_block_ids()

        # Calculate how many blocks we need to fill
        # num_external_tokens are the tokens we said we'd provide
        num_blocks_to_fill = cdiv(num_external_tokens, self.block_size)

        # Extract the first num_blocks_to_fill blocks from each group
        # All groups should have the same block IDs for the same request
        block_ids_per_group = tuple(
            group_blocks[:num_blocks_to_fill] for group_blocks in block_groups
        )

        # Store the blocks to fill for all groups. _pending_fills doesn't need
        # cleanup as it's cleared after build_connector_meta()
        self._pending_fills[req_id] = (
            block_ids_per_group,
            num_external_tokens,
        )
        self._filled_requests.add(req_id)

        logger.debug(
            "DecodeBenchConnector: Allocated %d blocks across %d KV cache groups "
            "for request %s",
            num_blocks_to_fill,
            len(block_groups),
            req_id,
        )

    def build_connector_meta(
        self, scheduler_output: "SchedulerOutput"
    ) -> KVConnectorMetadata:
        """
        Build metadata containing information about which blocks to fill
        with dummy KV values.
        """
        meta = DecodeBenchConnectorMetadata(reqs_to_fill=self._pending_fills.copy())

        # Clear pending fills after building metadata
        self._pending_fills.clear()

        return meta

    def request_finished(self, request: "Request"):
        """
        Called when a request has finished. Clean up any state.
        """
        self._filled_requests.discard(request.request_id)


class DecodeBenchConnectorWorker:
    """Worker-side implementation for DecodeBenchConnector."""

    def __init__(self, vllm_config: "VllmConfig"):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size

        # Get fill parameters from extra config
        kv_transfer_config = vllm_config.kv_transfer_config
        assert kv_transfer_config is not None
        self.fill_mean = kv_transfer_config.get_from_extra_config("fill_mean", 0.015)
        self.fill_std = kv_transfer_config.get_from_extra_config("fill_std", 0.0)

        # Will be populated via register_kv_caches
        self.kv_caches: dict[str, torch.Tensor] | None = None

        # Mapping from KV cache group index to list of layer names in that group
        self.group_to_layers: dict[int, list[str]] | None = None

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Store references to the KV cache tensors and build group mapping."""
        self.kv_caches = kv_caches

        # For simplicity, assume all layers belong to group 0 (standard attention)
        # For MLA models with multiple groups, the metadata will handle the mapping
        # We just need to fill the blocks specified in the metadata
        self.group_to_layers = {0: list(kv_caches.keys())}

        logger.debug(
            "DecodeBenchConnector: Registered %d KV cache layers",
            len(kv_caches),
        )

    def start_fill_kv(self, metadata: DecodeBenchConnectorMetadata):
        """
        Fill the allocated KV cache blocks with dummy (non-zero) values.

        This simulates having a populated KV cache from a prefill phase,
        allowing decode performance testing with larger context sizes.

        Supports both standard attention (single group) and MLA (multiple groups).
        """
        if not metadata.reqs_to_fill:
            return

        assert self.kv_caches is not None, "KV caches must be registered before filling"
        assert self.group_to_layers is not None, "Group mapping must be initialized"

        for req_id, (block_ids_per_group, num_tokens) in metadata.reqs_to_fill.items():
            # Fill blocks for each KV cache group
            for group_idx, block_ids in enumerate(block_ids_per_group):
                self._fill_blocks(group_idx, block_ids, num_tokens)

            logger.debug(
                "DecodeBenchConnector: Filled %d blocks (%d tokens) across %d groups "
                "for request %s",
                len(block_ids_per_group[0]) if block_ids_per_group else 0,
                num_tokens,
                len(block_ids_per_group),
                req_id,
            )

    def _fill_blocks(self, group_idx: int, block_ids: list[int], num_tokens: int):
        """
        Fill specified blocks with dummy non-zero values for a specific KV cache group.

        Args:
            group_idx: The KV cache group index to fill
            block_ids: List of block IDs to fill in this group
            num_tokens: Total number of tokens to fill across these blocks
        """
        if not block_ids:
            return

        assert self.kv_caches is not None
        assert self.group_to_layers is not None

        # Get the layers that belong to this group
        layer_names = self.group_to_layers.get(group_idx, [])

        # Fill only the layers in this group
        for layer_name in layer_names:
            if layer_name not in self.kv_caches:
                logger.warning(
                    "DecodeBenchConnector: Layer %s not found in KV caches", layer_name
                )
                continue

            kv_cache = self.kv_caches[layer_name]

            # Convert block_ids to tensor on device
            block_ids_tensor = torch.tensor(
                block_ids, dtype=torch.long, device=kv_cache.device
            )

            # Filter invalid block IDs
            valid_mask = block_ids_tensor < kv_cache.shape[0]
            valid_block_ids = block_ids_tensor[valid_mask]

            if len(valid_block_ids) == 0:
                continue

            # Create fill values - either constant or random
            block_shape = kv_cache.shape[1:]
            if self.fill_std > 0:
                # Random normal sampling
                fill_values = torch.normal(
                    mean=self.fill_mean,
                    std=self.fill_std,
                    size=(len(valid_block_ids),) + block_shape,
                    dtype=kv_cache.dtype,
                    device=kv_cache.device,
                )
            else:
                # Constant fill value
                fill_values = torch.full(
                    (len(valid_block_ids),) + block_shape,
                    self.fill_mean,
                    dtype=kv_cache.dtype,
                    device=kv_cache.device,
                )

            # Batch fill operation
            kv_cache[valid_block_ids] = fill_values

            logger.debug(
                "DecodeBenchConnector: Filled %d blocks in group %d with %s values "
                "(mean=%.3f, std=%.3f)",
                len(block_ids),
                group_idx,
                "random" if self.fill_std > 0 else "constant",
                self.fill_mean,
                self.fill_std,
            )
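
# A minimal usage sketch (illustrative; it assumes this connector is
# registered under the name "DecodeBenchConnector" and uses the standard
# --kv-transfer-config flag, with the fill parameters read above via
# get_from_extra_config):
#
#   vllm serve <model> --kv-transfer-config \
#     '{"kv_connector": "DecodeBenchConnector", "kv_role": "kv_both",
#       "kv_connector_extra_config": {"fill_mean": 0.015, "fill_std": 0.0}}'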
@@ -0,0 +1,442 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import safetensors
import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


@dataclass
class ReqMeta:
    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor
    # Is store or load
    is_store: bool
    mm_hashes: list[str]

    @staticmethod
    def make_meta(
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        is_store: bool,
        mm_hashes: list[str],
    ) -> "ReqMeta":
        valid_num_tokens = align_to_block_size(len(token_ids), block_size)
        token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
        block_ids_tensor = torch.tensor(block_ids)
        num_blocks = block_ids_tensor.shape[0]
        block_offsets = torch.arange(0, block_size)
        slot_mapping = (
            block_offsets.reshape((1, block_size))
            + block_ids_tensor.reshape((num_blocks, 1)) * block_size
        )
        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
        return ReqMeta(
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
            is_store=is_store,
            mm_hashes=mm_hashes,
        )
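
# For illustration (hypothetical numbers): in make_meta above, with
# block_size = 4 and block_ids = [7, 2], slot_mapping expands to
# [28, 29, 30, 31, 8, 9, 10, 11] before truncation - one paged-KV slot
# index per block-aligned token.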


@dataclass
class ExampleConnectorMetadata(KVConnectorMetadata):
    requests: list[ReqMeta] = field(default_factory=list)

    def add_request(
        self,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        is_store: bool,
        mm_hashes: list[str],
    ) -> None:
        self.requests.append(
            ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes)
        )


class ExampleConnector(KVConnectorBase_V1):
    # NOTE: This is a simple debug implementation of the KV connector.
    # It saves / loads the KV cache to / from disk.
    # It does extra work that overwrites the existing prefix cache in GPU
    # memory - to remove that overhead, a "mask" would need to be added to
    # the ReqMeta class.

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig | None" = None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            role=role,
            kv_cache_config=kv_cache_config,
        )
        self._block_size = vllm_config.cache_config.block_size
        self._requests_need_load: dict[str, Request] = {}
        self._storage_path = self._kv_transfer_config.get_from_extra_config(
            "shared_storage_path", "/tmp"
        )
        logger.info(self._kv_transfer_config)
        logger.info("Shared storage path is %s", self._storage_path)

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """Start loading the KV cache from the connector buffer to vLLM's
        paged KV buffer.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        attn_metadata = forward_context.attn_metadata

        def inject_kv_into_layer(
            dst_kv_cache_layer: torch.Tensor,
            src_kv_cache: torch.Tensor,
            slot_mapping: torch.Tensor,
        ) -> None:
            """Inject the KV cache into the layer.

            Args:
                dst_kv_cache_layer (torch.Tensor): the destination KV cache
                    layer. In shape [2, num_pages, page_size, xxx] if not
                    using MLA, [num_pages, page_size, xxx] otherwise.
                src_kv_cache (torch.Tensor): the source KV cache. In shape
                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
                    otherwise.
                slot_mapping (torch.Tensor): the slot mapping. In shape
                    [num_tokens].
            """
            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages = dst_kv_cache_layer_shape[0]
                page_size = dst_kv_cache_layer_shape[1]
                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                    num_pages * page_size, -1
                )
                dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
            else:
                num_pages = dst_kv_cache_layer_shape[1]
                page_size = dst_kv_cache_layer_shape[2]
                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                    2, num_pages * page_size, -1
                )
                dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache

        # Get the metadata
        metadata: KVConnectorMetadata = self._get_connector_metadata()
        assert isinstance(metadata, ExampleConnectorMetadata)

        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            logger.warning("In connector.start_load_kv, but the attn_metadata is None")
            return

        # Load the KV for each request each layer
        for request in metadata.requests:
            if request.is_store:
                continue
            logger.info(
                "Inject KV cache of %d tokens to the paged memory",
                len(request.slot_mapping),
            )
            for layer_name in forward_context.no_compile_layers:
                layer = forward_context.no_compile_layers[layer_name]

                # Only process layers that have a kv_cache attribute
                # (attention layers); skip non-attention layers like
                # FusedMoE / MLP, etc.
                kv_cache_attr = getattr(layer, "kv_cache", None)
                if kv_cache_attr is None:
                    continue

                kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]

                filename = self._generate_filename_debug(
                    layer_name, request.token_ids, request.mm_hashes
                )
                kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
                inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Blocking until the KV for a specific layer is loaded into vLLM's
        paged buffer.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        return

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """Start saving the KV cache of the layer from vLLM's paged buffer
        to the connector.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """

        def extract_kv_from_layer(
            layer: torch.Tensor,
            slot_mapping: torch.Tensor,
        ) -> torch.Tensor:
            """Extract the KV cache from the layer.

            Assume the shape of the layer is (2, num_pages, page_size, xxx)
            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
            """
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages, page_size = layer.shape[0], layer.shape[1]
                return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
            num_pages, page_size = layer.shape[1], layer.shape[2]
            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, ExampleConnectorMetadata)
        for request in connector_metadata.requests:
            if request.is_store:
                filename = self._generate_filename_debug(
                    layer_name, request.token_ids, request.mm_hashes
                )
                kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
                tensors = {"kv_cache": kv_cache.detach().cpu()}
                safetensors.torch.save_file(tensors, filename)

    def wait_for_save(self):
        return

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        # NOTE: in this debug implementation, we assume that the prompt is
        # cached_prompt + newly_generated_single_token
        # Therefore, we use prompt_token_ids[:-1] to determine the folder name

        # NOTE: in current v1 scheduler, the num_computed_tokens is aligned
        # with the block granularity. And it expects the returned blocks and
        # num_computed_tokens to also be aligned with the block granularity.
        if not self._found_match_for_request(request):
            return 0, False

        logger.info("External Cache Hit!")

        # Now, first num_tokens_to_check tokens are hit, we need to prepare
        # the metadata for the worker connector to correctly load the KV
        token_ids = request.prompt_token_ids or []
        num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)

        return num_tokens_to_check - num_computed_tokens, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If blocks were allocated, add to _requests_need_load,
        such that we load the KVs in the next forward pass.
        """
        if num_external_tokens > 0:
            self._requests_need_load[request.request_id] = request

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = ExampleConnectorMetadata()

        total_need_load = 0
        for new_req in scheduler_output.scheduled_new_reqs:
            token_ids = new_req.prompt_token_ids or []
            mm_hashes = [f.identifier for f in new_req.mm_features]
            if new_req.req_id in self._requests_need_load:
                meta.add_request(
                    token_ids=token_ids,
                    block_ids=new_req.block_ids[0],
                    block_size=self._block_size,
                    is_store=False,
                    mm_hashes=mm_hashes,
                )
                total_need_load += 1
            else:
                # NOTE: here, we set the store and load being exclusive,
                # but a single request can have both store and load.
                # NOTE(rob): for this debug implementation, we only cache
                # the original prompt tokens.
                if not self._found_match_for_prompt(token_ids, mm_hashes):
                    meta.add_request(
                        token_ids=token_ids,
                        block_ids=new_req.block_ids[0],
                        block_size=self._block_size,
                        is_store=True,
                        mm_hashes=mm_hashes,
                    )

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
            if not resumed_from_preemption or req_id not in self._requests_need_load:
                continue

            num_computed_tokens = cached_reqs.num_computed_tokens[i]
            num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            new_block_ids = cached_reqs.new_block_ids[i]

            # NOTE(rob): cached_req_data does not have the full
            # list of token ids (only new tokens). So we look it
            # up in the actual request object.
            request = self._requests_need_load[req_id]
            total_tokens = num_computed_tokens + num_new_tokens
            token_ids = request.all_token_ids[:total_tokens]

            # NOTE(rob): For resumed req, new_block_ids is all
            # of the block_ids for the request.
            assert new_block_ids is not None
            block_ids = new_block_ids[0]

            meta.add_request(
                token_ids=token_ids,
                block_ids=block_ids,
                block_size=self._block_size,
                is_store=False,
                mm_hashes=[f.identifier for f in request.mm_features],
            )
            total_need_load += 1

        assert total_need_load == len(self._requests_need_load)
        self._requests_need_load.clear()
        return meta

    # ==============================
    # Helper functions
    # ==============================

    def _found_match_for_request(
        self,
        request: "Request",
    ) -> bool:
        """Check if the cache is hit for the request."""
        return self._found_match_for_prompt(
            list(request.prompt_token_ids or []),
            [f.identifier for f in request.mm_features],
        )

    def _found_match_for_prompt(
        self,
        prompt_token_ids: list[int],
        mm_hashes: list[str],
    ) -> bool:
        num_tokens_to_check = align_to_block_size(
            len(prompt_token_ids) - 1, self._block_size
        )
        foldername = self._generate_foldername_debug(
            torch.tensor(prompt_token_ids)[:num_tokens_to_check],
            mm_hashes,
            create_folder=False,
        )
        return os.path.exists(foldername)

    def _generate_foldername_debug(
        self,
        token_ids: torch.Tensor,
        mm_hashes: list[str],
        create_folder=False,
    ) -> str:
        """Generate a folder name based on the hash of the bytes of the input
        ids.
        """
        token_bytes = token_ids.numpy().tobytes()
        # Add mm_hashes to the bytes being hashed to avoid path traversal and
        # to create a canonical key.
        if mm_hashes:
            mm_str = "-".join(mm_hashes)
            token_bytes += mm_str.encode("utf-8")
        input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()

        foldername = os.path.join(self._storage_path, input_ids_hash)
        if create_folder:
            os.makedirs(foldername, exist_ok=True)
        return foldername

    def _generate_filename_debug(
        self,
        layer_name: str,
        token_ids: torch.Tensor,
        mm_hashes: list[str],
    ) -> str:
        """Generate a file name based on the layer name and the hash
        of the bytes of the input ids.
        """
        foldername = self._generate_foldername_debug(
            token_ids, mm_hashes=mm_hashes, create_folder=True
        )
        return os.path.join(foldername, f"{layer_name}.safetensors")


def align_to_block_size(num_tokens: int, block_size: int) -> int:
    """Align the number of tokens down to a block boundary, always
    excluding the final token (there must be at least one token left
    to compute)."""
    return (num_tokens - 1) // block_size * block_size
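
# Example (illustrative): with block_size = 16, 33 tokens align down to 32,
# while exactly 16 tokens align to 0 because the final token is excluded.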
@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
    BlockStored,
    KVCacheEvent,
    KVConnectorKVEvents,
    KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


class LMCacheKVEvents(KVConnectorKVEvents):
    """
    Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
    """

    def __init__(self, num_workers: int) -> None:
        self._aggregator = KVEventAggregator(num_workers)

    def add_events(self, events: list[KVCacheEvent]) -> None:
        self._aggregator.add_events(events)

    def aggregate(self) -> "LMCacheKVEvents":
        """
        Aggregate KV events and retain only common events.
        """
        common_events = self._aggregator.get_common_events()
        self._aggregator.clear_events()
        self._aggregator.add_events(common_events)
        self._aggregator.reset_workers()
        return self
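
    # Illustration of aggregate(): with num_workers == 2, an event survives
    # only if both workers reported it; the surviving events are re-added
    # and the worker counter is reset for the next interval.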

    def increment_workers(self, count: int = 1) -> None:
        self._aggregator.increment_workers(count)

    def get_all_events(self) -> list[KVCacheEvent]:
        return self._aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        return self._aggregator.get_number_of_workers()

    def clear_events(self) -> None:
        self._aggregator.clear_events()
        self._aggregator.reset_workers()

    def __repr__(self) -> str:
        return f"<LMCacheKVEvents events={self.get_all_events()}>"


class LMCacheConnectorV1(KVConnectorBase_V1):
    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        super().__init__(
            vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
        )
        assert vllm_config.kv_transfer_config is not None
        use_native = vllm_config.kv_transfer_config.get_from_extra_config(
            "use_native", False
        )
        if use_native:
            logger.info("Initializing native LMCache connector")
            # lazy import
            from vllm.distributed.kv_transfer.kv_connector.v1 import lmcache_integration

            _adapter = lmcache_integration.vllm_v1_adapter

            cls = _adapter.LMCacheConnectorV1Impl
        else:
            logger.info("Initializing latest dev LMCache connector")
            # lazy import
            from lmcache.integration.vllm.vllm_v1_adapter import (
                LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
            )

            cls = LMCacheConnectorLatestImpl

        self._lmcache_engine = cls(vllm_config, role, self)

        self._kv_cache_events: LMCacheKVEvents | None = None

    # ==============================
    # Worker-side methods
    # ==============================
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: dictionary of layer names, kv cache
        """
        if hasattr(self._lmcache_engine, "register_kv_caches"):
            self._lmcache_engine.register_kv_caches(kv_caches)
        else:
            logger.warning(
                "LMCache engine does not support register_kv_caches, "
                "please check and use the latest version"
            )

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.

        """
        self._lmcache_engine.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within attention layer to ensure
        async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        self._lmcache_engine.wait_for_layer_load(layer_name)

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._lmcache_engine.save_kv_layer(
            layer_name, kv_layer, attn_metadata, **kwargs
        )

    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of the paged KV buffer before saving is done.
        """
        self._lmcache_engine.wait_for_save()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from request_finished()),
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """
        return self._lmcache_engine.get_finished(finished_req_ids)

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.
        """
        method = getattr(self._lmcache_engine, "get_block_ids_with_load_errors", None)
        if callable(method):
            return method()

        # Fallback for older versions that don't support this method
        return set()

    def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
        """
        Get the KV connector kv cache events collected during the last interval.
        """

        events = self._lmcache_engine.get_kv_events()  # type: ignore [attr-defined]
        if not events:
            return None

        blocks: list[BlockStored] = [
            BlockStored(
                block_hashes=e.block_hashes,
                parent_block_hash=e.parent_block_hash,
                token_ids=e.token_ids,
                lora_id=e.lora_id,
                block_size=e.block_size,
                medium=e.medium,
                lora_name=getattr(e, "lora_name", None),
            )
            for e in events
        ]

        lmcache_kv_events = LMCacheKVEvents(num_workers=1)
        lmcache_kv_events.add_events(blocks)
        return lmcache_kv_events

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        return self._lmcache_engine.get_num_new_matched_tokens(
            request, num_computed_tokens
        ), False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.
        """
        self._lmcache_engine.update_state_after_alloc(request, num_external_tokens)

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        return self._lmcache_engine.build_connector_meta(scheduler_output)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        # Get the KV events
        kv_cache_events = connector_output.kv_cache_events
        if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
            return

        if self._kv_cache_events is None:
            self._kv_cache_events = kv_cache_events
        else:
            self._kv_cache_events.add_events(kv_cache_events.get_all_events())
            self._kv_cache_events.increment_workers(
                kv_cache_events.get_number_of_workers()
            )
        return

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return self._lmcache_engine.request_finished(request, block_ids)

    def take_events(self) -> Iterable["KVCacheEvent"]:
        """
        Take the KV cache events from the connector.

        Yields:
            New KV cache events since the last call.
        """
        if self._kv_cache_events is not None:
            self._kv_cache_events.aggregate()
            kv_cache_events = self._kv_cache_events.get_all_events()
            yield from kv_cache_events
            self._kv_cache_events.clear_events()
            self._kv_cache_events = None
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from . import multi_process_adapter, vllm_v1_adapter
from .multi_process_adapter import (
    LMCacheMPSchedulerAdapter,
    LMCacheMPWorkerAdapter,
    LoadStoreOp,
)

__all__ = [
    "vllm_v1_adapter",
    "multi_process_adapter",
    "LMCacheMPSchedulerAdapter",
    "LMCacheMPWorkerAdapter",
    "LoadStoreOp",
]
@@ -0,0 +1,666 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any

import torch
import zmq
from lmcache.utils import _lmcache_nvtx_annotate, init_logger
from lmcache.v1.multiprocess.custom_types import (
    CudaIPCWrapper,
    IPCCacheEngineKey,
    KVCache,
)
from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class

logger = init_logger(__name__)


def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache:
    logger.info("KV caches keys are %s", list(kv_caches.keys()))
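    # The CudaIPCWrapper objects (as the name suggests) let the LMCache
    # server process map the same GPU buffers over CUDA IPC rather than
    # copying tensor data between processes.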
    return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()]


def striding_block_hashes(
    block_hashes: list[bytes], blocks_in_chunk: int
) -> Iterable[bytes]:
    """Extract chunk-level hashes from block hashes by striding.

    In hash-based vLLM, each vLLM block has its own hash. LMCache chunks
    span ``blocks_in_chunk`` consecutive blocks. The representative hash
    for a chunk is the hash of the **last** block in that chunk (because
    each block hash already encodes its prefix). So we start at index
    ``blocks_in_chunk - 1`` and stride by ``blocks_in_chunk``.
    """
    return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
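
# Illustration (hypothetical values): with blocks_in_chunk = 4 and block
# hashes [h0, ..., h7], striding_block_hashes yields h3 and h7 - one
# representative hash per chunk.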


def send_lmcache_request(
    mq_client: MessageQueueClient,
    request_type: RequestType,
    payloads: list[Any],
) -> MessagingFuture[Any]:
    """
    Helper function to send the request to the LMCache multiprocess server

    Args:
        mq_client: The LMCache multiprocess mode message queue client
        request_type: The request type
        payloads: The request payloads

    Returns:
        A messaging future for the request
    """

    future = mq_client.submit_request(
        request_type, payloads, get_response_class(request_type)
    )
    return future


def get_lmcache_chunk_size(
    mq_client: MessageQueueClient,
) -> int:
    """
    Helper function to get the LMCache chunk size from the server

    Args:
        mq_client: The LMCache multiprocess mode message queue client

    Returns:
        An integer representing the LMCache chunk size
    """
    future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, [])
    chunk_size = future.result()
    return chunk_size


@dataclass
class LoadStoreOp:
    block_ids: list[int]
    """Block ids for the load/store operation"""

    token_ids: list[int] | None = None
    """Token IDs for the load/store operation (token mode)"""

    block_hashes: list[bytes] | None = None
    """Block hashes for the load/store operation (hash mode)"""

    start: int = 0
    """Start token index (token mode only)"""

    end: int = 0
    """End token index (token mode only)"""

    def __len__(self) -> int:
        return len(self.block_ids)


StoreResult = bool
RetrieveResult = list[bool]
LookupResult = int


class LMCacheMPSchedulerAdapter:
    def __init__(
        self,
        server_url: str,
        context: zmq.Context,
        model_name: str,
        world_size: int,
        kv_rank: int,
        vllm_block_size: int,
    ):
        """
        Args:
            server_url: The server URL for the LMCache message queue
            context: The ZMQ context

            model_name: The model name used for LMCache keys
            world_size: The world size used for LMCache keys
            kv_rank: The kv rank used for LMCache keys
            vllm_block_size: The block size used in vLLM
        """
        self.mq_client = MessageQueueClient(server_url, context)

        # Request futures
        self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}

        self.model_name = model_name
        self.world_size = world_size
        self.worker_id = kv_rank

        # Read chunk size from lmcache
        self.chunk_size = get_lmcache_chunk_size(self.mq_client)
        assert self.chunk_size % vllm_block_size == 0, (
            "LMCache chunk size should be a multiple of vLLM block size"
        )
        self.blocks_in_chunk = self.chunk_size // vllm_block_size

    @_lmcache_nvtx_annotate
    def maybe_submit_lookup_request(
        self,
        request_id: str,
        block_hashes: list[bytes] | None = None,
        token_ids: list[int] | None = None,
    ) -> None:
        """
        Submit a new lookup request to LMCache if there is no ongoing request.

        Supports both token-based and hash-based vLLM:
        - token_ids: token IDs (token-based vLLM) -> single token-mode key
        - block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys

        Exactly one of block_hashes or token_ids must be provided.

        Args:
            request_id: The ID of the lookup request. The same ID indicates it's
                from the same request
            block_hashes: Block hashes to lookup from LMCache (hash mode)
            token_ids: Token IDs to lookup from LMCache (token mode)

        Returns:
            None

        Notes:
            This function will have a side-effect: submitting a look up request to
            LMCache, which will essentially 'lock' the KV cache chunks in the LMCache
            for later retrieve operations.
            In the meantime, this function will record the lookup request, and the
            status of the look up request can be checked by `check_lookup_result`.
        """
        if request_id in self.lookup_futures:
            # Skip if there is already a lookup request
            return

        assert (block_hashes is None) != (token_ids is None), (
            "Exactly one of block_hashes or token_ids must be provided"
        )

        if block_hashes is not None:
            # Hash mode: stride block hashes -> N hash-mode keys
            chunk_hashes = list(
                striding_block_hashes(block_hashes, self.blocks_in_chunk)
            )
            keys = [
                self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
            ]
        else:
            # Token mode: truncate to chunk-aligned length
            assert token_ids is not None
            aligned_end = (len(token_ids) // self.chunk_size) * self.chunk_size
            if aligned_end == 0:
                return
            keys = [
                self._create_key(
                    token_ids,
                    start=0,
                    end=aligned_end,
                    request_id=request_id,
                ).no_worker_id_version()
            ]

        future = send_lmcache_request(
            self.mq_client,
            RequestType.LOOKUP,
            [keys],
        )
        self.lookup_futures[request_id] = future

    @_lmcache_nvtx_annotate
    def check_lookup_result(self, request_id: str) -> int | None:
        """
        Check the result of a previously submitted lookup request.

        Args:
            request_id: The ID of the lookup request submitted in
                `maybe_submit_lookup_request`

        Returns:
            An integer representing the total number of tokens matched
            in LMCache (prefix matching), or
            None if the lookup request is not finished yet.
        """
        assert request_id in self.lookup_futures, (
            f"Lookup request for request_id={request_id} has not been submitted"
        )

        future = self.lookup_futures[request_id]
        if not future.query():
            return None

        result = future.result()
        num_chunks = result
        return num_chunks * self.chunk_size

    def num_blocks_per_chunk(self) -> int:
        """
        Returns:
            The number of vllm blocks in a LMCache data chunk
        """
        return self.blocks_in_chunk

    def cleanup_lookup_result(self, request_id: str) -> None:
        """
        Clean up lookup future for a finished request to prevent memory leak.
        Args:
            request_id: The ID of the finished request.
        """
        self.lookup_futures.pop(request_id, None)

    def end_session(self, request_id: str) -> None:
        """
        Notify LMCache server to remove the session for a finished request.
        Args:
            request_id: The ID of the finished request.
        """
        send_lmcache_request(
            self.mq_client,
            RequestType.END_SESSION,
            [request_id],
        )

    # Helper functions
    def _create_key(
        self,
        token_ids: list[int],
        start: int = 0,
        end: int = 0,
        request_id: str | None = None,
    ) -> IPCCacheEngineKey:
        """Convert token IDs to an IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            token_ids=tuple(token_ids),
            start=start,
            end=end,
            request_id=request_id,
        )

    def _create_hash_key(
        self, chunk_hash: bytes, request_id: str | None = None
    ) -> IPCCacheEngineKey:
        """Create a hash-mode IPC cache engine key"""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=None,
            chunk_hash=chunk_hash,
            request_id=request_id,
        )


class LMCacheMPWorkerAdapter:
    def __init__(
        self,
        server_url: str,
        context: zmq.Context,
        model_name: str,
        world_size: int,
        kv_rank: int,
        vllm_block_size: int,
    ):
        self.mq_client = MessageQueueClient(server_url, context)

        # Instance id for GPU worker
        self.instance_id = os.getpid()

        # Registered kv caches from vLLM
        self.kv_caches: dict[str, torch.Tensor] = {}

        # Request futures
        # request_id -> (future, other merged requests)
        self.store_futures: dict[
            str, tuple[MessagingFuture[StoreResult], list[str]]
        ] = {}
        self.retrieve_futures: dict[
            str, tuple[MessagingFuture[RetrieveResult], list[str]]
        ] = {}

        # The store requests that have finished execution in LMCache
        self.finished_stores: set[str] = set()
        # The finished request ids that are passed via vLLM and also
        # have corresponding store requests submitted to LMCache before
        self.previously_finished: set[str] = set()

        self.model_name = model_name
        self.world_size = world_size
        self.worker_id = kv_rank

        # Read chunk size from lmcache
        chunk_size = get_lmcache_chunk_size(self.mq_client)
        assert chunk_size % vllm_block_size == 0, (
            "LMCache chunk size should be a multiple of vLLM block size"
        )
        self.blocks_in_chunk = chunk_size // vllm_block_size

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Register the kv caches with LMCache server

        Args:
            kv_caches: A dict of kv caches to register. The keys are the
                layer names and the values are the corresponding tensors.
        """
        # Register kv cache and send the request
        self.kv_caches = kv_caches
        logger.info("Registering kv caches")
        future = send_lmcache_request(
            self.mq_client,
            RequestType.REGISTER_KV_CACHE,
            [self.instance_id, wrap_kv_caches(kv_caches)],
        )
        future.result()

    @_lmcache_nvtx_annotate
    def submit_store_request(
        self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
    ):
        """
        Submit a KV cache store request to LMCache

        Args:
            request_id: The ID of the request
            op: The LoadStoreOp describing the store operation.
            event: The CUDA event that is recorded after the current
                model inference step
        """
        if op.block_hashes is not None:
            # Hash mode
            chunk_hashes = list(
                striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
            )
            keys = [
                self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
            ]
        else:
            # Token mode
            assert op.token_ids is not None
            keys = [
                self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
            ]
        future = send_lmcache_request(
            self.mq_client,
            RequestType.STORE,
            [keys, self.instance_id, op.block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.store_futures[request_id] = (future, [])

    @_lmcache_nvtx_annotate
    def submit_retrieve_request(
        self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
    ):
        """
        Submit a KV cache retrieve request to LMCache

        Args:
            request_id: The ID of the request
            op: The LoadStoreOp describing the retrieve operation.
            event: The CUDA event that is recorded after the current
                model inference step
        """
        if op.block_hashes is not None:
            # Hash mode
            chunk_hashes = list(
                striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
            )
            keys = [
                self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
            ]
        else:
            # Token mode
            assert op.token_ids is not None
            keys = [
                self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
            ]
        future = send_lmcache_request(
            self.mq_client,
            RequestType.RETRIEVE,
            [keys, self.instance_id, op.block_ids, event.ipc_handle()],
        ).to_cuda_future()
        self.retrieve_futures[request_id] = (future, [])

    @_lmcache_nvtx_annotate
    def batched_submit_store_requests(
        self,
        request_ids: list[str],
        ops: list[LoadStoreOp],
        event: torch.cuda.Event,
    ):
        """
        Submit a batched store request to LMCache

        Args:
            request_ids: The IDs of the requests
            ops: The LoadStoreOps describing the store operations. Should have
                the same length as request_ids
            event: The CUDA event that is recorded after the current
                model inference step
        """
        all_keys: list[IPCCacheEngineKey] = []
        block_ids: list[int] = []
        for request_id, op in zip(request_ids, ops, strict=False):
            if op.block_hashes is not None:
                chunk_hashes = list(
                    striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
                )
                keys = [
                    self._create_hash_key(ch, request_id=request_id)
                    for ch in chunk_hashes
                ]
                all_keys.extend(keys)
            else:
                assert op.token_ids is not None
                all_keys.append(
                    self._create_key(
                        op.token_ids, op.start, op.end, request_id=request_id
                    )
                )
            block_ids.extend(op.block_ids)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.STORE,
            [
                all_keys,
                self.instance_id,
                block_ids,
                event.ipc_handle(),
            ],
        ).to_cuda_future()
        self.store_futures[request_ids[0]] = (future, list(request_ids[1:]))
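
    # NOTE (illustrative): the batched future above is tracked under the
    # first request id; the remaining ids ride along in the tuple and are
    # reported finished together once that single future completes
    # (see get_finished()).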

    @_lmcache_nvtx_annotate
    def batched_submit_retrieve_requests(
        self,
        request_ids: list[str],
        ops: list[LoadStoreOp],
        event: torch.cuda.Event,
    ):
        """
        Submit a batched retrieve request to LMCache

        Args:
            request_ids: The IDs of the requests
            ops: The LoadStoreOps describing the retrieve operations. Should have
                the same length as request_ids
            event: The CUDA event that is recorded after the current
                model inference step
        """
        all_keys: list[IPCCacheEngineKey] = []
        block_ids: list[int] = []
        for request_id, op in zip(request_ids, ops, strict=False):
            if op.block_hashes is not None:
                chunk_hashes = list(
                    striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
                )
                keys = [
                    self._create_hash_key(ch, request_id=request_id)
                    for ch in chunk_hashes
                ]
                all_keys.extend(keys)
            else:
                assert op.token_ids is not None
                all_keys.append(
                    self._create_key(
                        op.token_ids, op.start, op.end, request_id=request_id
                    )
                )
            block_ids.extend(op.block_ids)
        future = send_lmcache_request(
            self.mq_client,
            RequestType.RETRIEVE,
            [
                all_keys,
                self.instance_id,
                block_ids,
                event.ipc_handle(),
            ],
        ).to_cuda_future()
        self.retrieve_futures[request_ids[0]] = (future, list(request_ids[1:]))

    @_lmcache_nvtx_annotate
    def get_finished(
        self, finished_req_ids_from_engine: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Check and get the finished store and retrieve requests.

        Args:
            finished_req_ids_from_engine: the set of request ids that are
                reported as finished from the vLLM engine side.

        Returns:
            A tuple of two sets:
            - The first set contains the finished store request ids. The returned
              store request ids MUST be seen before in the
              `finished_req_ids_from_engine`.
            - The second set contains the finished retrieve request ids.

        Notes:
            When enabling async scheduling in vLLM, the same request ID may appear
            multiple times in `finished_req_ids_from_engine`. The adapter should
            take care of deduplicating the request IDs and only return the request
            IDs that have not been returned before.
        """
        finished_stores = set()
        finished_retrieves = set()
        for request_id, (s_future, other_reqs) in self.store_futures.items():
            if not s_future.query():
                continue

            s_result = s_future.result()
            finished_stores.add(request_id)
            finished_stores.update(other_reqs)

            if not s_result:
                # TODO: add error handling here
                logger.error(
                    "Something went wrong when processing the "
                    "store request for request_id=%s",
                    request_id,
                )

        for request_id, (r_future, other_reqs) in self.retrieve_futures.items():
            if not r_future.query():
                continue

            r_result = r_future.result()
            finished_retrieves.add(request_id)
            finished_retrieves.update(other_reqs)

            if not all(r_result):
                # TODO: add error handling here
                logger.error(
                    "Something went wrong when processing the "
                    "retrieve request for request_id=%s, result=%s",
                    request_id,
                    r_result,
                )

        # Remove the finished requests from the tracking dicts
        for request_id in finished_stores:
            self.store_futures.pop(request_id, None)
        for request_id in finished_retrieves:
            self.retrieve_futures.pop(request_id, None)

        # Update the internal states
        self.finished_stores.update(finished_stores)

        ret_stores = set()
        for req_id in finished_req_ids_from_engine:
            if req_id in self.finished_stores or req_id in self.store_futures:
                self.previously_finished.add(req_id)
            else:
                ret_stores.add(req_id)

        # Calculate the final finished stores
        ret_stores.update(self._update_and_get_finished_store())

        return ret_stores, finished_retrieves

    def num_blocks_per_chunk(self) -> int:
        """
        Returns:
            The number of vllm blocks in a LMCache data chunk
        """
        return self.blocks_in_chunk

    def shutdown(self):
        """
        Shutdown the LMCache MP worker adapter
        """
        logger.info("Unregistering kv caches")
        send_lmcache_request(
            self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
        ).result()

        self.mq_client.close()

    # Helper functions
    def _update_and_get_finished_store(
        self,
    ) -> set[str]:
"""Converge the internal states about finished stores
|
||||
and returns the 'safe finished store request ids' back
|
||||
"""
|
||||
safe_finished_s = self.finished_stores.intersection(self.previously_finished)
|
||||
self.finished_stores.difference_update(self.previously_finished)
|
||||
self.previously_finished.difference_update(safe_finished_s)
|
||||
|
||||
return safe_finished_s
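
    # Worked example of the set algebra above (hypothetical ids): suppose
    # finished_stores = {"a", "b"} and previously_finished = {"b", "c"}.
    # Then safe_finished_s = {"b"} (stored AND already reported finished by
    # the engine), finished_stores becomes {"a"} (still waiting on the
    # engine), and previously_finished becomes {"c"} (engine-finished, but
    # the store is still pending).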

    def _create_key(
        self,
        token_ids: list[int],
        start: int = 0,
        end: int = 0,
        request_id: str | None = None,
    ) -> IPCCacheEngineKey:
        """Convert token IDs to an IPC cache engine key."""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            token_ids=tuple(token_ids),
            start=start,
            end=end,
            request_id=request_id,
        )

    def _create_hash_key(
        self, chunk_hash: bytes, request_id: str | None = None
    ) -> IPCCacheEngineKey:
        """Create a hash-mode IPC cache engine key."""
        return IPCCacheEngineKey(
            model_name=self.model_name,
            world_size=self.world_size,
            worker_id=self.worker_id,
            chunk_hash=chunk_hash,
            request_id=request_id,
        )
@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import os
import threading
from typing import TYPE_CHECKING, Union

import torch
from lmcache.logging import init_logger
from lmcache.v1.config import LMCacheEngineConfig as V1Config

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.multimodal.inputs import PlaceholderRange
    from vllm.v1.core.sched.output import NewRequestData
    from vllm.v1.request import Request

logger = init_logger(__name__)
ENGINE_NAME = "vllm-instance"

# Thread-safe singleton storage
_config_instance: V1Config | None = None
_config_lock = threading.Lock()


def is_false(value: str) -> bool:
    """Check if the given string value is equivalent to 'false'."""
    return value.lower() in ("false", "0", "no", "n", "off")


def lmcache_get_or_create_config() -> V1Config:
    """Get the LMCache configuration from the file named by the environment
    variable `LMCACHE_CONFIG_FILE`. If the environment variable is not set,
    this function falls back to reading the configuration from environment
    variables.

    This function is thread-safe and implements the singleton pattern,
    ensuring the configuration is loaded only once.
    """
    global _config_instance

    # Double-checked locking for thread-safe singleton
    if _config_instance is None:
        with _config_lock:
            if _config_instance is None:  # Check again within lock
                LMCacheEngineConfig = V1Config  # type: ignore[assignment]

                if "LMCACHE_CONFIG_FILE" not in os.environ:
                    logger.warning(
                        "No LMCache configuration file is set. Trying to read"
                        " configurations from the environment variables."
                    )
                    logger.warning(
                        "You can set the configuration file through "
                        "the environment variable: LMCACHE_CONFIG_FILE"
                    )
                    _config_instance = LMCacheEngineConfig.from_env()
                else:
                    config_file = os.environ["LMCACHE_CONFIG_FILE"]
                    logger.info("Loading LMCache config file %s", config_file)
                    _config_instance = LMCacheEngineConfig.from_file(config_file)
                # Update config from environment variables
                _config_instance.update_config_from_env()
    return _config_instance
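

# A hypothetical configuration sketch for the lookup above (the file path and
# the exact option names are illustrative, not shipped defaults):
#
#     # /etc/lmcache/config.yaml
#     chunk_size: 256
#     local_cpu: true
#     max_local_cpu_size: 5.0
#
# launched as: LMCACHE_CONFIG_FILE=/etc/lmcache/config.yaml vllm serve ...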


def hex_hash_to_int16(s: str) -> int:
    """
    Convert a hex hash string to a 16-bit integer.
    """
    return int(s, 16) & 0xFFFF


def apply_mm_hashes_to_token_ids(
    token_ids: torch.Tensor,
    mm_hashes: list[str],
    mm_positions: list["PlaceholderRange"],
) -> torch.Tensor:
    """
    Overwrite token_ids in-place for multimodal placeholders using
    efficient slice assignments.
    """
    n = token_ids.size(0)
    for hash_str, placeholder in zip(mm_hashes, mm_positions):
        start, length = placeholder.offset, placeholder.length
        if start >= n:
            continue
        end = min(start + length, n)
        token_ids[start:end] = hex_hash_to_int16(hash_str)
    return token_ids
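

# Worked example (hypothetical values): a prompt whose image placeholder spans
# offsets [4, 8) with mm hash "beef1234" gets those four token ids overwritten
# with int("beef1234", 16) & 0xFFFF == 0x1234 == 4660, so the cache key
# reflects the image content rather than the generic placeholder tokens:
#
#     ids = torch.zeros(10, dtype=torch.long)
#     apply_mm_hashes_to_token_ids(
#         ids, ["beef1234"], [PlaceholderRange(offset=4, length=4)]
#     )
#     # ids[4:8] is now tensor([4660, 4660, 4660, 4660])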


def mla_enabled(model_config: "ModelConfig") -> bool:
    return (
        hasattr(model_config, "use_mla")
        and isinstance(model_config.use_mla, bool)
        and model_config.use_mla
    )


def create_lmcache_metadata(
    vllm_config=None, model_config=None, parallel_config=None, cache_config=None
):
    """
    Create LMCacheEngineMetadata from vLLM configuration.

    This function extracts common metadata creation logic that was duplicated
    across multiple files.

    Args:
        vllm_config (VllmConfig): vLLM configuration object containing model,
            parallel, and cache configs (alternative to individual config
            parameters)
        model_config (ModelConfig): Model configuration (alternative to
            vllm_config)
        parallel_config (ParallelConfig): Parallel configuration (alternative
            to vllm_config)
        cache_config (CacheConfig): Cache configuration (alternative to
            vllm_config)
    """
    # Third Party
    # First Party
    from lmcache.config import LMCacheEngineMetadata

    from vllm.utils.torch_utils import get_kv_cache_torch_dtype

    config = lmcache_get_or_create_config()
    # Support both vllm_config object and individual config parameters
    if vllm_config is not None:
        model_cfg = vllm_config.model_config
        parallel_cfg = vllm_config.parallel_config
        cache_cfg = vllm_config.cache_config
    else:
        if model_config is None or parallel_config is None or cache_config is None:
            raise ValueError(
                "Either vllm_config must be provided, or all of "
                "model_config, parallel_config, and cache_config must be provided."
            )
        model_cfg = model_config
        parallel_cfg = parallel_config
        cache_cfg = cache_config

    # Get KV cache dtype
    kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)

    # Check if MLA is enabled
    use_mla = mla_enabled(model_cfg)

    # Construct KV shape (for memory pool)
    num_layer = model_cfg.get_num_layers(parallel_cfg)
    chunk_size = config.chunk_size
    num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
    head_size = model_cfg.get_head_size()
    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)

    # Create metadata
    metadata = LMCacheEngineMetadata(
        model_cfg.model,
        parallel_cfg.world_size,
        parallel_cfg.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )

    return metadata, config


def extract_mm_features(
    request: Union["Request", "NewRequestData"], modify: bool = False
) -> tuple[list[str], list["PlaceholderRange"]]:
    """
    Normalize multimodal information from a Request into parallel lists.

    This helper reads either:
      1) `request.mm_features` (objects each exposing `.identifier` and
         `.mm_position`), or
      2) legacy fields `request.mm_hashes` and `request.mm_positions`.

    It returns two equally sized lists: the multimodal hash identifiers and
    their corresponding positions. If the request contains no multimodal
    info, it returns `([], [])`.

    Args:
        request (Request): The source object.
        modify (bool):
            Controls copy semantics for the legacy-path return values.
            - If True and legacy fields are used, shallow copies are returned
              so the caller can mutate the lists without affecting `request`.
            - If False, the original legacy sequences are returned as-is
              (zero-copy); treat them as read-only.

    Returns:
        tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
        May be `([], [])` when no multimodal data is present.
    """
    if getattr(request, "mm_features", None):
        mm_hashes, mm_positions = zip(
            *((f.identifier, f.mm_position) for f in request.mm_features)
        )
        return (list(mm_hashes), list(mm_positions))
    elif getattr(request, "mm_hashes", None):
        if modify:
            return (
                request.mm_hashes.copy(),  # type: ignore
                request.mm_positions.copy(),  # type: ignore
            )
        else:
            return (request.mm_hashes, request.mm_positions)  # type: ignore
    else:
        return ([], [])
File diff suppressed because it is too large
@@ -0,0 +1,955 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

import torch
import zmq
from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList

try:
    from lmcache.integration.vllm.vllm_multi_process_adapter import (
        LMCacheMPSchedulerAdapter,
        LMCacheMPWorkerAdapter,
        LoadStoreOp,
    )
except ImportError:
    from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
        LMCacheMPSchedulerAdapter,
        LMCacheMPWorkerAdapter,
        LoadStoreOp,
    )

if TYPE_CHECKING:
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorPromMetrics,
        KVConnectorStats,
        PromMetric,
        PromMetricT,
    )
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.core.kv_cache_utils import BlockHash
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = lmcache_init_logger(__name__)


# Helper functions
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
    if block_ids is None:
        return []
    assert isinstance(block_ids, tuple), (
        f"Expected block_ids to be a tuple of lists, but got {type(block_ids)}"
    )

    if len(block_ids) > 1:
        raise RuntimeError(
            "LMCacheMPConnector only works without the hybrid kv cache manager. "
            "Please pass --disable-hybrid-kv-cache-manager when starting vLLM."
        )

    return block_ids[0]


def extract_world_size_and_kv_rank(
    world_size: int,
    rank: int,
    vllm_config: VllmConfig,
) -> tuple[int, int]:
    """
    Adjust the world size and rank for MLA models, whose KV caches are not
    sharded by tensor parallelism.
    """
    use_mla = mla_enabled(vllm_config.model_config)
    if not use_mla:
        return world_size, rank
    else:
        # Tensor parallel does not change the KV caches for MLA models.
        # So we need to "exclude" the effect of TP on rank and world size.
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        # vLLM constructs TP groups first, and then constructs other
        # parallel groups on top of TP groups.
        # For example, with TP=4, PP=2:
        # TP groups: [0, 1, 2, 3], [4, 5, 6, 7]
        # PP groups: [0, 4], [1, 5], [2, 6], [3, 7]
        # So we can "exclude" the effect of TP by rank // tp_size.
        return world_size // tp_size, rank // tp_size
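

# Worked example of the rank math above (hypothetical setup): with TP=4, PP=2
# (world size 8) and an MLA model, worker rank 6 maps to kv rank 6 // 4 == 1
# and kv world size 8 // 4 == 2, i.e., one KV shard per pipeline stage.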


def create_scheduler_adapter(
    server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPSchedulerAdapter:
    world_size, kv_rank = extract_world_size_and_kv_rank(
        vllm_config.parallel_config.world_size,
        vllm_config.parallel_config.rank,
        vllm_config,
    )
    return LMCacheMPSchedulerAdapter(
        server_url,
        zmq_context,
        vllm_config.model_config.model,
        world_size,
        kv_rank,
        vllm_config.cache_config.block_size,
    )


def create_worker_adapter(
    server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPWorkerAdapter:
    world_size, kv_rank = extract_world_size_and_kv_rank(
        vllm_config.parallel_config.world_size,
        vllm_config.parallel_config.rank,
        vllm_config,
    )
    return LMCacheMPWorkerAdapter(
        server_url,
        zmq_context,
        vllm_config.model_config.model,
        world_size,
        kv_rank,
        vllm_config.cache_config.block_size,
    )


class LMCacheMPRequestState(enum.Enum):
    """
    State machine:
    PREFETCHING -- update_state_after_alloc --> WAITING_FOR_LOAD
    WAITING_FOR_LOAD -- process_loading_requests --> READY
    """

    PREFETCHING = enum.auto()
    WAITING_FOR_LOAD = enum.auto()
    READY = enum.auto()


@dataclass
class LMCacheMPRequestTracker:
    # NOTE: this class uses vLLM data structures, so it should be part of
    # the vLLM integration code

    request_id: str

    # Read-only lists to track the token ids and block hashes
    all_token_ids: ConstantList[int]
    block_hashes: ConstantList["BlockHash"]

    # Block ids and hashes will be updated at update_states_after_alloc and
    # during the generation
    allocated_block_ids: list[int] = field(default_factory=list)

    # Number of scheduled tokens in this request. We keep tracking this to
    # avoid saving half-full blocks.
    num_scheduled_tokens: int = 0

    # The number of stored blocks is initialized when looking up the external
    # hit tokens and is updated when processing new requests and cached
    # requests.
    num_stored_blocks: int = 0

    # Staging load operation -- save vllm and lmcache hit tokens during lookup
    num_vllm_hit_blocks: int = 0
    num_lmcache_hit_blocks: int = 0

    # Main state
    state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING

    def __init__(self, request: "Request"):
        self.request_id = request.request_id
        self.all_token_ids = request.all_token_ids
        self.block_hashes = ConstantList(request.block_hashes)
        self.allocated_block_ids = []
        self.num_stored_blocks = 0
        self.num_vllm_hit_blocks = 0
        self.num_lmcache_hit_blocks = 0
        self.state = LMCacheMPRequestState.PREFETCHING

    ####
    # Check the state of the request
    ####
    def needs_retrieve(self) -> bool:
        """Check whether the current request needs retrieve; will be used in
        update_state_after_alloc"""
        return (
            self.num_lmcache_hit_blocks > self.num_vllm_hit_blocks
            and self.state != LMCacheMPRequestState.READY
        )

    def is_ready_for_retrieving(self) -> bool:
        """Check whether the current request is ready for retrieving;
        will be used in process_loading_requests"""
        return (
            self.state == LMCacheMPRequestState.WAITING_FOR_LOAD
            and self.needs_retrieve()
        )

    ####
    # Update internal states
    ####
    def increase_num_scheduled_tokens(self, num_new_tokens: int):
        self.num_scheduled_tokens += num_new_tokens

    def increase_num_stored_blocks(self, num_new_blocks: int):
        """Increase the number of stored blocks for the current request.
        This function will be called when processing the cached requests.
        """
        self.num_stored_blocks += num_new_blocks

    def append_block_ids(
        self,
        new_block_ids: list[int],
    ):
        """Update the block ids for the current request.
        This function will be called when processing the cached requests.
        """
        self.allocated_block_ids.extend(new_block_ids)

    ####
    # For debugging
    ####
    def __repr__(self) -> str:
        return (
            f"LMCacheMPRequestTracker(request_id={self.request_id}, "
            f"num_tokens={len(self.all_token_ids)}, "
            f"num_block_hashes={len(self.block_hashes)}, "
            f"num_allocated_blocks={len(self.allocated_block_ids)}, "
            f"num_stored_blocks={self.num_stored_blocks}, "
            f"vllm_hit_blocks={self.num_vllm_hit_blocks}, "
            f"lmcache_hit_blocks={self.num_lmcache_hit_blocks}, "
            f"state={self.state})"
        )

    def __str__(self) -> str:
        return self.__repr__()


@dataclass
class LMCacheMPRequestMetadata:
    request_id: str
    direction: Literal["STORE", "RETRIEVE"]
    op: LoadStoreOp

    @staticmethod
    def GetStoreMetadata(
        tracker: LMCacheMPRequestTracker,
        blocks_in_chunk: int,
        vllm_block_size: int,
    ) -> "LMCacheMPRequestMetadata | None":
        """
        Generate the store metadata for the current request tracker.

        Args:
            tracker: The request tracker to generate the metadata from.
            blocks_in_chunk: the number of blocks in an LMCache data chunk
            vllm_block_size: the block size used in vLLM
        """
        # Store the blocks that have block hashes
        # NOTE: the invariant here is that `num_stored_blocks` should
        # always be a multiple of `blocks_in_chunk`
        # TODO: This should be checked every time we update num_stored_blocks
        min_available_blocks = min(
            len(tracker.block_hashes),
            len(tracker.allocated_block_ids),
            tracker.num_scheduled_tokens // vllm_block_size,
        )
        num_staging_blocks = min_available_blocks - tracker.num_stored_blocks
        num_chunks = num_staging_blocks // blocks_in_chunk

        if num_chunks >= 1:
            start = tracker.num_stored_blocks
            end = start + num_chunks * blocks_in_chunk
            block_ids = tracker.allocated_block_ids[start:end]
            start_token_idx = start * vllm_block_size
            end_token_idx = end * vllm_block_size
            token_ids = list(tracker.all_token_ids)
            op = LoadStoreOp(
                token_ids=token_ids,
                block_ids=block_ids,
                start=start_token_idx,
                end=end_token_idx,
            )

            ret = LMCacheMPRequestMetadata(
                request_id=tracker.request_id,
                direction="STORE",
                op=op,
            )

            # Update the request tracker
            tracker.increase_num_stored_blocks(end - start)
            return ret

        return None
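
    # Worked example of the chunking math above (hypothetical numbers): with
    # blocks_in_chunk=16, vllm_block_size=16, 40 hashed+allocated+scheduled
    # blocks, and num_stored_blocks=16, we get num_staging_blocks=24 and
    # num_chunks=1, so exactly blocks [16, 32) (tokens [256, 512)) are stored;
    # the remaining 8 blocks wait until they fill a whole chunk.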

    @staticmethod
    def GetRetrieveMetadata(
        tracker: LMCacheMPRequestTracker,
        blocks_in_chunk: int,
        vllm_block_size: int,
    ) -> "LMCacheMPRequestMetadata | None":
        """
        Generate the retrieve metadata for the current request tracker.

        Args:
            tracker: The request tracker to generate the metadata from.
            blocks_in_chunk: the number of blocks in an LMCache data chunk
            vllm_block_size: the block size used in vLLM
        """
        if not tracker.is_ready_for_retrieving():
            return None

        # |---------------------|-----------------|----------------|
        # | num_vllm_hit_blocks |
        # |   lmcache chunk 1   | lmcache chunk 2 |
        # |           need to retrieve            |

        start = tracker.num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk
        end = tracker.num_lmcache_hit_blocks
        assert end % blocks_in_chunk == 0, (
            "The number of LMCache hit blocks should be a multiple of the "
            "number of blocks in an lmcache chunk. "
        )
        assert len(tracker.block_hashes) >= end, (
            "The number of block hashes should be greater than or equal to the "
            "number of LMCache hit blocks. "
        )
        if end > start:
            block_ids = tracker.allocated_block_ids[start:end]
            start_token_idx = start * vllm_block_size
            end_token_idx = end * vllm_block_size
            token_ids = list(tracker.all_token_ids)
            op = LoadStoreOp(
                token_ids=token_ids,
                block_ids=block_ids,
                start=start_token_idx,
                end=end_token_idx,
            )

            ret = LMCacheMPRequestMetadata(
                request_id=tracker.request_id,
                direction="RETRIEVE",
                op=op,
            )
            return ret

        return None
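
    # Retrieval starts at the chunk boundary at or below the vLLM prefix-cache
    # hit. Worked example (hypothetical numbers): blocks_in_chunk=16,
    # num_vllm_hit_blocks=20, num_lmcache_hit_blocks=32 -> start=16, end=32,
    # so chunk 2 is fetched whole and the already-cached blocks [16, 20) are
    # simply overwritten with identical data.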


class LMCacheMPConnectorMetadata(KVConnectorMetadata):
    def __init__(self):
        super().__init__()
        self.requests: list[LMCacheMPRequestMetadata] = []

    def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata):
        self.requests.append(request_metadata)

    def __len__(self):
        return len(self.requests)

    # For debugging
    def __str__(self):
        request_strs = []
        for req_meta in self.requests:
            request_strs.append(
                f"RequestMetadata(request_id={req_meta.request_id}, "
                f"direction={req_meta.direction}, "
                f"num_blocks={len(req_meta.op)}, "
                f"block_ids={req_meta.op.block_ids})"
            )
        return "[" + "\n".join(request_strs) + "]"

    def __repr__(self):
        return self.__str__()


class LMCacheMPConnector(KVConnectorBase_V1):
    """
    The connector for LMCache multi-process mode.

    Extra configs (kv_transfer_config.extra_config):
    - lmcache.mp.host: the host of the LMCache server.
    - lmcache.mp.port: the port of the LMCache server.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig | None" = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        assert vllm_config.kv_transfer_config is not None
        server_host = vllm_config.kv_transfer_config.get_from_extra_config(
            "lmcache.mp.host", "tcp://localhost"
        )
        server_port = vllm_config.kv_transfer_config.get_from_extra_config(
            "lmcache.mp.port", 5555
        )

        server_url = f"{server_host}:{server_port}"
        zmq_context = zmq.Context.instance()
        if self.role == KVConnectorRole.SCHEDULER:
            self.scheduler_adapter = create_scheduler_adapter(
                server_url, zmq_context, vllm_config
            )
            self.request_trackers: dict[str, LMCacheMPRequestTracker] = {}
        elif self.role == KVConnectorRole.WORKER:
            self.worker_adapter = create_worker_adapter(
                server_url, zmq_context, vllm_config
            )
        else:
            raise ValueError(f"Unknown KVConnectorRole: {self.role}")

        self.vllm_block_size = vllm_config.cache_config.block_size

    @property
    def role(self) -> KVConnectorRole:
        return self._role

    # ==============================
    # Worker-side methods
    # ==============================

    def _get_connector_metadata(self) -> KVConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            ConnectorMetadata: the connector metadata.
        """

        # Should only be called while the metadata is set to a valid value.
        assert self._connector_metadata is not None
        return self._connector_metadata

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: dictionary of layer names, kv cache
        """
        logger.info("Registering kv caches!")
        self.worker_adapter.register_kv_caches(kv_caches)
        return

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        metadata = self._get_connector_metadata()
        assert isinstance(metadata, LMCacheMPConnectorMetadata)

        request_ids = []
        ops = []

        for meta in metadata.requests:
            if meta.direction != "RETRIEVE":
                continue
            request_ids.append(meta.request_id)
            ops.append(meta.op)

        if len(request_ids) == 0:
            return

        with torch.cuda.stream(torch.cuda.current_stream()):
            event = torch.cuda.Event(interprocess=True)
            event.record()

        self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within the attention layer to
        ensure the async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        return

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within the attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        return

    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of the paged KV buffer before saving is done.
        """
        metadata = self._get_connector_metadata()
        assert isinstance(metadata, LMCacheMPConnectorMetadata)

        request_ids = []
        ops = []
        for meta in metadata.requests:
            if meta.direction != "STORE":
                continue
            request_ids.append(meta.request_id)
            ops.append(meta.op)

        if len(request_ids) == 0:
            return

        with torch.cuda.stream(torch.cuda.current_stream()):
            event = torch.cuda.Event(interprocess=True)
            event.record()

        self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from request_finished()),
            as a tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in
            a call to this method (this call or a prior one).
        """
        val = self.worker_adapter.get_finished(finished_req_ids)
        # logger.error("Finished req ids: %s, %s", val[0], val[1])
        return val

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.

        Notes:
            - Applies to both sync- and async-loading requests.
            - Async loading: failed blocks may be reported in any forward pass
              up to and including the pass where the request ID is returned by
              `get_finished()`. Even if failures occur, the request must still
              be reported via `get_finished()`, and the failed block IDs must
              appear here no later than that same pass.
            - Sync loading: failed blocks should be reported in the forward
              pass in which they are detected.
        """
        # TODO: add error tracking
        return set()

    def shutdown(self):
        """
        Shut down the connector. This is called when the worker process
        is shutting down to ensure that all the async operations are
        completed and the connector is cleaned up properly.
        """
        if hasattr(self, "worker_adapter"):
            self.worker_adapter.shutdown()
        return None

    def get_kv_connector_stats(self) -> "KVConnectorStats | None":
        """
        Get the KV connector stats collected during the last interval.
        """
        return None

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get the number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - An optional number of tokens that can be loaded from the
                  external KV cache beyond what is already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if external KV cache tokens will be loaded
                  asynchronously (between scheduler steps). Must be
                  'False' if the first element is 0.

        Notes:
            The connector should only consider the largest prefix of prompt
            tokens for which KV cache is actually available at the time of
            the call. If the cache cannot be loaded for some tokens (e.g.,
            due to connectivity issues or eviction), those tokens must not
            be taken into account.
        """
        tracker = self._get_or_create_request_tracker(request)
        # TODO: support loading KV for preempted requests in the future
        if request.status == RequestStatus.PREEMPTED:
            return 0, False

        self.scheduler_adapter.maybe_submit_lookup_request(
            request.request_id,
            token_ids=list(request.all_token_ids),
        )

        ret = self.scheduler_adapter.check_lookup_result(request.request_id)
        if ret is None:
            return None, True

        if ret == 0:
            return 0, False

        assert (
            ret % (self.scheduler_adapter.num_blocks_per_chunk() * self.vllm_block_size)
            == 0
        )

        # Update num stored blocks for the tracker
        num_vllm_blocks = num_computed_tokens // self.vllm_block_size
        num_lmcache_blocks = ret // self.vllm_block_size
        tracker.increase_num_stored_blocks(num_lmcache_blocks)

        # Save the vllm and lmcache hit tokens
        tracker.num_vllm_hit_blocks = num_vllm_blocks
        tracker.num_lmcache_hit_blocks = num_lmcache_blocks

        need_to_load = max(0, ret - num_computed_tokens)
        logger.debug(
            "vLLM hit is: %d, Need to load is %d", num_computed_tokens, need_to_load
        )
        return need_to_load, need_to_load > 0
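
    # Worked example (hypothetical numbers): an LMCache lookup hit of ret=512
    # tokens against num_computed_tokens=320 locally cached tokens yields
    # need_to_load = 512 - 320 = 192, returned as (192, True) so the scheduler
    # allocates blocks for the gap and waits for the asynchronous load.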

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If get_num_new_matched_tokens previously returned True for a
        request, this function may be called twice for that same request -
        first when blocks are allocated for the connector tokens to be
        asynchronously loaded into, and second when any additional blocks
        are allocated, after the load/transfer is complete.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        # NOTE: the `blocks` are NEW BLOCKS allocated for this request.
        tracker = self._get_request_tracker(request.request_id)
        block_ids = reformat_block_ids(blocks.get_block_ids())

        # Whether or not we need to retrieve, we must record the new
        # block ids in the tracker
        tracker.append_block_ids(block_ids)

        # Update the state of the tracker
        condition = tracker.needs_retrieve()
        if tracker.state == LMCacheMPRequestState.PREFETCHING:
            # If we need to retrieve, change to WAITING_FOR_LOAD;
            # otherwise, change to READY
            tracker.state = (
                LMCacheMPRequestState.WAITING_FOR_LOAD
                if condition
                else LMCacheMPRequestState.READY
            )
        # Clean up lookup future in scheduler adapter
        self.scheduler_adapter.cleanup_lookup_result(request.request_id)

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        metadata = LMCacheMPConnectorMetadata()

        self._process_retrieve_requests(metadata)
        self._process_new_requests(scheduler_output, metadata)
        self._process_cached_requests(scheduler_output, metadata)

        if len(metadata) > 0:
            logger.debug("Final connector metadata: %s", metadata)

        return metadata

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from the worker-side connectors' output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors' output.
        """
        return

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished, before its blocks
        are freed.

        The connector may assume responsibility for freeing the blocks
        asynchronously by returning True.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        # Clean up request tracker to prevent memory leak
        self._cleanup_request_tracker(request.request_id)
        # Notify LMCache to end the session for this request
        self.scheduler_adapter.end_session(request.request_id)

        return True, None

    def take_events(self) -> Iterable["KVCacheEvent"]:
        """
        Take the KV cache events from the connector.

        Yields:
            New KV cache events since the last call.
        """
        return ()

    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
        """
        Get the required KV cache layout for this connector.

        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout, e.g. HND or NHD.
            None if the connector does not require a specific layout.
        """

        if cls is KVConnectorBase_V1:
            raise TypeError(
                "get_required_kvcache_layout should not be called "
                "on the abstract base class"
            )
        return None

    def get_finished_count(self) -> int | None:
        """
        Get the count of requests expected to complete send/receive
        operations via this connector. This method is used to initialize the
        KVOutputAggregator, overwriting the default world_size.

        Returns:
            int: expected sending or receiving completion count.
        """
        return None

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> "KVConnectorStats | None":
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,
        which can implement custom aggregation logic on the data dict.
        """
        return None

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> "KVConnectorPromMetrics | None":
        """
        Create a KVConnectorPromMetrics subclass which should register
        per-connector Prometheus metrics and implement observe() to
        expose connector transfer stats via Prometheus.
        """
        return None

    ##############################
    # Helper functions
    ##############################
    def _process_retrieve_requests(
        self,
        metadata: LMCacheMPConnectorMetadata,
    ) -> None:
        blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()

        for request_tracker in self.request_trackers.values():
            if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD:
                continue
            r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata(
                request_tracker,
                blocks_per_chunk,
                vllm_block_size=self.vllm_block_size,
            )
            if r_metadata is not None:
                metadata.add_request_metadata(r_metadata)
            request_tracker.state = LMCacheMPRequestState.READY

    def _process_new_requests(
        self,
        scheduler_output: SchedulerOutput,
        metadata: LMCacheMPConnectorMetadata,
    ) -> None:
        blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()

        for new_request in scheduler_output.scheduled_new_reqs:
            request_tracker = self._get_request_tracker(new_request.req_id)

            num_new_tokens = scheduler_output.num_scheduled_tokens[new_request.req_id]
            request_tracker.increase_num_scheduled_tokens(num_new_tokens)

            r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
                request_tracker, blocks_per_chunk, self.vllm_block_size
            )
            if r_meta is not None:
                metadata.add_request_metadata(r_meta)

    def _process_cached_requests(
        self,
        scheduler_output: SchedulerOutput,
        metadata: LMCacheMPConnectorMetadata,
    ) -> None:
        blocks_per_chunk = self.scheduler_adapter.num_blocks_per_chunk()

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for idx, request_id in enumerate(cached_reqs.req_ids):
            request_tracker = self._get_request_tracker(request_id)

            # Update block ids
            new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
            if request_id not in cached_reqs.resumed_req_ids:
                request_tracker.append_block_ids(new_block_ids)

            # Update new scheduled tokens
            num_new_tokens = cached_reqs.num_computed_tokens[idx]
            request_tracker.increase_num_scheduled_tokens(num_new_tokens)

            r_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
                request_tracker, blocks_per_chunk, self.vllm_block_size
            )

            if r_meta is not None:
                metadata.add_request_metadata(r_meta)

    def _get_request_tracker(self, request_id: str) -> LMCacheMPRequestTracker:
        assert request_id in self.request_trackers, (
            f"Request tracker for request_id {request_id} not found. "
        )
        return self.request_trackers[request_id]

    def _get_or_create_request_tracker(
        self, request: "Request"
    ) -> LMCacheMPRequestTracker:
        request_id = request.request_id
        # Remove the old tracker that was created before the preemption
        if (
            request.status == RequestStatus.PREEMPTED
            and request_id in self.request_trackers
        ):
            tracker = self.request_trackers[request_id]

            # NOTE: since get_num_new_matched_tokens may be called multiple
            # times for the same request, we only remove the tracker if it
            # is not in the "fresh" state, i.e., PREFETCHING
            if tracker.state != LMCacheMPRequestState.PREFETCHING:
                self.request_trackers.pop(request_id)

        if request_id not in self.request_trackers:
            new_tracker = LMCacheMPRequestTracker(request)
            self.request_trackers[request_id] = new_tracker
        return self.request_trackers[request_id]

    def _cleanup_request_tracker(self, request_id: str) -> None:
        """
        Clean up the request tracker and associated lookup future for a
        request. This should be called when a request is finished to prevent
        a memory leak.
        """
        # Clean up request tracker
        if self.request_trackers.pop(request_id, None):
            logger.debug(
                "[KVConnector] Cleaned up request_tracker for request %s",
                request_id,
            )
186
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Normal file
@@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any, TypeAlias, TypeVar

from prometheus_client import Counter, Gauge, Histogram

from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger

PromMetric: TypeAlias = Gauge | Counter | Histogram
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)

logger = init_logger(__name__)


@dataclass
class KVConnectorStats:
    """
    Base class for KV Connector Stats, a container for transfer performance
    metrics or otherwise important telemetry from the connector.
    All subclasses need to be serializable, as stats are sent from the worker
    to the logger process.
    """

    data: dict[str, Any] = field(default_factory=dict)

    def reset(self):
        """Reset the stats; clear the state."""
        raise NotImplementedError

    def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
        """
        Aggregate stats with another `KVConnectorStats` object.
        """
        raise NotImplementedError

    def reduce(self) -> dict[str, int | float]:
        """
        Reduce the observations collected during a time interval to one or
        more representative values (e.g. avg/median/sum of the series).
        This is meant to be called by the logger to produce a summary of the
        stats for the last time interval.
        """
        raise NotImplementedError

    def is_empty(self) -> bool:
        """Return True if the stats are empty."""
        raise NotImplementedError
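

# A minimal sketch of a concrete stats subclass (hypothetical field names;
# real connectors define their own schema):
#
#     @dataclass
#     class MyConnectorStats(KVConnectorStats):
#         def reset(self):
#             self.data = {"num_bytes": 0}
#
#         def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
#             self.data["num_bytes"] += other.data.get("num_bytes", 0)
#             return self
#
#         def reduce(self) -> dict[str, int | float]:
#             return {"num_bytes": self.data.get("num_bytes", 0)}
#
#         def is_empty(self) -> bool:
#             return not self.data.get("num_bytes")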


class KVConnectorLogging:
    def __init__(self, kv_transfer_config: KVTransferConfig | None):
        # Instantiate the connector's stats class.
        if kv_transfer_config and kv_transfer_config.kv_connector:
            self.connector_cls = KVConnectorFactory.get_connector_class(
                kv_transfer_config
            )
        self.reset()

    def reset(self):
        self.transfer_stats_accumulator: KVConnectorStats | None = None

    def observe(self, transfer_stats_data: dict[str, Any]):
        # Should not be called when a KVConnector is not configured.
        assert self.connector_cls is not None
        # Called periodically when the connector syncs with the scheduler.
        # Note that this is not the same as the logging interval.
        # We expect transfer_stats_data to be aggregated across all workers
        # and to consist of observations from a single connector or a
        # MultiConnector.
        transfer_stats = self.connector_cls.build_kv_connector_stats(
            transfer_stats_data
        )
        if transfer_stats is None:
            logger.warning_once(
                "The connector %s is collecting stats but "
                "does not implement the "
                "`build_kv_connector_stats` method. "
                "Stats will not be logged.",
                self.connector_cls,
            )
            return

        if self.transfer_stats_accumulator is None:
            self.transfer_stats_accumulator = transfer_stats
        else:
            # Accumulate last interval stats.
            self.transfer_stats_accumulator = self.transfer_stats_accumulator.aggregate(
                transfer_stats
            )

    def log(self, log_fn=logger.info):
        """Log transfer metrics periodically, similar to throughput logging."""
        if (
            self.transfer_stats_accumulator
            and not self.transfer_stats_accumulator.is_empty()
        ):
            # Produce a single cumulative stats object for the last time
            # interval from the recorded observations.
            xfer_metrics = self.transfer_stats_accumulator.reduce()
            xfer_metrics_str = ", ".join(f"{k}={v}" for k, v in xfer_metrics.items())
            log_fn("KV Transfer metrics: %s", xfer_metrics_str)

            # Reset metrics for next interval
            self.reset()


class KVConnectorPromMetrics:
    """
    A base class for per-connector Prometheus metric registration
    and recording.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        self._kv_transfer_config = vllm_config.kv_transfer_config
        self._gauge_cls = metric_types[Gauge]
        self._counter_cls = metric_types[Counter]
        self._histogram_cls = metric_types[Histogram]
        self._labelnames = labelnames
        self.per_engine_labelvalues = per_engine_labelvalues

    def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
        """
        Create a per-engine child of a prometheus_client.Metric with
        the appropriate labels set. The parent metric must be created
        using the labelnames list.
        """
        return {
            idx: metric.labels(*labelvalues)
            for idx, labelvalues in self.per_engine_labelvalues.items()
        }
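
    # A usage sketch (the metric name is hypothetical): a subclass would
    # create the parent metric with the shared labelnames, then fan it out
    # per engine:
    #
    #     counter = self._counter_cls(
    #         "vllm:kv_transfer_bytes", "Bytes transferred", self._labelnames
    #     )
    #     self._bytes_per_engine = self.make_per_engine(counter)
    #     ...
    #     self._bytes_per_engine[engine_idx].inc(num_bytes)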

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """
        Record the supplied transfer statistics to Prometheus metrics. These
        statistics are engine-specific, and should be recorded to a metric
        with the appropriate 'engine' label. These metric instances can be
        created using the make_per_engine() helper method.
        """
        raise NotImplementedError


class KVConnectorPrometheus:
    """
    Support for registering per-connector Prometheus metrics, and
    recording transfer statistics to those metrics. Uses
    KVConnectorBase.build_prom_metrics().
    """

    _gauge_cls = Gauge
    _counter_cls = Counter
    _histogram_cls = Histogram

    def __init__(
        self,
        vllm_config: VllmConfig,
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        self.prom_metrics: KVConnectorPromMetrics | None = None
        kv_transfer_config = vllm_config.kv_transfer_config
        if kv_transfer_config and kv_transfer_config.kv_connector:
            connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
            metric_types = {
                Gauge: self._gauge_cls,
                Counter: self._counter_cls,
                Histogram: self._histogram_cls,
            }
            self.prom_metrics = connector_cls.build_prom_metrics(
                vllm_config,
                metric_types,
                labelnames,
                per_engine_labelvalues,
            )

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        if self.prom_metrics is None:
            return
        self.prom_metrics.observe(transfer_stats_data, engine_idx)
File diff suppressed because it is too large
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
import time
from dataclasses import dataclass

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId
from vllm.logger import init_logger

WorkerAddr = str

logger = init_logger(__name__)


class RegisterWorkerPayload(BaseModel):
    engine_id: EngineId
    dp_rank: int
    tp_rank: int
    pp_rank: int
    addr: WorkerAddr


@dataclass
class EngineEntry:
    engine_id: EngineId
    # {tp_rank: {pp_rank: worker_addr}}
    worker_addr: dict[int, dict[int, WorkerAddr]]


class MooncakeBootstrapServer:
    """
    A centralized server running on the global rank 0 prefiller worker.
    Prefiller workers register their connection info (IP, port, ranks) here.
    """

    def __init__(self, vllm_config: VllmConfig, host: str, port: int):
        # {dp_rank: EngineEntry}
        self.workers: dict[int, EngineEntry] = {}

        self.host = host
        self.port = port
        self.app = FastAPI()
        self._register_routes()
        self.server_thread: threading.Thread | None = None
        self.server: uvicorn.Server | None = None

    def __del__(self):
        self.shutdown()

    def _register_routes(self):
        # All handlers are async, so no lock is needed to protect the data.
        self.app.post("/register")(self.register_worker)
        self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)

    def start(self):
        if self.server_thread:
            return

        config = uvicorn.Config(app=self.app, host=self.host, port=self.port)
        self.server = uvicorn.Server(config=config)
        self.server_thread = threading.Thread(
            target=self.server.run, name="mooncake_bootstrap_server", daemon=True
        )
        self.server_thread.start()
        while not self.server.started:
            time.sleep(0.1)  # Wait for the server to start
        logger.info("Mooncake Bootstrap Server started at %s:%d", self.host, self.port)

    def shutdown(self):
        if self.server_thread is None or self.server is None or not self.server.started:
            return

        self.server.should_exit = True
        self.server_thread.join()
        logger.info("Mooncake Bootstrap Server stopped.")

    async def register_worker(self, payload: RegisterWorkerPayload):
        """Handles registration of a prefiller worker."""
        if payload.dp_rank not in self.workers:
            self.workers[payload.dp_rank] = EngineEntry(
                engine_id=payload.engine_id,
                worker_addr={},
            )

        dp_entry = self.workers[payload.dp_rank]
        if dp_entry.engine_id != payload.engine_id:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Engine ID mismatch for dp_rank={payload.dp_rank}: "
                    f"expected {dp_entry.engine_id}, got {payload.engine_id}"
                ),
            )
        if payload.tp_rank not in dp_entry.worker_addr:
            dp_entry.worker_addr[payload.tp_rank] = {}

        tp_entry = dp_entry.worker_addr[payload.tp_rank]
        if payload.pp_rank in tp_entry:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Worker with dp_rank={payload.dp_rank}, "
                    f"tp_rank={payload.tp_rank}, pp_rank={payload.pp_rank} "
                    f"is already registered at "
                    f"{tp_entry[payload.pp_rank]}, "
                    f"but still wants to register at {payload.addr}"
                ),
            )

        tp_entry[payload.pp_rank] = payload.addr
        logger.debug(
            "Registered worker: engine_id=%s, dp_rank=%d, tp_rank=%d, pp_rank=%d at %s",
            payload.engine_id,
            payload.dp_rank,
            payload.tp_rank,
            payload.pp_rank,
            payload.addr,
        )

        return {"status": "ok"}

    async def query(self) -> dict[int, EngineEntry]:
        return self.workers
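
    # An illustrative registration exchange (addresses and ports are
    # hypothetical):
    #
    #     POST /register
    #     {"engine_id": "engine-0", "dp_rank": 0, "tp_rank": 1,
    #      "pp_rank": 0, "addr": "10.0.0.2:7001"}
    #
    #     GET /query
    #     -> {"0": {"engine_id": "engine-0",
    #               "worker_addr": {"1": {"0": "10.0.0.2:7001"}}}}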
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
make_zmq_socket,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from dataclasses import field
|
||||
from enum import Enum
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
Transfer = tuple[int, float]
|
||||
EngineId = str
|
||||
ReqId = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class WriteTask:
|
||||
request_id: str
|
||||
dst_engine_id: str
|
||||
local_block_ids: list[int]
|
||||
remote_block_ids_hint: list[int] | None
|
||||
layer_name: str
|
||||
event: torch.cuda.Event
|
||||
remote_notify_port: int
|
||||
remote_ip: str
|
||||
enqueue_time: float = field(default_factory=time.perf_counter)
|
||||
retried: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerTransferPlan:
|
||||
"""Plan for transferring a single layer."""
|
||||
|
||||
request_id: str
|
||||
layer_name: str
|
||||
sess_idx: int
|
||||
transfer_local_offsets: list[int]
|
||||
transfer_remote_offsets: list[int]
|
||||
transfer_sizes: list[int]
|
||||
use_batch: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteAllocInfo:
|
||||
"""Information about remote block allocation."""
|
||||
|
||||
block_ids: list[int]
|
||||
writes_done: int = 0
|
||||
decode_dp_rank: int = 0
|
||||
transfer_offset: tuple[list[int], list[int], list[int]] | None = None
|
||||
|
||||
|
||||
class ROLE(Enum):
|
||||
PRODUCER = "producer"
|
||||
CONSUMER = "consumer"
|
||||
NOTINIT = "notinit"
|
||||
|
||||
|
||||
class MoRIIOAgentMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
# required for @cached_property.d
|
||||
dict=True,
|
||||
):
|
||||
engine_id: str
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
num_blocks: int
|
||||
block_len: int
|
||||
attn_backend_name: str
|
||||
|
||||
|
||||
class RoleManager:
|
||||
"""Manages role state across the connector."""
|
||||
|
||||
_instance: "RoleManager | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._role: ROLE = ROLE.NOTINIT
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "RoleManager":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def set_role(self, role: ROLE) -> None:
|
||||
"""Set the current role."""
|
||||
with self._lock:
|
||||
self._role = role
|
||||
|
||||
def get_role(self) -> ROLE:
|
||||
"""Get the current role."""
|
||||
return self._role
|
||||
|
||||
|
||||
def set_role(role: ROLE):
|
||||
"""Set the global role."""
|
||||
RoleManager.get_instance().set_role(role)
|
||||
|
||||
|
||||
def get_role() -> ROLE:
|
||||
"""Get the global role."""
|
||||
return RoleManager.get_instance().get_role()
|
||||
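RoleManager uses double-checked locking: concurrent callers contend on the lock at most once, during instance creation, and subsequent reads are cheap. A short usage sketch of the module-level helpers defined above:

# Usage sketch: the prefill process marks itself as the producer once at
# startup; any module in the same process can then query the role without
# threading it through every call site.
set_role(ROLE.PRODUCER)
assert get_role() is ROLE.PRODUCER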
class MoRIIOMode(Enum):
    READ = "read"
    WRITE = "write"


class MoRIIOError(Exception):
    """Base exception for MoRIIO operations."""


class HandshakeError(MoRIIOError):
    """Exception raised when a handshake fails."""


class TransferError(MoRIIOError):
    """Exception raised when a transfer fails."""


def get_moriio_mode() -> MoRIIOMode:
    read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
    logger.debug("MoRIIO Connector read_mode: %s", read_mode)
    if read_mode:
        return MoRIIOMode.READ
    else:
        return MoRIIOMode.WRITE


def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
    return dp_rank * tp_size + tp_rank


@dataclass
class MoRIIOConfig:
    local_ip: str
    local_kv_port: int
    proxy_ip: str
    local_ping_port: int
    proxy_ping_port: int
    http_port: int
    handshake_port: int
    notify_port: int
    tp_rank: int
    dp_rank: int
    dp_size: int
    tp_size: int

    @classmethod
    def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
        # Port configuration:
        #   local_ping_port  -> outgoing heartbeat to the proxy
        #   proxy_ping_port  -> remote proxy's heartbeat ingress port
        #   http_port        -> this instance's HTTP service endpoint
        #   local_kv_port    -> service port for the mori engine
        #   notify_port      -> synchronizes stages between prefill and decode
        #   handshake_port   -> initial handshake between mori engines

        # TODO: merge notify_port and handshake_port to simplify port
        # management; supports non-contiguous ports.
        assert vllm_config.kv_transfer_config is not None, (
            "kv_transfer_config must be set for MoRIIOConnector"
        )
        kv_transfer_config = vllm_config.kv_transfer_config
        extra_config = kv_transfer_config.kv_connector_extra_config
        tp_rank = get_tensor_model_parallel_rank()
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        base_notify_port = int(extra_config["notify_port"])
        dp_size = vllm_config.parallel_config.data_parallel_size
        tp_size = get_tensor_model_parallel_world_size()
        port_offset = get_port_offset(dp_rank, tp_rank)

        return cls(
            local_ip=get_ip(),
            local_kv_port=get_open_port(),
            proxy_ip=extra_config["proxy_ip"],
            local_ping_port=get_open_port(),
            proxy_ping_port=int(extra_config["proxy_ping_port"]),
            http_port=int(extra_config["http_port"]),
            handshake_port=int(extra_config["handshake_port"]),
            notify_port=base_notify_port + port_offset,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            dp_size=dp_size,
            tp_size=tp_size,
        )


class MoRIIOConstants:
    """Constants for the MoRIIO connector."""

    # ZMQ message types
    GET_META_MSG = b"get_meta_msg"
    POP_DONE_RECV = b"pop_done_recv"
    OVER = b"OVER"
    COMPLETION_PREFIX = "cmpl"

    PING_INTERVAL = 5
    MAX_PING_RETRIES = 100
    DEFAULT_HANDSHAKE_PORT = "6301"
    DEFAULT_NOTIFY_PORT = "61005"

    VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600


@dataclass
class ReqMeta:
    """Metadata for a single request."""

    local_block_ids: list[int]
    remote_block_ids: list[int]
    remote_host: str
    remote_port: int
    remote_handshake_port: int
    remote_notify_port: int
    remote_engine_id: str
    tp_size: int
    remote_dp_size: int


class MoRIIOConnectorMetadata(KVConnectorMetadata):
    def __init__(self):
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        self.reqs_to_send: dict[ReqId, float] = {}

    def __repr__(self):
        return_str = ""
        for req_id, req_meta in self.reqs_to_recv.items():
            return_str += (
                f"{req_id = },{req_meta.local_block_ids = },"
                f"{req_meta.remote_host = },{req_meta.remote_port = },"
                f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
            )
        return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str},"

        for req_id, expiry in self.reqs_to_send.items():
            return_str += f"{req_id = },{expiry = }"
        return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str},"
        return return_str

    def add_new_req(
        self,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        write_mode=False,
    ):
        _req = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=kv_transfer_params["remote_block_ids"],
            remote_engine_id=kv_transfer_params["remote_engine_id"],
            remote_host=kv_transfer_params["remote_host"],
            remote_port=kv_transfer_params["remote_port"],
            remote_handshake_port=kv_transfer_params["remote_handshake_port"],
            remote_notify_port=kv_transfer_params["remote_notify_port"],
            tp_size=kv_transfer_params.get("tp_size", 1),
            remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
        )
        if write_mode:
            self.reqs_to_save[request_id] = _req
        else:
            self.reqs_to_recv[request_id] = _req


@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Context manager for a ZMQ socket."""

    if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER):
        raise ValueError(f"Unexpected socket type: {socket_type}")

    ctx: zmq.Context | None = None
    try:
        ctx = zmq.Context()  # type: ignore[attr-defined]
        yield make_zmq_socket(
            ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER
        )
    finally:
        if ctx is not None:
            ctx.destroy(linger=0)
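Each (dp_rank, tp_rank) pair derives a distinct notify port from the base port via get_port_offset above, so every worker in a data-parallel-by-tensor-parallel grid listens on its own port. A small worked example of the layout:

# Worked example of the notify-port layout for a dp_size=2, tp_size=2 grid,
# using get_port_offset() as defined above with an illustrative base port.
base_notify_port = 61005
tp_size = 2
for dp_rank in range(2):
    for tp_rank in range(tp_size):
        offset = get_port_offset(dp_rank, tp_rank, tp_size)
        print(dp_rank, tp_rank, base_notify_port + offset)
# Yields ports 61005..61008 -- one per (dp_rank, tp_rank) worker.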
File diff suppressed because it is too large
@@ -0,0 +1,609 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Write task execution logic for the MoRIIO connector."""

import threading
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any
from weakref import ref as weakref_ref

import msgpack
import torch
import zmq

from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
    ROLE,
    HandshakeError,
    LayerTransferPlan,
    MoRIIOAgentMetadata,
    MoRIIOConstants,
    MoRIIOError,
    RemoteAllocInfo,
    TransferError,
    WriteTask,
    get_port_offset,
    get_role,
    zmq_ctx,
)
from vllm.logger import init_logger
from vllm.utils.network_utils import (
    make_zmq_path,
    make_zmq_socket,
)

if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
        MoRIIOConnectorWorker,
    )

logger = init_logger(__name__)

try:
    from mori.io import (
        EngineDesc,
        IOEngine,
        MemoryDesc,
        PollCqMode,
        RdmaBackendConfig,
    )

    logger.info("MoRIIO is available")
except ImportError:
    logger.error("MoRIIO is not available")


class MoRIIOWriter:
    """Handles write operations for KV cache transfers.

    Implements distributed KV cache transfer using the MoRIIO library
    for RDMA-based communication between prefill and decode instances.
    """

    def __init__(self, worker: "MoRIIOConnectorWorker"):
        """Initialize the writer.

        Args:
            worker: Reference to the parent worker
        """
        self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker)
        self._write_task_q: Queue[WriteTask] = Queue()
        self._write_worker_started = False
        self._write_worker_lock = threading.Lock()
        self._deferred_tasks: list[WriteTask] = []

    @property
    def worker(self) -> "MoRIIOConnectorWorker":
        """Get the worker instance.

        Returns:
            The parent worker instance

        Raises:
            RuntimeError: If the worker has been garbage collected
        """
        worker = self._worker_ref()
        if worker is None:
            raise RuntimeError("Parent worker has been garbage collected")
        return worker

    def ensure_worker_started(self) -> None:
        """Ensure the background write worker is running."""
        if self._write_worker_started:
            return
        # Double-checked locking: re-test the flag under the lock so that
        # only one thread ever starts the worker.
        with self._write_worker_lock:
            if self._write_worker_started:
                return
            self._write_worker_started = True
            thread = threading.Thread(
                target=self._write_worker_loop, daemon=True, name="moriio-write-worker"
            )
            thread.start()
            logger.info("Started MoRIIO write worker thread")

    def schedule_write(self, task: WriteTask) -> None:
        """Schedule a write task.

        Args:
            task: The write task to schedule
        """
        self.ensure_worker_started()
        self._write_task_q.put(task)

    def _write_worker_loop(self) -> None:
        """Main loop for the write worker thread."""

        while True:
            # Process deferred tasks first
            self._process_deferred_tasks()

            # Get a new task
            try:
                task = self._write_task_q.get(timeout=0.01)
            except Empty:
                continue

            # If remote blocks are not ready yet, defer the task
            if not self._is_remote_ready(task):
                self._deferred_tasks.append(task)
                continue

            # Execute the task
            self._execute_write_task(task)

    def _process_deferred_tasks(self) -> None:
        """Process tasks that were previously deferred."""
        if not self._deferred_tasks:
            return

        still_deferred: list[WriteTask] = []
        for task in self._deferred_tasks:
            if self._is_remote_ready(task):
                self._execute_write_task(task)
            else:
                still_deferred.append(task)

        self._deferred_tasks = still_deferred

    def _is_remote_ready(self, task: WriteTask) -> bool:
        """Check whether remote blocks are allocated for this task.

        Args:
            task: The write task

        Returns:
            True if remote blocks are ready
        """
        return (
            task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict
        )

    def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo:
        """Get remote allocation info for a request.

        Args:
            request_id: The request ID

        Returns:
            Remote allocation information

        Raises:
            KeyError: If allocation info is missing
        """
        try:
            return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id]
        except KeyError as e:
            raise KeyError(
                f"Remote allocation info missing for request {request_id}"
            ) from e

    def _execute_write_task(self, task: WriteTask) -> None:
        """Execute a single write task.

        Args:
            task: The write task to execute
        """
        # Get remote allocation info
        request_info = self._get_remote_alloc_info(task.request_id)

        if request_info.block_ids is None:
            logger.debug("Request %s remote block IDs not ready", task.request_id)
            return

        # Wait for the CUDA event. The attention computation of the current
        # layer must not overlap with the KV transfer task, otherwise it
        # causes precision issues. This event synchronizes the KV transfer
        # with the computation.
        task.event.synchronize()

        # Update the engine ID with the DP rank
        task.dst_engine_id = self.worker.get_engine_name_with_dp(
            task.dst_engine_id, request_info.decode_dp_rank
        )

        # Get or create sessions
        sessions, remote_moriio_meta = self.worker._get_built_session(
            task.dst_engine_id
        )

        # Prepare the transfer plan
        plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)

        # Execute the transfer
        self._do_layer_write(plan, sessions)

        # Finalize if all layers are complete
        self._finalize_if_complete(task, request_info)

    def _prepare_transfer_plan(
        self,
        task: WriteTask,
        request_info: RemoteAllocInfo,
        remote_moriio_meta: MoRIIOAgentMetadata,
    ) -> LayerTransferPlan:
        """Prepare the transfer plan for a layer.

        Args:
            task: The write task
            request_info: Remote allocation information
            remote_moriio_meta: Metadata of the remote MoRIIO agent

        Returns:
            The transfer plan
        """
        # Compute offsets if not cached
        if request_info.transfer_offset is None:
            offsets = self.worker._compute_block_transfer_offsets(
                task.layer_name,
                task.local_block_ids,
                request_info.block_ids,
                remote_moriio_meta,
            )
            request_info.transfer_offset = offsets

        # Get the session index
        layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys())
        sess_idx = layer_names.index(task.layer_name)

        local_off, remote_off, sizes = request_info.transfer_offset

        return LayerTransferPlan(
            request_id=task.request_id,
            layer_name=task.layer_name,
            sess_idx=sess_idx,
            transfer_local_offsets=local_off,
            transfer_remote_offsets=remote_off,
            transfer_sizes=sizes,
            use_batch=True,
        )

    def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None:
        """Perform the actual layer write.

        Args:
            plan: The transfer plan
            sessions: List of transfer sessions
        """
        if plan.use_batch:
            self.worker.moriio_wrapper.write_remote_data(
                plan.transfer_sizes,
                plan.transfer_local_offsets,
                plan.transfer_remote_offsets,
                sessions[plan.sess_idx],
            )
        else:
            for i in range(len(plan.transfer_local_offsets)):
                self.worker.moriio_wrapper.write_remote_data_single(
                    plan.transfer_sizes[i],
                    plan.transfer_local_offsets[i],
                    plan.transfer_remote_offsets[i],
                    plan.sess_idx,
                )

    def _finalize_if_complete(
        self, task: WriteTask, request_info: RemoteAllocInfo
    ) -> None:
        """Finalize the transfer if all layers are complete.

        Args:
            task: The write task
            request_info: Remote allocation information
        """
        request_info.writes_done += 1

        if request_info.writes_done >= self.worker.num_layers:
            # Wait for the transfer to complete
            self.worker.moriio_wrapper.waiting_for_transfer_complete()

            remote_port = task.remote_notify_port + get_port_offset(
                request_info.decode_dp_rank, self.worker.tp_rank
            )
            # TODO: consider using RDMA immediate data on the decode side to
            # eliminate the need for this notification, and consider including
            # the first generated token from prefill in the notification.

            # Send the completion notification
            self.worker.moriio_wrapper.send_notify(
                task.request_id, task.remote_ip, remote_port
            )
            # Mark the request as done so its blocks can be freed
            with self.worker.moriio_wrapper.lock:
                self.worker.moriio_wrapper.done_req_ids.append(task.request_id)
                del self.worker.moriio_wrapper.done_remote_allocate_req_dict[
                    task.request_id
                ]
            logger.debug(
                "Completed transfer for request %s, notified port %d",
                task.request_id,
                remote_port,
            )
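The per-layer handoff hinges on recording a CUDA event after a layer's KV cache is populated and synchronizing on it in the writer thread before the RDMA write touches that memory. A hedged sketch of the producer side follows; `writer` is assumed to be a MoRIIOWriter wired to a live worker, and the field values are illustrative:

# Sketch: enqueue one WriteTask per layer once its KV data is on the GPU.
# The event is recorded on the compute stream so the background writer can
# synchronize before issuing the RDMA write for that layer.
event = torch.cuda.Event()
event.record(torch.cuda.current_stream())
writer.schedule_write(  # `writer` is an assumed MoRIIOWriter instance
    WriteTask(
        request_id="req-0",
        dst_engine_id="decode-engine",
        local_block_ids=[0, 1],
        remote_block_ids_hint=None,
        layer_name="model.layers.0.self_attn.attn",
        event=event,
        remote_notify_port=61005,
        remote_ip="10.0.0.9",
    )
)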
class MoRIIOWrapper:
    """Wrapper for MoRIIO engine operations.

    Handles both producer and consumer roles for KV cache transfers.

    Args:
        moriio_engine: MoRIIO engine instance
        tp_rank: Tensor parallel rank
        dp_rank: Data parallel rank
    """

    def __init__(
        self,
        moriio_engine: "IOEngine | None" = None,
        tp_rank: int = 0,
        dp_rank: int = 0,
    ):
        self.tp_rank = tp_rank
        self.dp_rank = dp_rank
        self.moriio_engine = moriio_engine
        self.remote_memory_metadata = None
        self.local_memory_registered = False
        self.local_memory_metadata = None
        self.transfer_status: list[Any] = []
        self.remote_engine_ip: str | None = None
        self.notify_port: int | None = None
        self.lock = threading.Lock()
        self.done_req_ids: list[str] = []
        self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {}
        self.done_write_cache_req_ids: list[str] = []
        self.notify_thread: threading.Thread | None = None
        self.sessions: list[IOEngine.Session] = []
        self.paths: dict[str, zmq.Socket] = {}

    def set_moriio_engine(self, moriio_engine):
        assert moriio_engine is not None, (
            "You cannot pass a None engine to MoRIIOWrapper!"
        )
        self.moriio_engine = moriio_engine

    def set_backend_type(self, backend_type):
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
        post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
        num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
        poll_mode = PollCqMode.POLLING
        rdma_cfg = RdmaBackendConfig(
            qp_per_transfer,
            post_batch_size,
            num_worker_threads,
            poll_mode,
        )
        self.moriio_engine.create_backend(backend_type, rdma_cfg)

    def get_agent_metadata(self):
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        engine_metadata = self.moriio_engine.get_engine_desc()
        engine_metadata_packed = engine_metadata.pack()
        return engine_metadata_packed

    def register_remote_engine(self, remote_packed_engine_metadata):
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata)
        self.moriio_engine.register_remote_engine(consumer_engine_metadata)
        return consumer_engine_metadata.key

    def register_local_tensor(self, tensor: torch.Tensor):
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        try:
            self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
                tensor
            )
            assert self.local_memory_metadata is not None, (
                "register_torch_tensor returned None"
            )
            local_memory_metadata_packed = self.local_memory_metadata.pack()
        except Exception as e:
            raise MoRIIOError(f"Failed to register local memory: {e}") from e
        self.local_memory_registered = True
        return local_memory_metadata_packed

    def get_unpack_memory_metadata(self, packed_memory_metadata):
        return MemoryDesc.unpack(packed_memory_metadata)

    def build_session(self, local_memory_metadata, remote_memory_metadata):
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        return self.moriio_engine.create_session(
            local_memory_metadata, remote_memory_metadata
        )

    def read_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        assert self.local_memory_registered, "You have not registered local memory!"
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        transfer_status = session.batch_read(
            local_offset,
            remote_offset,
            transfer_size_byte,
            self.moriio_engine.allocate_transfer_uid(),
        )

        return transfer_status

    def write_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        assert self.local_memory_registered, "You have not registered local memory!"
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        write_uid = self.moriio_engine.allocate_transfer_uid()

        transfer_status = session.batch_write(
            local_offset, remote_offset, transfer_size_byte, write_uid
        )
        with self.lock:
            self.transfer_status.append(transfer_status)

    def write_remote_data_single(
        self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
    ):
        assert self.local_memory_registered, "You have not registered local memory!"
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        transfer_status = self.sessions[sess_idx].write(
            local_offset,
            remote_offset,
            transfer_size_byte,
            self.moriio_engine.allocate_transfer_uid(),
        )
        with self.lock:
            self.transfer_status.append(transfer_status)

    def waiting_for_transfer_complete(self):
        if not self.transfer_status:
            return

        with self.lock:
            transfers_to_wait = self.transfer_status[:]
            self.transfer_status.clear()

        for status in transfers_to_wait:
            try:
                status.Wait()
                if not status.Succeeded():
                    logger.error(
                        "Transfer failed: %s, Code: %s", status.Message(), status.Code()
                    )
                    raise TransferError("MoRIIO transfer failed!")
            except Exception as e:
                logger.error("Transfer %s failed: %s", status, e)
                raise

    def async_wait_reqid(self):
        assert self.notify_port is not None, "Notify port cannot be None"

        if self.notify_thread is not None:
            return

        def _async_wait():
            host = "*"
            path = make_zmq_path("tcp", host, self.notify_port)
            logger.info("Node listening for notifications on path = %s", path)

            with zmq_ctx(zmq.ROUTER, path) as sock:
                while True:
                    try:
                        identity, msg = sock.recv_multipart()
                        self._handle_message(msg)
                    except Exception as e:
                        logger.error("Error processing message: %s", e)
                        raise HandshakeError(f"Error processing message: {e}") from e

        self.notify_thread = threading.Thread(
            target=_async_wait, daemon=True, name="moriio-notify-listener"
        )
        self.notify_thread.start()

    def _handle_message(self, msg: bytes):
        """Handle incoming messages from remote nodes.

        Prefill role:
            [write] mode: receives block information (allocation)
            [read] mode: receives block release messages from the decode side
        Decode role:
            [write] mode: receives KV cache write completion notifications
        """
        try:
            data = msgpack.loads(msg)
            if isinstance(data, dict) and "req_id" in data:
                self._handle_structured_message(data)
                return
        except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
            logger.debug("Failed to decode msgpack message, will try as string")

        try:
            msg_str = msg.decode("UTF-8")
        except UnicodeDecodeError:
            # The payload is neither msgpack nor UTF-8; log the raw bytes.
            logger.warning("Received non-UTF8 message: %r", msg)
            raise MoRIIOError(f"Unhandled message format: {msg!r}") from None

        if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX):
            self._handle_completion_message(msg_str)
        else:
            raise MoRIIOError(f"Unhandled message format: {msg_str}")

    def _handle_structured_message(self, data: dict):
        assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"
        req_id = data["req_id"]
        block_notify_list = data.get("block_notify_list", [])
        decode_dp_rank = data.get("decode_rank", 0)
        assert len(block_notify_list) > 0, (
            "block_notify_list cannot be empty in a remote allocate message"
        )

        with self.lock:
            self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(
                block_ids=block_notify_list, decode_dp_rank=decode_dp_rank
            )

    def _handle_completion_message(self, msg: str):
        with self.lock:
            if get_role() == ROLE.PRODUCER:
                self.done_req_ids.append(msg)
            else:
                self.done_write_cache_req_ids.append(msg)

    def send_notify(self, req_ids, remote_ip, remote_port):
        if not remote_ip or not remote_port:
            logger.warning("Missing remote_ip or remote_port for notification")
            return

        path = make_zmq_path("tcp", remote_ip, remote_port)

        if path not in self.paths:
            ctx = zmq.Context.instance()
            sock = make_zmq_socket(
                ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
            )
            self.paths[path] = sock

        req_list = req_ids if isinstance(req_ids, list) else [req_ids]

        sock = self.paths[path]
        try:
            for req_id in req_list:
                if not isinstance(req_id, str):
                    logger.warning(
                        "Invalid req_id type: %s, expected str", type(req_id)
                    )
                    continue
                sock.send(req_id.encode("utf-8"))
        except Exception as e:
            logger.error("Failed to send notification to %s: %s", path, e)
            self.paths.pop(path, None)
            raise

    def pop_finished_req_ids(self):
        # Producer invocation: get the set of requests completed on the
        # decode side.
        with self.lock:
            done_send = set(self.done_req_ids)
            self.done_req_ids = []
        return done_send

    def pop_finished_write_req_ids(self):
        # Consumer invocation in write mode: get the set of requests whose
        # writes have completed.
        with self.lock:
            done_write_cache = set(self.done_write_cache_req_ids)
            self.done_write_cache_req_ids = []
        return done_write_cache

    def shutdown(self):
        logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
        for path, sock in self.paths.items():
            try:
                sock.close(linger=0)
                logger.debug("Closed ZMQ socket for path: %s", path)
            except Exception as e:
                logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
        self.paths.clear()
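The completion notification path is a plain ZMQ DEALER-to-ROUTER hop: the receiving side listens via async_wait_reqid() and the sending side pushes request IDs with send_notify(). A minimal local sketch under those assumptions (ports and IPs are illustrative; in vLLM they come from MoRIIOConfig):

# Sketch of the notification flow between two MoRIIOWrapper instances.
consumer = MoRIIOWrapper()
consumer.notify_port = 61005
consumer.async_wait_reqid()  # starts the ROUTER listener thread

producer = MoRIIOWrapper()
producer.send_notify("cmpl-req-0", "127.0.0.1", 61005)
# The listener routes "cmpl..." strings to _handle_completion_message();
# on a non-producer role, pop_finished_write_req_ids() then returns
# {"cmpl-req-0"}.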
515
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
@@ -0,0 +1,515 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch

from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    CopyBlocksOp,
    KVConnectorBase_V1,
    KVConnectorHandshakeMetadata,
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
    KVConnectorPromMetrics,
    KVConnectorStats,
    PromMetric,
    PromMetricT,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
    metadata: tuple[KVConnectorMetadata, ...]
    extra_async_saves: dict[str, int] | None = None


@dataclass
class MultiKVConnectorStats(KVConnectorStats):
    """
    Maintain a dict of KVConnectorStats objects, one for each connector.
    This is used to aggregate the stats from all connectors separately.
    """

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        for connector_id, stats in other.data.items():
            if connector_id not in self.data:
                self[connector_id] = stats
            else:
                assert isinstance(stats, type(self.data[connector_id]))
                self[connector_id] = self[connector_id].aggregate(stats)
        return self

    def reset(self):
        for stats in self.data.values():
            stats.reset()

    def reduce(self) -> dict[str, Any]:
        # TODO (NickLucche) Adjust for logging on separate lines
        return {
            connector_id: stats.reduce() for connector_id, stats in self.data.items()
        }

    def is_empty(self) -> bool:
        return all(stats.is_empty() for stats in self.data.values())

    def __getitem__(self, connector_id: str) -> KVConnectorStats:
        return self.data[connector_id]

    def __setitem__(self, connector_id: str, stats: KVConnectorStats):
        self.data[connector_id] = stats


class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
    def __init__(
        self,
        vllm_config: "VllmConfig",
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
        prom_metrics: dict[str, KVConnectorPromMetrics],
    ):
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
        self._prom_metrics = prom_metrics

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        for connector_id, stats_data in transfer_stats_data.items():
            assert connector_id in self._prom_metrics, (
                f"{connector_id} is not contained in the list of registered connectors "
                f"with Prometheus metrics support: {self._prom_metrics.keys()}"
            )
            self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)


class MultiConnector(KVConnectorBase_V1):
    """
    A wrapper for using multiple KVConnectors at the same time.

    The current logic is:
    - Load KV from the first connector that advertises available tokens from
      get_num_new_matched_tokens(), based on the order in the config.
    - Save to all connectors.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        super().__init__(
            vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
        )

        self._connectors: list[KVConnectorBase_V1] = []
        self._ktc_kv_transfer_config = []
        for connector_cls, temp_config in self._get_connector_classes_and_configs(
            vllm_config
        ):
            self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
            self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

        # A mapping from request id to the index of the connector chosen to
        # load the request from (if any).
        self._requests_to_connector: dict[str, int] = {}

        # Keeps track of *additional* remaining async saves (beyond 1) to be
        # finished per request. Not needed for async loads since we only allow
        # a single connector to load.
        # Propagated from scheduler to worker side via the connector metadata.
        self._extra_async_saves: dict[str, int] = {}

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        if not self._connectors:
            return False
        return all(c.prefer_cross_layer_blocks for c in self._connectors)

    @classmethod
    def _get_connector_classes_and_configs(
        cls, vllm_config: "VllmConfig"
    ) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
        assert vllm_config.kv_transfer_config is not None
        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "connectors"
        )
        assert ktcs is not None
        ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
        for ktc in ktcs:
            temp_config = copy.copy(vllm_config)
            engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
            temp_config.kv_transfer_config = KVTransferConfig(
                **ktc, engine_id=engine_id
            )
            ret.append(
                (
                    KVConnectorFactory.get_connector_class(
                        temp_config.kv_transfer_config
                    ),
                    temp_config,
                )
            )
        return ret

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        # Register on all connectors
        for c in self._connectors:
            c.register_cross_layers_kv_cache(kv_cache, attn_backend)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        for c in self._connectors:
            c.register_kv_caches(kv_caches)

    # We must override the base class method here because we need to bind
    # the metadata to each connector in the order of the connectors in the
    # MultiKVConnectorMetadata.
    #
    # Note: Call the base class method to ensure metadata is also set on the
    # MultiConnector instance itself; otherwise, `has_connector_metadata()` will
    # always return False.
    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, MultiKVConnectorMetadata)
        if connector_metadata.extra_async_saves:
            self._extra_async_saves.update(connector_metadata.extra_async_saves)
        for c, cm in zip(self._connectors, connector_metadata.metadata):
            c.bind_connector_metadata(cm)
        super().bind_connector_metadata(connector_metadata)

    def clear_connector_metadata(self) -> None:
        for c in self._connectors:
            c.clear_connector_metadata()
        super().clear_connector_metadata()

    def shutdown(self):
        exception: Exception | None = None
        for c in self._connectors:
            try:
                c.shutdown()
            except Exception as e:
                logger.exception(
                    "Exception during connector %s shutdown.", c.__class__.__name__
                )
                exception = e
        if exception:
            raise exception

    # ==============================
    # Worker-side methods
    # ==============================
    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        for c in self._connectors:
            c.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        for c in self._connectors:
            c.wait_for_layer_load(layer_name)

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        for c in self._connectors:
            c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)

    def wait_for_save(self):
        for c in self._connectors:
            c.wait_for_save()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        finished_sending: set[str] = set()
        finished_recving: set[str] = set()
        for c in self._connectors:
            sending, recving = c.get_finished(finished_req_ids)
            if not recving and not sending:
                continue
            # Aggregate finished recving request ids.
            finished_recving.update(recving or ())
            # Aggregate finished sending request ids - only include
            # once we've drained the "extra" count (for cases where
            # more than one connector is async-saving the same request).
            for req_id in sending or ():
                extra_pending = self._extra_async_saves.get(req_id)
                if extra_pending is None:
                    finished_sending.add(req_id)
                    continue
                assert extra_pending > 0
                if extra_pending == 1:
                    del self._extra_async_saves[req_id]
                else:
                    self._extra_async_saves[req_id] = extra_pending - 1

        return finished_sending or None, finished_recving or None

    def get_block_ids_with_load_errors(self) -> set[int]:
        agg_block_ids: set[int] = set()
        for c in self._connectors:
            agg_block_ids |= c.get_block_ids_with_load_errors()
        return agg_block_ids

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        """Set xPU-specific copy ops for all sub-connectors."""
        for c in self._connectors:
            c.set_host_xfer_buffer_ops(copy_operation)

    def handle_preemptions(self, preempted_req_ids: set[str]):
        """Handle preempted requests for all sub-connectors."""
        for c in self._connectors:
            c.handle_preemptions(preempted_req_ids)

    def get_finished_count(self) -> int | None:
        # TODO(https://github.com/vllm-project/vllm/issues/33400)
        # Currently no connectors return non-None
        return None

    # TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events'
    # method for the MultiConnector. It should be able to get events from
    # multiple connectors, handling the case where only a subset of the
    # requested connectors implements 'get_kv_connector_kv_cache_events'.
    # WIP: https://github.com/vllm-project/vllm/pull/31811

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        to_return = (0, False)
        for i, c in enumerate(self._connectors):
            toks, load_async = c.get_num_new_matched_tokens(
                request, num_computed_tokens
            )
            # If there is a connector still looking up the matches,
            # we return None to indicate that we are not done yet.
            if toks is None:
                return (None, False)
            # The first connector that has new matched tokens will be assigned
            # to this request.
            if to_return[0] == 0 and toks > 0:
                self._requests_to_connector[request.request_id] = i
                to_return = (toks, load_async)
        return to_return

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        chosen_connector = self._requests_to_connector.get(request.request_id, -1)
        empty_blocks = blocks.new_empty()
        for i, c in enumerate(self._connectors):
            if i == chosen_connector:
                # Forward call to the chosen connector (if any).
                c.update_state_after_alloc(request, blocks, num_external_tokens)
            else:
                # Call with empty blocks for other connectors.
                c.update_state_after_alloc(request, empty_blocks, 0)

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> MultiKVConnectorMetadata:
        metadata = MultiKVConnectorMetadata(
            metadata=tuple(
                c.build_connector_meta(scheduler_output) for c in self._connectors
            )
        )
        if self._extra_async_saves:
            metadata.extra_async_saves = self._extra_async_saves
            self._extra_async_saves = {}
        return metadata

    def update_connector_output(self, connector_output: KVConnectorOutput):
        for c in self._connectors:
            c.update_connector_output(connector_output)

    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
        """
        Get the KVConnector handshake metadata from sub-connectors.
        Returns the first non-None metadata from sub-connectors.
        """
        for c in self._connectors:
            metadata = c.get_handshake_metadata()
            if metadata is not None:
                return metadata
        return None

    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        """
        Set the KV connector handshake metadata for all sub-connectors.
        This is needed to start the NIXL listener thread for NixlConnector.
        """
        for c in self._connectors:
            c.set_xfer_handshake_metadata(metadata)

    def request_finished(
        self,
        request: "Request",
        blocks: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        async_saves = 0
        kv_txfer_params = None
        for c in self._connectors:
            async_save, txfer_params = c.request_finished(request, blocks)
            if async_save:
                async_saves += 1
            if txfer_params is not None:
                if kv_txfer_params is not None:
                    # TODO we can probably change this to merge the dicts here,
                    # checking for key clashes.
                    raise RuntimeError(
                        "Only one connector can produce KV transfer params"
                    )
                kv_txfer_params = txfer_params
        if async_saves > 1:
            self._extra_async_saves[request.request_id] = async_saves - 1

        # Clean up other state for this request.
        self._requests_to_connector.pop(request.request_id, None)

        return async_saves > 0, kv_txfer_params

    def take_events(self) -> Iterable["KVCacheEvent"]:
        for c in self._connectors:
            yield from c.take_events()

    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
        """
        Get the required KV cache layout for this connector.

        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout, e.g. HND or NHD.
            None if the connector does not require a specific layout.
        """
        assert vllm_config.kv_transfer_config is not None
        layouts: set[str] = set()
        for connector_cls, temp_config in cls._get_connector_classes_and_configs(
            vllm_config
        ):
            required_kvcache_layout = connector_cls.get_required_kvcache_layout(
                temp_config
            )
            if required_kvcache_layout is not None:
                layouts.add(required_kvcache_layout)

        if len(layouts) > 1:
            raise ValueError(
                f"KV cache layout mismatch: "
                f"found {len(layouts)} different layouts "
                f"({', '.join(layouts)}). "
                f"All connectors must use the same layout."
            )
        return next(iter(layouts), None)

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        if data is None:
            return MultiKVConnectorStats()

        # data is a dict mapping connector name to their stats data.
        # The stats data can be either:
        # 1. Already-instantiated KVConnectorStats objects (same process)
        # 2. Serialized dicts (cross-process after serialization)
        # We need to reconstruct proper KVConnectorStats objects from dicts
        reconstructed_data = {}
        for connector_name, stats_value in data.items():
            # If already a KVConnectorStats object, use it directly
            if isinstance(stats_value, KVConnectorStats):
                reconstructed_data[connector_name] = stats_value
                continue

            # Otherwise, reconstruct from the serialized dict.
            # Get the connector class to reconstruct its stats.
            connector_cls = KVConnectorFactory.get_connector_class_by_name(
                connector_name
            )

            # stats_value is the serialized dataclass which contains {'data': {...}}
            # We need to extract the inner 'data' field to avoid double-nesting
            assert isinstance(stats_value, dict) and "data" in stats_value, (
                f"Expected a dict with a 'data' field, got {stats_value}"
            )
            inner_data = stats_value["data"]

            # Use the connector's build_kv_connector_stats to reconstruct
            if reconstructed_stats := connector_cls.build_kv_connector_stats(
                data=inner_data
            ):
                reconstructed_data[connector_name] = reconstructed_stats

        return MultiKVConnectorStats(data=reconstructed_data)

    def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
        # Group connector stats by connector type.
        stats_by_connector: MultiKVConnectorStats | None = None
        for c in self._connectors:
            stats = c.get_kv_connector_stats()
            if stats is None:
                continue
            if stats_by_connector is None:
                # Lazy init to allow optional return value.
                stats_by_connector = MultiKVConnectorStats()
            stats_by_connector[c.__class__.__name__] = stats
        return stats_by_connector

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> KVConnectorPromMetrics:
        prom_metrics: dict[str, KVConnectorPromMetrics] = {}
        for connector_cls, temp_config in cls._get_connector_classes_and_configs(
            vllm_config
        ):
            connector_prom = connector_cls.build_prom_metrics(
                temp_config, metric_types, labelnames, per_engine_labelvalues
            )
            if connector_prom is not None:
                prom_metrics[connector_cls.__name__] = connector_prom
        return MultiKVConnectorPromMetrics(
            vllm_config,
            metric_types,
            labelnames,
            per_engine_labelvalues,
            prom_metrics,
        )

    def reset_cache(self) -> bool:
        results = [c.reset_cache() is not False for c in self._connectors]
        return all(results)
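MultiConnector is driven entirely by the "connectors" list in kv_connector_extra_config, which _get_connector_classes_and_configs() reads above; each entry becomes its own KVTransferConfig, inheriting the parent engine_id unless overridden. A hedged configuration sketch (the sub-connector names are illustrative, not prescribed by this diff):

# Sketch: stacking two sub-connectors under MultiConnector. The first
# connector in the list gets load priority; saves go to all of them.
from vllm.config.kv_transfer import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="MultiConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "connectors": [
            {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
            {"kv_connector": "OffloadingConnector", "kv_role": "kv_both"},
        ]
    },
)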
2790
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,800 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (
|
||||
OffloadingWorker,
|
||||
TransferSpec,
|
||||
TransferType,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
ReqId = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingOperationMetrics:
|
||||
op_size: int
|
||||
op_time: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConnectorStats(KVConnectorStats):
|
||||
def __post_init__(self):
|
||||
if not self.data:
|
||||
# Empty container init, no data is passed in.
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
|
||||
|
||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||
if not other.is_empty():
|
||||
for k, v in other.data.items():
|
||||
if k not in self.data:
|
||||
self.data[k] = v
|
||||
else:
|
||||
accumulator = self.data[k]
|
||||
assert isinstance(accumulator, list)
|
||||
accumulator.extend(v)
|
||||
return self
|
||||
|
||||
def reduce(self) -> dict[str, int | float]:
|
||||
"""
|
||||
Reduce the observations collected during a time interval to one or
|
||||
more representative values (eg avg/median/sum of the series).
|
||||
This is meant to be called by the logger to produce a summary of the
|
||||
stats for the last time interval.
|
||||
"""
|
||||
return_dict: dict[str, int | float] = {}
|
||||
for transfer_type, ops_list in self.data.items():
|
||||
assert isinstance(ops_list, list)
|
||||
total_bytes = 0
|
||||
total_time = 0.0
|
||||
for op in ops_list:
|
||||
assert isinstance(op, dict)
|
||||
total_bytes += op["op_size"]
|
||||
total_time += op["op_time"]
|
||||
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
|
||||
return_dict[f"{transfer_type}_total_time"] = total_time
|
||||
return return_dict
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.data
|
||||
|
||||
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
|
||||
src, dst = transfer_type
|
||||
transfer_type_key = src + "_to_" + dst
|
||||
op = OffloadingOperationMetrics(num_bytes, time)
|
||||
        if transfer_type_key in self.data:
            self.data[transfer_type_key].append(op)
        else:
            self.data[transfer_type_key] = [op]


@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
    reqs_to_load: dict[ReqId, TransferSpec]
    reqs_to_store: dict[ReqId, TransferSpec]


class OffloadingConnector(KVConnectorBase_V1):
    @property
    def prefer_cross_layer_blocks(self) -> bool:
        return True

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)

        self.connector_scheduler: OffloadingConnectorScheduler | None = None
        self.connector_worker: OffloadingConnectorWorker | None = None
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = OffloadingConnectorScheduler(spec)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = OffloadingConnectorWorker(spec)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        assert self.connector_worker is not None
        self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)

    def handle_preemptions(self, preempted_req_ids: set[str]):
        assert self.connector_worker is not None
        self.connector_worker.handle_preemptions(preempted_req_ids)

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
        self.connector_worker.start_kv_transfers(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs,
    ) -> None:
        pass

    def wait_for_save(self):
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
        self.connector_worker.prepare_store_kv(self._connector_metadata)

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        assert self.connector_worker is not None
        return self.connector_worker.get_finished(finished_req_ids)

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        assert self.connector_scheduler is not None
        self.connector_scheduler.update_connector_output(connector_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    def take_events(self) -> Iterable[KVCacheEvent]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.take_events()

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        if self.connector_worker is None:
            return None  # We only emit stats from the worker-side
        return self.connector_worker.get_kv_connector_stats()

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        return (
            OffloadingConnectorStats(data=data)
            if data is not None
            else OffloadingConnectorStats()
        )

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> KVConnectorPromMetrics:
        return OffloadPromMetrics(
            vllm_config, metric_types, labelnames, per_engine_labelvalues
        )


class OffloadingConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.gpu_block_size = spec.gpu_block_size
        self.offloaded_block_size = spec.offloaded_block_size
        self.block_size_factor = self.offloaded_block_size // self.gpu_block_size
        self.manager: OffloadingManager = spec.get_manager()

        self._requests: dict[ReqId, Request] = {}
        # list of GPU block IDs per request
        self._request_block_ids: dict[ReqId, list[int]] = {}
        # requests to load for the current scheduler step
        self._reqs_to_load: dict[ReqId, TransferSpec] = {}
        # request blocks are stored in order
        # index of next block (of size offloaded_block_size) to offload
        self._next_stored_block_idx: dict[ReqId, int] = {}
        # if GPU prefix caching is enabled,
        # track loaded blocks to avoid redundant loads
        self._blocks_being_loaded: set[BlockHash] | None = (
            set() if spec.vllm_config.cache_config.enable_prefix_caching else None
        )

        # request ID -> set(block hashes being stored/loaded)
        self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
        self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)

    def _get_block_hashes(
        self,
        req: Request,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> Iterable[BlockHash]:
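        # Each offloaded block spans `block_size_factor` GPU blocks and is
        # keyed by the hash of its *last* GPU block, so start at index
        # `block_size_factor * start_idx + block_size_factor - 1` and step by
        # `block_size_factor`. For example (illustrative values only): with
        # gpu_block_size=16 and offloaded_block_size=64 the factor is 4 and
        # the selected hash indices are 3, 7, 11, ...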
        return islice(
            req.block_hashes,
            self.block_size_factor * start_idx + self.block_size_factor - 1,
            self.block_size_factor * end_idx if end_idx else None,
            self.block_size_factor,
        )

    def get_num_new_matched_tokens(
        self, request: Request, num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded beyond the
        num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded beyond what is
                  already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if tokens will be loaded asynchronously
                  (between scheduler steps).
        """
        num_blocks = request.num_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor == num_blocks
        block_hashes = self._get_block_hashes(request)

        self.manager.touch(block_hashes)

        full_block_tokens = self.offloaded_block_size * num_blocks
        if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
            # we can load less than a block, skip
            return 0, False

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        hits = self.manager.lookup(
            self._get_block_hashes(request, start_idx=start_block_idx)
        )
        if hits is None:
            # indicates a lookup that should be tried later
            return None, False
        if hits == 0:
            return 0, False

        num_hit_tokens = (
            self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
        )
        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            request.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )
        if num_hit_tokens < self.offloaded_block_size:
            return 0, False

        if self._blocks_being_loaded:
            block_hashes = self._get_block_hashes(
                request, start_idx=start_block_idx, end_idx=start_block_idx + hits
            )

            if any(
                block_hash in self._blocks_being_loaded for block_hash in block_hashes
            ):
                # hit blocks are being loaded, delay request
                logger.debug(
                    "Delaying request %s since some of its blocks are already"
                    " being loaded",
                    request.request_id,
                )
                return None, False

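        # Returning True marks the load as asynchronous: the scheduler holds
        # the request back until the worker reports the transfer finished via
        # get_finished()/update_connector_output().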
        return num_hit_tokens, True

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        self._requests[request.request_id] = request
        # the block ids are updated in _get_reqs_to_store
        self._request_block_ids[request.request_id] = []

        if num_external_tokens == 0:
            return

        block_groups = blocks.get_block_ids()
        block_ids = block_groups[0]

        num_computed_gpu_blocks = sum(
            block.block_hash is not None for block in blocks.blocks[0]
        )
        num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
        full_block_tokens = num_computed_tokens + num_external_tokens
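        # get_num_new_matched_tokens() only ever reports whole offloaded
        # blocks, so computed + external tokens must land exactly on an
        # offloaded-block boundary.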
        assert full_block_tokens % self.offloaded_block_size == 0

        num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
        assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        num_blocks = full_block_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor >= num_blocks
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        src_spec = self.manager.prepare_load(block_hashes)
        dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])

        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
        req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
        req_blocks_being_loaded.update(block_hashes)
        self._next_stored_block_idx[request.request_id] = num_blocks

        if self._blocks_being_loaded is not None:
            self._blocks_being_loaded.update(req_blocks_being_loaded)

    def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
        reqs_to_store: dict[ReqId, TransferSpec] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
            if preempted:
                self._request_block_ids[req_id] = []

            if new_block_id_groups:
                new_block_ids = new_block_id_groups[0]
                self._request_block_ids[req_id] += new_block_ids

            block_ids = self._request_block_ids[req_id]

            req = self._requests[req_id]
            new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            total_tokens = req.num_computed_tokens + new_tokens
            num_blocks = total_tokens // self.offloaded_block_size
            start_block_idx = self._next_stored_block_idx.get(req_id, 0)
            num_new_blocks = num_blocks - start_block_idx

            if num_new_blocks <= 0:
                continue

            # NOTE: In async scheduling, placeholders may temporarily make
            # len(req.block_hashes) < num_blocks * self.block_size_factor.

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            store_output = self.manager.prepare_store(new_block_hashes)
            if store_output is None:
                logger.warning(
                    "Request %s: cannot store %s blocks", req_id, num_new_blocks
                )
                continue

            self._next_stored_block_idx[req_id] = num_blocks

            if not store_output.block_hashes_to_store:
                continue
            block_hashes_to_store = set(store_output.block_hashes_to_store)

            block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
            self.manager.touch(block_hashes)

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            dst_spec = store_output.store_spec
            src_block_ids: list[int] = []
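            # Expand each offloaded block back into its block_size_factor
            # consecutive GPU blocks to build the source of the transfer.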
            for idx, blk_hash in enumerate(new_block_hashes):
                if blk_hash not in block_hashes_to_store:
                    continue
                offloaded_block_idx = start_block_idx + idx
                gpu_block_idx = offloaded_block_idx * self.block_size_factor
                for i in range(self.block_size_factor):
                    src_block_ids.append(block_ids[gpu_block_idx + i])
            src_spec = GPULoadStoreSpec(src_block_ids)

            reqs_to_store[req_id] = (src_spec, dst_spec)
            self._reqs_being_stored[req_id] |= block_hashes_to_store

            logger.debug(
                "Request %s offloading %s blocks starting from block #%d",
                req_id,
                len(block_hashes_to_store),
                start_block_idx,
            )

        return reqs_to_store

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        meta = OffloadingConnectorMetadata(
            reqs_to_load=self._reqs_to_load,
            reqs_to_store=self._get_reqs_to_store(scheduler_output),
        )
        self._reqs_to_load = {}

        # NOTE (orozery): we should move this logic to update_connector_output
        # once KVConnectorOutput allows us to report completed transfers
        for req_id in scheduler_output.preempted_req_ids or ():
            block_hashes = self._reqs_being_stored.get(req_id)
            if block_hashes:
                self.manager.complete_store(block_hashes)
                block_hashes.clear()

        return meta

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        for req_id in connector_output.finished_sending or []:
            block_hashes = self._reqs_being_stored.pop(req_id, None)
            if block_hashes:
                self.manager.complete_store(block_hashes)

        for req_id in connector_output.finished_recving or []:
            block_hashes = self._reqs_being_loaded.pop(req_id, None)
            if block_hashes:
                if self._blocks_being_loaded:
                    self._blocks_being_loaded.difference_update(block_hashes)
                self.manager.complete_load(block_hashes)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id
        self._requests.pop(req_id, None)
        self._request_block_ids.pop(req_id, None)
        self._next_stored_block_idx.pop(req_id, None)

        request_being_stored = req_id in self._reqs_being_stored
        return request_being_stored, None

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Take the KV cache events from the connector.

        Returns:
            A list of KV cache events.
        """
        for event in self.manager.take_events():
            if event.removed:
                yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
            else:
                yield BlockStored(
                    block_hashes=event.block_hashes,
                    parent_block_hash=None,
                    token_ids=[],
                    lora_id=None,
                    block_size=event.block_size,
                    medium=event.medium,
                    lora_name=None,
                )


class OffloadingConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.spec = spec
        self.worker = OffloadingWorker()

        self._job_counter = 0

        self.kv_connector_stats = OffloadingConnectorStats()
        # job_id -> (req_id, store)
        self._jobs: dict[int, tuple[ReqId, bool]] = {}
        # req_id -> active load job ID
        self._load_job: dict[ReqId, int] = {}
        # req_id -> set(active job IDs)
        self._store_jobs = defaultdict[ReqId, set[int]](set)
        # list of store jobs pending submission (job_id, transfer_spec)
        self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []

        self._finished_reqs_waiting_for_store: set[ReqId] = set()

    def _generate_job_id(self) -> int:
        job_id = self._job_counter
        self._job_counter = job_id + 1
        return job_id

    def _register_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        for src_cls, dst_cls, handler in self.spec.get_handlers(
            kv_caches, attn_backends
        ):
            self.worker.register_handler(src_cls, dst_cls, handler)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        layer_names = list(kv_caches.keys())
        layers = get_layers_from_vllm_config(
            self.spec.vllm_config, Attention, layer_names
        )
        attn_backends = {
            layer_name: layers[layer_name].get_attn_backend()
            for layer_name in layer_names
        }
        self._register_handlers(kv_caches, attn_backends)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        cross_layer_name = "ALL_LAYERS"
        kv_caches = {cross_layer_name: kv_cache}
        attn_backends = {cross_layer_name: attn_backend}
        self._register_handlers(kv_caches, attn_backends)

    def handle_preemptions(self, preempted_req_ids: set[str]):
        for job_id, transfer_spec in self._unsubmitted_store_jobs:
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success
        self._unsubmitted_store_jobs.clear()

        for req_id in preempted_req_ids:
            job_ids = self._store_jobs.get(req_id)
            if job_ids:
                self.worker.wait(job_ids)

    def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
        for job_id, transfer_spec in self._unsubmitted_store_jobs:
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success
        self._unsubmitted_store_jobs.clear()

        for req_id, transfer_spec in metadata.reqs_to_load.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, False)
            assert req_id not in self._load_job
            self._load_job[req_id] = job_id
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success

    def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
        for req_id, transfer_spec in metadata.reqs_to_store.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, True)
            self._store_jobs[req_id].add(job_id)
            # NOTE(orozery): defer the store to the beginning of the next engine step,
            # so that offloading starts AFTER transfers related to token sampling,
            # thereby avoiding delays to token generation due to offloading.
            self._unsubmitted_store_jobs.append((job_id, transfer_spec))

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """
        Notifies the worker-side connector of the IDs of requests that have
        finished generating tokens.

        Returns:
            A tuple (sending/saving ids, recving/loading ids) of request IDs
            that have finished their asynchronous transfers.
        """
        finished_sending = set()
        finished_recving = set()
        for transfer_result in self.worker.get_finished():
            # we currently do not support job failures
            job_id = transfer_result.job_id
            assert transfer_result.success
            req_id, store = self._jobs.pop(job_id)
            if (
                transfer_result.transfer_time
                and transfer_result.transfer_size is not None
                and transfer_result.transfer_type is not None
            ):
                self.kv_connector_stats.record_transfer(
                    num_bytes=transfer_result.transfer_size,
                    time=transfer_result.transfer_time,
                    transfer_type=transfer_result.transfer_type,
                )
            if store:
                req_jobs = self._store_jobs[req_id]
                req_jobs.remove(job_id)
                if req_jobs:
                    continue

                if req_id in self._finished_reqs_waiting_for_store:
                    self._finished_reqs_waiting_for_store.remove(req_id)
                    finished_sending.add(req_id)
                    del self._store_jobs[req_id]
            else:
                req_job = self._load_job[req_id]
                assert job_id == req_job
                del self._load_job[req_id]
                finished_recving.add(req_id)

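        # Requests can finish generating while their store jobs are still in
        # flight; park them in _finished_reqs_waiting_for_store and report
        # them as finished_sending only once the last store job completes.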
        for req_id in finished_req_ids:
            pending_req_jobs = self._store_jobs.get(req_id)
            if pending_req_jobs:
                self._finished_reqs_waiting_for_store.add(req_id)
            elif pending_req_jobs is not None:
                finished_sending.add(req_id)
                del self._store_jobs[req_id]

        return finished_sending, finished_recving

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        """
        Get the KV transfer stats for the connector.
        """

        if self.kv_connector_stats.is_empty():
            return None
        # Clear stats for next iteration
        kv_connector_stats = self.kv_connector_stats
        self.kv_connector_stats = OffloadingConnectorStats()
        return kv_connector_stats


class OffloadPromMetrics(KVConnectorPromMetrics):
    def __init__(
        self,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
        # (engine_idx, transfer_type) -> (metric with bounded labels)
        self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
        self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
        self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
        buckets = [  # In bytes
            1e6,
            5e6,
            10e6,
            20e6,
            40e6,
            60e6,
            80e6,
            100e6,
            150e6,
            200e6,
        ]

        self._counter_kv_bytes = self._counter_cls(
            name="vllm:kv_offload_total_bytes",
            documentation="Number of bytes offloaded by KV connector",
            labelnames=labelnames + ["transfer_type"],
        )

        self._counter_kv_transfer_time = self._counter_cls(
            name="vllm:kv_offload_total_time",
            documentation="Total time measured by all KV offloading operations",
            labelnames=labelnames + ["transfer_type"],
        )

        self._histogram_transfer_size = self._histogram_cls(
            name="vllm:kv_offload_size",
            documentation="Histogram of KV offload transfer size, in bytes.",
            buckets=buckets[:],
            labelnames=labelnames + ["transfer_type"],
        )

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """
        Observe transfer statistics from the new data structure.
        transfer_stats_data is expected to be a dict where:
        - keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
        - values are lists of OffloadingOperationMetrics objects
        """

        for transfer_type, ops in transfer_stats_data.items():
            # Cache the per-label metric handles on first use:
            if (engine_idx, transfer_type) not in self.histogram_transfer_size:
                self.histogram_transfer_size[(engine_idx, transfer_type)] = (
                    self._histogram_transfer_size.labels(
                        *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                    )
                )
                self.counter_kv_bytes[(engine_idx, transfer_type)] = (
                    self._counter_kv_bytes.labels(
                        *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                    )
                )
                self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
                    self._counter_kv_transfer_time.labels(
                        *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                    )
                )

            # Process ops:
            assert isinstance(ops, list)
            for op in ops:  # ops is a list of serialized OffloadingOperationMetrics
                assert isinstance(op, dict)
                # Observe size histogram
                self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
                    op["op_size"]
                )

                # Increment byte and time counters
                self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])

                self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
                    op["op_time"]
                )
@@ -0,0 +1,531 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import regex as re
import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
    P2pNcclEngine,
)
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


@dataclass
class ReqMeta:
    # Request Id
    request_id: str
    # Request block ids
    block_ids: torch.Tensor
    # Request num tokens
    num_tokens: int

    @staticmethod
    def make_meta(
        request_id: str, token_ids: list[int], block_ids: list[int], block_size: int
    ) -> "ReqMeta":
        block_ids_tensor = torch.tensor(block_ids)
        return ReqMeta(
            request_id=request_id,
            block_ids=block_ids_tensor,
            num_tokens=len(token_ids),
        )


@dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata):
    requests: list[ReqMeta]

    def __init__(self):
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        self.requests.append(
            ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)
        )


class P2pNcclConnector(KVConnectorBase_V1):
    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig | None" = None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            role=role,
            kv_cache_config=kv_cache_config,
        )
        self._block_size = vllm_config.cache_config.block_size
        self._requests_need_load: dict[str, Any] = {}
        self.is_producer = self._kv_transfer_config.is_kv_producer
        self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}

        self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
        self._local_rank = (
            get_world_group().local_rank if role == KVConnectorRole.WORKER else 0
        )

        self.p2p_nccl_engine = (
            P2pNcclEngine(
                local_rank=self._local_rank,
                config=self._kv_transfer_config,
                hostname="",
                port_offset=self._rank,
            )
            if role == KVConnectorRole.WORKER
            else None
        )

    # ==============================
    # Worker-side methods
    # ==============================

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """Start loading the KV cache from the connector buffer to vLLM's
        paged KV buffer.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """

        # Only consumer/decode loads KV Cache
        if self.is_producer:
            return

        assert self.p2p_nccl_engine is not None

        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            return

        def inject_kv_into_layer(
            layer: torch.Tensor,
            kv_cache: torch.Tensor,
            block_ids: torch.Tensor,
            request_id: str,
        ) -> None:
            """
            Inject KV cache data into a given attention layer tensor.

            This function updates `layer` in-place with values from `kv_cache`,
            handling different backend layouts:
            - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
              indexed along the first dimension.
            - FlashAttention: KV tensors are indexed along the second
              dimension.

            If the number of provided block IDs does not match the number of KV
            blocks, only the overlapping portion is updated, and a warning is
            logged.

            Args:
                layer (torch.Tensor): The attention layer KV tensor to update.
                kv_cache (torch.Tensor): The KV cache tensor to inject.
                block_ids (torch.Tensor): Indices of the blocks to update.
                request_id (str): Request identifier used for logging.

            Returns:
                None. The function modifies `layer` in-place.
            """
            if (
                isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2
            ):  # MLA or FlashInfer
                num_block = kv_cache.shape[0]
                self.check_tensors_except_dim(layer, kv_cache, 0)
                if len(block_ids) == num_block:
                    layer[block_ids, ...] = kv_cache
                else:
                    layer[block_ids[:num_block], ...] = kv_cache
                    logger.warning(
                        "🚧kv_cache does not match, block_ids:%d, "
                        "num_block:%d, request_id:%s",
                        len(block_ids),
                        num_block,
                        request_id,
                    )

            elif layer.shape[0] == 2:  # FlashAttention
                num_block = kv_cache.shape[1]
                self.check_tensors_except_dim(layer, kv_cache, 1)
                if len(block_ids) == num_block:
                    layer[:, block_ids, ...] = kv_cache
                else:
                    layer[:, block_ids[:num_block], ...] = kv_cache
                    logger.warning(
                        "🚧kv_cache does not match, block_ids:%d, "
                        "num_block:%d, request_id:%s",
                        len(block_ids),
                        num_block,
                        request_id,
                    )

        # Get the metadata
        metadata: KVConnectorMetadata = self._get_connector_metadata()
        assert isinstance(metadata, P2pNcclConnectorMetadata)

        if metadata is None:
            return

        # Load the KV for each request each layer
        for request in metadata.requests:
            request_id = request.request_id
            ip, port = self.parse_request_id(request_id, False)
            remote_address = ip + ":" + str(port + self._rank)
            for layer_name in forward_context.no_compile_layers:
                layer = forward_context.no_compile_layers[layer_name]

                # Only process layers that have a kv_cache attribute
                # (attention layers); skip non-attention layers like
                # FusedMoE.
                kv_cache = getattr(layer, "kv_cache", None)
                if kv_cache is None:
                    continue

                layer = kv_cache[forward_context.virtual_engine]

                kv_cache = self.p2p_nccl_engine.recv_tensor(
                    request.request_id + "#" + layer_name, remote_address
                )

                if kv_cache is None:
                    logger.warning("🚧kv_cache is None, %s", request.request_id)
                    continue

                inject_kv_into_layer(
                    layer, kv_cache, request.block_ids, request.request_id
                )

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Blocking until the KV for a specific layer is loaded into vLLM's
        paged buffer.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        return

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """Start saving the KV cache of the layer from vLLM's paged buffer
        to the connector.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """

        # Only producer/prefill saves KV Cache
        if not self.is_producer:
            return

        assert self.p2p_nccl_engine is not None

        def extract_kv_from_layer(
            layer: torch.Tensor,
            block_ids: torch.Tensor,
        ) -> torch.Tensor:
            """
            Extract KV cache slices from a given attention layer tensor.

            This function handles multiple backend layouts:
            - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
              indexed along the first dimension.
            - FlashAttention: KV tensors are indexed along the second
              dimension.

            Args:
                layer (torch.Tensor): The KV cache from the attention layer.
                block_ids (torch.Tensor): Indices of blocks to extract.

            Returns:
                torch.Tensor: A tensor containing the extracted KV slices.
                Returns None if the layout is unsupported.
            """
            if (
                isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2
            ):  # MLA or FlashInfer
                return layer[block_ids, ...]

            if layer.shape[0] == 2:  # FlashAttention
                return layer[:, block_ids, ...]

            return None

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
        for request in connector_metadata.requests:
            request_id = request.request_id
            ip, port = self.parse_request_id(request_id, True)
            remote_address = ip + ":" + str(port + self._rank)

            kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
            self.p2p_nccl_engine.send_tensor(
                request_id + "#" + layer_name, kv_cache, remote_address
            )

    def wait_for_save(self):
        if self.is_producer:
            assert self.p2p_nccl_engine is not None
            self.p2p_nccl_engine.wait_for_sent()

    def get_finished(
        self, finished_req_ids: set[str], **kwargs: Any
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfer,
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """

        assert self.p2p_nccl_engine is not None

        no_compile_layers = self._vllm_config.compilation_config.static_forward_context
        return self.p2p_nccl_engine.get_finished(finished_req_ids, no_compile_layers)

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        if self.is_producer:
            return 0, False

        prompt_token_ids = request.prompt_token_ids or []
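        # Assume the producer prefilled the whole prompt except the last
        # token, which the consumer must compute itself to produce the
        # first-step logits.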
        num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens

        if num_external_tokens < 0:
            num_external_tokens = 0

        return num_external_tokens, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.
        """
        if not self.is_producer and num_external_tokens > 0:
            self._requests_need_load[request.request_id] = (
                request,
                blocks.get_block_ids()[0],
            )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """

        meta = P2pNcclConnectorMetadata()

        for new_req in scheduler_output.scheduled_new_reqs:
            if self.is_producer:
                num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                    new_req.req_id
                ]
                num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
                # the request's prompt is chunked prefill
                if num_tokens < len(new_req.prompt_token_ids or []):
                    # 'CachedRequestData' has no attribute 'prompt_token_ids'
                    self.chunked_prefill[new_req.req_id] = (
                        new_req.block_ids[0],
                        new_req.prompt_token_ids,
                    )
                    continue
                # the request's prompt is not chunked prefill
                meta.add_request(
                    request_id=new_req.req_id,
                    token_ids=new_req.prompt_token_ids or [],
                    block_ids=new_req.block_ids[0],
                    block_size=self._block_size,
                )
                continue
            if new_req.req_id in self._requests_need_load:
                meta.add_request(
                    request_id=new_req.req_id,
                    token_ids=new_req.prompt_token_ids or [],
                    block_ids=new_req.block_ids[0],
                    block_size=self._block_size,
                )
                self._requests_need_load.pop(new_req.req_id)

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            num_computed_tokens = cached_reqs.num_computed_tokens[i]
            new_block_ids = cached_reqs.new_block_ids[i]
            resumed_from_preemption = req_id in cached_reqs.resumed_req_ids

            if self.is_producer:
                num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
                num_tokens = num_scheduled_tokens + num_computed_tokens
                assert req_id in self.chunked_prefill
                assert new_block_ids is not None
                block_ids = new_block_ids[0]
                if not resumed_from_preemption:
                    block_ids = self.chunked_prefill[req_id][0] + block_ids
                prompt_token_ids = self.chunked_prefill[req_id][1]
                assert prompt_token_ids is not None
                # the request's prompt is chunked prefill again
                if num_tokens < len(prompt_token_ids):
                    self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
                    continue
                # the request's prompt is all prefilled finally
                meta.add_request(
                    request_id=req_id,
                    token_ids=prompt_token_ids,
                    block_ids=block_ids,
                    block_size=self._block_size,
                )
                self.chunked_prefill.pop(req_id, None)
                continue

            # NOTE(rob): here we rely on the resumed requests being
            # the first N requests in the list scheduled_cache_reqs.
            if not resumed_from_preemption:
                break
            if req_id in self._requests_need_load:
                request, _ = self._requests_need_load.pop(req_id)
                total_tokens = num_computed_tokens + 1
                token_ids = request.all_token_ids[:total_tokens]

                # NOTE(rob): For resumed req, new_block_ids is all
                # of the block_ids for the request.
                assert new_block_ids is not None
                block_ids = new_block_ids[0]

                meta.add_request(
                    request_id=req_id,
                    token_ids=token_ids,
                    block_ids=block_ids,
                    block_size=self._block_size,
                )

        self._requests_need_load.clear()
        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """

        self.chunked_prefill.pop(request.request_id, None)

        return False, None

    # ==============================
    # Static methods
    # ==============================

    @staticmethod
    def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
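        # Illustrative example (hypothetical request id): on the prefill
        # side, an id containing "___decode_addr_10.0.0.1:21001" parses to
        # ("10.0.0.1", 21001), i.e. the peer address embedded in the id.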
        # Regular expression to match the string hostname and integer port
        if is_prefill:
            pattern = r"___decode_addr_(.*):(\d+)"
        else:
            pattern = r"___prefill_addr_(.*):(\d+)___"

        # Use re.search to find the pattern in the request_id
        match = re.search(pattern, request_id)
        if match:
            # Extract the hostname and port
            ip = match.group(1)
            port = int(match.group(2))

            return ip, port
        raise ValueError(f"Request id {request_id} does not contain hostname and port")

    @staticmethod
    def check_tensors_except_dim(tensor1, tensor2, dim):
        shape1 = tensor1.size()
        shape2 = tensor2.size()

        if len(shape1) != len(shape2) or not all(
            s1 == s2 for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim
        ):
            raise NotImplementedError(
                "Currently, only symmetric TP is supported. Asymmetric TP, PP, "
                "and others will be supported in future PRs."
            )
@@ -0,0 +1,632 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import logging
import os
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

import msgpack
import torch
import zmq

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary,
    buffer_type,
    cudaStream_t,
    ncclComm_t,
    ncclDataTypeEnum,
)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import (  # noqa: E501
    TensorMemoryPool,
)
from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import current_stream

logger = logging.getLogger(__name__)

DEFAULT_MEM_POOL_SIZE_GB = 32


@contextmanager
def set_p2p_nccl_context(num_channels: str):
    original_values: dict[str, Any] = {}
    env_vars = [
        "NCCL_MAX_NCHANNELS",
        "NCCL_MIN_NCHANNELS",
        "NCCL_CUMEM_ENABLE",
        "NCCL_BUFFSIZE",
        "NCCL_PROTO",  # LL,LL128,SIMPLE
        "NCCL_ALGO",  # RING,TREE
    ]

    for var in env_vars:
        original_values[var] = os.environ.get(var)

    logger.info("set_p2p_nccl_context, original_values: %s", original_values)

    try:
        os.environ["NCCL_MAX_NCHANNELS"] = num_channels
        os.environ["NCCL_MIN_NCHANNELS"] = num_channels
        os.environ["NCCL_CUMEM_ENABLE"] = "1"
        yield
    finally:
        for var in env_vars:
            if original_values[var] is not None:
                os.environ[var] = original_values[var]
            else:
                os.environ.pop(var, None)
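
# Usage sketch (mirrors the call sites below): the channel settings apply
# only while a communicator is created inside the context, e.g.
#
#     with set_p2p_nccl_context("8"):
#         comm = nccl.ncclCommInitRank(2, unique_id, rank)
#
# and the previous environment is restored on exit.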


@dataclass
class SendQueueItem:
    tensor_id: str
    remote_address: str
    tensor: torch.Tensor


class P2pNcclEngine:
    def __init__(
        self,
        local_rank: int,
        config: KVTransferConfig,
        hostname: str = "",
        port_offset: int = 0,
        library_path: str | None = None,
    ) -> None:
        self.config = config
        self.rank = port_offset
        self.local_rank = local_rank
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.nccl = NCCLLibrary(library_path)

        if not hostname:
            hostname = get_ip()
        port = int(self.config.kv_port) + port_offset
        if port == 0:
            raise ValueError("Port cannot be 0")
        self._hostname = hostname
        self._port = port

        # Each card corresponds to a ZMQ address.
        self.zmq_address = f"{self._hostname}:{self._port}"

        # If `proxy_ip` or `proxy_port` is `""`,
        # then the ping thread will not be enabled.
        proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
        proxy_port = self.config.get_from_extra_config("proxy_port", "")
        if proxy_ip == "" or proxy_port == "":
            self.proxy_address = ""
            self.http_address = ""
        else:
            self.proxy_address = proxy_ip + ":" + proxy_port
            # The `http_port` must match the port of the OpenAI-compatible
            # API server.
            http_port = self.config.get_from_extra_config("http_port", None)
            if http_port is None:
                example_cfg = {
                    "kv_connector": "P2pNcclConnector",
                    "kv_connector_extra_config": {"http_port": 8000},
                }
                example = (
                    f"--port=8000 --kv-transfer-config='{json.dumps(example_cfg)}'"
                )
                raise ValueError(
                    "kv_connector_extra_config.http_port is required. "
                    f"Example: {example}"
                )
            self.http_address = f"{self._hostname}:{http_port}"

        self.context = zmq.Context()
        self.router_socket = self.context.socket(zmq.ROUTER)
        self.router_socket.bind(f"tcp://{self.zmq_address}")

        self.poller = zmq.Poller()
        self.poller.register(self.router_socket, zmq.POLLIN)

        self.send_store_cv = threading.Condition()
        self.send_queue_cv = threading.Condition()
        self.recv_store_cv = threading.Condition()

        self.send_stream = torch.cuda.Stream()
        self.recv_stream = torch.cuda.Stream()

        mem_pool_size_gb = float(
            self.config.get_from_extra_config(
                "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB
            )
        )
        self.pool = TensorMemoryPool(
            max_block_size=int(mem_pool_size_gb * 1024**3)
        )  # GB

        # The sending type includes three mutually exclusive options:
        # PUT, GET, PUT_ASYNC.
        self.send_type = self.config.get_from_extra_config("send_type", "PUT_ASYNC")
        if self.send_type == "GET":
            # tensor_id: torch.Tensor
            self.send_store: dict[str, torch.Tensor] = {}
        else:
            # PUT or PUT_ASYNC
            # tensor_id: torch.Tensor
            self.send_queue: deque[SendQueueItem] = deque()
            if self.send_type == "PUT_ASYNC":
                self._send_thread = threading.Thread(
                    target=self.send_async, daemon=True
                )
                self._send_thread.start()

        # tensor_id: torch.Tensor/(addr, dtype, shape)
        self.recv_store: dict[str, Any] = {}
        self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
        self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
        self.socks: dict[str, Any] = {}  # remote_address: client socket
        self.comms: dict[str, Any] = {}  # remote_address: (ncclComm_t, rank)

        self.buffer_size = 0
        self.buffer_size_threshold = float(self.config.kv_buffer_size)

        self.nccl_num_channels = self.config.get_from_extra_config(
            "nccl_num_channels", "8"
        )

        self._listener_thread = threading.Thread(
            target=self.listen_for_requests, daemon=True
        )
        self._listener_thread.start()

        self._ping_thread = None
        if port_offset == 0 and self.proxy_address != "":
            self._ping_thread = threading.Thread(target=self.ping, daemon=True)
            self._ping_thread.start()

        logger.info(
            "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
            "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
            "threshold:%.2f, nccl_num_channels:%s",
            self.rank,
            self.local_rank,
            self.http_address,
            self.zmq_address,
            self.proxy_address,
            self.send_type,
            self.buffer_size_threshold,
            self.nccl_num_channels,
        )

    def create_connect(self, remote_address: str | None = None):
        assert remote_address is not None
        if remote_address not in self.socks:
            sock = self.context.socket(zmq.DEALER)
            sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
            sock.connect(f"tcp://{remote_address}")
            self.socks[remote_address] = sock
            if remote_address in self.comms:
                logger.info(
                    "👋comm exists, remote_address:%s, comms:%s",
                    remote_address,
                    self.comms,
                )
                return sock, self.comms[remote_address]

            unique_id = self.nccl.ncclGetUniqueId()
            data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
            sock.send(msgpack.dumps(data))

            with torch.cuda.device(self.device):
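                # The connecting side always takes NCCL rank 0; the listener
                # assigns itself rank 1 (see listen_for_requests), so every
                # point-to-point communicator has exactly two ranks and the
                # peer is addressed as rank ^ 1 in send()/recv().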
                rank = 0
                with set_p2p_nccl_context(self.nccl_num_channels):
                    comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank)
                self.comms[remote_address] = (comm, rank)
                logger.info(
                    "🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
                    self.zmq_address,
                    remote_address,
                    rank,
                )

        return self.socks[remote_address], self.comms[remote_address]

    def send_tensor(
        self,
        tensor_id: str,
        tensor: torch.Tensor,
        remote_address: str | None = None,
    ) -> bool:
        if remote_address is None:
            with self.recv_store_cv:
                self.recv_store[tensor_id] = tensor
                self.recv_store_cv.notify()
            return True

        item = SendQueueItem(
            tensor_id=tensor_id, remote_address=remote_address, tensor=tensor
        )

        if self.send_type == "PUT":
            return self.send_sync(item)

        if self.send_type == "PUT_ASYNC":
            with self.send_queue_cv:
                self.send_queue.append(item)
                self.send_queue_cv.notify()
            return True

        # GET
        with self.send_store_cv:
            tensor_size = tensor.element_size() * tensor.numel()
            if tensor_size > self.buffer_size_threshold:
                logger.warning(
                    "❗[GET]tensor_id:%s, tensor_size:%d, is greater than "
                    "buffer size threshold:%d, skip send to %s, rank:%d",
                    tensor_id,
                    tensor_size,
                    self.buffer_size_threshold,
                    remote_address,
                    self.rank,
                )
                return False
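            # Evict oldest entries first: send_store is insertion-ordered and
            # GET hits re-insert the tensor (see listen_for_requests), so
            # next(iter(...)) is the least recently used entry.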
            while self.buffer_size + tensor_size > self.buffer_size_threshold:
                assert len(self.send_store) > 0
                oldest_tensor_id = next(iter(self.send_store))
                oldest_tensor = self.send_store.pop(oldest_tensor_id)
                oldest_tensor_size = (
                    oldest_tensor.element_size() * oldest_tensor.numel()
                )
                self.buffer_size -= oldest_tensor_size
                logger.debug(
                    "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                    " buffer_size:%d, oldest_tensor_size:%d, rank:%d",
                    remote_address,
                    tensor_id,
                    tensor_size,
                    self.buffer_size,
                    oldest_tensor_size,
                    self.rank,
                )

            self.send_store[tensor_id] = tensor
            self.buffer_size += tensor_size
            logger.debug(
                "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
                "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
                remote_address,
                tensor_id,
                tensor_size,
                tensor.shape,
                self.rank,
                self.buffer_size,
                self.buffer_size / self.buffer_size_threshold * 100,
            )
            return True

    def recv_tensor(
        self,
        tensor_id: str,
        remote_address: str | None = None,
    ) -> torch.Tensor:
        if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
            start_time = time.time()
            with self.recv_store_cv:
                while tensor_id not in self.recv_store:
                    self.recv_store_cv.wait()
                tensor = self.recv_store[tensor_id]

            if tensor is not None:
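                # A tuple entry means the listener spilled the tensor into the
                # CPU-side TensorMemoryPool when the GPU buffer budget was
                # exceeded; reload it onto this device.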
                if isinstance(tensor, tuple):
                    addr, dtype, shape = tensor
                    tensor = self.pool.load_tensor(addr, dtype, shape, self.device)
                else:
                    self.buffer_size -= tensor.element_size() * tensor.numel()
            else:
                duration = time.time() - start_time
                logger.warning(
                    "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, rank:%d",
                    remote_address,
                    tensor_id,
                    duration * 1000,
                    self.rank,
                )
            return tensor

        # GET
        if remote_address is None:
            return None

        if remote_address not in self.socks:
            self.create_connect(remote_address)

        sock = self.socks[remote_address]
        comm, rank = self.comms[remote_address]

        data = {"cmd": "GET", "tensor_id": tensor_id}
        sock.send(msgpack.dumps(data))

        message = sock.recv()
        data = msgpack.loads(message)
        if data["ret"] != 0:
            logger.warning(
                "🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                remote_address,
                tensor_id,
                data["ret"],
            )
            return None

        with torch.cuda.stream(self.recv_stream):
            tensor = torch.empty(
                data["shape"], dtype=getattr(torch, data["dtype"]), device=self.device
            )

        self.recv(comm, tensor, rank ^ 1, self.recv_stream)

        return tensor

    def listen_for_requests(self):
        while True:
            socks = dict(self.poller.poll())
            if self.router_socket not in socks:
                continue

            remote_address, message = self.router_socket.recv_multipart()
            data = msgpack.loads(message)
            if data["cmd"] == "NEW":
                unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"]))
                with torch.cuda.device(self.device):
                    rank = 1
                    with set_p2p_nccl_context(self.nccl_num_channels):
                        comm: ncclComm_t = self.nccl.ncclCommInitRank(
                            2, unique_id, rank
                        )
                    self.comms[remote_address.decode()] = (comm, rank)
                    logger.info(
                        "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
                        self.zmq_address,
                        remote_address.decode(),
                        rank,
                    )
            elif data["cmd"] == "PUT":
                tensor_id = data["tensor_id"]
                try:
                    with torch.cuda.stream(self.recv_stream):
                        tensor = torch.empty(
                            data["shape"],
                            dtype=getattr(torch, data["dtype"]),
                            device=self.device,
                        )
                    self.router_socket.send_multipart([remote_address, b"0"])
                    comm, rank = self.comms[remote_address.decode()]
                    self.recv(comm, tensor, rank ^ 1, self.recv_stream)
                    tensor_size = tensor.element_size() * tensor.numel()
                    if self.buffer_size + tensor_size > self.buffer_size_threshold:
                        # Store Tensor in memory pool
                        addr = self.pool.store_tensor(tensor)
                        tensor = (addr, tensor.dtype, tensor.shape)
                        logger.warning(
                            "🔴[PUT]Recv Tensor, Out Of Threshold, "
                            "%s👈%s, data:%s, addr:%d",
                            self.zmq_address,
                            remote_address.decode(),
                            data,
                            addr,
                        )
                    else:
                        self.buffer_size += tensor_size

                except torch.cuda.OutOfMemoryError:
                    self.router_socket.send_multipart([remote_address, b"1"])
                    tensor = None
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, data:%s",
                        self.zmq_address,
                        remote_address.decode(),
                        data,
                    )

                with self.recv_store_cv:
                    self.recv_store[tensor_id] = tensor
                    self.have_received_tensor_id(tensor_id)
                    self.recv_store_cv.notify()

            elif data["cmd"] == "GET":
                tensor_id = data["tensor_id"]
                with self.send_store_cv:
                    tensor = self.send_store.pop(tensor_id, None)
                    if tensor is not None:
                        data = {
                            "ret": 0,
                            "shape": tensor.shape,
                            "dtype": str(tensor.dtype).replace("torch.", ""),
                        }
                        # LRU
                        self.send_store[tensor_id] = tensor
                        self.have_sent_tensor_id(tensor_id)
                    else:
                        data = {"ret": 1}

                self.router_socket.send_multipart([remote_address, msgpack.dumps(data)])

                if data["ret"] == 0:
                    comm, rank = self.comms[remote_address.decode()]
                    self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
            else:
                logger.warning(
                    "🚧Unexpected, Received message from %s, data:%s",
                    remote_address,
                    data,
                )

    def have_sent_tensor_id(self, tensor_id: str):
        request_id = tensor_id.split("#")[0]
        if request_id not in self.send_request_id_to_tensor_ids:
            self.send_request_id_to_tensor_ids[request_id] = set()
        self.send_request_id_to_tensor_ids[request_id].add(tensor_id)

    def have_received_tensor_id(self, tensor_id: str):
        request_id = tensor_id.split("#")[0]
        if request_id not in self.recv_request_id_to_tensor_ids:
            self.recv_request_id_to_tensor_ids[request_id] = set()
        self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)

    def send_async(self):
        while True:
            with self.send_queue_cv:
                while not self.send_queue:
                    self.send_queue_cv.wait()
                item = self.send_queue.popleft()
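                # Wake wait_for_sent() once the queue has drained.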
if not self.send_queue:
|
||||
self.send_queue_cv.notify()
|
||||
self.send_sync(item)
|
||||
|
||||
    def wait_for_sent(self):
        if self.send_type == "PUT_ASYNC":
            start_time = time.time()
            with self.send_queue_cv:
                while self.send_queue:
                    self.send_queue_cv.wait()
                duration = time.time() - start_time
                logger.debug(
                    "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
                    " to be empty, rank:%d",
                    duration * 1000,
                    self.rank,
                )

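    # send_async() and wait_for_sent() pair up on send_queue_cv: the consumer
    # thread notifies once it drains the queue, which is exactly the condition
    # wait_for_sent() blocks on for the PUT_ASYNC path.
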
    def send_sync(self, item: SendQueueItem) -> bool:
        if item.remote_address is None:
            return False
        if item.remote_address not in self.socks:
            self.create_connect(item.remote_address)

        tensor = item.tensor

        sock = self.socks[item.remote_address]
        comm, rank = self.comms[item.remote_address]
        data = {
            "cmd": "PUT",
            "tensor_id": item.tensor_id,
            "shape": tensor.shape,
            "dtype": str(tensor.dtype).replace("torch.", ""),
        }
        sock.send(msgpack.dumps(data))

        response = sock.recv()
        if response != b"0":
            logger.error(
                "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
                "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
                self.zmq_address,
                item.remote_address,
                rank,
                data,
                tensor.shape,
                tensor.element_size() * tensor.numel() / 1024**3,
                response.decode(),
            )
            return False

        self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)

        if self.send_type == "PUT_ASYNC":
            self.have_sent_tensor_id(item.tensor_id)

        return True

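    # Sender-side PUT flow, as implemented by send_sync() above:
    #   1. lazily create the ZMQ socket / NCCL comm for the peer,
    #   2. send the msgpack header (cmd, tensor_id, shape, dtype),
    #   3. wait for the one-byte ack; anything but b"0" aborts the transfer,
    #   4. NCCL-send the tensor to the peer rank (rank ^ 1).
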
    def get_finished(
        self, finished_req_ids: set[str], no_compile_layers
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens.

        Returns:
            A tuple of (sending/saving ids, recving/loading ids) for requests
            that have finished their asynchronous transfer. The finished
            saves/sends req ids must belong to a set provided in a call to
            this method (this call or a prior one).
        """

        # Clear the buffer upon request completion.
        for request_id in finished_req_ids:
            for layer_name in no_compile_layers:
                tensor_id = request_id + "#" + layer_name
                if tensor_id in self.recv_store:
                    with self.recv_store_cv:
                        tensor = self.recv_store.pop(tensor_id, None)
                        self.send_request_id_to_tensor_ids.pop(request_id, None)
                        self.recv_request_id_to_tensor_ids.pop(request_id, None)
                    # Entries spilled to the host pool are stored as
                    # (addr, dtype, shape); release the pooled block.
                    if isinstance(tensor, tuple):
                        addr, _, _ = tensor
                        self.pool.free(addr)

        # TODO: Retrieve requests that have already sent the KV cache.
        finished_sending: set[str] = set()

        # TODO: Retrieve requests that have already received the KV cache.
        finished_recving: set[str] = set()

        return finished_sending or None, finished_recving or None

    def ping(self):
        sock = self.context.socket(zmq.DEALER)
        sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
        logger.debug("ping start, zmq_address:%s", self.zmq_address)
        sock.connect(f"tcp://{self.proxy_address}")
        data = {
            "type": "P" if self.config.is_kv_producer else "D",
            "http_address": self.http_address,
            "zmq_address": self.zmq_address,
        }
        # Heartbeat: re-register with the proxy every 3 seconds.
        while True:
            sock.send(msgpack.dumps(data))
            time.sleep(3)

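    # send()/recv() below assume each communicator is a point-to-point NCCL
    # comm with exactly two ranks, so the peer of rank r is always r ^ 1.
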
    def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = current_stream()

        with torch.cuda.stream(stream):
            self.nccl.ncclSend(
                buffer_type(tensor.data_ptr()),
                tensor.numel(),
                ncclDataTypeEnum.from_torch(tensor.dtype),
                dst,
                comm,
                cudaStream_t(stream.cuda_stream),
            )
        stream.synchronize()

    def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = current_stream()

        with torch.cuda.stream(stream):
            self.nccl.ncclRecv(
                buffer_type(tensor.data_ptr()),
                tensor.numel(),
                ncclDataTypeEnum.from_torch(tensor.dtype),
                src,
                comm,
                cudaStream_t(stream.cuda_stream),
            )
        stream.synchronize()

    def close(self) -> None:
        self._listener_thread.join()
        if self.send_type == "PUT_ASYNC":
            self._send_thread.join()
        if self._ping_thread is not None:
            self._ping_thread.join()
@@ -0,0 +1,273 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import atexit
import ctypes
import math
from dataclasses import dataclass

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class MemoryBlock:
    size: int
    addr: int


class TensorMemoryPool:
    """A memory pool for managing pinned host memory allocations for tensors.

    This class implements a buddy allocation system to efficiently manage pinned
    host memory for tensor storage. It supports allocation, deallocation, and
    tensor storage/retrieval operations.

    Key Features:
    - Uses power-of-two block sizes for efficient buddy allocation
    - Supports splitting and merging of memory blocks
    - Provides methods to store CUDA tensors in pinned host memory
    - Allows loading tensors from pinned memory back to device
    - Automatically cleans up memory on destruction

    Attributes:
        max_block_size (int): Maximum block size (rounded to the nearest power
            of two)
        min_block_size (int): Minimum block size (rounded to the nearest power
            of two)
        free_lists (dict): Dictionary of free memory blocks keyed by size
        allocated_blocks (dict): Dictionary of currently allocated blocks
        base_tensor (torch.Tensor): Base pinned memory tensor
        base_address (int): Base memory address of the pinned memory region

    Example:
        >>> pool = TensorMemoryPool(max_block_size=1024 * 1024)
        >>> tensor = torch.randn(100, device="cuda")
        >>> addr = pool.store_tensor(tensor)
        >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype, tensor.shape, "cuda")
        >>> pool.free(addr)
    """

    def __init__(self, max_block_size: int, min_block_size: int = 512):
        """Initializes the memory pool with the given size constraints.

        Args:
            max_block_size (int): Maximum size of memory blocks to manage
            min_block_size (int, optional): Minimum size of memory blocks
                to manage. Defaults to 512.

        Raises:
            ValueError: If block sizes are invalid or max_block_size is less
                than min_block_size
        """
        if max_block_size <= 0 or min_block_size <= 0:
            raise ValueError("Block sizes must be positive")
        if max_block_size < min_block_size:
            raise ValueError("Max block size must be greater than min block size")

        self.max_block_size = self._round_to_power_of_two(max_block_size)
        self.min_block_size = self._round_to_power_of_two(min_block_size)

        self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
        self.allocated_blocks: dict[int, MemoryBlock] = {}

        self._initialize_free_lists()
        self._allocate_pinned_memory()

        atexit.register(self.cleanup)

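    # Rounding sketch: _round_to_power_of_two() returns the next power of two
    # >= size, e.g. size=600 -> (599).bit_length() == 10 -> 1 << 10 == 1024,
    # and exact powers map to themselves (512 -> 512).
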
    def _round_to_power_of_two(self, size: int) -> int:
        return 1 << (size - 1).bit_length()

    def _initialize_free_lists(self):
        size = self.max_block_size
        while size >= self.min_block_size:
            self.free_lists[size] = {}
            size //= 2

    def _allocate_pinned_memory(self):
        # float32 elements are 4 bytes each, so max_block_size // 4 elements
        # give exactly max_block_size bytes of pinned host memory.
        self.base_tensor = torch.empty(
            self.max_block_size // 4, dtype=torch.float32, pin_memory=True
        )
        self.base_address = self.base_tensor.data_ptr()
        initial_block = MemoryBlock(size=self.max_block_size, addr=self.base_address)
        self.free_lists[self.max_block_size][initial_block.addr] = initial_block

        logger.debug(
            "TensorMemoryPool, base_address:%d, max_block_size:%d",
            self.base_address,
            self.max_block_size,
        )

    def allocate(self, size: int) -> int:
        """Allocates a memory block of at least the requested size.

        Args:
            size (int): Minimum size of memory to allocate

        Returns:
            int: Address of the allocated memory block

        Raises:
            ValueError: If size is invalid or insufficient memory is available
        """
        if size <= 0:
            raise ValueError("Allocation size must be positive")

        required_size = self._round_to_power_of_two(max(size, self.min_block_size))
        if required_size > self.max_block_size:
            raise ValueError("Requested size exceeds maximum block size")

        current_size = required_size
        while current_size <= self.max_block_size:
            if self.free_lists[current_size]:
                _, block = self.free_lists[current_size].popitem()
                self._split_block(block, required_size)
                self.allocated_blocks[block.addr] = block
                return block.addr
            current_size *= 2

        raise ValueError("Insufficient memory")

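    # Allocation walkthrough (illustrative sizes): requesting 600 bytes with
    # min_block_size=512 rounds up to 1024. If only a 4096-byte block is free,
    # _split_block() halves it twice (parking 2048- and 1024-byte buddies in
    # free_lists) and the remaining 1024-byte block is returned.
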
    def _split_block(self, block: MemoryBlock, required_size: int):
        while block.size > required_size and block.size // 2 >= self.min_block_size:
            buddy_size = block.size // 2
            buddy_addr = block.addr + buddy_size

            buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
            block.size = buddy_size

            self.free_lists[buddy_size][buddy.addr] = buddy

    def free(self, addr: int):
        """Frees an allocated memory block.

        Args:
            addr (int): Address of the block to free

        Raises:
            ValueError: If the address is invalid or not allocated
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to free")

        block = self.allocated_blocks.pop(addr)
        self._merge_buddies(block)

    def _merge_buddies(self, block: MemoryBlock):
        MAX_MERGE_DEPTH = 30
        depth = 0

        while depth < MAX_MERGE_DEPTH:
            buddy_offset = (
                block.size
                if (block.addr - self.base_address) % (2 * block.size) == 0
                else -block.size
            )
            buddy_addr = block.addr + buddy_offset
            buddy = self.free_lists[block.size].get(buddy_addr)
            if buddy:
                del self.free_lists[buddy.size][buddy.addr]
                merged_addr = min(block.addr, buddy.addr)
                merged_size = block.size * 2
                block = MemoryBlock(size=merged_size, addr=merged_addr)
                depth += 1
            else:
                break
        self.free_lists[block.size][block.addr] = block

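    # Buddy address arithmetic: a block whose offset from base_address is a
    # multiple of 2 * size is the "left" buddy, so its partner sits at
    # addr + size; otherwise the partner is at addr - size. Merging repeats
    # until the partner is absent (still allocated or split to another size).
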
    def store_tensor(self, tensor: torch.Tensor) -> int:
        """Stores a CUDA tensor in pinned host memory.

        Args:
            tensor (torch.Tensor): CUDA tensor to store

        Returns:
            int: Address where the tensor is stored

        Raises:
            ValueError: If the tensor is not on CUDA or allocation fails
        """
        if not tensor.is_cuda:
            raise ValueError("Only CUDA tensors can be stored")

        size = tensor.element_size() * tensor.numel()
        addr = self.allocate(size)
        block = self.allocated_blocks[addr]

        if block.size < size:
            self.free(addr)
            raise ValueError(
                f"Allocated block size {block.size} is smaller than "
                f"required size {size}"
            )

        try:
            # Wrap the raw block in a ctypes buffer and view it as a tensor,
            # so the device-to-host copy lands directly in pinned memory.
            buffer = (ctypes.c_byte * block.size).from_address(block.addr)
            cpu_tensor = torch.frombuffer(
                buffer, dtype=tensor.dtype, count=tensor.numel()
            ).reshape(tensor.shape)
        except ValueError as err:
            self.free(addr)
            raise ValueError(f"Failed to create tensor view: {err}") from err

        cpu_tensor.copy_(tensor)

        return addr

    def load_tensor(
        self,
        addr: int,
        dtype: torch.dtype,
        shape: tuple[int, ...],
        device: torch.device,
    ) -> torch.Tensor:
        """Loads a tensor from pinned host memory to the specified device.

        Args:
            addr (int): Address where the tensor is stored
            dtype (torch.dtype): Data type of the tensor
            shape (tuple[int, ...]): Shape of the tensor
            device: Target device for the loaded tensor

        Returns:
            torch.Tensor: The loaded tensor on the specified device

        Raises:
            ValueError: If the address is invalid or the sizes don't match
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to load")

        block = self.allocated_blocks[addr]
        num_elements = math.prod(shape)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        required_size = num_elements * dtype_size

        if required_size > block.size:
            raise ValueError("Requested tensor size exceeds block size")

        buffer = (ctypes.c_byte * block.size).from_address(block.addr)
        cpu_tensor = torch.frombuffer(buffer, dtype=dtype, count=num_elements).reshape(
            shape
        )

        cuda_tensor = torch.empty(shape, dtype=dtype, device=device)
        cuda_tensor.copy_(cpu_tensor)

        return cuda_tensor

    def cleanup(self):
        """Cleans up all memory resources and resets the pool state."""
        self.free_lists.clear()
        self.allocated_blocks.clear()
        if hasattr(self, "base_tensor"):
            del self.base_tensor

    def __del__(self):
        self.cleanup()