Sync from v0.13
This commit is contained in:
29
vllm/distributed/kv_transfer/README.md
Normal file
29
vllm/distributed/kv_transfer/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
|
||||
# Distributed KV cache transfer
|
||||
|
||||
This folder implements distributed KV cache transfer across vLLM instances.
|
||||
Currently the main use case is for disaggregated prefilling.
|
||||
|
||||
## Abstractions
|
||||
|
||||
The KV cache transfer contains three layers of abstraction:
|
||||
|
||||
- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`.
|
||||
- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics).
|
||||
- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`.
|
||||
|
||||
Why we need the KV lookup buffer: a FIFO pipe by itself is not enough, because the prefill vLLM worker may process requests in a different order than the decode vLLM worker. For example, when the QPS is very high, the prefill worker may handle requests in the order A -> B -> C, while the decode worker may process request C first. A FIFO pipe cannot naturally handle this case, so we provide the KV lookup buffer to translate a FIFO pipe into a lookup buffer.
|
||||
|
||||
NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed
|
||||
communication service already supports key-value-based lookup (like redis or
|
||||
RDMA database).
|
||||
|
||||
NOTE: If you want to not only transfer KV caches but also adjust the model execution flow of vLLM (for example, allow vLLM to receive KV caches for some tokens and do prefill on the remaining tokens), you can bypass both the KV pipe layer and the KV lookup buffer layer and implement directly on the KV connector layer. Bear in mind that because vLLM's model input is constantly changing, such an implementation will likely break when vLLM is updated.
|
||||
|
||||
## Disaggregated prefilling
|
||||
|
||||
The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh).
|
||||
|
||||
Here is the diagram of how we run disaggregated prefilling.

![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg)
|
||||
20
vllm/distributed/kv_transfer/__init__.py
Normal file
20
vllm/distributed/kv_transfer/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
KVConnectorBaseType,
|
||||
ensure_kv_transfer_initialized,
|
||||
ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_kv_transfer_group",
|
||||
"has_kv_transfer_group",
|
||||
"is_v1_kv_transfer_group",
|
||||
"ensure_kv_transfer_initialized",
|
||||
"ensure_kv_transfer_shutdown",
|
||||
"KVConnectorBaseType",
|
||||
]
|
||||
BIN
vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg
Normal file
BIN
vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 139 KiB |
10
vllm/distributed/kv_transfer/kv_connector/base.py
Normal file
10
vllm/distributed/kv_transfer/kv_connector/base.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Defines the base type for KV cache connectors."""
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
|
||||
# The v1 connector base class is the only connector base exposed here; both
# names below are plain aliases of KVConnectorBase_V1.
KVConnectorBase = KVConnectorBase_V1
KVConnectorBaseType = KVConnectorBase_V1

__all__ = ["KVConnectorBase", "KVConnectorBaseType"]
|
||||
197
vllm/distributed/kv_transfer/kv_connector/factory.py
Normal file
197
vllm/distributed/kv_transfer/kv_connector/factory.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import (
|
||||
KVConnectorBase,
|
||||
KVConnectorBaseType,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorRole,
|
||||
supports_hma,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KVConnectorFactory:
    """Name-based registry and factory for KV connector classes.

    Connectors are registered together with a lazy loader, so the module that
    defines a connector is imported only when that connector is requested.
    Connectors that are not registered can still be resolved through
    ``KVTransferConfig.kv_connector_module_path``.
    """

    # Connector name -> zero-argument loader that imports and returns the class.
    _registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a connector with a lazy-loading module and class name.

        Args:
            name: Key under which the connector is registered.
            module_path: Dotted path of the module that defines the connector.
            class_name: Attribute name of the connector class in that module.

        Raises:
            ValueError: If ``name`` is already registered.
        """
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[KVConnectorBase]:
            # Deferred until the connector is actually requested.
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_connector(
        cls,
        config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ) -> KVConnectorBase:
        """Instantiate the connector selected by ``config.kv_transfer_config``.

        Args:
            config: The vLLM config; its ``kv_transfer_config`` must be set.
            role: Whether the connector is built for the scheduler or a worker.
            kv_cache_config: Passed to the connector constructor unless the
                connector was detected as using the legacy two-argument
                signature.

        Returns:
            The newly constructed connector instance.

        Raises:
            ValueError: If ``kv_transfer_config`` is missing, or if HMA is
                enabled and the connector class does not support it.
        """
        kv_transfer_config = config.kv_transfer_config
        if kv_transfer_config is None:
            raise ValueError("kv_transfer_config must be set to create a connector")
        connector_cls, compat_sig = cls._get_connector_class_with_compat(
            kv_transfer_config
        )

        # check if the connector supports HMA
        hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
        if hma_enabled and not supports_hma(connector_cls):
            raise ValueError(
                f"Connector {connector_cls.__name__} does not support HMA but "
                f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
            )

        logger.info(
            "Creating v1 connector with name: %s and engine_id: %s",
            connector_cls.__name__,
            kv_transfer_config.engine_id,
        )
        # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
        # Scheduler connector:
        # - Co-locate with scheduler process
        # - Should only be used inside the Scheduler class
        # Worker connector:
        # - Co-locate with worker process
        # - Should only be used inside the forward context & attention layer
        # We build separately to enforce strict separation
        if compat_sig:
            # Old signature: __init__(self, vllm_config, role)
            return connector_cls(config, role)
        else:
            # New signature: __init__(self, vllm_config, role, kv_cache_config)
            return connector_cls(config, role, kv_cache_config)

    @classmethod
    def get_connector_class_by_name(
        cls, connector_name: str
    ) -> type[KVConnectorBaseType]:
        """Get a registered connector class by name.

        Raises ValueError if the connector is not registered.

        Args:
            connector_name: Name of the registered connector.

        Returns:
            The connector class.
        """
        if connector_name not in cls._registry:
            raise ValueError(f"Connector '{connector_name}' is not registered.")
        return cls._registry[connector_name]()

    @classmethod
    def _get_connector_class_with_compat(
        cls, kv_transfer_config: "KVTransferConfig"
    ) -> tuple[type[KVConnectorBaseType], bool]:
        """Resolve the connector class and detect the legacy constructor.

        Registered names are resolved through the registry. Any other name is
        looked up as an attribute of the module named by
        ``kv_transfer_config.kv_connector_module_path``.

        Returns:
            ``(connector_cls, compat_sig)`` where ``compat_sig`` is True when
            an externally loaded connector does not accept ``kv_cache_config``
            (as reported by ``supports_kw``) and therefore has to be built
            with ``(vllm_config, role)`` only.

        Raises:
            ValueError: If no connector name is configured, or the name is not
                registered and no module path is configured.
            AttributeError: If the configured module has no such class.
        """
        connector_name = kv_transfer_config.kv_connector
        if connector_name is None:
            raise ValueError("Connector name is not set in KVTransferConfig")
        compat_sig = False
        if connector_name in cls._registry:
            connector_cls = cls._registry[connector_name]()
        else:
            connector_module_path = kv_transfer_config.kv_connector_module_path
            if connector_module_path is None:
                raise ValueError(f"Unsupported connector type: {connector_name}")
            connector_module = importlib.import_module(connector_module_path)
            try:
                connector_cls = getattr(connector_module, connector_name)
            except AttributeError as e:
                raise AttributeError(
                    f"Class {connector_name} not found in {connector_module_path}"
                ) from e
            connector_cls = cast(type[KVConnectorBaseType], connector_cls)
            if not supports_kw(connector_cls, "kv_cache_config"):
                compat_sig = True
                # The current signature is (vllm_config, role, kv_cache_config),
                # so kv_cache_config is the third constructor argument.
                logger.warning(
                    "Connector %s uses deprecated signature with 2 required arguments. "
                    "Please update to include kv_cache_config as the third argument.",
                    connector_cls.__name__,
                )
        return connector_cls, compat_sig

    @classmethod
    def get_connector_class(
        cls, kv_transfer_config: "KVTransferConfig"
    ) -> type[KVConnectorBaseType]:
        """Get the connector class selected by ``kv_transfer_config``."""
        connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
        return connector_cls
|
||||
|
||||
|
||||
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
# Each entry is (name, module_path, class_name), registered in this order.
for _connector_spec in (
    (
        "ExampleConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.example_connector",
        "ExampleConnector",
    ),
    (
        "P2pNcclConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
        "P2pNcclConnector",
    ),
    (
        "LMCacheConnectorV1",
        "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
        "LMCacheConnectorV1",
    ),
    (
        "LMCacheMPConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector",
        "LMCacheMPConnector",
    ),
    (
        "NixlConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
        "NixlConnector",
    ),
    (
        "MultiConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
        "MultiConnector",
    ),
    (
        "OffloadingConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
        "OffloadingConnector",
    ),
    (
        "DecodeBenchConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
        "DecodeBenchConnector",
    ),
    (
        "MooncakeConnector",
        "vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector",
        "MooncakeConnector",
    ),
):
    KVConnectorFactory.register_connector(*_connector_spec)
# Do not leave the loop variable behind in the module namespace.
del _connector_spec
|
||||
322
vllm/distributed/kv_transfer/kv_connector/utils.py
Normal file
322
vllm/distributed/kv_transfer/kv_connector/utils.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KV cache helper for store.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_kv_connector_cache_layout():
    """Return the KV cache layout required by the configured connector.

    Falls back to ``"NHD"`` when no KV transfer config is set or when the
    connector class does not request a specific layout.
    """
    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
    # used for faster transfer.
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    if kv_config is None:
        return "NHD"

    connector_cls = KVConnectorFactory.get_connector_class(kv_config)
    layout = connector_cls.get_required_kvcache_layout(vllm_config)
    if layout is None:
        logger.info_once(
            "Connectors do not specify a kv cache layout, defaulting to NHD."
        )
        return "NHD"
    return layout
|
||||
|
||||
|
||||
class KVOutputAggregator:
    """Utility class to aggregate the output of all workers into a single
    output corresponding to Rank 0 for scheduler."""

    def __init__(self, expected_finished_count: int):
        """
        Args:
            expected_finished_count: Number of per-worker "finished" reports
                a request needs before it counts as finished overall.
        """
        # Complete transfer tracker. Used to track finished requests
        # [req_id -> n_remaining_workers]
        self._recv_remaining_count = dict[str, int]()
        self._send_remaining_count = dict[str, int]()
        self._expected_finished_count = expected_finished_count

    @classmethod
    def from_connector(cls, connector: "KVConnectorBase", world_size: int):
        """Build an aggregator from a connector.

        Uses ``connector.get_finished_count()`` when it is truthy and
        ``world_size`` otherwise.
        """
        return cls(connector.get_finished_count() or world_size)

    def aggregate(
        self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
    ) -> ModelRunnerOutput | None:
        """Merge the KV connector outputs of all workers into one output.

        Args:
            outputs: One model runner output per worker.
            output_rank: Index of the output that is returned, with its
                ``kv_connector_output`` replaced by the aggregate.

        Returns:
            ``outputs[output_rank]`` carrying the aggregated
            ``KVConnectorOutput``, or ``None`` if that entry is falsy.

        Raises:
            AssertionError: If any entry of ``outputs`` is ``None`` while the
                entry at ``output_rank`` is not.
        """
        if not outputs[output_rank]:
            return None

        # Aggregate kv_connector_output from all workers

        def update_finished_set(
            req_ids: set[str] | None,
            remaining_count_dict: dict[str, int],
            finished_set: set[str],
        ) -> None:
            # Count down one report per request; a request moves to
            # finished_set once its counter reaches zero.
            for req_id in req_ids or ():
                remaining_count = remaining_count_dict.get(
                    req_id, self._expected_finished_count
                )
                remaining_count_dict[req_id] = remaining_count - 1
                if remaining_count_dict[req_id] == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]

        finished_sending = set[str]()
        finished_recving = set[str]()
        aggregated_kv_connector_stats = None
        combined_kv_cache_events = None
        invalid_block_ids = set[int]()
        for model_runner_output in outputs:
            assert model_runner_output is not None
            kv_output = model_runner_output.kv_connector_output
            if not kv_output:
                continue
            # Allow the worker to dynamically update the expected number of
            # finished sending/recving for new requests.
            if (
                kv_output.expected_finished_count > 0
                and kv_output.expected_finished_count != self._expected_finished_count
            ):
                logger.debug(
                    "Expected finished requests updated from %d to %d",
                    self._expected_finished_count,
                    kv_output.expected_finished_count,
                )
                self._expected_finished_count = kv_output.expected_finished_count

            update_finished_set(
                kv_output.finished_sending, self._send_remaining_count, finished_sending
            )
            update_finished_set(
                kv_output.finished_recving, self._recv_remaining_count, finished_recving
            )

            # Aggregate kv_connector_stats from all workers.
            kv_connector_stats = kv_output.kv_connector_stats
            if aggregated_kv_connector_stats is None:
                # Use the first worker's kv_connector_stats as accumulator.
                aggregated_kv_connector_stats = kv_connector_stats
            elif kv_connector_stats:
                # The accumulator is known to be non-None in this branch.
                assert isinstance(
                    aggregated_kv_connector_stats, type(kv_connector_stats)
                )
                aggregated_kv_connector_stats = (
                    aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                )

            # Combine kv_cache_events from all workers.
            if combined_kv_cache_events is None:
                # Use the first worker's kv_cache events as start event list.
                combined_kv_cache_events = kv_output.kv_cache_events
            elif kv_cache_events := kv_output.kv_cache_events:
                assert isinstance(
                    combined_kv_cache_events,
                    type(kv_cache_events),
                )
                worker_kv_cache_events = kv_cache_events.get_all_events()
                combined_kv_cache_events.add_events(worker_kv_cache_events)
                combined_kv_cache_events.increment_workers(1)

            invalid_block_ids |= kv_output.invalid_block_ids

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

        assert output is not None
        # Empty/falsy aggregates are reported as None.
        output.kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending or None,
            finished_recving=finished_recving or None,
            kv_connector_stats=aggregated_kv_connector_stats or None,
            kv_cache_events=combined_kv_cache_events or None,
            invalid_block_ids=invalid_block_ids,
            expected_finished_count=self._expected_finished_count,
        )

        return output
|
||||
|
||||
|
||||
def _make_src_and_dst_indices(
|
||||
src_block_ids: list[int],
|
||||
dst_block_ids: list[int],
|
||||
src_device: torch.device | str,
|
||||
dst_device: torch.device | str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64)
|
||||
dst_indices = torch.tensor(dst_block_ids, device=dst_device, dtype=torch.int64)
|
||||
return src_indices, dst_indices
|
||||
|
||||
|
||||
def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy kv blocks between different buffers.

    Does nothing when any of the caches or block-id lists is empty, or when
    the two block-id lists differ in length.

    Args:
        src_kv_caches: Source cache tensors keyed by layer name.
        dst_kv_caches: Destination cache tensors keyed by layer name.
        src_block_ids: Block ids to read from in the source tensors.
        dst_block_ids: Block ids to write to in the destination tensors.
        direction: ``"h2d"`` uses the platform's ``insert_blocks_to_device``;
            any other value uses ``swap_out_blocks_to_host``.
    """
    have_all_inputs = (
        src_kv_caches and dst_kv_caches and src_block_ids and dst_block_ids
    )
    if not have_all_inputs or len(src_block_ids) != len(dst_block_ids):
        return

    # The device of the first tensor on each side decides where the index
    # tensors are created.
    src_device = next(iter(src_kv_caches.values())).device
    dst_device = next(iter(dst_kv_caches.values())).device

    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=src_device,
        dst_device=dst_device,
    )

    from vllm.platforms import current_platform

    copy_fn = (
        current_platform.insert_blocks_to_device
        if direction == "h2d"
        else current_platform.swap_out_blocks_to_host
    )
    for layer_name, src_tensor in src_kv_caches.items():
        copy_fn(src_tensor, dst_kv_caches[layer_name], src_indices, dst_indices)
|
||||
|
||||
|
||||
@dataclass
class TpKVTopology:
    """
    Helper class for tensor parallel and KV topology information for
    mapping between local and remote TP workers.
    """

    # TP rank of the local worker.
    tp_rank: int
    # TP size per engine id; the entry for `engine_id` is the local TP size.
    remote_tp_size: dict[str, int]
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    # Id of the local engine; used as the key into the two per-engine dicts.
    engine_id: str
    # Block size per engine id; the entry for `engine_id` is the local one.
    remote_block_size: dict[str, int]

    def __post_init__(self):
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks. This is used to register the memory regions correctly.
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
        )
        # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
        # we just mock num_blocks to 1 for the dimension check below.
        self._is_kv_layout_blocks_first = (
            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
        )

        attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS

    @property
    def is_kv_layout_blocks_first(self) -> bool:
        """Whether the probed cache shape is 5-D with a leading dim of 1."""
        return self._is_kv_layout_blocks_first

    @property
    def split_k_and_v(self) -> bool:
        """Whether to register regions for K and V separately (when present).

        False for MLA, for the Pallas backend, and for blocks-first layouts.
        """
        return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)

    @property
    def tp_size(self) -> int:
        """TP size of the local engine."""
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        """Block size of the local engine."""
        return self.remote_block_size[self.engine_id]

    def tp_ratio(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Calculate the tensor parallel ratio between local and remote TP.
        We can think of it as the number of local TP workers-per-remote TP
        workers. Local workers will read from the same remote TP worker in
        groups of size `tp_ratio`.

        Raises:
            AssertionError: If the local TP size is not a multiple of
                `remote_tp_size`.
        """
        assert self.tp_size % remote_tp_size == 0, (
            f"Local tensor parallel size {self.tp_size} is not divisible "
            f"by remote tensor parallel size {remote_tp_size}."
        )
        return self.tp_size // remote_tp_size

    def block_size_ratio(
        self,
        remote_block_size: int,
    ) -> int:
        """
        Calculate the block size ratio between local and remote TP.

        The result is the integer quotient local // remote. Only the case
        where the local block size is a multiple of the remote one is
        accepted.

        Raises:
            AssertionError: If the local block size is not a multiple of
                `remote_block_size`.
        """
        assert self.block_size % remote_block_size == 0, (
            f"Local block size {self.block_size} is not divisible "
            f"by remote block size {remote_block_size} or vice versa."
        )
        return self.block_size // remote_block_size

    def tp_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """`tp_ratio` for the TP size recorded under `remote_engine_id`."""
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.tp_ratio(remote_tp_size)

    def block_size_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """`block_size_ratio` for the block size of `remote_engine_id`."""
        remote_block_size = self.remote_block_size[remote_engine_id]
        return self.block_size_ratio(remote_block_size)

    def is_kv_replicated(self, engine_id: str) -> bool:
        """
        Whether the KV cache is treated as replicated across the TP workers
        of `engine_id`.

        True when that engine's TP size is at least the total number of KV
        heads.
        """
        # NOTE(review): equality (tp_size == total_num_kv_heads) also counts
        # as replicated here; confirm that this is intended.
        tp_size = self.remote_tp_size[engine_id]
        return tp_size // self.total_num_kv_heads >= 1

    def replicates_kv_cache(self, remote_engine_id: str) -> bool:
        """True for MLA, or when `is_kv_replicated(remote_engine_id)` holds."""
        # MLA is always replicated as the hidden dim can't be split.
        return self.is_mla or self.is_kv_replicated(remote_engine_id)

    def get_target_remote_rank(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Get the remote TP rank (on P) that the current local TP rank
        (on D) will read from.
        """
        tp_ratio = self.tp_ratio(remote_tp_size)
        return self.tp_rank // tp_ratio

    def get_target_remote_rank_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """`get_target_remote_rank` for the TP size of `remote_engine_id`."""
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.get_target_remote_rank(remote_tp_size)
|
||||
19
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
19
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
SupportsHMA,
|
||||
supports_hma,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import ( # noqa E:501
|
||||
DecodeBenchConnector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"KVConnectorRole",
|
||||
"KVConnectorBase_V1",
|
||||
"supports_hma",
|
||||
"SupportsHMA",
|
||||
"DecodeBenchConnector",
|
||||
]
|
||||
597
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
597
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
@@ -0,0 +1,597 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
|
||||
communication in vLLM v1
|
||||
|
||||
The class provides the following primitives:
|
||||
Scheduler-side: runs in the scheduler, binds metadata, which
|
||||
is used by the worker-side to load/save KV cache.
|
||||
get_num_new_matched_tokens() - get number of new tokens
|
||||
that exist in the remote KV cache. Might be called multiple
|
||||
times for a given request and should be side-effect free.
|
||||
update_state_after_alloc() - update KVConnector state after
|
||||
temporary buffer alloc by the CacheManager.
|
||||
update_connector_output() - update KVConnector state after
|
||||
output is received from worker-side connectors.
|
||||
request_finished() - called once when a request is finished,
|
||||
with the computed kv cache blocks for the request.
|
||||
Returns whether KV cache should be freed now or if the
|
||||
connector now assumes responsibility for freeing the
|
||||
blocks asynchronously. Also optionally returns KV
|
||||
transfer params.
|
||||
take_events() - returns new KV events that were collected
|
||||
by the connector since the last call.
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
start_load_kv() - starts loading all KVs (maybe async)
|
||||
wait_for_layer_load() - blocks until layer i load is done
|
||||
|
||||
save_kv_layer() - starts saving KV for layer i (maybe async)
|
||||
wait_for_save() - blocks until all saves are done
|
||||
|
||||
get_finished() - called with ids of finished requests, returns
|
||||
ids of requests that have completed async sending/recving.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
|
||||
CopyBlocksOp = Callable[
|
||||
[
|
||||
dict[str, torch.Tensor],
|
||||
dict[str, torch.Tensor],
|
||||
list[int],
|
||||
list[int],
|
||||
Literal["h2d", "d2h"],
|
||||
],
|
||||
None,
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SupportsHMA(ABC):
|
||||
"""
|
||||
The class that indicates the corresponding connector supports hybrid memory
|
||||
allocator (HMA).
|
||||
This is required to use the connector together with hybrid memory allocator.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished for all kv cache groups,
|
||||
before its blocks are freed for each group.
|
||||
|
||||
NOTE(Kuntai): This function is only supported by connectors that support HMA.
|
||||
|
||||
The connector may assumes responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def supports_hma(connector: Any) -> bool:
    """Return whether a connector class or instance derives from SupportsHMA."""
    # Classes need a subclass check, everything else an instance check.
    check = issubclass if isinstance(connector, type) else isinstance
    return check(connector, SupportsHMA)
|
||||
|
||||
|
||||
class KVConnectorRole(enum.Enum):
    """Process in which a KV connector instance runs."""

    # The connector lives in the scheduler process.
    SCHEDULER = 0

    # The connector lives in a worker process.
    WORKER = 1
|
||||
|
||||
|
||||
class KVConnectorHandshakeMetadata(ABC):  # noqa: B024
    """
    Metadata used for the out-of-band connector handshake between
    P/D workers. This needs to be serializable.
    """
|
||||
|
||||
|
||||
class KVConnectorMetadata(ABC):  # noqa: B024
    """
    Abstract metadata passed from the scheduler-side KVConnector to the
    worker-side KVConnector.
    """
|
||||
|
||||
|
||||
class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
Base class for KV connectors.
|
||||
|
||||
Attributes:
|
||||
prefer_cross_layer_blocks (bool): Indicates whether this connector
|
||||
prefers KV blocks that hold KV data for all layers (for speeding
|
||||
up KV data transfers).
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
prefer_cross_layer_blocks: ClassVar[bool] = False
|
||||
|
||||
def __init__(
    self,
    vllm_config: "VllmConfig",
    role: KVConnectorRole,
    kv_cache_config: Optional["KVCacheConfig"] = None,
):
    """Store the configuration and role of this connector.

    Args:
        vllm_config: The vLLM config; its ``kv_transfer_config`` must be set.
        role: Whether this connector runs in the scheduler or in a worker.
        kv_cache_config: KV cache configuration. Omitting it is deprecated
            and logs a warning.

    Raises:
        ValueError: If ``vllm_config.kv_transfer_config`` is None.
    """
    logger.warning(
        "Initializing KVConnectorBase_V1. This API is experimental and "
        "subject to change in the future as we iterate the design."
    )
    self._connector_metadata: KVConnectorMetadata | None = None
    self._vllm_config = vllm_config

    kv_transfer_config = vllm_config.kv_transfer_config
    if kv_transfer_config is None:
        raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
    self._kv_transfer_config = kv_transfer_config

    self._kv_cache_config = kv_cache_config
    if kv_cache_config is None:
        logger.warning(
            "KVConnectorBase_V1 initialized without kv_cache_config. "
            "This is deprecated - please update your connector to accept "
            "kv_cache_config as the third constructor argument and pass it "
            "to super().__init__()."
        )
    self._role = role
|
||||
|
||||
@property
def role(self) -> KVConnectorRole:
    """The role this connector was constructed with."""
    return self._role
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
    """Set the connector metadata from the scheduler.

    The model runner should call this before every model execution; the
    metadata is then used for runtime KV cache loading and saving.

    Args:
        connector_metadata (KVConnectorMetadata): the connector metadata.
    """
    self._connector_metadata = connector_metadata
|
||||
|
||||
def clear_connector_metadata(self) -> None:
    """Clear the connector metadata.

    The model runner should call this after every model execution.
    """
    self._connector_metadata = None
|
||||
|
||||
def _get_connector_metadata(self) -> KVConnectorMetadata:
    """Return the currently bound connector metadata.

    Intended for use inside the connector only.

    Returns:
        KVConnectorMetadata: the metadata set by bind_connector_metadata().
    """
    metadata = self._connector_metadata
    # Only valid between bind_connector_metadata() and
    # clear_connector_metadata().
    assert metadata is not None
    return metadata
|
||||
|
||||
def has_connector_metadata(self) -> bool:
    """Tell whether connector metadata is currently bound.

    Returns:
        bool: True if metadata is set, False otherwise.
    """
    is_bound = self._connector_metadata is not None
    return is_bound
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Hand the KV cache tensors to the connector.

    Useful for connectors that need to pre-register the KV caches
    (e.g. for NIXL). The base implementation does nothing.

    Args:
        kv_caches: dictionary mapping layer names to their kv cache.
    """
    return None
|
||||
|
||||
def register_cross_layers_kv_cache(
    self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
):
    """Hand a single KV cache tensor shared by all layers to the connector.

    The first dimension of the tensor should be num_layers. This is only
    called for models with uniform layers, and only when
    prefers_cross_layer_blocks is set to True. Exactly one of
    {register_kv_caches, register_cross_layers_kv_cache} will be called.
    The base implementation does nothing.

    Args:
        kv_cache: a cross-layers kv cache tensor.
        attn_backend: the attention backend that corresponds to all layers.
    """
    return None
|
||||
|
||||
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
    """Provide the xPU-specific op for copying KV between host and device.

    Needed when a host buffer is used for kv transfer (e.g., in
    NixlConnector). The base implementation ignores the op.

    Args:
        copy_operation: the host/device block copy operation.
    """
    return None
|
||||
|
||||
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
    """Begin loading KV cache from the connector into vLLM's paged KV buffer.

    Called from the forward context before the forward pass so that
    loading can overlap with model execution.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation.

    Note:
        The number of elements in kv_caches and layer_names should be
        the same. NOTE(review): neither name is a parameter of this
        method - presumably they are passed through kwargs; confirm.
    """
|
||||
|
||||
@abstractmethod
def wait_for_layer_load(self, layer_name: str) -> None:
    """Block until the KV of one layer is loaded into vLLM's paged buffer.

    Called from within the attention layer to make sure the async copy
    started by start_load_kv has completed. This hook is what enables
    layer-by-layer pipelining.

    Args:
        layer_name: the name of the layer to wait for.
    """
|
||||
|
||||
@abstractmethod
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: "AttentionMetadata",
    **kwargs: Any,
) -> None:
    """Begin saving one layer of KV cache from vLLM's paged buffer.

    Called from within the attention layer so that the copy to the
    connector can overlap with execution.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """
|
||||
|
||||
@abstractmethod
def wait_for_save(self):
    """Block until all save operations are done.

    Called as the forward context exits, so that the async saving
    started by save_kv_layer completes before the forward finishes.
    This prevents the paged KV buffer from being overwritten before
    saving is done.
    """
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens on the worker.
|
||||
The scheduler process (via the Executors) will use this output
|
||||
to track which workers are done.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the IDs of blocks that failed to load.

    Returns:
        Set of block IDs that encountered load errors; an empty set if
        no load errors occurred. The base implementation always returns
        an empty set.

    Notes:
        - Applies to both sync- and async-loading requests.
        - Async loading: failed blocks may be reported in any forward
          pass up to and including the pass where the request ID is
          returned by `get_finished()`. Even if failures occur, the
          request must still be reported via `get_finished()`, and the
          failed block IDs must appear here no later than that same pass.
        - Sync loading: failed blocks should be reported in the forward
          pass in which they are detected.
    """
    no_failures: set[int] = set()
    return no_failures
|
||||
|
||||
def shutdown(self):
    """Shut the connector down.

    Called when the worker process is shutting down, so that all async
    operations can complete and the connector can clean up properly.
    The base implementation has nothing to release.
    """
    return
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
    """Return the KV connector stats collected during the last interval.

    The base implementation collects nothing and returns None.
    """
    stats = None
    return stats
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
    """Return the kv cache events collected during the last interval.

    The model runner should call this after every model execution and
    before cleanup. The base implementation returns None.
    """
    events = None
    return events
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """Return the handshake metadata for this connector.

    This metadata is used for the out-of-band connector handshake
    between P/D workers.

    Returns:
        KVConnectorHandshakeMetadata: the handshake metadata, or None if
        no handshake metadata is available (always the case in the base
        implementation).
    """
    handshake_metadata = None
    return handshake_metadata
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
@abstractmethod
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- An optional number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
If None, it means that the connector needs more time to
|
||||
determine the number of matched tokens, and the scheduler
|
||||
should query for this request again later.
|
||||
- `True` if external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps). Must be
|
||||
'False' if the first element is 0.
|
||||
|
||||
Notes:
|
||||
The connector should only consider the largest prefix of prompt-
|
||||
tokens for which KV cache is actually available at the time of the
|
||||
call. If the cache cannot be loaded for some tokens (e.g., due to
|
||||
connectivity issues or eviction), those tokens must not be taken
|
||||
into account.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """Update the connector state once blocks have been allocated.

    If get_num_new_matched_tokens previously returned True for a
    request, this may be called twice for that request: first when
    blocks are allocated for the connector tokens to be asynchronously
    loaded into, and again when any additional blocks are allocated
    after the load/transfer is complete.

    Args:
        request (Request): the request object.
        blocks (KVCacheBlocks): the blocks allocated for the request.
        num_external_tokens (int): the number of tokens that will be
            loaded from the external KV cache.
    """
|
||||
|
||||
@abstractmethod
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
    """Build the connector metadata for the current step.

    Implementations must NOT modify fields in ``scheduler_output``.
    Calling this also resets the state of the connector.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        KVConnectorMetadata: the metadata for this step.
    """
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
    """Update connector state from the worker-side connectors' output.

    The base implementation ignores the output.

    Args:
        connector_output (KVConnectorOutput): the worker-side connectors
            output.
    """
    return None
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished, before its blocks are
|
||||
freed.
|
||||
|
||||
The connector may assumes responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return False, None
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
    """Take the KV cache events from the connector.

    Returns:
        An iterable of the new KV cache events since the last call. The
        base implementation returns an empty tuple.
    """
    return tuple()
|
||||
|
||||
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
    """Return the KV cache layout this connector requires.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the required KV cache layout, e.g. HND or NHD; None if the
        connector does not require a specific layout (the default for
        subclasses).

    Raises:
        TypeError: if called on KVConnectorBase_V1 itself rather than on
            a subclass.
    """
    if cls is not KVConnectorBase_V1:
        return None

    raise TypeError(
        "get_required_kvcache_layout should not be called "
        "on the abstract base class"
    )
|
||||
|
||||
def get_finished_count(self) -> int | None:
|
||||
"""
|
||||
Get the count of requests expected to complete send/receive operations
|
||||
via this connector. This method is used to initialize the
|
||||
KVOutputAggregator, overwriting the default world_size.
|
||||
|
||||
Returns:
|
||||
int: expected sending or receiving completion count.
|
||||
"""
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls, data: dict[str, Any] | None = None
|
||||
) -> Optional["KVConnectorStats"]:
|
||||
"""
|
||||
KVConnectorStats resolution method. This method allows dynamically
|
||||
registered connectors to return their own KVConnectorStats object,
|
||||
which can implement custom aggregation logic on the data dict.
|
||||
"""
|
||||
return None
|
||||
|
||||
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """Set the KV connector handshake metadata for this connector.

    The base implementation ignores the metadata.

    Args:
        metadata (dict[int, KVConnectorHandshakeMetadata]): the handshake
            metadata to set, keyed by an integer. NOTE(review): the
            meaning of the key is not visible here - confirm with callers.
    """
    return
|
||||
|
||||
@classmethod
def build_prom_metrics(
    cls,
    vllm_config: "VllmConfig",
    metric_types: dict[type["PromMetric"], type["PromMetricT"]],
    labelnames: list[str],
    per_engine_labelvalues: dict[int, list[object]],
) -> Optional["KVConnectorPromMetrics"]:
    """Create the Prometheus metrics object for this connector class.

    Overrides should return a KVConnectorPromMetrics subclass which
    registers per-connector Prometheus metrics and implements observe()
    to expose connector transfer stats via Prometheus. The base
    implementation returns None.
    """
    prom_metrics = None
    return prom_metrics
|
||||
|
||||
def reset_cache(self) -> bool | None:
    """Reset the connector's internal cache.

    Returns:
        bool: True if the cache was successfully reset, False otherwise.
        The base implementation does not reset anything: it only logs a
        debug message and returns None.
    """
    connector_name = type(self).__name__
    logger.debug(
        "Connector cache reset requested, but %s does not implement reset_cache().",
        connector_name,
    )
    return None
|
||||
@@ -0,0 +1,419 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
DecodeBenchConnector: A KV Connector for decode instance performance testing.
|
||||
|
||||
This connector emulates a prefill-decode disaggregated setting by filling
|
||||
the KV cache with dummy values, allowing measurement of decoder performance
|
||||
under larger input sequence lengths (ISL) in resource-limited environments.
|
||||
|
||||
Usage:
|
||||
To use this connector for benchmarking, configure it in the kv_transfer_config:
|
||||
|
||||
Example:
|
||||
vllm serve <model> --kv-transfer-config '{
|
||||
"kv_connector": "DecodeBenchConnector",
|
||||
"kv_role": "kv_both",
|
||||
"kv_connector_extra_config": {
|
||||
"fill_mean": 0.015,
|
||||
"fill_std": 0.0
|
||||
}
|
||||
}'
|
||||
|
||||
Then run your benchmark with desired input/output lengths:
|
||||
vllm bench serve --base-url http://127.0.0.1:8000 --model <model> \\
|
||||
--dataset-name random --random-input-len 40000 \\
|
||||
--random-output-len 100 --max-concurrency 10
|
||||
|
||||
Configuration options (via kv_connector_extra_config):
|
||||
- fill_mean (float): Mean value for random normal fill (default: 0.015)
|
||||
- fill_std (float): Standard deviation for random fill (default: 0.0)
|
||||
Set to 0 for constant values, >0 for random sampling
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class DecodeBenchConnectorMetadata(KVConnectorMetadata):
    """Per-step metadata sent from the scheduler side to the worker side.

    Describes which requests need their KV cache filled with dummy
    values for benchmarking purposes.
    """

    # Maps request_id -> (block_ids_per_group, num_tokens_to_fill).
    # block_ids_per_group holds one list of block IDs per KV cache group:
    #   - standard attention: a single group, e.g. ([1, 2, 3],)
    #   - MLA: multiple groups, e.g. ([1, 2], [1, 2])
    reqs_to_fill: dict[str, tuple[tuple[list[int], ...], int]]
|
||||
|
||||
|
||||
class DecodeBenchConnector(KVConnectorBase_V1):
    """KV connector used to benchmark a decode instance.

    It fills the KV cache with dummy (non-zero) values to emulate a
    prefill-decode disaggregated setting, so that decoder performance can
    be measured with larger input sequence lengths.

    The connector itself is a thin facade: depending on the role it owns
    either a scheduler-side or a worker-side helper and forwards to it.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        """Create the helper matching ``role``; the other one stays None."""
        super().__init__(vllm_config, role, kv_cache_config)

        self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
        self.connector_worker: DecodeBenchConnectorWorker | None = None

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = DecodeBenchConnectorScheduler(vllm_config)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = DecodeBenchConnectorWorker(vllm_config)

    # ==============================
    # Worker-side methods
    # ==============================

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Forward the KV cache tensors to the worker-side helper."""
        worker = self.connector_worker
        assert worker is not None
        worker.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """Fill the blocks listed in the bound metadata with dummy values."""
        worker = self.connector_worker
        assert worker is not None
        fill_metadata = self._connector_metadata
        assert isinstance(fill_metadata, DecodeBenchConnectorMetadata)
        worker.start_fill_kv(fill_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op: all operations are synchronous, so nothing to wait for."""
        return None

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """No-op: this connector doesn't save KV cache (benchmarking only)."""
        return None

    def wait_for_save(self):
        """No-op: this connector doesn't save KV cache (benchmarking only)."""
        return None

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Delegate to the scheduler-side helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.get_num_new_matched_tokens(request, num_computed_tokens)

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Delegate to the scheduler-side helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self, scheduler_output: "SchedulerOutput"
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Let the scheduler-side helper forget the request.

        Always returns (False, None); ``block_ids`` is not used.
        """
        scheduler = self.connector_scheduler
        assert scheduler is not None
        scheduler.request_finished(request)
        return (False, None)
|
||||
|
||||
|
||||
class DecodeBenchConnectorScheduler:
    """Scheduler-side half of DecodeBenchConnector.

    Decides how many tokens of a request are reported as externally
    available and records which allocated blocks must be filled.
    """

    def __init__(self, vllm_config: "VllmConfig"):
        """Read the block size and start with empty bookkeeping."""
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size

        # IDs of requests for which a fill has already been recorded.
        self._filled_requests: set[str] = set()

        # Fills queued during the current scheduler step:
        # request_id -> (block_ids_per_group, num_tokens_to_fill).
        # No explicit cleanup is needed - build_connector_meta() clears it
        # in the same scheduler step.
        self._pending_fills: dict[str, tuple[tuple[list[int], ...], int]] = {}

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """Return how many tokens should be filled with dummy KV values.

        Returns:
            (num_tokens_to_fill, is_async)
            - num_tokens_to_fill: number of uncomputed tokens minus 1
              (everything except the last token, which will be decoded),
              never negative; 0 for a request that was already filled.
            - is_async: always False (synchronous filling).
        """
        # A request is only filled once, on its first scheduling.
        if request.request_id in self._filled_requests:
            return 0, False

        # Leave exactly one token uncomputed so that it gets decoded;
        # this simulates having processed a long prefill.
        remaining = request.num_tokens - num_computed_tokens
        return max(0, remaining - 1), False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Record the allocated blocks that have to be filled.

        Supports both standard attention (single KV cache group) and MLA
        (multiple KV cache groups). Does nothing when
        ``num_external_tokens`` is 0.
        """
        req_id = request.request_id
        if num_external_tokens == 0:
            return

        # One list of block IDs per KV cache group.
        block_groups = blocks.get_block_ids()

        # Number of blocks needed to hold the tokens we promised to provide.
        num_blocks_to_fill = cdiv(num_external_tokens, self.block_size)

        # Keep only the leading num_blocks_to_fill blocks of every group.
        leading_blocks = []
        for group_blocks in block_groups:
            leading_blocks.append(group_blocks[:num_blocks_to_fill])
        block_ids_per_group = tuple(leading_blocks)

        self._pending_fills[req_id] = (block_ids_per_group, num_external_tokens)
        self._filled_requests.add(req_id)

        logger.debug(
            "DecodeBenchConnector: Allocated %d blocks across %d KV cache groups "
            "for request %s",
            num_blocks_to_fill,
            len(block_groups),
            req_id,
        )

    def build_connector_meta(
        self, scheduler_output: "SchedulerOutput"
    ) -> KVConnectorMetadata:
        """Package the pending fills into metadata and reset them.

        ``scheduler_output`` is not read.
        """
        snapshot = dict(self._pending_fills)
        # The fills are handed over exactly once.
        self._pending_fills.clear()
        return DecodeBenchConnectorMetadata(reqs_to_fill=snapshot)

    def request_finished(self, request: "Request"):
        """Forget a finished request so its ID no longer counts as filled."""
        self._filled_requests.discard(request.request_id)
|
||||
|
||||
|
||||
class DecodeBenchConnectorWorker:
    """Worker-side half of DecodeBenchConnector.

    Holds references to the registered KV cache tensors and overwrites the
    blocks named in the connector metadata with dummy values.
    """

    def __init__(self, vllm_config: "VllmConfig"):
        """Read the block size and the fill parameters from the config.

        ``fill_mean`` (default 0.015) and ``fill_std`` (default 0.0) come
        from the kv-transfer extra config.
        """
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size

        # Get fill parameters from extra config
        kv_transfer_config = vllm_config.kv_transfer_config
        assert kv_transfer_config is not None
        self.fill_mean = kv_transfer_config.get_from_extra_config("fill_mean", 0.015)
        self.fill_std = kv_transfer_config.get_from_extra_config("fill_std", 0.0)

        # Will be populated via register_kv_caches
        self.kv_caches: dict[str, torch.Tensor] | None = None

        # Mapping from KV cache group index to list of layer names in that group
        self.group_to_layers: dict[int, list[str]] | None = None

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Store references to the KV cache tensors and build group mapping."""
        self.kv_caches = kv_caches

        # For simplicity every registered layer is placed in group 0
        # (standard attention). Blocks listed for any other group index
        # therefore map to no layers in _fill_blocks.
        self.group_to_layers = {0: list(kv_caches.keys())}

        logger.debug(
            "DecodeBenchConnector: Registered %d KV cache layers",
            len(kv_caches),
        )

    def start_fill_kv(self, metadata: "DecodeBenchConnectorMetadata"):
        """Fill the allocated KV cache blocks with dummy (non-zero) values.

        This simulates having a populated KV cache from a prefill phase,
        allowing decode performance testing with larger context sizes.
        Does nothing when the metadata lists no requests.

        Args:
            metadata: the per-step metadata; ``reqs_to_fill`` maps each
                request ID to (block_ids_per_group, num_tokens).
        """
        if not metadata.reqs_to_fill:
            return

        assert self.kv_caches is not None, "KV caches must be registered before filling"
        assert self.group_to_layers is not None, "Group mapping must be initialized"

        for req_id, (block_ids_per_group, num_tokens) in metadata.reqs_to_fill.items():
            # Fill blocks for each KV cache group
            for group_idx, block_ids in enumerate(block_ids_per_group):
                self._fill_blocks(group_idx, block_ids, num_tokens)

            logger.debug(
                "DecodeBenchConnector: Filled %d blocks (%d tokens) across %d groups "
                "for request %s",
                len(block_ids_per_group[0]) if block_ids_per_group else 0,
                num_tokens,
                len(block_ids_per_group),
                req_id,
            )

    def _fill_blocks(self, group_idx: int, block_ids: list[int], num_tokens: int):
        """Overwrite the given blocks of one KV cache group with dummy values.

        Whole blocks are overwritten in every layer of the group. Values
        are drawn from a normal distribution when ``fill_std > 0`` and are
        the constant ``fill_mean`` otherwise. Block IDs outside
        ``[0, kv_cache.shape[0])`` are skipped.

        Args:
            group_idx: The KV cache group index to fill
            block_ids: List of block IDs to fill in this group
            num_tokens: Total number of tokens to fill across these blocks.
                Currently unused: blocks are always filled completely.
        """
        if not block_ids:
            return

        assert self.kv_caches is not None
        assert self.group_to_layers is not None

        use_random_fill = self.fill_std > 0

        # Fill only the layers that belong to this group
        for layer_name in self.group_to_layers.get(group_idx, []):
            if layer_name not in self.kv_caches:
                logger.warning(
                    "DecodeBenchConnector: Layer %s not found in KV caches", layer_name
                )
                continue

            kv_cache = self.kv_caches[layer_name]

            # Convert block_ids to tensor on device
            block_ids_tensor = torch.tensor(
                block_ids, dtype=torch.long, device=kv_cache.device
            )

            # Filter invalid block IDs. The lower bound matters: a negative
            # ID would otherwise index from the end of the cache and
            # silently overwrite an unrelated block.
            # NOTE(review): assumes dim 0 of each cache tensor indexes
            # blocks - confirm for layouts that stack K/V along dim 0.
            valid_mask = (block_ids_tensor >= 0) & (
                block_ids_tensor < kv_cache.shape[0]
            )
            valid_block_ids = block_ids_tensor[valid_mask]

            if len(valid_block_ids) == 0:
                continue

            # Create fill values - either constant or random
            fill_shape = (len(valid_block_ids),) + kv_cache.shape[1:]
            if use_random_fill:
                # Random normal sampling
                fill_values = torch.normal(
                    mean=self.fill_mean,
                    std=self.fill_std,
                    size=fill_shape,
                    dtype=kv_cache.dtype,
                    device=kv_cache.device,
                )
            else:
                # Constant fill value
                fill_values = torch.full(
                    fill_shape,
                    self.fill_mean,
                    dtype=kv_cache.dtype,
                    device=kv_cache.device,
                )

            # Batch fill operation
            kv_cache[valid_block_ids] = fill_values

        logger.debug(
            "DecodeBenchConnector: Filled %d blocks in group %d with %s values "
            "(mean=%.3f, std=%.3f)",
            len(block_ids),
            group_idx,
            "random" if use_random_fill else "constant",
            self.fill_mean,
            self.fill_std,
        )
|
||||
@@ -0,0 +1,450 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ReqMeta:
    """Per-request metadata exchanged by the example connector."""

    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor
    # Is store or load
    is_store: bool
    mm_hashes: list[str]

    @staticmethod
    def make_meta(
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        is_store: bool,
        mm_hashes: list[str],
    ) -> "ReqMeta":
        """Build a ReqMeta from plain token and block ID lists.

        The token count is first reduced to ``align_to_block_size(
        len(token_ids), block_size)``; both ``token_ids`` and the slot
        mapping are truncated to that length. Each block contributes the
        slots ``block_id * block_size + [0, block_size)``.

        Args:
            token_ids: the request's token IDs.
            block_ids: the block IDs holding the request's KV cache.
            block_size: number of slots per block.
            is_store: whether this is a store (True) or a load (False).
            mm_hashes: hashes stored unchanged on the result.

        Returns:
            ReqMeta: metadata whose tensors are always of integer dtype.
        """
        valid_num_tokens = align_to_block_size(len(token_ids), block_size)

        # Pin the dtype: torch.tensor([]) would otherwise infer a float
        # dtype, and float tensors cannot be used as indices downstream.
        token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)[:valid_num_tokens]
        block_ids_tensor = torch.tensor(block_ids, dtype=torch.long)
        num_blocks = block_ids_tensor.shape[0]

        # Broadcast (1, block_size) offsets against (num_blocks, 1) block
        # starts to get one row of slot indices per block.
        block_offsets = torch.arange(0, block_size)
        slot_mapping = (
            block_offsets.reshape((1, block_size))
            + block_ids_tensor.reshape((num_blocks, 1)) * block_size
        )
        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]

        return ReqMeta(
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
            is_store=is_store,
            mm_hashes=mm_hashes,
        )
|
||||
|
||||
|
||||
@dataclass
class ExampleConnectorMetadata(KVConnectorMetadata):
    """Connector metadata carrying one ReqMeta per request."""

    requests: list[ReqMeta] = field(default_factory=list)

    def add_request(
        self,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        is_store: bool,
        mm_hashes: list[str],
    ) -> None:
        """Build a ReqMeta from the arguments and append it to ``requests``."""
        req_meta = ReqMeta.make_meta(
            token_ids, block_ids, block_size, is_store, mm_hashes
        )
        self.requests.append(req_meta)
|
||||
|
||||
|
||||
class ExampleConnector(KVConnectorBase_V1):
|
||||
# NOTE: This is a simple debug implementation of the KV connector.
|
||||
# It saves / loads the KV cache to / from the disk.
|
||||
# It does extra work which will overwrite the existing prefix-cache in GPU
|
||||
# - to remove the overhead, need to add some "mask" in the ReqMeta class
|
||||
|
||||
def __init__(
    self,
    vllm_config: "VllmConfig",
    role: KVConnectorRole,
    kv_cache_config: Optional["KVCacheConfig"] = None,
):
    """Initialize the connector and read the storage path from the config.

    The storage directory is taken from the ``shared_storage_path`` entry of
    the KV-transfer extra config and defaults to ``/tmp``.
    """
    super().__init__(
        vllm_config=vllm_config,
        role=role,
        kv_cache_config=kv_cache_config,
    )
    self._block_size = vllm_config.cache_config.block_size
    # request_id -> request whose KV must be loaded in the next step
    self._requests_need_load: dict[str, Request] = {}

    transfer_config = self._kv_transfer_config
    self._storage_path = transfer_config.get_from_extra_config(
        "shared_storage_path", "/tmp"
    )
    logger.info(transfer_config)
    logger.info("Shared storage path is %s", self._storage_path)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
    """Load stored KV from disk into vLLM's paged KV buffer.

    For every request in the connector metadata that is not a store request,
    and for every layer in ``forward_context.no_compile_layers`` that has a
    ``kv_cache`` attribute, the layer's safetensors file is read and written
    into the slots given by the request's slot mapping.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation (unused here).
    """
    # Check for missing metadata *before* the type assert so the warning
    # path below is actually reachable.
    metadata = self._get_connector_metadata()
    if metadata is None:
        logger.warning(
            "In connector.start_load_kv, but the connector metadata is None"
        )
        return
    assert isinstance(metadata, ExampleConnectorMetadata)

    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        logger.warning("In connector.start_load_kv, but the attn_metadata is None")
        return

    is_mla = isinstance(attn_metadata, MLACommonMetadata)

    def inject_kv_into_layer(
        dst_kv_cache_layer: torch.Tensor,
        src_kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Write ``src_kv_cache`` into ``dst_kv_cache_layer`` at ``slot_mapping``.

        Args:
            dst_kv_cache_layer (torch.Tensor): the destination KV cache
                layer. In shape [2, num_pages, page_size, xxx] if not
                using MLA, [num_pages, page_size, xxx] otherwise.
            src_kv_cache (torch.Tensor): the source KV cache. In shape
                [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
                otherwise.
            slot_mapping (torch.Tensor): the slot mapping. In shape
                [num_tokens].
        """
        # NOTE(review): the indexed assignment only reaches the paged
        # buffer when reshape() returns a view of it; for a layout that
        # forces a copy the write would be lost - confirm the layers are
        # always view-compatible.
        layer_shape = dst_kv_cache_layer.shape
        if is_mla:
            num_pages, page_size = layer_shape[0], layer_shape[1]
            flat = dst_kv_cache_layer.reshape(num_pages * page_size, -1)
            flat[slot_mapping, ...] = src_kv_cache
        else:
            num_pages, page_size = layer_shape[1], layer_shape[2]
            flat = dst_kv_cache_layer.reshape(2, num_pages * page_size, -1)
            flat[:, slot_mapping, ...] = src_kv_cache

    # Load the KV for each request each layer
    for request in metadata.requests:
        if request.is_store:
            continue
        logger.info(
            "Inject KV cache of %d tokens to the paged memory",
            len(request.slot_mapping),
        )
        for layer_name in forward_context.no_compile_layers:
            layer = forward_context.no_compile_layers[layer_name]

            # Only process layers that have kv_cache
            # attribute (attention layers) Skip non-attention
            # layers like FusedMoE/MLP etc.
            kv_cache_attr = getattr(layer, "kv_cache", None)
            if kv_cache_attr is None:
                continue

            kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]

            filename = self._generate_filename_debug(
                layer_name, request.token_ids, request.mm_hashes
            )
            kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
            inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
    """Wait until the KV of ``layer_name`` is loaded into the paged buffer.

    This implementation does nothing and returns immediately.

    Args:
        layer_name: the name of that layer
    """
    return None
|
||||
|
||||
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs: Any,
) -> None:
    """Write this layer's KV for every store request to a safetensors file.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """

    def extract_kv_from_layer(
        layer: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> torch.Tensor:
        """Gather the entries of ``layer`` addressed by ``slot_mapping``.

        The layer is assumed to be (2, num_pages, page_size, xxx) without
        MLA and (num_pages, page_size, xxx) with MLA.
        """
        if not isinstance(attn_metadata, MLACommonMetadata):
            total_slots = layer.shape[1] * layer.shape[2]
            return layer.reshape(2, total_slots, -1)[:, slot_mapping, ...]
        total_slots = layer.shape[0] * layer.shape[1]
        return layer.reshape(total_slots, -1)[slot_mapping, ...]

    connector_metadata = self._get_connector_metadata()
    assert isinstance(connector_metadata, ExampleConnectorMetadata)
    for request in connector_metadata.requests:
        if not request.is_store:
            continue
        filename = self._generate_filename_debug(
            layer_name, request.token_ids, request.mm_hashes
        )
        extracted = extract_kv_from_layer(kv_layer, request.slot_mapping)
        safetensors.torch.save_file(
            {"kv_cache": extracted.detach().cpu()}, filename
        )
|
||||
|
||||
def wait_for_save(self):
    """Do nothing; this implementation returns immediately."""
    return None
|
||||
|
||||
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int | None, bool]:
    """
    Get number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A tuple ``(num_tokens, False)``. ``num_tokens`` is the number of
        tokens that can be loaded from the external KV cache beyond what is
        already computed; it is never negative. The second element is
        always False in this implementation.
    """
    # NOTE: in this debug implementation, we assume that the prompt is
    # cached_prompt + newly_generated_single_token
    # Therefore, we use prompt_token_ids[:-1] to determine the folder name

    # NOTE: in current v1 scheduler, the num_computed_tokens is aligned
    # with the block granularity. And it expects the returned blocks and
    # num_computed_tokens to also be aligned with the block granularity.
    if not self._found_match_for_request(request):
        return 0, False

    logger.info("External Cache Hit!")

    # Now, first num_tokens_to_check tokens are hit, we need to prepare
    # the metadata for the worker connector to correctly load the KV
    token_ids = request.prompt_token_ids or []
    num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)

    # The local count can exceed the aligned external count (the alignment
    # rounds down from len - 2); a negative token count is meaningless, so
    # report zero new tokens in that case.
    return max(0, num_tokens_to_check - num_computed_tokens), False
|
||||
|
||||
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """
    Record the request for loading once its blocks are allocated.

    When ``num_external_tokens`` is positive the request is stored in
    ``_requests_need_load`` under its request id; otherwise nothing changes.
    ``blocks`` is not used.
    """
    if not num_external_tokens > 0:
        return
    self._requests_need_load[request.request_id] = request
|
||||
|
||||
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Assemble the connector metadata for this scheduling step.

    This function does not modify the scheduler_output. It clears
    ``_requests_need_load`` before returning.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        An ExampleConnectorMetadata holding one entry per request to load
        or store in this step.
    """
    meta = ExampleConnectorMetadata()
    pending_loads = self._requests_need_load
    load_count = 0

    for new_req in scheduler_output.scheduled_new_reqs:
        token_ids = new_req.prompt_token_ids or []
        mm_hashes = [feature.identifier for feature in new_req.mm_features]
        needs_load = new_req.req_id in pending_loads

        # NOTE: here, we set the store and load being exclusive,
        # but a single request can have both store and load.
        # NOTE(rob): for this debug implementation, we only cache
        # the original prompt tokens.
        if not needs_load and self._found_match_for_prompt(token_ids, mm_hashes):
            # Not a load, and the prompt is already on disk: nothing to do.
            continue

        meta.add_request(
            token_ids=token_ids,
            block_ids=new_req.block_ids[0],
            block_size=self._block_size,
            is_store=not needs_load,
            mm_hashes=mm_hashes,
        )
        if needs_load:
            load_count += 1

    cached_reqs = scheduler_output.scheduled_cached_reqs
    for idx, req_id in enumerate(cached_reqs.req_ids):
        is_resumed = req_id in cached_reqs.resumed_req_ids
        if not (is_resumed and req_id in pending_loads):
            continue

        computed = cached_reqs.num_computed_tokens[idx]
        scheduled = scheduler_output.num_scheduled_tokens[req_id]
        new_block_ids = cached_reqs.new_block_ids[idx]

        # NOTE(rob): cached_req_data does not have the full
        # list of token ids (only new tokens). So we look it
        # up in the actual request object.
        request = pending_loads[req_id]
        token_ids = request.all_token_ids[: computed + scheduled]

        # NOTE(rob): For resumed req, new_block_ids is all
        # of the block_ids for the request.
        assert new_block_ids is not None

        meta.add_request(
            token_ids=token_ids,
            block_ids=new_block_ids[0],
            block_size=self._block_size,
            is_store=False,
            mm_hashes=[feature.identifier for feature in request.mm_features],
        )
        load_count += 1

    assert load_count == len(pending_loads)
    pending_loads.clear()
    return meta
|
||||
|
||||
# ==============================
|
||||
# Helper functions
|
||||
# ==============================
|
||||
|
||||
def _found_match_for_request(
    self,
    request: "Request",
) -> bool:
    """Return whether the request's prompt already has a stored entry."""
    prompt_token_ids = list(request.prompt_token_ids or [])
    mm_hashes = [feature.identifier for feature in request.mm_features]
    return self._found_match_for_prompt(prompt_token_ids, mm_hashes)
|
||||
|
||||
def _found_match_for_prompt(
    self,
    prompt_token_ids: list[int],
    mm_hashes: list[str],
) -> bool:
    """Return whether a storage folder exists for the prompt.

    The lookup key is the prompt truncated to
    ``align_to_block_size(len(prompt_token_ids) - 1, block_size)`` tokens,
    combined with ``mm_hashes``. No folder is created by this check.
    """
    prefix_len = align_to_block_size(len(prompt_token_ids) - 1, self._block_size)
    prefix_tokens = torch.tensor(prompt_token_ids)[:prefix_len]
    folder = self._generate_foldername_debug(
        prefix_tokens,
        mm_hashes,
        create_folder=False,
    )
    return os.path.exists(folder)
|
||||
|
||||
def _generate_foldername_debug(
    self,
    token_ids: torch.Tensor,
    mm_hashes: list[str],
    create_folder=False,
) -> str:
    """Return the storage folder for the given tokens and mm hashes.

    The folder name is the hex digest of the token bytes followed, when
    ``mm_hashes`` is non-empty, by the "-"-joined hashes encoded as UTF-8.
    The folder is created only when ``create_folder`` is true.
    """
    payload = token_ids.numpy().tobytes()
    # mm_hashes go into the hashed bytes rather than into the path itself,
    # which avoids path traversal and yields a canonical key.
    mm_suffix = "-".join(mm_hashes).encode("utf-8") if mm_hashes else b""
    digest = safe_hash(payload + mm_suffix, usedforsecurity=False).hexdigest()

    folder = os.path.join(self._storage_path, digest)
    if create_folder:
        os.makedirs(folder, exist_ok=True)
    return folder
|
||||
|
||||
def _generate_filename_debug(
    self,
    layer_name: str,
    token_ids: torch.Tensor,
    mm_hashes: list[str],
) -> str:
    """Return the path of the safetensors file for one layer of one prompt.

    The containing folder comes from ``_generate_foldername_debug`` and is
    created if it does not exist yet.
    """
    folder = self._generate_foldername_debug(
        token_ids, mm_hashes=mm_hashes, create_folder=True
    )
    file_name = f"{layer_name}.safetensors"
    return os.path.join(folder, file_name)
|
||||
|
||||
|
||||
def align_to_block_size(num_tokens: int, block_size) -> int:
    """Round ``num_tokens - 1`` down to a multiple of ``block_size``.

    Note that an exact multiple is therefore mapped to the previous
    multiple (e.g. 16 with block size 16 gives 0).
    """
    full_blocks, _ = divmod(num_tokens - 1, block_size)
    return full_blocks * block_size
|
||||
@@ -0,0 +1,327 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import (
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
KVConnectorKVEvents,
|
||||
KVEventAggregator,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LMCacheKVEvents(KVConnectorKVEvents):
    """
    KVConnectorKVEvents implementation that delegates to a KVEventAggregator.
    """

    def __init__(self, num_workers: int) -> None:
        self._aggregator = KVEventAggregator(num_workers)

    def __repr__(self) -> str:
        return f"<LMCacheKVEvents events={self.get_all_events()}>"

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Forward ``events`` to the aggregator."""
        self._aggregator.add_events(events)

    def increment_workers(self, count: int = 1) -> None:
        """Forward the worker-count increment to the aggregator."""
        self._aggregator.increment_workers(count)

    def get_all_events(self) -> list[KVCacheEvent]:
        """Return all events currently held by the aggregator."""
        return self._aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        """Return the aggregator's worker count."""
        return self._aggregator.get_number_of_workers()

    def aggregate(self) -> "LMCacheKVEvents":
        """
        Replace the stored events with the aggregator's common events,
        reset the worker count, and return self.
        """
        aggregator = self._aggregator
        shared_events = aggregator.get_common_events()
        aggregator.clear_events()
        aggregator.add_events(shared_events)
        aggregator.reset_workers()
        return self

    def clear_events(self) -> None:
        """Drop all stored events and reset the worker count."""
        aggregator = self._aggregator
        aggregator.clear_events()
        aggregator.reset_workers()
|
||||
|
||||
|
||||
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
def __init__(
    self,
    vllm_config: "VllmConfig",
    role: KVConnectorRole,
    kv_cache_config: "KVCacheConfig",
):
    """Create the connector and the LMCache implementation it delegates to.

    The ``use_native`` entry of the KV-transfer extra config (default
    False) selects between the adapter bundled under
    ``lmcache_integration`` and the one shipped in the ``lmcache`` package.
    """
    super().__init__(
        vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
    )
    transfer_config = vllm_config.kv_transfer_config
    assert transfer_config is not None

    if transfer_config.get_from_extra_config("use_native", False):
        logger.info("Initializing native LMCache connector")
        # lazy import
        from vllm.distributed.kv_transfer.kv_connector.v1 import lmcache_integration

        impl_cls = lmcache_integration.vllm_v1_adapter.LMCacheConnectorV1Impl
    else:
        logger.info("Initializing latest dev LMCache connector")
        # lazy import
        from lmcache.integration.vllm.vllm_v1_adapter import (
            LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
        )

        impl_cls = LMCacheConnectorLatestImpl

    self._lmcache_engine = impl_cls(vllm_config, role, self)

    # Events accumulated by update_connector_output until take_events runs
    self._kv_cache_events: LMCacheKVEvents | None = None
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
    """
    Start loading the KV cache from the connector into vLLM's paged
    KV buffer by delegating to the LMCache engine.

    Args:
        forward_context (ForwardContext): the forward context.
        **kwargs: additional arguments for the load operation, passed
            through unchanged.
    """
    engine = self._lmcache_engine
    engine.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
    """
    Wait for the KV of one layer to be loaded, by delegating to the
    LMCache engine.

    Args:
        layer_name: the name of that layer
    """
    engine = self._lmcache_engine
    engine.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs: Any,
) -> None:
    """
    Start saving one layer of KV cache from vLLM's paged buffer to the
    connector by delegating to the LMCache engine.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation, passed
            through unchanged.
    """
    engine = self._lmcache_engine
    engine.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
|
||||
|
||||
def wait_for_save(self):
    """
    Wait for all save operations to finish, by delegating to the
    LMCache engine.
    """
    engine = self._lmcache_engine
    engine.wait_for_save()
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return self._lmcache_engine.get_finished(finished_req_ids)
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
    """
    Return the block IDs that failed to load, as reported by the engine.

    Returns:
        The engine's answer when it exposes a callable
        ``get_block_ids_with_load_errors``; an empty set otherwise
        (fallback for older versions without that method).
    """
    reporter = getattr(self._lmcache_engine, "get_block_ids_with_load_errors", None)
    return reporter() if callable(reporter) else set()
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
    """
    Return the engine's KV events wrapped in an LMCacheKVEvents.

    Each engine event is copied field by field into a BlockStored. Returns
    None when the engine reports no events.
    """
    engine_events = self._lmcache_engine.get_kv_events()  # type: ignore [attr-defined]
    if not engine_events:
        return None

    stored_blocks: list[BlockStored] = []
    for event in engine_events:
        stored_blocks.append(
            BlockStored(
                block_hashes=event.block_hashes,
                parent_block_hash=event.parent_block_hash,
                token_ids=event.token_ids,
                lora_id=event.lora_id,
                block_size=event.block_size,
                medium=event.medium,
            )
        )

    wrapped = LMCacheKVEvents(num_workers=1)
    wrapped.add_events(stored_blocks)
    return wrapped
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
return self._lmcache_engine.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
), False
|
||||
|
||||
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """
    Forward the post-allocation update to the LMCache engine.

    Only ``request`` and ``num_external_tokens`` are passed on; ``blocks``
    is not used.
    """
    engine = self._lmcache_engine
    engine.update_state_after_alloc(request, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
    """
    Build the connector metadata for this step by delegating to the
    LMCache engine.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        The metadata object produced by the engine.
    """
    engine = self._lmcache_engine
    connector_meta = engine.build_connector_meta(scheduler_output)
    return connector_meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Merge the KV cache events of a worker-side connector output into
    this connector's accumulated events.

    Outputs whose ``kv_cache_events`` is falsy or is not an
    LMCacheKVEvents are ignored.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    incoming = connector_output.kv_cache_events
    if not (incoming and isinstance(incoming, LMCacheKVEvents)):
        return

    # First batch since the last take_events(): adopt the object as-is.
    if self._kv_cache_events is None:
        self._kv_cache_events = incoming
        return

    # Otherwise fold the new events and worker count into the existing one.
    accumulated = self._kv_cache_events
    accumulated.add_events(incoming.get_all_events())
    accumulated.increment_workers(incoming.get_number_of_workers())
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return self._lmcache_engine.request_finished(request, block_ids)
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Yield the accumulated KV cache events, then reset the accumulator.

    The stored events are aggregated first. The accumulator is cleared
    and set to None only after all events have been yielded. Yields
    nothing when no events are stored.

    Yields:
        New KV cache events since the last call.
    """
    if self._kv_cache_events is None:
        return

    self._kv_cache_events.aggregate()
    yield from self._kv_cache_events.get_all_events()
    self._kv_cache_events.clear_events()
    self._kv_cache_events = None
|
||||
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from . import multi_process_adapter, vllm_v1_adapter
|
||||
from .multi_process_adapter import (
|
||||
LMCacheMPSchedulerAdapter,
|
||||
LMCacheMPWorkerAdapter,
|
||||
LoadStoreOp,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"vllm_v1_adapter",
|
||||
"multi_process_adapter",
|
||||
"LMCacheMPSchedulerAdapter",
|
||||
"LMCacheMPWorkerAdapter",
|
||||
"LoadStoreOp",
|
||||
]
|
||||
@@ -0,0 +1,378 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from lmcache.utils import _lmcache_nvtx_annotate, init_logger
|
||||
from lmcache.v1.multiprocess.custom_types import (
|
||||
CudaIPCWrapper,
|
||||
IPCCacheEngineKey,
|
||||
KVCache,
|
||||
)
|
||||
from lmcache.v1.multiprocess.mq import MessageQueueClient, MessagingFuture
|
||||
from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache:
    """Wrap every tensor in ``kv_caches`` in a CudaIPCWrapper, in dict order."""
    logger.info("KV caches keys are %s", list(kv_caches.keys()))
    return list(map(CudaIPCWrapper, kv_caches.values()))
|
||||
|
||||
|
||||
def send_lmcache_request(
    mq_client: MessageQueueClient,
    request_type: RequestType,
    payloads: list[Any],
) -> MessagingFuture[Any]:
    """Submit a request through ``mq_client`` and return its future.

    The response class is looked up from ``request_type``.
    """
    response_cls = get_response_class(request_type)
    return mq_client.submit_request(request_type, payloads, response_cls)
|
||||
|
||||
|
||||
def get_lmcache_chunk_size(
    mq_client: MessageQueueClient,
) -> int:
    """Send a GET_CHUNK_SIZE request and return the result of its future."""
    return send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, []).result()
|
||||
|
||||
|
||||
def striding_block_hashes(
    block_hashes: list[bytes],
    blocks_in_chunk,
) -> Iterable[bytes]:
    """Pick the hash of the last block of every complete chunk.

    With ``blocks_in_chunk`` of 16 this yields the hashes of the 16th,
    32nd, 48th, ... blocks. The result is a lazy iterator.
    """
    last_block_of_first_chunk = blocks_in_chunk - 1
    return islice(
        block_hashes,
        last_block_of_first_chunk,
        None,
        blocks_in_chunk,
    )
|
||||
|
||||
|
||||
@dataclass
class LoadStoreOp:
    """A pair of equally long lists: block hashes and their block ids."""

    block_hashes: list[bytes]
    block_ids: list[int]

    def __post_init__(self):
        # Both lists must have the same length.
        lengths_match = len(self.block_hashes) == len(self.block_ids)
        assert lengths_match, (
            "The number of block hashes should be equal to the number of block ids "
            f"But got {len(self.block_hashes)} and {len(self.block_ids)}"
        )

    def __len__(self) -> int:
        """Return the number of block hashes."""
        return len(self.block_hashes)
|
||||
|
||||
|
||||
StoreResult = bool
|
||||
RetrieveResult = list[bool]
|
||||
LookupResult = list[bool]
|
||||
|
||||
|
||||
class LMCacheMPSchedulerAdapter:
|
||||
def __init__(
    self,
    server_url: str,
    context: zmq.Context,
    model_name: str,
    world_size: int,
    kv_rank: int,
    vllm_block_size: int,
):
    """
    Connect to the LMCache message queue and read its chunk size.

    Args:
        server_url: The server URL for the LMCache message queue
        context: The ZMQ context

        model_name: The model name used for LMCache keys
        world_size: The world size used for LMCache keys
        kv_rank: The kv rank used for LMCache keys
        vllm_block_size: The block size used in vLLM
    """
    self.mq_client = MessageQueueClient(server_url, context)

    # request_id -> future of the submitted lookup request
    self.lookup_futures: dict[str, MessagingFuture[LookupResult]] = {}

    self.model_name = model_name
    self.world_size = world_size
    self.worker_id = kv_rank

    # The chunk size is owned by LMCache; it must be a whole number of
    # vLLM blocks.
    self.chunk_size = get_lmcache_chunk_size(self.mq_client)
    blocks_in_chunk, leftover = divmod(self.chunk_size, vllm_block_size)
    assert leftover == 0, (
        "LMCache chunk size should be a multiple of vLLM block size"
    )
    self.blocks_in_chunk = blocks_in_chunk
|
||||
|
||||
@_lmcache_nvtx_annotate
def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]):
    """Submit a LOOKUP request for ``request_id`` unless one already exists.

    One key is built per chunk, from the hash of the chunk's last block.
    The resulting future is stored in ``lookup_futures``.
    """
    if request_id in self.lookup_futures:
        # Skip if there is already a lookup request
        return

    chunk_hashes = striding_block_hashes(block_hashes, self.blocks_in_chunk)
    keys = list(map(self._create_key, chunk_hashes))
    self.lookup_futures[request_id] = send_lmcache_request(
        self.mq_client,
        RequestType.LOOKUP,
        [keys, True],
    )
|
||||
|
||||
@_lmcache_nvtx_annotate
def check_lookup_result(self, request_id: str) -> int | None:
    """Return the looked-up token count for ``request_id`` if available.

    Returns:
        None while the lookup future is not ready; otherwise the sum of
        the lookup result multiplied by the chunk size.
    """
    assert request_id in self.lookup_futures, (
        f"Lookup request for request_id={request_id} has not been submitted"
    )

    pending = self.lookup_futures[request_id]
    if pending.query():
        hit_chunks = sum(pending.result())
        return hit_chunks * self.chunk_size
    return None
|
||||
|
||||
def num_blocks_per_chunk(self) -> int:
    """
    Returns:
        The number of vllm blocks in a LMCache data chunk
    """
    blocks = self.blocks_in_chunk
    return blocks
|
||||
|
||||
# Helper functions
|
||||
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
    """Build the IPC cache engine key for ``block_hash``.

    The key combines this adapter's model name, world size and worker id
    with the hash, which is used as the chunk hash.
    """
    key_fields = {
        "model_name": self.model_name,
        "world_size": self.world_size,
        "worker_id": self.worker_id,
        "chunk_hash": block_hash,
    }
    return IPCCacheEngineKey(**key_fields)
|
||||
|
||||
|
||||
class LMCacheMPWorkerAdapter:
|
||||
def __init__(
    self,
    server_url: str,
    context: zmq.Context,
    model_name: str,
    world_size: int,
    kv_rank: int,
    vllm_block_size: int,
):
    """Connect to the LMCache message queue and set up worker-side state.

    Args:
        server_url: URL passed to the message queue client.
        context: the ZMQ context passed to the message queue client.
        model_name: stored as ``model_name``.
        world_size: stored as ``world_size``.
        kv_rank: stored as ``worker_id``.
        vllm_block_size: the vLLM block size; the LMCache chunk size must
            be a multiple of it.
    """
    self.mq_client = MessageQueueClient(server_url, context)

    # Instance id for GPU worker
    self.instance_id = os.getpid()

    # Registered kv caches from vLLM
    self.kv_caches: dict[str, torch.Tensor] = {}

    # Request futures
    # request_id -> (future, other merged requests)
    self.store_futures: dict[
        str, tuple[MessagingFuture[StoreResult], list[str]]
    ] = {}
    self.retrieve_futures: dict[
        str, tuple[MessagingFuture[RetrieveResult], list[str]]
    ] = {}

    self.finished_stores: set[str] = set()
    self.previously_finished: set[str] = set()

    self.model_name = model_name
    self.world_size = world_size
    self.worker_id = kv_rank

    # Read chunk size from lmcache
    lmcache_chunk_size = get_lmcache_chunk_size(self.mq_client)
    blocks_in_chunk, leftover = divmod(lmcache_chunk_size, vllm_block_size)
    assert leftover == 0, (
        "LMCache chunk size should be a multiple of vLLM block size"
    )
    self.blocks_in_chunk = blocks_in_chunk
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, KVCache]):
|
||||
# Register kv cache and send the request
|
||||
self.kv_caches = kv_caches
|
||||
logger.info("Registering kv caches")
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.REGISTER_KV_CACHE,
|
||||
[self.instance_id, wrap_kv_caches(kv_caches)],
|
||||
)
|
||||
future.result()
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def submit_store_request(
|
||||
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
|
||||
):
|
||||
keys = self._block_hashes_to_keys(op.block_hashes)
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.STORE,
|
||||
[keys, self.instance_id, op.block_ids, event.ipc_handle()],
|
||||
).to_cuda_future()
|
||||
self.store_futures[request_id] = (future, [])
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def submit_retrieve_request(
|
||||
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
|
||||
):
|
||||
keys = self._block_hashes_to_keys(op.block_hashes)
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.RETRIEVE,
|
||||
[keys, self.instance_id, op.block_ids, event.ipc_handle()],
|
||||
).to_cuda_future()
|
||||
self.retrieve_futures[request_id] = (future, [])
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def batched_submit_store_requests(
|
||||
self,
|
||||
request_ids: list[str],
|
||||
ops: list[LoadStoreOp],
|
||||
event: torch.cuda.Event,
|
||||
):
|
||||
keys = []
|
||||
block_ids = []
|
||||
for op in ops:
|
||||
keys.extend(self._block_hashes_to_keys(op.block_hashes))
|
||||
block_ids.extend(op.block_ids)
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.STORE,
|
||||
[keys, self.instance_id, block_ids, event.ipc_handle()],
|
||||
).to_cuda_future()
|
||||
self.store_futures[request_ids[0]] = (future, request_ids[1:])
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def batched_submit_retrieve_requests(
|
||||
self,
|
||||
request_ids: list[str],
|
||||
ops: list[LoadStoreOp],
|
||||
event: torch.cuda.Event,
|
||||
):
|
||||
keys = []
|
||||
block_ids = []
|
||||
for op in ops:
|
||||
keys.extend(self._block_hashes_to_keys(op.block_hashes))
|
||||
block_ids.extend(op.block_ids)
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.RETRIEVE,
|
||||
[keys, self.instance_id, block_ids, event.ipc_handle()],
|
||||
).to_cuda_future()
|
||||
self.retrieve_futures[request_ids[0]] = (future, request_ids[1:])
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
finished_stores = set()
|
||||
finished_retrieves = set()
|
||||
for request_id, (future, other_reqs) in self.store_futures.items():
|
||||
if not future.query():
|
||||
continue
|
||||
|
||||
result = future.result()
|
||||
finished_stores.add(request_id)
|
||||
finished_stores.update(other_reqs)
|
||||
|
||||
if not result:
|
||||
# TODO: add error handling here
|
||||
logger.error(
|
||||
"Something went wrong when processing the "
|
||||
"store request for request_id=%s",
|
||||
request_id,
|
||||
)
|
||||
|
||||
for request_id, (future, other_reqs) in self.retrieve_futures.items():
|
||||
if not future.query():
|
||||
continue
|
||||
|
||||
result = future.result()
|
||||
finished_retrieves.add(request_id)
|
||||
finished_retrieves.update(other_reqs)
|
||||
|
||||
if not all(result):
|
||||
# TODO: add error handing here
|
||||
logger.error(
|
||||
"Something went wrong when processing the "
|
||||
"retrieve request for request_id=%s, result=%s",
|
||||
request_id,
|
||||
result,
|
||||
)
|
||||
|
||||
# Remove the finished requests from the tracking dicts
|
||||
for request_id in finished_stores:
|
||||
self.store_futures.pop(request_id, None)
|
||||
for request_id in finished_retrieves:
|
||||
self.retrieve_futures.pop(request_id, None)
|
||||
|
||||
# Update the internal states
|
||||
self.finished_stores.update(finished_stores)
|
||||
|
||||
ret_stores = set()
|
||||
for req_id in finished_req_ids:
|
||||
if req_id in self.finished_stores or req_id in self.store_futures:
|
||||
self.previously_finished.add(req_id)
|
||||
else:
|
||||
ret_stores.add(req_id)
|
||||
|
||||
# Calculate the final finished stores
|
||||
ret_stores.update(self._update_and_get_finished_store())
|
||||
|
||||
return ret_stores, finished_retrieves
|
||||
|
||||
def num_blocks_per_chunk(self) -> int:
|
||||
"""
|
||||
Returns:
|
||||
The number of vllm blocks in a LMCache data chunk
|
||||
"""
|
||||
return self.blocks_in_chunk
|
||||
|
||||
def shutdown(self):
|
||||
# Unregister kv cache
|
||||
logger.info("Unregistering kv caches")
|
||||
send_lmcache_request(
|
||||
self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
|
||||
).result()
|
||||
|
||||
self.mq_client.close()
|
||||
|
||||
# Helper functions
|
||||
def _update_and_get_finished_store(
|
||||
self,
|
||||
) -> set[str]:
|
||||
"""Converge the internal states about finished stores
|
||||
and returns the 'safe finished store request ids' back
|
||||
"""
|
||||
safe_finished_s = self.finished_stores.intersection(self.previously_finished)
|
||||
self.finished_stores.difference_update(self.previously_finished)
|
||||
self.previously_finished.difference_update(safe_finished_s)
|
||||
|
||||
return safe_finished_s
|
||||
|
||||
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
|
||||
"""Convert a block hash to an IPC cache engine key"""
|
||||
return IPCCacheEngineKey(
|
||||
model_name=self.model_name,
|
||||
world_size=self.world_size,
|
||||
worker_id=self.worker_id,
|
||||
chunk_hash=block_hash,
|
||||
)
|
||||
|
||||
def _block_hashes_to_keys(
|
||||
self, block_hashes: list[bytes]
|
||||
) -> list[IPCCacheEngineKey]:
|
||||
"""Convert block hashes to IPC cache engine keys"""
|
||||
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
|
||||
return [self._create_key(block_hash) for block_hash in s]
|
||||
@@ -0,0 +1,221 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Standard
|
||||
import os
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from lmcache.config import LMCacheEngineConfig as Config
|
||||
from lmcache.logging import init_logger
|
||||
from lmcache.v1.config import LMCacheEngineConfig as V1Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
from vllm.v1.request import Request
|
||||
|
||||
# Module-wide logger obtained from LMCache's logging helper.
logger = init_logger(__name__)

ENGINE_NAME = "vllm-instance"

# Cached configuration object and the lock that guards its creation
# (see lmcache_get_or_create_config).
_config_instance: Config | V1Config | None = None
_config_lock = threading.Lock()
|
||||
|
||||
|
||||
def is_false(value: str) -> bool:
    """Tell whether ``value`` spells a false-like flag.

    The comparison ignores case. The recognised spellings are ``false``,
    ``0``, ``no``, ``n`` and ``off``; surrounding whitespace is not
    stripped.
    """
    false_spellings = ("false", "0", "no", "n", "off")
    lowered = value.lower()
    return lowered in false_spellings
|
||||
|
||||
|
||||
def lmcache_get_or_create_config() -> Config | V1Config:
    """Return the cached LMCache configuration, creating it on first use.

    The configuration class is chosen from ``LMCACHE_USE_EXPERIMENTAL``
    (the legacy ``Config`` when it is false-like, ``V1Config`` otherwise).
    When ``LMCACHE_CONFIG_FILE`` is set the configuration is loaded from
    that file and then updated from environment variables; otherwise it is
    built from environment variables alone.

    Creation is serialised by ``_config_lock`` and the result is stored in
    the module-level ``_config_instance`` only once it is fully built, so
    callers never observe a partially initialised configuration.

    Returns:
        The shared configuration object.
    """
    global _config_instance

    # Lock-free fast path once the configuration has been published.
    if _config_instance is not None:
        return _config_instance

    with _config_lock:
        # Another thread may have finished loading while we waited.
        if _config_instance is None:
            config_cls: type[Config] | type[V1Config]
            if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")):
                logger.warning(
                    "Detected LMCACHE_USE_EXPERIMENTAL is set to False. "
                    "Using legacy configuration is deprecated and will "
                    "be remove soon! Please set LMCACHE_USE_EXPERIMENTAL "
                    "to True."
                )
                config_cls = Config
            else:
                config_cls = V1Config

            if "LMCACHE_CONFIG_FILE" not in os.environ:
                logger.warning(
                    "No LMCache configuration file is set. Trying to read"
                    " configurations from the environment variables."
                )
                logger.warning(
                    "You can set the configuration file through "
                    "the environment variable: LMCACHE_CONFIG_FILE"
                )
                config = config_cls.from_env()
            else:
                config_file = os.environ["LMCACHE_CONFIG_FILE"]
                logger.info("Loading LMCache config file %s", config_file)
                config = config_cls.from_file(config_file)
                # Update config from environment variables
                config.update_config_from_env()

            # Publish last: readers on the fast path take no lock, so the
            # global must never point at a half-initialised object.
            _config_instance = config

        return _config_instance
|
||||
|
||||
|
||||
def hex_hash_to_int16(s: str) -> int:
    """Reduce a hexadecimal hash string to its low 16 bits.

    Args:
        s: Hash written in base 16 (anything ``int(s, 16)`` accepts).

    Returns:
        An integer in the range 0..65535.
    """
    parsed = int(s, 16)
    low_bits_mask = (1 << 16) - 1
    return parsed & low_bits_mask
|
||||
|
||||
|
||||
def apply_mm_hashes_to_token_ids(
    token_ids: torch.Tensor,
    mm_hashes: list[str],
    mm_positions: list["PlaceholderRange"],
) -> torch.Tensor:
    """Stamp each multimodal placeholder span with a value derived from its hash.

    ``token_ids`` is modified in place: for every (hash, placeholder) pair
    the slice starting at ``placeholder.offset`` and spanning
    ``placeholder.length`` entries is overwritten with
    ``hex_hash_to_int16(hash)``. Spans are clipped to the tensor length and
    spans that start at or beyond the end are ignored.

    Args:
        token_ids: Tensor whose first dimension indexes tokens.
        mm_hashes: Hexadecimal hash strings, one per placeholder.
        mm_positions: Placeholder ranges paired with ``mm_hashes``.

    Returns:
        The same ``token_ids`` tensor that was passed in.
    """
    total = token_ids.size(0)
    for digest, span in zip(mm_hashes, mm_positions):
        first, span_len = span.offset, span.length
        if first < total:
            last = min(first + span_len, total)
            token_ids[first:last] = hex_hash_to_int16(digest)
    return token_ids
|
||||
|
||||
|
||||
def mla_enabled(model_config: "ModelConfig") -> bool:
    """Return True only when ``model_config.use_mla`` exists and is the bool ``True``.

    A missing attribute, ``False``, or any non-bool value yields ``False``.
    """
    # ``is True`` covers both the bool type check and the truth test.
    return getattr(model_config, "use_mla", None) is True
|
||||
|
||||
|
||||
def create_lmcache_metadata(
    vllm_config=None, model_config=None, parallel_config=None, cache_config=None
):
    """
    Build an ``LMCacheEngineMetadata`` from vLLM configuration objects.

    Either pass ``vllm_config`` (its ``model_config``, ``parallel_config``
    and ``cache_config`` attributes are used), or pass all three individual
    configs. ``vllm_config`` wins when both forms are supplied.

    Args:
        vllm_config (VllmConfig): Combined vLLM configuration object.
        model_config (ModelConfig): Model configuration (alternative to
            vllm_config).
        parallel_config (ParallelConfig): Parallel configuration
            (alternative to vllm_config).
        cache_config (CacheConfig): Cache configuration (alternative to
            vllm_config).

    Returns:
        A ``(metadata, config)`` tuple where ``config`` is the object
        returned by ``lmcache_get_or_create_config()``.

    Raises:
        ValueError: If ``vllm_config`` is None and any individual config is
            missing.
    """
    # Third Party
    # First Party
    from lmcache.config import LMCacheEngineMetadata

    from vllm.utils.torch_utils import get_kv_cache_torch_dtype

    config = lmcache_get_or_create_config()

    # Resolve the three config objects from whichever form was supplied.
    if vllm_config is not None:
        model_cfg, parallel_cfg, cache_cfg = (
            vllm_config.model_config,
            vllm_config.parallel_config,
            vllm_config.cache_config,
        )
    elif any(cfg is None for cfg in (model_config, parallel_config, cache_config)):
        raise ValueError(
            "Either vllm_config must be provided, or all of "
            "model_config, parallel_config, and cache_config must be provided."
        )
    else:
        model_cfg, parallel_cfg, cache_cfg = (
            model_config,
            parallel_config,
            cache_config,
        )

    kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)
    use_mla = mla_enabled(model_cfg)

    # KV shape (for memory pool): the second dimension is 1 with MLA and
    # 2 otherwise.
    num_layer = model_cfg.get_num_layers(parallel_cfg)
    chunk_size = config.chunk_size
    num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
    head_size = model_cfg.get_head_size()
    kv_planes = 1 if use_mla else 2
    kv_shape = (num_layer, kv_planes, chunk_size, num_kv_head, head_size)

    metadata = LMCacheEngineMetadata(
        model_cfg.model,
        parallel_cfg.world_size,
        parallel_cfg.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )
    return metadata, config
|
||||
|
||||
|
||||
def extract_mm_features(
    request: Union["Request", "NewRequestData"], modify: bool = False
) -> tuple[list[str], list["PlaceholderRange"]]:
    """
    Collect a request's multimodal hashes and positions as parallel lists.

    Sources are tried in this order:
      1) a non-empty ``request.mm_features``, whose items expose
         ``.identifier`` and ``.mm_position``;
      2) non-empty legacy fields ``request.mm_hashes`` and
         ``request.mm_positions``.

    Args:
        request (Request): The source object.
        modify (bool):
            Only affects the legacy path. When True, ``.copy()`` results
            of the legacy sequences are returned so the caller may mutate
            them; when False the original sequences are returned as-is and
            should be treated as read-only.

    Returns:
        tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`),
        or `([], [])` when neither source holds data.
    """
    features = getattr(request, "mm_features", None)
    if features:
        # New-style path: always builds fresh lists.
        hashes: list = []
        positions: list = []
        for feature in features:
            hashes.append(feature.identifier)
            positions.append(feature.mm_position)
        return (hashes, positions)

    legacy_hashes = getattr(request, "mm_hashes", None)
    if not legacy_hashes:
        return ([], [])

    legacy_positions = request.mm_positions  # type: ignore
    if modify:
        return (legacy_hashes.copy(), legacy_positions.copy())
    return (legacy_hashes, legacy_positions)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,895 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import enum
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from lmcache.integration.vllm.utils import mla_enabled
|
||||
from lmcache.utils import init_logger as lmcache_init_logger
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
|
||||
LMCacheMPSchedulerAdapter,
|
||||
LMCacheMPWorkerAdapter,
|
||||
LoadStoreOp,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
# Module-wide logger created through LMCache's logger factory.
logger = lmcache_init_logger(__name__)
|
||||
|
||||
|
||||
# Helper functions
|
||||
def reformat_block_ids(block_ids: tuple[list[int], ...] | None) -> list[int]:
|
||||
if block_ids is None:
|
||||
return []
|
||||
assert isinstance(block_ids, tuple), (
|
||||
f"Expected block_ids to be a tuple of lists, but got {type(block_ids)}"
|
||||
)
|
||||
|
||||
if len(block_ids) > 1:
|
||||
raise RuntimeError(
|
||||
"LMCacheMPConnector only works without hybrid kv cache manager. "
|
||||
"Please pass --disable-hybrid-kv-cache-manager when starting vllm"
|
||||
)
|
||||
|
||||
return block_ids[0]
|
||||
|
||||
|
||||
def extract_world_size_and_kv_rank(
    world_size: int,
    rank: int,
    vllm_config: VllmConfig,
) -> tuple[int, int]:
    """
    Compute the (world size, rank) pair used for KV cache keys.

    Without MLA the inputs are returned unchanged. With MLA both values are
    floor-divided by the tensor-parallel size.

    Returns:
        ``(world_size, rank)`` after the optional MLA adjustment.
    """
    if mla_enabled(vllm_config.model_config):
        # Tensor parallel does not change the KV caches for MLA models,
        # so the effect of TP is removed from rank and world size.
        # vLLM constructs TP groups first, and then constructs other
        # parallel groups on top of TP groups.
        # For example, TP=4, PP=2:
        #   TP group: [0, 1, 2, 3], [4, 5, 6, 7]
        #   PP group: [0, 4], [1, 5], [2, 6], [3, 7]
        # Hence rank // tp_size strips the TP component.
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        return world_size // tp_size, rank // tp_size

    return world_size, rank
|
||||
|
||||
|
||||
def create_scheduler_adapter(
    server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPSchedulerAdapter:
    """Construct the scheduler-side LMCache adapter from ``vllm_config``.

    The world size and rank are first passed through
    ``extract_world_size_and_kv_rank``.
    """
    parallel = vllm_config.parallel_config
    world_size, kv_rank = extract_world_size_and_kv_rank(
        parallel.world_size,
        parallel.rank,
        vllm_config,
    )
    model_name = vllm_config.model_config.model
    block_size = vllm_config.cache_config.block_size
    return LMCacheMPSchedulerAdapter(
        server_url,
        zmq_context,
        model_name,
        world_size,
        kv_rank,
        block_size,
    )
|
||||
|
||||
|
||||
def create_worker_adapter(
    server_url: str, zmq_context: zmq.Context, vllm_config: VllmConfig
) -> LMCacheMPWorkerAdapter:
    """Construct the worker-side LMCache adapter from ``vllm_config``.

    The world size and rank are first passed through
    ``extract_world_size_and_kv_rank``.
    """
    parallel = vllm_config.parallel_config
    world_size, kv_rank = extract_world_size_and_kv_rank(
        parallel.world_size,
        parallel.rank,
        vllm_config,
    )
    model_name = vllm_config.model_config.model
    block_size = vllm_config.cache_config.block_size
    return LMCacheMPWorkerAdapter(
        server_url,
        zmq_context,
        model_name,
        world_size,
        kv_rank,
        block_size,
    )
|
||||
|
||||
|
||||
def convert_block_hashes_to_bytes(
    block_hashes: list["BlockHash"],
) -> list[bytes]:
    """Present ``block_hashes`` to the type checker as ``list[bytes]``.

    This is a static cast only: the same list object is returned and no
    element is converted at runtime.
    """
    as_bytes = cast(list[bytes], block_hashes)
    return as_bytes
|
||||
|
||||
|
||||
class LMCacheMPRequestState(enum.Enum):
    """
    Lifecycle states of a request tracked by the connector.

    Transitions:
        PREFETCHING      -- update_state_after_alloc --> WAITING_FOR_LOAD
        WAITING_FOR_LOAD -- process_loading_requests --> READY
    """

    PREFETCHING = enum.auto()
    WAITING_FOR_LOAD = enum.auto()
    READY = enum.auto()
|
||||
|
||||
|
||||
@dataclass
class LMCacheMPRequestTracker:
    """Per-request bookkeeping for store/retrieve planning.

    The class declares dataclass fields but supplies its own ``__init__``
    and ``__repr__``, which take precedence over the generated ones.
    """

    # NOTE: this class used vLLM data structures, should be part of
    # vLLM integration code

    request_id: str

    # Read-only lists to track the token ids and block hashes
    all_token_ids: ConstantList[int]
    block_hashes: ConstantList["BlockHash"]

    # Block ids and hashes will be updated at update_states_after_alloc and
    # during the generation
    allocated_block_ids: list[int] = field(default_factory=list)

    # Number of scheduled tokens in this request. We keep tracking this to
    # avoid saving half-full blocks.
    num_scheduled_tokens: int = 0

    # Number of blocks stored will be initialized when lookup the external
    # hit tokens and will be updated when processing new requests and cached
    # requests.
    num_stored_blocks: int = 0

    # Staging load operation -- save vllm and lmcache hit tokens during lookup
    num_vllm_hit_blocks: int = 0
    num_lmcache_hit_blocks: int = 0

    # Main state
    state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING

    def __init__(self, request: "Request"):
        """Start tracking ``request`` with all counters at zero.

        Args:
            request: Source of ``request_id``, ``all_token_ids`` and
                ``block_hashes``.
        """
        self.request_id = request.request_id
        self.all_token_ids = request.all_token_ids
        self.block_hashes = ConstantList(request.block_hashes)
        self.allocated_block_ids = []
        # Set explicitly, like the other counters, instead of relying on
        # the class-level default.
        self.num_scheduled_tokens = 0
        self.num_stored_blocks = 0
        self.num_vllm_hit_blocks = 0
        self.num_lmcache_hit_blocks = 0
        self.state = LMCacheMPRequestState.PREFETCHING

    ####
    # Check the state of the request
    ####
    def needs_retrieve(self) -> bool:
        """Check whether the current request needs retrieve; will be used
        in update_state_after_alloc.

        True when LMCache hit more blocks than vLLM did and the request is
        not yet READY.
        """
        return (
            self.num_lmcache_hit_blocks > self.num_vllm_hit_blocks
            and self.state != LMCacheMPRequestState.READY
        )

    def is_ready_for_retrieving(self) -> bool:
        """Check whether the current request is ready for retrieving,
        will be used in process_loading_requests"""
        return (
            self.state == LMCacheMPRequestState.WAITING_FOR_LOAD
            and self.needs_retrieve()
        )

    ####
    # Update internal states
    ####
    def increase_num_scheduled_tokens(self, num_new_tokens: int):
        """Add ``num_new_tokens`` to the scheduled-token counter."""
        self.num_scheduled_tokens += num_new_tokens

    def increase_num_stored_blocks(self, num_new_blocks: int):
        """Increase the number of stored blocks for the current request
        This function will be called when processing the cached requests.
        """
        self.num_stored_blocks += num_new_blocks

    def update_block_ids(
        self,
        new_block_ids: list[int],
    ):
        """Update the block ids for the current request
        This function will be called when processing the cached requests.
        """
        self.allocated_block_ids.extend(new_block_ids)

    ####
    # For debugging
    ####
    def __repr__(self) -> str:
        return (
            f"LMCacheMPRequestTracker(request_id={self.request_id}, "
            f"num_tokens={len(self.all_token_ids)}, "
            f"num_block_hashes={len(self.block_hashes)}, "
            f"num_allocated_blocks={len(self.allocated_block_ids)}, "
            f"num_stored_blocks={self.num_stored_blocks}, "
            f"vllm_hit_blocks={self.num_vllm_hit_blocks}, "
            f"lmcache_hit_blocks={self.num_lmcache_hit_blocks}, "
            f"state={self.state})"
        )

    def __str__(self) -> str:
        return self.__repr__()
|
||||
|
||||
|
||||
@dataclass
class LMCacheMPRequestMetadata:
    """One store or retrieve operation planned for a single request."""

    request_id: str
    direction: Literal["STORE", "RETRIEVE"]
    op: LoadStoreOp

    @staticmethod
    def GetStoreMetadata(
        tracker: LMCacheMPRequestTracker,
        blocks_in_chunk: int,
        vllm_block_size: int,
    ) -> "LMCacheMPRequestMetadata | None":
        """
        Plan the next STORE operation for ``tracker``, if any.

        Only whole LMCache chunks are stored. On success the tracker's
        ``num_stored_blocks`` is advanced by the number of planned blocks.

        Args:
            tracker: The request tracker to generate the metadata from.
            blocks_in_chunk: the number of blocks in a LMCache data chunk
            vllm_block_size: divisor that converts the tracker's scheduled
                token count into a count of full blocks.

        Returns:
            The store metadata, or None when less than one full chunk is
            available beyond what is already stored.
        """
        # Store the blocks that have block hashes.
        # NOTE: the invariant here is that `num_stored_blocks` should
        # always be a multiple of `blocks_in_chunk`
        # TODO: This should be checked everytime we update the num_stored_blocks
        storable_blocks = min(
            len(tracker.block_hashes),
            len(tracker.allocated_block_ids),
            tracker.num_scheduled_tokens // vllm_block_size,
        )
        unsaved_blocks = storable_blocks - tracker.num_stored_blocks
        whole_chunks = unsaved_blocks // blocks_in_chunk

        if whole_chunks < 1:
            return None

        first = tracker.num_stored_blocks
        last = first + whole_chunks * blocks_in_chunk
        hashes = convert_block_hashes_to_bytes(tracker.block_hashes[first:last])
        ids = tracker.allocated_block_ids[first:last]

        store_meta = LMCacheMPRequestMetadata(
            request_id=tracker.request_id,
            direction="STORE",
            op=LoadStoreOp(block_hashes=hashes, block_ids=ids),
        )

        # Update the request tracker
        tracker.increase_num_stored_blocks(last - first)
        return store_meta

    @staticmethod
    def GetRetrieveMetadata(
        tracker: LMCacheMPRequestTracker,
        blocks_in_chunk: int,
    ) -> "LMCacheMPRequestMetadata | None":
        """
        Plan the RETRIEVE operation for ``tracker``, if any.

        Args:
            tracker: The request tracker to generate the metadata from.
            blocks_in_chunk: the number of blocks in a LMCache data chunk

        Returns:
            The retrieve metadata, or None when the tracker is not ready
            for retrieving or the block range to fetch is empty.
        """
        if not tracker.is_ready_for_retrieving():
            return None

        # |---------------------|-----------------|----------------|
        # | num_vllm_hit_blocks |
        # | lmcache chunk 1     | lmcache chunk 2 |
        #                       | need to retrieve |

        # Round the vLLM hit count down to a chunk boundary so that only
        # whole chunks are requested.
        first = tracker.num_vllm_hit_blocks // blocks_in_chunk * blocks_in_chunk
        last = tracker.num_lmcache_hit_blocks
        assert last % blocks_in_chunk == 0, (
            "The number of LMCache hit blocks should be a multiple of the "
            "number of blocks in a lmcache chunk. "
        )
        assert len(tracker.block_hashes) >= last, (
            "The number of block hashes should be greater than or equal to the "
            "number of LMCache hit blocks. "
        )

        if last <= first:
            return None

        hashes = convert_block_hashes_to_bytes(tracker.block_hashes[first:last])
        ids = tracker.allocated_block_ids[first:last]
        return LMCacheMPRequestMetadata(
            request_id=tracker.request_id,
            direction="RETRIEVE",
            op=LoadStoreOp(block_hashes=hashes, block_ids=ids),
        )
|
||||
|
||||
|
||||
class LMCacheMPConnectorMetadata(KVConnectorMetadata):
    """Connector metadata holding a list of per-request operations."""

    def __init__(self):
        super().__init__()
        self.requests: list[LMCacheMPRequestMetadata] = []

    def add_request_metadata(self, request_metadata: LMCacheMPRequestMetadata):
        """Append one request operation to the list."""
        self.requests.append(request_metadata)

    def __len__(self):
        return len(self.requests)

    # For debugging
    def __str__(self):
        # One line per request, wrapped in square brackets.
        rendered = [
            f"RequestMetadata(request_id={req_meta.request_id}, "
            f"direction={req_meta.direction}, "
            f"num_blocks={len(req_meta.op)}, "
            f"block_ids={req_meta.op.block_ids})"
            for req_meta in self.requests
        ]
        return "[" + "\n".join(rendered) + "]"

    def __repr__(self):
        return self.__str__()
|
||||
|
||||
|
||||
class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
The connector for LMCache multi-process mode.
|
||||
|
||||
Extra configs (kv_transfer_config.extra_config):
|
||||
- lmcache.mp.host: the host of the LMCache server.
|
||||
- lmcache.mp.port: the port of the LMCache server.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    vllm_config: "VllmConfig",
    role: KVConnectorRole,
    kv_cache_config: Optional["KVCacheConfig"] = None,
):
    """Create the adapter that matches ``role``.

    The server address is built from the ``lmcache.mp.host`` and
    ``lmcache.mp.port`` extra-config entries, which default to
    ``tcp://localhost`` and ``5555``. A scheduler-role connector also
    starts with an empty ``request_trackers`` dict.

    Raises:
        AssertionError: If ``vllm_config.kv_transfer_config`` is None.
        ValueError: If the role is neither SCHEDULER nor WORKER.
    """
    super().__init__(vllm_config, role, kv_cache_config)

    transfer_config = vllm_config.kv_transfer_config
    assert transfer_config is not None
    server_host = transfer_config.get_from_extra_config(
        "lmcache.mp.host", "tcp://localhost"
    )
    server_port = transfer_config.get_from_extra_config("lmcache.mp.port", 5555)
    server_url = f"{server_host}:{server_port}"

    zmq_context = zmq.Context.instance()
    if self.role == KVConnectorRole.SCHEDULER:
        self.scheduler_adapter = create_scheduler_adapter(
            server_url, zmq_context, vllm_config
        )
        self.request_trackers: dict[str, LMCacheMPRequestTracker] = {}
    elif self.role == KVConnectorRole.WORKER:
        self.worker_adapter = create_worker_adapter(
            server_url, zmq_context, vllm_config
        )
    else:
        raise ValueError(f"Unknown KVConnectorRole: {self.role}")

    self.vllm_block_size = vllm_config.cache_config.block_size
|
||||
|
||||
@property
def role(self) -> KVConnectorRole:
    """The connector role stored in ``self._role``."""
    current_role = self._role
    return current_role
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def _get_connector_metadata(self) -> KVConnectorMetadata:
    """Return the connector metadata currently held by this connector.

    This function should only be called inside the connector, and only
    while valid metadata is set.

    Returns:
        ConnectorMetadata: the connector metadata.

    Raises:
        AssertionError: If no metadata is set.
    """
    current_metadata = self._connector_metadata
    assert current_metadata is not None
    return current_metadata
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """
    Hand vLLM's KV cache tensors to the worker adapter for registration.

    Args:
        kv_caches: dictionary of layer names, kv cache
    """
    logger.info("Registering kv caches!")
    adapter = self.worker_adapter
    adapter.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
    """
    Submit this step's RETRIEVE operations to the worker adapter.

    All RETRIEVE entries of the current connector metadata are sent as one
    batched request, together with an interprocess CUDA event recorded on
    the current stream. When there is nothing to retrieve the method
    returns without creating an event.

    Args:
        forward_context (ForwardContext): the forward context (not used
            by this implementation).
        **kwargs: additional arguments for the load operation (not used).
    """
    metadata = self._get_connector_metadata()
    assert isinstance(metadata, LMCacheMPConnectorMetadata)

    retrieves = [meta for meta in metadata.requests if meta.direction == "RETRIEVE"]
    if not retrieves:
        # Nothing to load: skip allocating an interprocess event.
        return

    with torch.cuda.stream(torch.cuda.current_stream()):
        event = torch.cuda.Event(interprocess=True)
        event.record()

    self.worker_adapter.batched_submit_retrieve_requests(
        [meta.request_id for meta in retrieves],
        [meta.op for meta in retrieves],
        event,
    )
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
    """
    Per-layer load synchronisation hook; this connector does nothing here.

    Args:
        layer_name: the name of that layer (ignored)
    """
    return None
|
||||
|
||||
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs: Any,
) -> None:
    """
    Per-layer save hook; this connector does nothing here.

    Args:
        layer_name (str): the name of the layer (ignored).
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM (ignored).
        attn_metadata (AttentionMetadata): the attention metadata (ignored).
        **kwargs: additional arguments for the save operation (ignored).
    """
    return None
|
||||
|
||||
def wait_for_save(self):
    """
    Submit this step's STORE operations to the worker adapter.

    All STORE entries of the current connector metadata are sent as one
    batched request, together with an interprocess CUDA event recorded on
    the current stream. The method only submits the request; it does not
    wait for the store to complete here. When there is nothing to store it
    returns without creating an event.
    """
    metadata = self._get_connector_metadata()
    assert isinstance(metadata, LMCacheMPConnectorMetadata)

    stores = [meta for meta in metadata.requests if meta.direction == "STORE"]
    if not stores:
        # Nothing to save: skip allocating an interprocess event.
        return

    with torch.cuda.stream(torch.cuda.current_stream()):
        event = torch.cuda.Event(interprocess=True)
        event.record()

    self.worker_adapter.batched_submit_store_requests(
        [meta.request_id for meta in stores],
        [meta.op for meta in stores],
        event,
    )
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens on the worker.
|
||||
The scheduler process (via the Executors) will use this output
|
||||
to track which workers are done.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
val = self.worker_adapter.get_finished(finished_req_ids)
|
||||
# logger.error("Finished req ids: %s, %s", val[0], val[1])
|
||||
return val
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
    """
    Get the set of block IDs that failed to load.

    This implementation does not track load errors yet and always
    returns an empty set.

    Returns:
        Set of block IDs that encountered load errors.
        Empty set if no load errors occurred.

    Notes:
        - Applies to both sync- and async-loading requests.
        - Async loading: failed blocks may be reported in any forward pass
          up to and including the pass where the request ID is returned by
          `get_finished()`. Even if failures occur, the request must still
          be reported via `get_finished()`, and the failed block IDs must
          appear here no later than that same pass.
        - Sync loading: failed blocks should be reported in the forward
          pass in which they are detected.
    """
    # TODO: add error tracking
    return set()
|
||||
|
||||
def shutdown(self):
    """
    Shut the connector down.

    Called when the worker process is shutting down so that async
    operations can complete and the connector is cleaned up. The worker
    adapter is shut down if this instance has one; otherwise nothing
    happens.
    """
    # Instances without a `worker_adapter` attribute have nothing to stop.
    if not hasattr(self, "worker_adapter"):
        return None

    self.worker_adapter.shutdown()
    return None
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
    """
    Get the KV connector stats collected during the last interval.

    This implementation collects no stats and always returns None.
    """
    return None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int | None, bool]:
    """
    Report how many tokens beyond `num_computed_tokens` can be loaded
    from the external KV cache.

    A lookup is (possibly) submitted through the scheduler adapter and
    its result is polled. When a non-zero hit is available, the hit
    counts are recorded on the request's tracker.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A tuple with the following elements:
            - An optional number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
              If None, the lookup result is not available yet and the
              scheduler should query for this request again later.
            - `True` if external KV cache tokens will be loaded
              asynchronously (between scheduler steps). It is `False`
              whenever the first element is 0.

    Notes:
        The connector should only consider the largest prefix of prompt-
        tokens for which KV cache is actually available at the time of the
        call. If the cache cannot be loaded for some tokens (e.g., due to
        connectivity issues or eviction), those tokens must not be taken
        into account.
    """
    tracker = self._get_or_create_request_tracker(request)
    req_id = request.request_id

    self.scheduler_adapter.maybe_submit_lookup_request(
        req_id, convert_block_hashes_to_bytes(request.block_hashes)
    )

    hit_tokens = self.scheduler_adapter.check_lookup_result(req_id)
    if hit_tokens is None:
        # Lookup still pending.
        return None, True
    if hit_tokens == 0:
        return 0, False

    # The hit is expected to cover whole chunks only.
    assert (
        hit_tokens
        % (self.scheduler_adapter.num_blocks_per_chunk() * self.vllm_block_size)
        == 0
    )

    # Convert both hit counts to vLLM blocks and record them on the tracker.
    vllm_hit_blocks = num_computed_tokens // self.vllm_block_size
    lmcache_hit_blocks = hit_tokens // self.vllm_block_size
    tracker.increase_num_stored_blocks(lmcache_hit_blocks)
    tracker.num_vllm_hit_blocks = vllm_hit_blocks
    tracker.num_lmcache_hit_blocks = lmcache_hit_blocks

    # Only the part of the hit not already computed locally needs loading.
    need_to_load = hit_tokens - num_computed_tokens
    if need_to_load < 0:
        need_to_load = 0
    logger.debug(
        "vLLM hit is: %d, Need to load is %d", num_computed_tokens, need_to_load
    )
    return need_to_load, need_to_load > 0
|
||||
|
||||
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """
    Update KVConnector state after block allocation.

    If get_num_new_matched_tokens previously returned True for a
    request, this function may be called twice for that same request -
    first when blocks are allocated for the connector tokens to be
    asynchronously loaded into, and second when any additional blocks
    are allocated, after the load/transfer is complete.

    The newly allocated block ids are always appended to the request's
    tracker. A tracker in the PREFETCHING state then moves to
    WAITING_FOR_LOAD when it needs a retrieve, or to READY otherwise;
    trackers in any other state keep their state.

    Args:
        request (Request): the request object.
        blocks (KVCacheBlocks): the blocks allocated for the request.
        num_external_tokens (int): the number of tokens that will be
            loaded from the external KV cache.
    """
    # NOTE: the `blocks` are NEW BLOCKS allocated for this request.
    tracker = self._get_request_tracker(request.request_id)

    # The tracker must learn about the new blocks whether or not a
    # retrieve is going to happen.
    tracker.update_block_ids(reformat_block_ids(blocks.get_block_ids()))

    needs_load = tracker.needs_retrieve()
    if tracker.state != LMCacheMPRequestState.PREFETCHING:
        return

    if needs_load:
        tracker.state = LMCacheMPRequestState.WAITING_FOR_LOAD
    else:
        tracker.state = LMCacheMPRequestState.READY
|
||||
|
||||
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
    """
    Build the connector metadata for this step.

    This function should NOT modify fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    The metadata is filled in three passes: pending retrieve requests,
    then newly scheduled requests, then already-running (cached) requests.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        The metadata object populated for this step.
    """
    metadata = LMCacheMPConnectorMetadata()

    # Retrieves are processed first, before any store metadata is added.
    self._process_retrieve_requests(metadata)
    self._process_new_requests(scheduler_output, metadata)
    self._process_cached_requests(scheduler_output, metadata)

    # Only log when there is something to report.
    if len(metadata) > 0:
        logger.debug("Final connector metadata: %s", metadata)

    return metadata
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Update KVConnector state from worker-side connectors output.

    This implementation is a no-op: the output is ignored.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    # Nothing to update from the worker output.
    return
|
||||
|
||||
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    """
    Called exactly once when a request has finished, before its blocks are
    freed.

    The connector may assume responsibility for freeing the blocks
    asynchronously by returning True.

    This implementation always returns ``(True, None)`` and uses neither
    argument.

    Returns:
        True if the request is being saved/sent asynchronously and blocks
        should not be freed until the request_id is returned from
        get_finished().
        Optional KVTransferParams to be included in the request outputs
        returned by the engine.
    """
    # Always defer freeing of the blocks; no KVTransferParams are produced.
    return True, None
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Take the KV cache events from the connector.

    This implementation produces no events and returns an empty tuple.

    Returns:
        New KV cache events since the last call (always empty here).
    """
    return ()
|
||||
|
||||
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
    """
    Get the required KV cache layout for this connector.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the required KV cache layout. e.g. HND, or NHD.
        None if the connector does not require a specific layout.
        This implementation always returns None.

    Raises:
        TypeError: if called with ``cls`` being ``KVConnectorBase_V1``
            itself.
    """

    # Refuse to answer for the abstract base class.
    # NOTE(review): if this method is defined on a concrete subclass, this
    # branch looks unreachable - confirm against the class header.
    if cls is KVConnectorBase_V1:
        raise TypeError(
            "get_required_kvcache_layout should not be called "
            "on the abstract base class"
        )
    return None
|
||||
|
||||
def get_finished_count(self) -> int | None:
    """
    Get the count of requests expected to complete send/receive operations
    via this connector. This method is used to initialize the
    KVOutputAggregator, overwriting the default world_size.

    This implementation always returns None, i.e. it supplies no count.

    Returns:
        int: expected sending or receiving completion count, or None.
    """
    return None
|
||||
|
||||
@classmethod
def build_kv_connector_stats(
    cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
    """
    KVConnectorStats resolution method. This method allows dynamically
    registered connectors to return their own KVConnectorStats object,
    which can implement custom aggregation logic on the data dict.

    This implementation ignores `data` and always returns None.
    """
    return None
|
||||
|
||||
@classmethod
def build_prom_metrics(
    cls,
    vllm_config: "VllmConfig",
    metric_types: dict[type["PromMetric"], type["PromMetricT"]],
    labelnames: list[str],
    per_engine_labelvalues: dict[int, list[object]],
) -> Optional["KVConnectorPromMetrics"]:
    """
    Create a KVConnectorPromMetrics subclass which should register
    per-connector Prometheus metrics and implement observe() to
    expose connector transfer stats via Prometheus.

    This implementation registers nothing and always returns None; all
    arguments are ignored.
    """
    return None
|
||||
|
||||
##############################
|
||||
# Helper functions
|
||||
##############################
|
||||
def _process_retrieve_requests(
    self,
    metadata: LMCacheMPConnectorMetadata,
) -> None:
    """
    Add retrieve metadata for every tracker that is waiting for a load.

    Each tracker in the WAITING_FOR_LOAD state gets its retrieve metadata
    (when one is produced) appended to `metadata`, and is then moved to
    READY regardless of whether metadata was produced.

    Args:
        metadata: the per-step connector metadata being built.
    """
    chunk_blocks = self.scheduler_adapter.num_blocks_per_chunk()

    for tracker in self.request_trackers.values():
        if tracker.state == LMCacheMPRequestState.WAITING_FOR_LOAD:
            retrieve_meta = LMCacheMPRequestMetadata.GetRetrieveMetadata(
                tracker, chunk_blocks
            )
            if retrieve_meta is not None:
                metadata.add_request_metadata(retrieve_meta)
            # The tracker leaves the waiting state either way.
            tracker.state = LMCacheMPRequestState.READY
|
||||
|
||||
def _process_new_requests(
    self,
    scheduler_output: SchedulerOutput,
    metadata: LMCacheMPConnectorMetadata,
) -> None:
    """
    Account for newly scheduled requests and queue their store metadata.

    For each new request, the number of tokens scheduled this step is
    added to its tracker, and the resulting store metadata (when one is
    produced) is appended to `metadata`.

    Args:
        scheduler_output: the scheduler output of this step.
        metadata: the per-step connector metadata being built.
    """
    chunk_blocks = self.scheduler_adapter.num_blocks_per_chunk()

    for new_req in scheduler_output.scheduled_new_reqs:
        tracker = self._get_request_tracker(new_req.req_id)

        scheduled_now = scheduler_output.num_scheduled_tokens[new_req.req_id]
        tracker.increase_num_scheduled_tokens(scheduled_now)

        store_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
            tracker, chunk_blocks, self.vllm_block_size
        )
        if store_meta is None:
            continue
        metadata.add_request_metadata(store_meta)
|
||||
|
||||
def _process_cached_requests(
    self,
    scheduler_output: SchedulerOutput,
    metadata: LMCacheMPConnectorMetadata,
) -> None:
    """
    Account for already-known (cached) requests scheduled in this step.

    For each cached request, its newly allocated block ids are appended
    to the tracker, the tracker's scheduled-token count is increased, and
    the resulting store metadata (when one is produced) is appended to
    `metadata`.

    Args:
        scheduler_output: the scheduler output of this step.
        metadata: the per-step connector metadata being built.
    """
    chunk_blocks = self.scheduler_adapter.num_blocks_per_chunk()
    cached = scheduler_output.scheduled_cached_reqs

    for position, req_id in enumerate(cached.req_ids):
        tracker = self._get_request_tracker(req_id)

        # Record the blocks newly allocated to this request.
        tracker.update_block_ids(reformat_block_ids(cached.new_block_ids[position]))

        # NOTE(review): the increment is taken from `num_computed_tokens`,
        # whereas new requests use `scheduler_output.num_scheduled_tokens`.
        # Confirm against the tracker's semantics that this is intended.
        tracker.increase_num_scheduled_tokens(cached.num_computed_tokens[position])

        store_meta = LMCacheMPRequestMetadata.GetStoreMetadata(
            tracker, chunk_blocks, self.vllm_block_size
        )
        if store_meta is None:
            continue
        metadata.add_request_metadata(store_meta)
|
||||
|
||||
def _get_request_tracker(self, request_id: str) -> LMCacheMPRequestTracker:
    """
    Return the tracker registered for `request_id`.

    The tracker is expected to exist already; a missing entry trips an
    assertion.
    """
    trackers = self.request_trackers
    assert request_id in trackers, (
        f"Request tracker for request_id {request_id} not found. "
    )
    return trackers[request_id]
|
||||
|
||||
def _get_or_create_request_tracker(
    self, request: "Request"
) -> LMCacheMPRequestTracker:
    """
    Return the tracker for `request`, creating and registering a new one
    the first time the request id is seen.
    """
    req_id = request.request_id
    try:
        return self.request_trackers[req_id]
    except KeyError:
        # First sighting of this request: build its tracker lazily.
        tracker = LMCacheMPRequestTracker(request)
        self.request_trackers[req_id] = tracker
        return self.request_trackers[req_id]
|
||||
186
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Normal file
186
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TypeAlias, TypeVar
|
||||
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.logger import init_logger
|
||||
|
||||
PromMetric: TypeAlias = Gauge | Counter | Histogram
|
||||
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class KVConnectorStats:
    """
    Base class for KV Connector Stats, a container for transfer performance
    metrics or otherwise important telemetry from the connector.
    All sub-classes need to be serializable as stats are sent from worker to
    logger process.

    Every method on this base raises NotImplementedError; sub-classes are
    expected to override them.
    """

    # Raw stats payload; each instance gets its own empty dict by default.
    data: dict[str, Any] = field(default_factory=dict)

    def reset(self):
        """Reset the stats, clear the state."""
        raise NotImplementedError

    def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
        """
        Aggregate stats with another `KVConnectorStats` object.

        Args:
            other: the stats object to combine with this one.

        Returns:
            The aggregated stats object.
        """
        raise NotImplementedError

    def reduce(self) -> dict[str, int | float]:
        """
        Reduce the observations collected during a time interval to one or
        more representative values (eg avg/median/sum of the series).
        This is meant to be called by the logger to produce a summary of the
        stats for the last time interval.

        Returns:
            A mapping from metric name to its reduced numeric value.
        """
        raise NotImplementedError

    def is_empty(self) -> bool:
        """Return True if the stats are empty."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class KVConnectorLogging:
    """
    Accumulate KV connector transfer stats between log calls and emit a
    one-line summary of them.
    """

    def __init__(self, kv_transfer_config: KVTransferConfig | None):
        """
        Args:
            kv_transfer_config: the KV transfer config, or None when no
                KV connector is configured.
        """
        # Always define the attribute so that observe() trips its own
        # assertion (instead of an AttributeError) when no connector is
        # configured.
        self.connector_cls: Any = None
        # Instantiate the connector's stats class.
        if kv_transfer_config and kv_transfer_config.kv_connector:
            self.connector_cls = KVConnectorFactory.get_connector_class(
                kv_transfer_config
            )
        self.reset()

    def reset(self):
        """Drop the accumulated stats, starting a new interval."""
        self.transfer_stats_accumulator: KVConnectorStats | None = None

    def observe(self, transfer_stats_data: dict[str, Any]):
        """
        Fold one batch of transfer stats into the accumulator.

        Args:
            transfer_stats_data: raw stats data, passed to the connector
                class' `build_kv_connector_stats`.
        """
        # Should not be called when a KVConnector is not configured.
        assert self.connector_cls is not None
        # Called periodically when connector syncs with the scheduler.
        # Note that this is not the same as the logging interval.
        # We expect transfer_stats_data to be aggregated across all workers and
        # consist of observations from a single connector or a MultiConnector.
        transfer_stats = self.connector_cls.build_kv_connector_stats(
            transfer_stats_data
        )
        if transfer_stats is None:
            logger.warning_once(
                "The connector %s is collecting stats but "
                "does not implement the "
                "`build_kv_connector_stats` method. "
                "Stats will not be logged.",
                self.connector_cls,
            )
            return

        if self.transfer_stats_accumulator is None:
            self.transfer_stats_accumulator = transfer_stats
        else:
            # Accumulate last interval stats.
            self.transfer_stats_accumulator = self.transfer_stats_accumulator.aggregate(
                transfer_stats
            )

    def log(self, log_fn=logger.info):
        """Log transfer metrics periodically, similar to throughput logging"""
        if (
            self.transfer_stats_accumulator
            and not self.transfer_stats_accumulator.is_empty()
        ):
            # Produce a single cumulative stats object for the last time
            # interval from the recorded observations.
            xfer_metrics = self.transfer_stats_accumulator.reduce()
            xfer_metrics_str = ", ".join(f"{k}={v}" for k, v in xfer_metrics.items())
            log_fn("KV Transfer metrics: %s", xfer_metrics_str)

            # Reset metrics for next interval
            self.reset()
|
||||
|
||||
|
||||
class KVConnectorPromMetrics:
    """
    Base class that connectors extend to register their own Prometheus
    metrics and to record transfer stats into them.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        """
        Args:
            vllm_config: the vLLM config; its `kv_transfer_config` is kept.
            metric_types: mapping from the Gauge/Counter/Histogram types to
                the classes to use for each of them.
            labelnames: label names for the metrics.
            per_engine_labelvalues: label values keyed by engine index.
        """
        self._kv_transfer_config = vllm_config.kv_transfer_config
        # Resolve the concrete metric class for each metric kind.
        (
            self._gauge_cls,
            self._counter_cls,
            self._histogram_cls,
        ) = (
            metric_types[Gauge],
            metric_types[Counter],
            metric_types[Histogram],
        )
        self._labelnames = labelnames
        self._per_engine_labelvalues = per_engine_labelvalues

    def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
        """
        Build one labelled child of `metric` per engine.

        The parent metric must be created using the labelnames list.

        Returns:
            A dict from engine index to the child metric carrying that
            engine's label values.
        """
        children: dict[int, PromMetric] = {}
        for engine_idx, values in self._per_engine_labelvalues.items():
            children[engine_idx] = metric.labels(*values)
        return children

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """
        Record the supplied transfer statistics to Prometheus metrics. These
        statistics are engine-specific, and should be recorded to a metric
        with the appropriate 'engine' label. These metric instances can be
        created using the make_per_engine() helper method.

        Not implemented on the base class.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class KVConnectorPrometheus:
    """
    Entry point for per-connector Prometheus metrics: asks the configured
    connector class to build its metrics object (via build_prom_metrics())
    and forwards transfer statistics to it.
    """

    _gauge_cls = Gauge
    _counter_cls = Counter
    _histogram_cls = Histogram

    def __init__(
        self,
        vllm_config: VllmConfig,
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        """
        Args:
            vllm_config: the vLLM config.
            labelnames: label names for the metrics.
            per_engine_labelvalues: label values keyed by engine index.
        """
        self.prom_metrics: KVConnectorPromMetrics | None = None

        kv_transfer_config = vllm_config.kv_transfer_config
        if not (kv_transfer_config and kv_transfer_config.kv_connector):
            # No connector configured: nothing to register.
            return

        connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
        self.prom_metrics = connector_cls.build_prom_metrics(
            vllm_config,
            {
                Gauge: self._gauge_cls,
                Counter: self._counter_cls,
                Histogram: self._histogram_cls,
            },
            labelnames,
            per_engine_labelvalues,
        )

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """Forward the stats to the connector's metrics object, if any."""
        if self.prom_metrics is not None:
            self.prom_metrics.observe(transfer_stats_data, engine_idx)
|
||||
@@ -0,0 +1,914 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run VLLM with MooncakeTransferEngine."
|
||||
) from e
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
EngineId = str
|
||||
ReqId = str
|
||||
|
||||
TRANS_DONE = b"trans_done"
|
||||
TRANS_ERROR = b"trans_error"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MooncakeAgentMetadata(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):
    """
    msgspec struct describing a Mooncake agent and a batch of requests.

    NOTE(review): the field notes below are inferred from the field names;
    confirm against the code that builds and consumes this struct.
    """

    # Presumably the host name of the remote side.
    remote_hostname: str
    # Presumably the port of the remote side.
    remote_port: int
    # Ids of the requests this message refers to.
    request_ids: list[ReqId]
    # Presumably base addresses of the KV cache regions.
    kv_caches_base_addr: list[int]
    # One list of block ids per entry; presumably aligned with request_ids.
    block_ids: list[list[int]]
|
||||
|
||||
|
||||
@dataclass
class RecvReqMeta:
    """Per-request metadata for a request whose KV cache is to be received."""

    # Local block ids associated with the request.
    local_block_ids: list[int]
    # Host of the remote side, taken from the request's kv_transfer_params.
    remote_host: str
    # Port of the remote side, taken from the request's kv_transfer_params.
    remote_port: int
|
||||
|
||||
|
||||
@dataclass
class SendBlockMeta:
    """Per-request bookkeeping for blocks on the sending side."""

    # Local block ids associated with the request.
    local_block_ids: list[int]
    # Event presumably set once the blocks are ready - confirm against
    # the worker code.
    ready: threading.Event
    # Expiry timestamp; defaults to infinity (presumably "never expires").
    expire_time: float = float("inf")
|
||||
|
||||
|
||||
@dataclass
class SendReqMeta:
    """Send-side request table bundled with a lock."""

    # Send-side block metadata keyed by request id.
    reqs: dict[ReqId, SendBlockMeta]
    # Lock presumably guarding `reqs` - confirm against the users.
    lock: threading.Lock
|
||||
|
||||
|
||||
@dataclass
class FinishedSendReqSet:
    """Set of request ids bundled with a threading lock."""

    # Request ids; presumably those whose send has finished.
    set: set[ReqId]
    # Lock presumably guarding `set` - confirm against the users.
    lock: threading.Lock
|
||||
|
||||
|
||||
@dataclass
class FinishedReceiveReqSet:
    """Set of request ids bundled with an asyncio lock."""

    # Request ids; presumably those whose receive has finished.
    set: set[ReqId]
    # asyncio lock presumably guarding `set` - confirm against the users.
    lock: asyncio.Lock
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Per-step metadata listing the requests to receive and to send."""

    def __init__(self):
        # Requests whose KV cache is to be received, keyed by request id.
        self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
        # Local block ids of requests to send, keyed by request id.
        self.reqs_to_send: dict[ReqId, list[int]] = {}

    def add_new_req(
        self,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        load_remote_cache: bool = True,
    ):
        """
        Register a request as either a receive or a send.

        Args:
            request_id: id of the request.
            local_block_ids: local block ids of the request.
            kv_transfer_params: transfer params; "remote_host" and
                "remote_port" are read only when receiving.
            load_remote_cache: True to register a receive, False to
                register a send.
        """
        if not load_remote_cache:
            self.reqs_to_send[request_id] = local_block_ids
            return

        recv_meta = RecvReqMeta(
            local_block_ids=local_block_ids,
            remote_host=kv_transfer_params["remote_host"],
            remote_port=kv_transfer_params["remote_port"],
        )
        self.reqs_to_recv[request_id] = recv_meta
|
||||
|
||||
|
||||
class MooncakeConnector(KVConnectorBase_V1):
    """
    KV connector facade that delegates to a scheduler-side or a
    worker-side implementation depending on the role it was created with.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        """
        Args:
            vllm_config: the vLLM config; must carry a kv_transfer_config
                with an engine id.
            role: whether this instance serves the scheduler or a worker.
            kv_cache_config: optional KV cache config, passed to the base.
        """
        super().__init__(vllm_config, role, kv_cache_config)

        assert vllm_config.kv_transfer_config is not None
        assert vllm_config.kv_transfer_config.engine_id is not None
        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id

        # Exactly one of the two sides is instantiated for a given role.
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: MooncakeConnectorScheduler | None = (
                MooncakeConnectorScheduler(vllm_config, self.engine_id)
            )
            self.connector_worker: MooncakeConnectorWorker | None = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """Delegate to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.get_num_new_matched_tokens(request, num_computed_tokens)

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Delegate to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.update_state_after_alloc(request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Delegate to the scheduler-side implementation."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Hand the KV cache tensors to the worker-side implementation."""
        worker = self.connector_worker
        assert worker is not None
        worker.register_kv_caches(kv_caches)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Get the finished recving and sending requests.

        `finished_req_ids` is not forwarded to the worker.
        """
        worker = self.connector_worker
        assert worker is not None
        return worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Start loading using the connector metadata bound to this instance."""
        worker = self.connector_worker
        assert worker is not None
        assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
        worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """MooncakeConnector does not do layerwise loading; this is a no-op."""
        return None

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        """MooncakeConnector does not save explicitly; this is a no-op."""
        return None

    def wait_for_save(self):
        """No-op: there is nothing to wait for."""
        return None
|
||||
|
||||
|
||||
class MooncakeConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
    """
    Set up the scheduler-side state.

    Args:
        vllm_config: the vLLM config; must carry a kv_transfer_config.
        engine_id: id of this engine.
    """
    self.vllm_config = vllm_config
    self.engine_id: EngineId = engine_id
    # Side-channel endpoint: this host's IP and the port derived from
    # the config.
    self.side_channel_host = get_ip()
    self.side_channel_port = get_mooncake_side_channel_port(vllm_config)

    assert vllm_config.kv_transfer_config
    self.kv_role = vllm_config.kv_transfer_config.kv_role
    logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)

    # Requests that need to start recv/send.
    # New requests are added by update_state_after_alloc in
    # the scheduler. Used to make metadata passed to Worker.
    self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
    self._reqs_need_send: dict[ReqId, list[int]] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
    self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
    """
    For remote prefill, pull all prompt blocks from remote
    asynchronously relative to engine execution.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request
    Returns:
        * the number of tokens that can be loaded from the
          external KV cache beyond what is already computed.
        * true if the external KV cache tokens will be loaded
          asynchronously (between scheduler steps).
        ``(0, False)`` is returned when the request does not ask for
        remote prefill or when no prompt tokens remain to be loaded.
    """

    params = request.kv_transfer_params
    logger.debug(
        "MooncakeConnector get_num_new_matched_tokens: "
        "num_computed_tokens=%s, kv_transfer_params=%s",
        num_computed_tokens,
        params,
    )

    wants_remote_prefill = params is not None and params.get("do_remote_prefill")
    if not wants_remote_prefill:
        # No remote prefill for this request.
        return 0, False

    # Remote prefill: everything in the prompt not computed locally
    # comes from the remote side.
    prompt_ids = request.prompt_token_ids or []
    remaining = len(prompt_ids) - num_computed_tokens
    if remaining <= 0:
        return 0, False
    return remaining, True
|
||||
|
||||
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """
    Queue the request for receiving or sending after block allocation.

    Requests with "do_remote_prefill" set are queued for receiving when
    their params carry both "remote_host" and "remote_port" (a warning is
    logged otherwise), and the flag is then cleared so only one transfer
    is triggered per request. Requests with "do_remote_decode" set are
    queued for sending with an empty block list. Requests without
    transfer params are ignored.

    Args:
        request (Request): the request object.
        blocks (KVCacheBlocks): the blocks allocated for the request.
        num_external_tokens (int): the number of tokens that will be
            loaded from the external KV cache.
    """
    params = request.kv_transfer_params
    logger.debug(
        "MooncakeConnector update_state_after_alloc: "
        "num_external_tokens=%s, kv_transfer_params=%s",
        num_external_tokens,
        params,
    )

    if not params:
        return

    if params.get("do_remote_prefill"):
        assert self.kv_role != "kv_producer"
        has_remote_addr = "remote_host" in params and "remote_port" in params
        if has_remote_addr:
            # With no external tokens the request is still queued, but
            # with an empty block list; otherwise the unhashed blocks
            # are the ones to pull from remote.
            local_block_ids = []
            if num_external_tokens > 0:
                local_block_ids = blocks.get_unhashed_block_ids()
            self._reqs_need_recv[request.request_id] = (request, local_block_ids)
        else:
            logger.warning(
                "Got invalid KVTransferParams: %s. This "
                "request will not utilize KVTransfer",
                params,
            )
        # Only trigger 1 KV transfer per request.
        params["do_remote_prefill"] = False
        return

    if params.get("do_remote_decode"):
        # Add an empty list to worker to create event.
        self._reqs_need_send[request.request_id] = []
|
||||
|
||||
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Package the queued receive/send requests into connector metadata.

    Pending entries of ``_reqs_need_recv`` (unless this instance is a pure
    producer) and ``_reqs_need_send`` (unless it is a pure consumer) are
    added to a new ``MooncakeConnectorMetadata``; each queue is cleared once
    it has been drained. ``scheduler_output`` is not consulted here.
    """
    meta = MooncakeConnectorMetadata()
    role = self.kv_role

    # Loop through scheduled reqs and convert to RecvReqMeta.
    if role != "kv_producer":
        for req_id, pending in self._reqs_need_recv.items():
            req, block_ids = pending
            assert req.kv_transfer_params is not None
            meta.add_new_req(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params=req.kv_transfer_params,
            )
        self._reqs_need_recv.clear()

    if role != "kv_consumer":
        for req_id, block_ids in self._reqs_need_send.items():
            # Send-side entries carry no transfer params and must not
            # trigger a remote load.
            meta.add_new_req(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params={},
                load_remote_cache=False,
            )
        self._reqs_need_send.clear()

    return meta
|
||||
|
||||
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    """Decide what happens to a finished request's blocks.

    Returns:
        ``(delay_free_blocks, kv_transfer_params)``. The flag is True only
        when the request asked for remote decode, finished with status
        ``FINISHED_LENGTH_CAPPED`` and owns at least one block; in that case
        the blocks are queued in ``_reqs_need_send``. For that status the
        returned params describe how to pull the KV cache from this
        instance; in every other case they are None.
    """
    params = request.kv_transfer_params
    logger.debug(
        "MooncakeConnector request_finished, request_status=%s, "
        "kv_transfer_params=%s",
        request.status,
        params,
    )
    if not params:
        return False, None

    if params.get("do_remote_prefill"):
        # The flag is cleared by update_state_after_alloc, so seeing it here
        # means the request ended before it was ever scheduled. Queue an
        # empty receive so the worker side can notify the prefill instance
        # and let it release the blocks it is holding.
        assert self.kv_role != "kv_producer"
        self._reqs_need_recv[request.request_id] = (request, [])
        params["do_remote_prefill"] = False
        return False, None

    wants_remote_decode = params.get("do_remote_decode")
    if (
        not wants_remote_decode
        or request.status != RequestStatus.FINISHED_LENGTH_CAPPED
    ):
        return False, None

    assert self.kv_role != "kv_consumer"

    # TODO: check whether block_ids actually ever be 0. If not we could
    # remove the conditional below
    delay_free_blocks = len(block_ids) > 0
    if delay_free_blocks:
        self._reqs_need_send[request.request_id] = block_ids

    remote_params = {
        "do_remote_prefill": True,
        "do_remote_decode": False,
        "remote_host": self.side_channel_host,
        "remote_port": self.side_channel_port,
    }
    return delay_free_blocks, remote_params
|
||||
|
||||
|
||||
class MooncakeConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
    """Bring up the Mooncake transfer engine and this worker's transfer state.

    Initializes the engine, reads the tensor-parallel rank and size, creates
    the sender thread pool (unless this is a pure consumer) and starts the
    receiver event-loop thread (unless this is a pure producer), then sets
    up the KV topology and the ZMQ / msgpack helpers.

    Raises:
        RuntimeError: if the transfer engine's ``initialize`` call returns a
            non-zero status.
    """
    logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)

    self.vllm_config = vllm_config

    self.engine = TransferEngine()
    self.hostname = get_ip()
    init_status = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "")
    if init_status != 0:
        raise RuntimeError("Mooncake Transfer Engine initialization failed.")

    self.rpc_port = self.engine.get_rpc_port()

    logger.debug(
        "Mooncake Transfer Engine initialized at %s:%d",
        self.hostname,
        self.rpc_port,
    )

    # Mooncake handshake port.
    self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config)

    self.engine_id: EngineId = engine_id
    self.tp_rank = get_tensor_model_parallel_rank()
    self.world_size = get_tensor_model_parallel_world_size()
    self.tp_group = get_tp_group()
    # Filled in by register_kv_caches().
    self.num_blocks = 0

    assert vllm_config.kv_transfer_config
    transfer_config = vllm_config.kv_transfer_config
    self.kv_role = transfer_config.kv_role
    self.num_workers = transfer_config.kv_connector_extra_config.get(
        "num_workers", 10
    )

    self.kv_caches_base_addr: list[int] = []
    self.device_kv_caches: dict[str, torch.Tensor] = {}
    self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())

    # For kv_both, we will act both prefiller and decoder.
    if self.kv_role != "kv_consumer":
        # Background thread for sending kvcaches to D; it is created later,
        # in register_kv_caches().
        self._mooncake_sender_t: threading.Thread | None = None
        # Thread pool that processes new sending requests.
        self._sender_executor = ThreadPoolExecutor(
            max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
        )
        logger.debug(
            "Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
        )
    if self.kv_role != "kv_producer":
        # Dedicated event loop, run by a daemon thread, for receive coroutines.
        self.receiver_loop = asyncio.new_event_loop()
        self._mooncake_receiver_t = threading.Thread(
            target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
        )
        self._mooncake_receiver_t.start()
        logger.debug("Mooncake Decoder: start receiver thread")

    self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
        set(), threading.Lock()
    )
    self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
        set(), asyncio.Lock()
    )

    self.block_size = vllm_config.cache_config.block_size
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config
    self.use_mla = self.model_config.use_mla

    attn_backend = get_attn_backend(
        self.model_config.get_head_size(),
        self.model_config.dtype,
        self.cache_config.cache_dtype,
        self.block_size,
        use_mla=self.use_mla,
    )
    self.backend_name = attn_backend.get_name()
    self.kv_cache_layout = get_kv_cache_layout()
    logger.debug("Detected attention backend %s", self.backend_name)
    logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

    # These two dicts are handed to TpKVTopology by reference (shared state).
    self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
    self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
    self.kv_topo = TpKVTopology(
        tp_rank=self.tp_rank,
        engine_id=self.engine_id,
        remote_tp_size=self._tp_size,  # shared state
        remote_block_size=self._block_size,  # shared state
        is_mla=self.use_mla,
        total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
        attn_backend=attn_backend,
    )
    self._use_pallas = self.kv_topo._use_pallas

    self.zmq_ctx = zmq.Context()
    self.async_zmq_ctx = zmq.asyncio.Context()
    self._encoder = msgspec.msgpack.Encoder()
    self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self):
    """Terminate the ZMQ contexts and stop the background threads.

    Both ZMQ contexts are terminated first. A non-consumer then shuts its
    sender thread pool down without waiting and joins the sender thread if
    one was started. A non-producer whose receiver loop is still running
    asks the loop to stop and joins the receiver thread.
    """
    self.zmq_ctx.term()
    self.async_zmq_ctx.term()

    if self.kv_role != "kv_consumer":
        self._sender_executor.shutdown(wait=False)
        sender_thread = self._mooncake_sender_t
        if sender_thread:
            sender_thread.join()

    if self.kv_role != "kv_producer":
        if self.receiver_loop.is_running():
            # stop() must be scheduled from inside the loop's own thread.
            self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
            self._mooncake_receiver_t.join()
|
||||
|
||||
def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
def _mooncake_sender(
    self, ready_event: threading.Event, base_port: int, tp_rank: int
):
    """
    Background thread that serves incoming Mooncake transfer requests.

    A ROUTER socket bound on ``base_port + tp_rank`` receives requests and
    hands each one to the sender thread pool. Workers report their status on
    an in-process PULL socket, and the status is relayed back to the
    requesting peer. ``ready_event`` is set once both sockets are registered
    with the poller. The loop ends when the ZMQ context is terminated or on
    any other exception; both sockets are closed on exit.
    """
    listen_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
    router = make_zmq_socket(self.zmq_ctx, listen_path, zmq.ROUTER)
    logger.debug("Mooncake sender starting listening on path: %s", listen_path)

    # Unique in-process endpoint on which pool workers push their status.
    results_path = make_zmq_path("inproc", str(uuid.uuid4()))
    results = make_zmq_socket(self.zmq_ctx, results_path, zmq.PULL)

    poller = zmq.Poller()
    for sock in (router, results):
        poller.register(sock, zmq.POLLIN)

    ready_event.set()

    try:
        while True:
            readable = dict(poller.poll())

            if router in readable:
                # New request: [identity, empty delimiter, payload].
                identity, _, metadata_bytes = router.recv_multipart()
                self._sender_executor.submit(
                    self._sender_worker,
                    identity,
                    metadata_bytes,
                    results_path,
                )

            if results in readable:
                # Finished worker: forward its status to the waiting peer.
                identity, status = results.recv_multipart()
                router.send_multipart((identity, b"", status))

    except zmq.ContextTerminated:
        logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
    except Exception as e:
        logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
    finally:
        router.close()
        results.close()
|
||||
|
||||
def _sender_worker(
    self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
):
    """Handle one transfer request inside the sender thread pool.

    Decodes the request metadata, sends the KV blocks, and always pushes
    ``(identity, status)`` to ``worker_channel_path``. The status is
    ``TRANS_DONE`` on success and ``TRANS_ERROR`` if decoding or sending
    raised.
    """
    outcome = TRANS_ERROR

    try:
        self.send_kv_to_decode(self._decoder.decode(metadata_bytes))
        outcome = TRANS_DONE
    except Exception as e:
        logger.error("Error processing Mooncake handshake: %s", e)
    finally:
        # Report back even on failure so the requester receives a status.
        pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
        try:
            pusher.send_multipart((identity, outcome))
        except zmq.ZMQError as e:
            logger.warning(
                "Internal error, maybe the server is shutting down. Error: %s",
                e,
            )
        finally:
            pusher.close()
|
||||
|
||||
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
    """Send the KV blocks of every request named in ``meta`` to the requester.

    All request ids are resolved before any state is touched. While the
    transfer runs, each request's ``expire_time`` is set to infinity so the
    timeout sweep in ``get_finished`` does not reclaim it. After a successful
    transfer the requests are removed from ``reqs_need_send`` and added to
    ``finished_sending_reqs``.

    Raises:
        RuntimeError: if a request id is not present in ``reqs_need_send``.
            Previously this case returned silently, which made the caller
            report success although nothing had been sent.
        Exception: whatever ``_send_blocks`` raises. Expiry times that are
            still infinite are restored first so the timeout can reclaim
            the requests.
    """
    no_expiry = float("inf")
    send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
    saved_expiry: list[float] = []

    with self.reqs_need_send.lock:
        # Validate the whole batch first, so a missing id cannot leave
        # earlier requests pinned with an infinite expiry.
        for req_id in meta.request_ids:
            send_meta = self.reqs_need_send.reqs.get(req_id)
            if send_meta is None:
                raise RuntimeError(f"Request {req_id} not found in reqs_need_send")
            send_reqs.append((req_id, send_meta))

        for _, send_meta in send_reqs:
            saved_expiry.append(send_meta.expire_time)
            # Mark it as not expired. We will send it now.
            send_meta.expire_time = no_expiry

    try:
        self._send_blocks(send_reqs, meta)
    except Exception:
        with self.reqs_need_send.lock:
            for (_, send_meta), previous in zip(send_reqs, saved_expiry):
                # Leave the value alone if it was refreshed in the meantime.
                if send_meta.expire_time == no_expiry:
                    send_meta.expire_time = previous
        raise

    with self.reqs_need_send.lock:
        for req_id in meta.request_ids:
            # pop() tolerates an id that appears twice in the batch.
            self.reqs_need_send.reqs.pop(req_id, None)

    with self.finished_sending_reqs.lock:
        self.finished_sending_reqs.set.update(meta.request_ids)
|
||||
|
||||
def _send_blocks(
    self,
    send_reqs: list[tuple[ReqId, SendBlockMeta]],
    agent_meta: MooncakeAgentMetadata,
):
    """Write the local KV blocks of ``send_reqs`` into the remote peer.

    For each request this waits on its ``ready`` event, pairs the last
    ``len(remote_block_ids)`` local blocks with the remote block ids, and
    collects one (source, destination, length) triple per layer and per run
    of contiguous blocks. All triples are then submitted in a single
    ``batch_transfer_sync_write`` call.

    Raises:
        RuntimeError: if the batch write returns a non-zero status.
    """
    src_ptrs = []
    dst_ptrs = []
    lengths = []
    local_base_addr = self.kv_caches_base_addr
    remote_base_addr = agent_meta.kv_caches_base_addr
    block_len = self.block_len
    remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"

    assert len(send_reqs) == len(agent_meta.block_ids)
    for (req_id, send_meta), remote_block_ids in zip(
        send_reqs, agent_meta.block_ids
    ):
        send_meta.ready.wait()

        num_remote_blocks = len(remote_block_ids)
        if num_remote_blocks == 0:
            # The peer asked for no blocks for this request.
            continue

        # Partial prefix cache hit: just read uncomputed blocks.
        local_block_ids = send_meta.local_block_ids
        num_local_blocks = len(local_block_ids)
        assert num_local_blocks >= num_remote_blocks
        if num_local_blocks > num_remote_blocks:
            local_block_ids = local_block_ids[-num_remote_blocks:]

        # Merge blocks that are contiguous on both sides into single copies.
        local_runs, remote_runs = group_concurrent_contiguous(
            local_block_ids, remote_block_ids
        )

        for local_layer_addr, remote_layer_addr in zip(
            local_base_addr, remote_base_addr
        ):
            for local_run, remote_run in zip(local_runs, remote_runs):
                src_ptrs.append(local_layer_addr + local_run[0] * block_len)
                dst_ptrs.append(remote_layer_addr + remote_run[0] * block_len)
                lengths.append(block_len * len(local_run))

        logger.debug(
            "Sending kv_caches for request %s (%d blocks) to %s",
            req_id,
            num_remote_blocks,
            remote_session,
        )

    start_time = time.perf_counter()
    ret_value = self.engine.batch_transfer_sync_write(
        remote_session, src_ptrs, dst_ptrs, lengths
    )
    if ret_value != 0:
        raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")

    logger.debug(
        "Sending to %s done, took %s",
        remote_session,
        time.perf_counter() - start_time,
    )
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register the KV cache memory with the Mooncake engine.

    Collects the distinct base address of every cache tensor, checks that
    all tensors have the same byte size and the expected block size,
    registers them with the engine in one batch, and derives ``num_blocks``
    and ``block_len``. Unless this worker is a pure consumer, the sender
    thread is started and awaited until its listener is ready.

    Raises:
        RuntimeError: if the batch memory registration returns non-zero.
    """
    logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)

    kv_data_ptrs = []
    kv_data_lens = []
    seen_base_addresses = []

    split_k_and_v = self.kv_topo.split_k_and_v
    # Which axis holds the block size depends on whether MLA is in use.
    block_size_axis = -2 if self.use_mla else -3
    tensor_size_bytes = None
    for layer_name, cache_or_caches in kv_caches.items():
        logger.debug(
            "registering layer %s with shape %s", layer_name, cache_or_caches.shape
        )
        if split_k_and_v:
            cache_list = cache_or_caches
        else:
            cache_list = [cache_or_caches]

        for cache in cache_list:
            base_addr = cache.data_ptr()
            if base_addr in seen_base_addresses:
                # Several layers may refer to the same underlying buffer.
                continue

            seen_base_addresses.append(base_addr)
            curr_tensor_size_bytes = cache.nbytes

            if tensor_size_bytes is None:
                # The first distinct tensor fixes the reference size and count.
                tensor_size_bytes = curr_tensor_size_bytes
                self.num_blocks = cache.shape[0]

            assert tensor_size_bytes == curr_tensor_size_bytes, (
                "All kv cache tensors must have the same size"
            )
            kernel_block_size = cache.shape[block_size_axis]
            assert self.block_size == kernel_block_size
            kv_data_ptrs.append(base_addr)
            kv_data_lens.append(tensor_size_bytes)

    self.kv_caches_base_addr = seen_base_addresses

    ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
    if ret_value != 0:
        raise RuntimeError("Mooncake batch memory registration failed.")

    assert tensor_size_bytes is not None
    assert self.num_blocks != 0
    assert tensor_size_bytes % self.num_blocks == 0
    self.block_len = tensor_size_bytes // self.num_blocks
    self.device_kv_caches = kv_caches
    logger.debug(
        "registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
    )

    # No need to launch server for D node.
    if self.kv_role == "kv_consumer":
        return

    listener_ready = threading.Event()
    self._mooncake_sender_t = threading.Thread(
        target=self._mooncake_sender,
        args=(listener_ready, self.side_channel_port, self.tp_rank),
        daemon=True,
        name="mooncake_sender",
    )
    self._mooncake_sender_t.start()
    listener_ready.wait()  # Wait for listener ZMQ socket to be ready.
|
||||
|
||||
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
    """Return the accumulated finished-receive ids and reset the accumulator.

    The swap happens under the tracker's asyncio lock.
    """
    tracker = self.finished_recving_reqs
    async with tracker.lock:
        drained, tracker.set = tracker.set, set()
    return drained
|
||||
|
||||
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
    """
    Collect the requests this worker has finished sending or receiving.

    Finished receives are fetched from the receiver event loop (skipped for
    a pure producer) and finished sends are drained under their lock
    (skipped for a pure consumer). Requests in ``reqs_need_send`` whose
    ``expire_time`` has passed are dropped and reported as finished sending.

    Returns:
        ``(finished_sending, finished_recving)``; each element is None when
        the corresponding set is empty.
    """
    # Start the receive-side fetch first so it overlaps with the work below.
    recv_future = None
    if self.kv_role != "kv_producer":
        recv_future = asyncio.run_coroutine_threadsafe(
            self.fetch_finished_recving_reqs(), self.receiver_loop
        )

    if self.kv_role == "kv_consumer":
        finished_sending_reqs = set()
    else:
        send_tracker = self.finished_sending_reqs
        with send_tracker.lock:
            finished_sending_reqs = send_tracker.set
            send_tracker.set = set()

    if recv_future is None:
        finished_recving_reqs = set()
    else:
        finished_recving_reqs = recv_future.result()

    if finished_sending_reqs or finished_recving_reqs:
        logger.debug(
            "Rank %s, get_finished: %s requests done sending "
            "and %s requests done recving",
            self.tp_rank,
            len(finished_sending_reqs),
            len(finished_recving_reqs),
        )

    # Handle timeout to avoid stranding blocks on remote.
    now = time.perf_counter()
    pending = self.reqs_need_send
    with pending.lock:
        expired_reqs = []
        for req_id, send_meta in pending.reqs.items():
            if send_meta.expire_time < now:
                expired_reqs.append(req_id)
        for req_id in expired_reqs:
            logger.warning(
                "Request %s timed out after %d seconds without "
                "being sent. Freeing its blocks on the producer side.",
                req_id,
                envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
            )
            del pending.reqs[req_id]
    if expired_reqs:
        finished_sending_reqs.update(expired_reqs)

    return finished_sending_reqs or None, finished_recving_reqs or None
|
||||
|
||||
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
    """Ask the peer at ``path`` to write KV blocks for a batch of requests.

    Sends this worker's transfer-engine address, KV cache base addresses and
    the destination block ids of every request, then waits for the reply.
    The requests are recorded in ``finished_recving_reqs`` only when the
    reply is ``TRANS_DONE``. On an error reply, a terminated ZMQ context or
    any other exception, the method returns without recording anything.

    Args:
        path: ZMQ endpoint of the sending peer.
        req_blocks: non-empty list of ``(request_id, local_block_ids)``.
    """
    req_ids, block_ids = map(list, zip(*req_blocks))
    metadata = MooncakeAgentMetadata(
        remote_hostname=self.hostname,
        remote_port=self.rpc_port,
        request_ids=req_ids,
        kv_caches_base_addr=self.kv_caches_base_addr,
        block_ids=block_ids,
    )

    encoded_data = self._encoder.encode(metadata)
    logger.debug(
        "Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data)
    )
    logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path)

    # Send query for the request.
    sock: zmq.asyncio.Socket = make_zmq_socket(
        self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0
    )
    sock.setsockopt(zmq.RCVTIMEO, 60000)
    try:
        await sock.send(encoded_data)
        ret_msg = await sock.recv()
        if ret_msg != TRANS_DONE:
            logger.error(
                "Error happens during tranfering kvcache for %s, see logs in prefiller.",  # noqa: E501
                req_ids,
            )
            return
    except zmq.ContextTerminated:
        logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
        # No completion was confirmed, so these requests must not be
        # reported as received (this branch used to fall through).
        return
    except Exception as e:
        logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e)
        return
    finally:
        sock.close()

    async with self.finished_recving_reqs.lock:
        self.finished_recving_reqs.set.update(req_ids)

    logger.debug("pulling kv_caches for %s finished", req_ids)
|
||||
|
||||
def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
    """Group the requests to receive by the ZMQ endpoint that serves them.

    The endpoint of a request is its remote host with port
    ``remote_port + tp_rank``.

    Returns:
        A ``defaultdict`` mapping each endpoint path to a list of
        ``(request_id, local_block_ids)`` tuples.
    """
    pulls_by_path = defaultdict(list)
    for req_id, recv_meta in metadata.reqs_to_recv.items():
        logger.debug(
            "start_load_kv for request %s from remote engine. "
            "Num local_block_ids: %s.",
            req_id,
            len(recv_meta.local_block_ids),
        )
        endpoint = make_zmq_path(
            "tcp", recv_meta.remote_host, recv_meta.remote_port + self.tp_rank
        )
        pulls_by_path[endpoint].append((req_id, recv_meta.local_block_ids))

    return pulls_by_path
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
    """Start the transfers described by ``metadata``.

    Receive side (any role except pure producer): schedule one
    ``receive_kv`` coroutine per remote endpoint on the receiver loop.

    Send side (any role except pure consumer): a request with block ids is
    marked ready and given a fresh expiry; a request with an empty block
    list gets a new, not-yet-ready ``SendBlockMeta`` placeholder.
    """
    if self.kv_role != "kv_producer":
        for path, req_blocks in self.group_kv_pull(metadata).items():
            asyncio.run_coroutine_threadsafe(
                self.receive_kv(path, req_blocks), self.receiver_loop
            )

    if self.kv_role == "kv_consumer":
        return

    pending = self.reqs_need_send
    with pending.lock:
        for req_id, block_ids in metadata.reqs_to_send.items():
            if not block_ids:
                # From update_state_after_alloc(),
                # but not reach request_finished() yet
                pending.reqs[req_id] = SendBlockMeta(
                    local_block_ids=[], ready=threading.Event()
                )
                continue

            # Already gone through request_finished()
            send_meta = pending.reqs[req_id]
            send_meta.local_block_ids = block_ids
            send_meta.ready.set()
            send_meta.expire_time = (
                time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
            )
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
    src_indices: list[int], dst_indices: list[int]
) -> tuple[list[list[int]], list[list[int]]]:
    """Split two parallel index lists into runs contiguous on both sides.

    A new run starts at every position where either list fails to increase
    by exactly one. The i-th source run and the i-th destination run have
    the same length.

    >>> group_concurrent_contiguous([1, 2, 3, 5, 6], [10, 11, 12, 20, 21])
    ([[1, 2, 3], [5, 6]], [[10, 11, 12], [20, 21]])
    """
    if len(src_indices) == 0:
        return [], []

    src_breaks = np.diff(src_indices) != 1
    dst_breaks = np.diff(dst_indices) != 1
    # +1 turns "break after position i" into "next run starts at i + 1".
    split_points = np.nonzero(src_breaks | dst_breaks)[0] + 1

    src_groups = [run.tolist() for run in np.split(src_indices, split_points)]
    dst_groups = [run.tolist() for run in np.split(dst_indices, split_points)]
    return src_groups, dst_groups
|
||||
|
||||
|
||||
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
    """Return the base side-channel port for this data-parallel rank.

    The result is ``VLLM_MOONCAKE_BOOTSTRAP_PORT`` plus
    ``data_parallel_rank * tensor_parallel_size``.
    """
    # This logic is now centralized
    base_port = envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
    parallel_config = vllm_config.parallel_config
    rank_offset = (
        parallel_config.data_parallel_rank * parallel_config.tensor_parallel_size
    )
    return base_port + rank_offset
|
||||
464
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
464
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
    """Metadata bundle that carries one metadata object per wrapped connector."""

    # Per-connector metadata, in the same order as the wrapped connectors.
    metadata: tuple[KVConnectorMetadata, ...]
    # Optional per-request counts of additional pending async saves.
    extra_async_saves: dict[str, int] | None = None
|
||||
|
||||
|
||||
@dataclass
class MultiKVConnectorStats(KVConnectorStats):
    """
    Per-connector statistics container.

    ``data`` maps a connector id to that connector's own KVConnectorStats,
    so the stats of each connector are aggregated separately.
    """

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        """Merge ``other`` into this object, connector by connector.

        Stats for a connector not seen before are adopted as-is; otherwise
        they are combined with the existing entry, which must be of a
        compatible type. Returns ``self``.
        """
        for connector_id, incoming in other.data.items():
            if connector_id in self.data:
                current = self.data[connector_id]
                assert isinstance(incoming, type(current))
                self[connector_id] = self[connector_id].aggregate(incoming)
            else:
                self[connector_id] = incoming
        return self

    def reset(self):
        """Reset the stats of every connector."""
        for connector_stats in self.data.values():
            connector_stats.reset()

    def reduce(self) -> dict[str, Any]:
        """Return each connector's reduced stats, keyed by connector id."""
        # TODO (NickLucche) Adjust for logging on separate lines
        reduced = {}
        for connector_id, connector_stats in self.data.items():
            reduced[connector_id] = connector_stats.reduce()
        return reduced

    def is_empty(self) -> bool:
        """Return True when every connector's stats are empty."""
        return all(
            connector_stats.is_empty() for connector_stats in self.data.values()
        )

    def __getitem__(self, connector_id: str) -> KVConnectorStats:
        return self.data[connector_id]

    def __setitem__(self, connector_id: str, stats: KVConnectorStats):
        self.data[connector_id] = stats
|
||||
|
||||
|
||||
class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
    """Prometheus metrics front-end that dispatches to per-connector metrics."""

    def __init__(
        self,
        vllm_config: "VllmConfig",
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
        prom_metrics: dict[str, KVConnectorPromMetrics],
    ):
        """Store ``prom_metrics``, a mapping from connector id to the metrics
        object that handles that connector's stats."""
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
        self._prom_metrics = prom_metrics

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """Forward each connector's ``"data"`` entry to that connector's
        metrics object. Every connector id must be registered."""
        for connector_id, stats_data in transfer_stats_data.items():
            assert connector_id in self._prom_metrics, (
                f"{connector_id} is not contained in the list of registered connectors "
                f"with Prometheus metrics support: {self._prom_metrics.keys()}"
            )
            connector_metrics = self._prom_metrics[connector_id]
            connector_metrics.observe(stats_data["data"], engine_idx)
|
||||
|
||||
|
||||
class MultiConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
A wrapper for using multiple KVConnectors at the same time.
|
||||
|
||||
The current logic is:
|
||||
- Load KV from the first connector that advertises available tokens from
|
||||
get_num_new_matched_tokens(), based on the order in the config.
|
||||
- Save to all connectors.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    vllm_config: "VllmConfig",
    role: KVConnectorRole,
    kv_cache_config: "KVCacheConfig",
):
    """Instantiate every configured sub-connector and the routing state."""
    super().__init__(
        vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
    )

    self._connectors: list[KVConnectorBase_V1] = []
    self._ktc_kv_transfer_config = []
    classes_and_configs = self._get_connector_classes_and_configs(vllm_config)
    for connector_cls, temp_config in classes_and_configs:
        connector = connector_cls(temp_config, role, kv_cache_config)
        self._connectors.append(connector)
        self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

    # A mapping from request id to the index of the connector chosen to
    # load the request from (if any).
    self._requests_to_connector: dict[str, int] = {}

    # Keeps track of *additional* remaining async saves (beyond 1) to be
    # finished per request. Not needed for async loads since we only allow
    # a single connector to load.
    # Propagated from scheduler to worker side via the connector metadata.
    self._extra_async_saves: dict[str, int] = {}
|
||||
|
||||
@classmethod
def _get_connector_classes_and_configs(
    cls, vllm_config: "VllmConfig"
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
    """Resolve the sub-connector classes and their per-connector configs.

    Reads the ``"connectors"`` list from the KV transfer extra config. For
    each entry a shallow copy of ``vllm_config`` is made whose
    ``kv_transfer_config`` is built from that entry. The entry's own
    ``engine_id`` is used when present; otherwise the parent's is inherited.

    Returns:
        One ``(connector_class, config)`` pair per entry, in config order.
    """
    assert vllm_config.kv_transfer_config is not None
    ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
        "connectors"
    )
    assert ktcs is not None
    ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
    for ktc in ktcs:
        temp_config = copy.copy(vllm_config)
        # Merge into one dict instead of passing engine_id as a separate
        # keyword: `KVTransferConfig(**ktc, engine_id=...)` raises TypeError
        # (duplicate keyword) whenever the entry already has "engine_id".
        ktc_kwargs = {
            "engine_id": vllm_config.kv_transfer_config.engine_id,
            **ktc,
        }
        temp_config.kv_transfer_config = KVTransferConfig(**ktc_kwargs)
        ret.append(
            (
                KVConnectorFactory.get_connector_class(
                    temp_config.kv_transfer_config
                ),
                temp_config,
            )
        )
    return ret
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Pass the KV cache tensors to every wrapped connector, in order."""
    for connector in self._connectors:
        connector.register_kv_caches(kv_caches)
|
||||
|
||||
# We must override the base class method here because we need to bind
|
||||
# the metadata to each connector in the order of the connectors in the
|
||||
# MultiKVConnectorMetadata.
|
||||
#
|
||||
# Note: Call the base class method to ensure metadata is also set on the
|
||||
# MultiConnector instance itself; otherwise, `has_connector_metadata()` will
|
||||
# always return False.
|
||||
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
    """Bind each sub-metadata to its connector, then bind the bundle itself.

    Any ``extra_async_saves`` carried by the bundle are merged into this
    instance's pending counts before the sub-metadata is distributed.
    """
    assert isinstance(connector_metadata, MultiKVConnectorMetadata)
    extra_saves = connector_metadata.extra_async_saves
    if extra_saves:
        self._extra_async_saves.update(extra_saves)
    # Sub-metadata is matched to connectors by position.
    for connector, sub_metadata in zip(
        self._connectors, connector_metadata.metadata
    ):
        connector.bind_connector_metadata(sub_metadata)
    super().bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
    """Clear the bound metadata on every wrapped connector, then on self."""
    for connector in self._connectors:
        connector.clear_connector_metadata()
    super().clear_connector_metadata()
|
||||
|
||||
def shutdown(self):
    """Shut down every wrapped connector, even if some of them fail.

    Each failure is logged with its traceback. After all connectors have
    been attempted, the last exception seen (if any) is re-raised.
    """
    last_error: Exception | None = None
    for connector in self._connectors:
        try:
            connector.shutdown()
        except Exception as err:
            logger.exception(
                "Exception during connector %s shutdown.",
                connector.__class__.__name__,
            )
            last_error = err
    if last_error:
        raise last_error
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
    """Ask every wrapped connector, in order, to start loading KV."""
    for connector in self._connectors:
        connector.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
    """Call ``wait_for_layer_load(layer_name)`` on every wrapped connector."""
    for connector in self._connectors:
        connector.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None:
    """Hand the layer's KV tensor to every wrapped connector for saving."""
    for connector in self._connectors:
        connector.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
|
||||
|
||||
def wait_for_save(self):
    """Call ``wait_for_save()`` on every wrapped connector, in order."""
    for connector in self._connectors:
        connector.wait_for_save()
|
||||
|
||||
def get_finished(
    self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
    """Combine the finished send/receive sets of all wrapped connectors.

    Finished receives are simply unioned. A finished send is reported only
    when no extra async save is pending for that request; otherwise the
    pending count in ``_extra_async_saves`` is decremented instead.

    Returns:
        ``(finished_sending, finished_recving)``, each None when empty.
    """
    done_sending: set[str] = set()
    done_recving: set[str] = set()
    for connector in self._connectors:
        sending, recving = connector.get_finished(finished_req_ids)
        if not (recving or sending):
            continue
        # Aggregate finished recving request ids.
        done_recving.update(recving or ())
        # A request saved asynchronously by several connectors is reported
        # only once the "extra" count has been drained.
        for req_id in sending or ():
            remaining = self._extra_async_saves.get(req_id)
            if remaining is None:
                done_sending.add(req_id)
                continue
            assert remaining > 0
            if remaining > 1:
                self._extra_async_saves[req_id] = remaining - 1
            else:
                del self._extra_async_saves[req_id]

    return done_sending or None, done_recving or None
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the union of the load-error block ids of all wrapped connectors."""
    failed_block_ids: set[int] = set()
    for connector in self._connectors:
        failed_block_ids |= connector.get_block_ids_with_load_errors()
    return failed_block_ids
|
||||
|
||||
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
|
||||
# for the MultiConnector. It should be able to get events from multiple
|
||||
# connectors, handling the case where only a subset of the requested connectors
|
||||
# implements the 'get_kv_connector_kv_cache_events'
|
||||
# Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
to_return = (0, False)
|
||||
for i, c in enumerate(self._connectors):
|
||||
toks, load_async = c.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
)
|
||||
# If there is a connector still looking up the matches,
|
||||
# we return None to indicate that we are not done yet.
|
||||
if toks is None:
|
||||
return (None, False)
|
||||
# The first connector that has new matched tokens will be assigned
|
||||
# to this request.
|
||||
if to_return[0] == 0 and toks > 0:
|
||||
self._requests_to_connector[request.request_id] = i
|
||||
to_return = (toks, load_async)
|
||||
return to_return
|
||||
|
||||
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """Forward the allocation result to every wrapped connector.

    The connector recorded for this request in ``_requests_to_connector``
    (if any) receives the real ``blocks`` and ``num_external_tokens``; every
    other connector is called with ``blocks.new_empty()`` and ``0``.
    """
    selected = self._requests_to_connector.get(request.request_id, -1)
    no_blocks = blocks.new_empty()
    for index, connector in enumerate(self._connectors):
        if index != selected:
            # Not the chosen connector: notify it with nothing to load.
            connector.update_state_after_alloc(request, no_blocks, 0)
            continue
        connector.update_state_after_alloc(request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> MultiKVConnectorMetadata:
    """Build one metadata object wrapping each connector's own metadata.

    Any pending ``_extra_async_saves`` entries are attached to the returned
    metadata, and the local mapping is then reset to an empty dict.
    """
    per_connector = tuple(
        connector.build_connector_meta(scheduler_output)
        for connector in self._connectors
    )
    multi_meta = MultiKVConnectorMetadata(metadata=per_connector)
    if self._extra_async_saves:
        # Hand the current dict over and start a fresh one locally.
        multi_meta.extra_async_saves = self._extra_async_saves
        self._extra_async_saves = {}
    return multi_meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
    """Pass the same ``connector_output`` to every wrapped connector."""
    for connector in self._connectors:
        connector.update_connector_output(connector_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
async_saves = 0
|
||||
kv_txfer_params = None
|
||||
for c in self._connectors:
|
||||
async_save, txfer_params = c.request_finished(request, blocks)
|
||||
if async_save:
|
||||
async_saves += 1
|
||||
if txfer_params is not None:
|
||||
if kv_txfer_params is not None:
|
||||
# TODO we can probably change this to merge the dicts here,
|
||||
# checking for key clashes.
|
||||
raise RuntimeError(
|
||||
"Only one connector can produce KV transfer params"
|
||||
)
|
||||
kv_txfer_params = txfer_params
|
||||
if async_saves > 1:
|
||||
self._extra_async_saves[request.request_id] = async_saves - 1
|
||||
|
||||
# Clean up other state for this request.
|
||||
self._requests_to_connector.pop(request.request_id, None)
|
||||
|
||||
return async_saves > 0, kv_txfer_params
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
    """Lazily yield the events of each wrapped connector, connector by connector."""
    for connector in self._connectors:
        for event in connector.take_events():
            yield event
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
|
||||
"""
|
||||
Get the required KV cache layout for this connector.
|
||||
Args:
|
||||
vllm_config (VllmConfig): the vllm config.
|
||||
|
||||
Returns:
|
||||
str: the required KV cache layout. e.g. HND, or NHD.
|
||||
None if the connector does not require a specific layout.
|
||||
"""
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
layouts: set[str] = set()
|
||||
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
|
||||
vllm_config
|
||||
):
|
||||
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
|
||||
temp_config
|
||||
)
|
||||
if required_kvcache_layout is not None:
|
||||
layouts.add(required_kvcache_layout)
|
||||
|
||||
if len(layouts) > 1:
|
||||
raise ValueError(
|
||||
f"KV cache layout mismatch: "
|
||||
f"found {len(layouts)} different layouts "
|
||||
f"({', '.join(layouts)})."
|
||||
f"All connectors must use the same layout."
|
||||
)
|
||||
return next(iter(layouts), None)
|
||||
|
||||
@classmethod
def build_kv_connector_stats(
    cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
    """Build a ``MultiKVConnectorStats`` from per-connector stats data.

    Args:
        data: mapping of connector name to either an already-built
            ``KVConnectorStats`` object (same process) or its serialized
            dict form containing a ``"data"`` field (cross-process).
            ``None`` produces an empty ``MultiKVConnectorStats``.

    Returns:
        A ``MultiKVConnectorStats`` holding one stats object per connector.
    """
    if data is None:
        return MultiKVConnectorStats()

    rebuilt = {}
    for connector_name, raw_stats in data.items():
        if isinstance(raw_stats, KVConnectorStats):
            # Already a stats object: keep it untouched.
            stats = raw_stats
        else:
            # Serialized form: let the owning connector class rebuild it.
            connector_cls = KVConnectorFactory.get_connector_class_by_name(
                connector_name
            )

            # The serialized dataclass looks like {'data': {...}}; only the
            # inner 'data' payload is passed on to avoid double-nesting.
            assert isinstance(raw_stats, dict) and "data" in raw_stats, (
                f"Expected a dict with a 'data' field, got {raw_stats}"
            )
            stats = connector_cls.build_kv_connector_stats(data=raw_stats["data"])
            if not stats:
                # Falsy results are left out of the result.
                continue
        rebuilt[connector_name] = stats

    return MultiKVConnectorStats(data=rebuilt)
|
||||
|
||||
def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
    """Collect the stats of every connector, keyed by connector class name.

    Returns:
        A ``MultiKVConnectorStats`` with one entry per connector that
        returned stats, or ``None`` when no connector returned any.
    """
    aggregated: MultiKVConnectorStats | None = None
    for connector in self._connectors:
        connector_stats = connector.get_kv_connector_stats()
        if connector_stats is not None:
            if aggregated is None:
                # Created lazily so that "no stats at all" stays None.
                aggregated = MultiKVConnectorStats()
            aggregated[connector.__class__.__name__] = connector_stats
    return aggregated
|
||||
|
||||
@classmethod
def build_prom_metrics(
    cls,
    vllm_config: "VllmConfig",
    metric_types: dict[type["PromMetric"], type["PromMetricT"]],
    labelnames: list[str],
    per_engine_labelvalues: dict[int, list[object]],
) -> KVConnectorPromMetrics:
    """Build the Prometheus metrics wrapper for all configured connectors.

    Each connector class from ``cls._get_connector_classes_and_configs`` is
    asked to build its own metrics; non-``None`` results are collected under
    the connector class name and handed to ``MultiKVConnectorPromMetrics``.
    """
    per_connector_metrics: dict[str, KVConnectorPromMetrics] = {}
    classes_and_configs = cls._get_connector_classes_and_configs(vllm_config)
    for connector_cls, connector_config in classes_and_configs:
        built = connector_cls.build_prom_metrics(
            connector_config, metric_types, labelnames, per_engine_labelvalues
        )
        if built is None:
            continue
        per_connector_metrics[connector_cls.__name__] = built

    return MultiKVConnectorPromMetrics(
        vllm_config,
        metric_types,
        labelnames,
        per_engine_labelvalues,
        per_connector_metrics,
    )
|
||||
|
||||
def reset_cache(self) -> bool:
    """Reset the cache of every connector.

    All connectors are always called (no short-circuit). The result is
    ``False`` only if some connector returned exactly ``False``; any other
    return value, including ``None``, counts as success.
    """
    all_ok = True
    for connector in self._connectors:
        if connector.reset_cache() is False:
            all_ok = False
    return all_ok
|
||||
2526
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
2526
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,538 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
ReqId = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
    # Connector metadata built by OffloadingConnectorScheduler.build_connector_meta
    # and consumed by OffloadingConnectorWorker.start_load_kv / start_store_kv.

    # request ID -> (source spec, destination spec) of the load to start
    reqs_to_load: dict[ReqId, TransferSpec]
    # request ID -> (source spec, destination spec) of the store to start
    reqs_to_store: dict[ReqId, TransferSpec]
|
||||
|
||||
|
||||
class OffloadingConnector(KVConnectorBase_V1):
    """KV connector that delegates to a role-specific offloading helper.

    Depending on ``role``, exactly one of ``connector_scheduler`` and
    ``connector_worker`` is created; each public method asserts that the
    helper it needs exists and forwards the call to it.
    """

    prefer_cross_layer_blocks: ClassVar[bool] = True

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        spec = OffloadingSpecFactory.create_spec(vllm_config)

        # Only the helper matching the role is instantiated.
        self.connector_scheduler: OffloadingConnectorScheduler | None = None
        self.connector_worker: OffloadingConnectorWorker | None = None
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = OffloadingConnectorScheduler(spec)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = OffloadingConnectorWorker(spec)

    # ------------------------------------------------------------------
    # Worker-side methods
    # ------------------------------------------------------------------

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Forward the per-layer KV cache tensors to the worker helper."""
        worker = self.connector_worker
        assert worker is not None
        worker.register_kv_caches(kv_caches)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """Forward a single cross-layer KV cache tensor to the worker helper."""
        worker = self.connector_worker
        assert worker is not None
        worker.register_cross_layers_kv_cache(kv_cache, attn_backend)

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Start the loads listed in the currently bound connector metadata."""
        worker = self.connector_worker
        assert worker is not None
        metadata = self._connector_metadata
        assert isinstance(metadata, OffloadingConnectorMetadata)
        worker.start_load_kv(metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op for this connector."""
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs,
    ) -> None:
        """No-op for this connector; stores are started in ``wait_for_save``."""
        pass

    def wait_for_save(self):
        """Start the stores listed in the currently bound connector metadata."""
        worker = self.connector_worker
        assert worker is not None
        metadata = self._connector_metadata
        assert isinstance(metadata, OffloadingConnectorMetadata)
        worker.start_store_kv(metadata)

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Return the worker helper's (finished sending, finished recving) ids."""
        worker = self.connector_worker
        assert worker is not None
        return worker.get_finished(finished_req_ids)

    # ------------------------------------------------------------------
    # Scheduler-side methods
    # ------------------------------------------------------------------

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """Delegate the lookup to the scheduler helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.get_num_new_matched_tokens(request, num_computed_tokens)

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Delegate the post-allocation update to the scheduler helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Delegate metadata construction to the scheduler helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.build_connector_meta(scheduler_output)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """Pass the worker-side connector output to the scheduler helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        scheduler.update_connector_output(connector_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Delegate request-finished handling to the scheduler helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.request_finished(request, block_ids)

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Return the KV cache events produced by the scheduler helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.take_events()
|
||||
|
||||
|
||||
class OffloadingConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.gpu_block_size = spec.gpu_block_size
        self.offloaded_block_size = spec.offloaded_block_size
        # number of GPU blocks per offloaded block
        self.block_size_factor = self.offloaded_block_size // self.gpu_block_size
        self.manager: OffloadingManager = spec.get_manager()

        self._requests: dict[ReqId, Request] = {}
        # list of GPU block IDs per request
        self._request_block_ids: dict[ReqId, list[int]] = {}
        # requests to load for the current scheduler step
        self._reqs_to_load: dict[ReqId, TransferSpec] = {}
        # request blocks are stored in order
        # index of next block (of size offloaded_block_size) to offload
        self._next_stored_block_idx: dict[ReqId, int] = {}

        # request ID -> set(block hashes being stored/load)
        self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
        self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)

    def _get_block_hashes(
        self,
        req: Request,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> Iterable[BlockHash]:
        """Return one hash per offloaded block in ``[start_idx, end_idx)``.

        ``req.block_hashes`` is sampled every ``block_size_factor`` entries,
        taking the last entry of each group, so each returned hash
        corresponds to one offloaded-size block.

        Args:
            req: request whose ``block_hashes`` are sampled.
            start_idx: index of the first offloaded block to include.
            end_idx: index one past the last offloaded block to include.
                ``None`` means "until the end"; ``0`` yields nothing.

        Returns:
            A single-use iterator over the selected hashes.
        """
        # Compare against None explicitly: end_idx == 0 is a valid (empty)
        # bound and must not be treated as "no upper bound".
        stop = self.block_size_factor * end_idx if end_idx is not None else None
        return islice(
            req.block_hashes,
            self.block_size_factor * start_idx + self.block_size_factor - 1,
            stop,
            self.block_size_factor,
        )

    def get_num_new_matched_tokens(
        self, request: Request, num_computed_tokens: int
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded beyond the
        num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded beyond what is
                  already computed.
                - `True` if tokens will be loaded asynchronously
                  (between scheduler steps).
        """
        num_blocks = request.num_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor == num_blocks
        block_hashes = self._get_block_hashes(request)

        self.manager.touch(block_hashes)

        full_block_tokens = self.offloaded_block_size * num_blocks
        if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
            # we can load less than a block, skip
            return 0, False

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        hits = self.manager.lookup(
            self._get_block_hashes(request, start_idx=start_block_idx)
        )
        if hits == 0:
            return 0, False

        num_hit_tokens = (
            self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
        )
        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            request.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )
        if num_hit_tokens < self.offloaded_block_size:
            return 0, False

        return num_hit_tokens, True

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """Record the request and, if tokens are to be loaded, queue the load.

        Args:
            request: the request that was just allocated blocks.
            blocks: the allocated blocks; only the first block group is used.
            num_external_tokens: number of tokens to load; ``0`` means only
                the bookkeeping for the request is (re)initialized.
        """
        self._requests[request.request_id] = request
        # the block ids are updated in _get_reqs_to_store
        self._request_block_ids[request.request_id] = []

        if num_external_tokens == 0:
            return

        block_groups = blocks.get_block_ids()
        block_ids = block_groups[0]

        # blocks that already carry a hash are counted as computed
        num_computed_gpu_blocks = sum(
            block.block_hash is not None for block in blocks.blocks[0]
        )
        num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
        full_block_tokens = num_computed_tokens + num_external_tokens
        assert full_block_tokens % self.offloaded_block_size == 0

        num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
        assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        num_blocks = full_block_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor >= num_blocks
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        src_spec = self.manager.prepare_load(block_hashes)
        dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])

        # _get_block_hashes returns a single-use iterator, so build a fresh
        # one for the bookkeeping below.
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
        self._reqs_being_loaded[request.request_id].update(block_hashes)
        self._next_stored_block_idx[request.request_id] = num_blocks

    def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
        """Compute the store transfer of every request scheduled this step.

        Returns:
            Mapping of request ID to ``(source GPU spec, destination spec)``
            for each request that has new full offloaded blocks to store.
        """
        reqs_to_store: dict[ReqId, TransferSpec] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
            if preempted:
                self._request_block_ids[req_id] = []

            if new_block_id_groups:
                new_block_ids = new_block_id_groups[0]
                self._request_block_ids[req_id] += new_block_ids

            block_ids = self._request_block_ids[req_id]

            req = self._requests[req_id]
            new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            total_tokens = req.num_computed_tokens + new_tokens
            num_blocks = total_tokens // self.offloaded_block_size
            start_block_idx = self._next_stored_block_idx.get(req_id, 0)
            num_new_blocks = num_blocks - start_block_idx

            if num_new_blocks <= 0:
                continue

            # NOTE: In async scheduling, placeholders may temporarily make
            # len(req.block_hashes) < num_blocks * self.block_size_factor.

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            store_output = self.manager.prepare_store(new_block_hashes)
            if store_output is None:
                logger.warning(
                    "Request %s: cannot store %s blocks", req_id, num_new_blocks
                )
                continue

            self._next_stored_block_idx[req_id] = num_blocks

            if not store_output.block_hashes_to_store:
                continue
            block_hashes_to_store = set(store_output.block_hashes_to_store)

            block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
            self.manager.touch(block_hashes)

            # fresh iterator: the previous one was consumed by prepare_store
            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            dst_spec = store_output.store_spec
            src_block_ids: list[int] = []
            for idx, blk_hash in enumerate(new_block_hashes):
                if blk_hash not in block_hashes_to_store:
                    continue
                # each offloaded block spans block_size_factor GPU blocks
                offloaded_block_idx = start_block_idx + idx
                gpu_block_idx = offloaded_block_idx * self.block_size_factor
                for i in range(self.block_size_factor):
                    src_block_ids.append(block_ids[gpu_block_idx + i])
            src_spec = GPULoadStoreSpec(src_block_ids)

            reqs_to_store[req_id] = (src_spec, dst_spec)
            self._reqs_being_stored[req_id] |= block_hashes_to_store

            logger.debug(
                "Request %s offloading %s blocks starting from block #%d",
                req_id,
                len(block_hashes_to_store),
                start_block_idx,
            )

        return reqs_to_store

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Build this step's metadata and clear the queued loads."""
        meta = OffloadingConnectorMetadata(
            reqs_to_load=self._reqs_to_load,
            reqs_to_store=self._get_reqs_to_store(scheduler_output),
        )
        self._reqs_to_load = {}
        return meta

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        for req_id in connector_output.finished_sending or []:
            block_hashes = self._reqs_being_stored.pop(req_id, None)
            if block_hashes:
                self.manager.complete_store(block_hashes)

        for req_id in connector_output.finished_recving or []:
            block_hashes = self._reqs_being_loaded.pop(req_id, None)
            if block_hashes:
                self.manager.complete_load(block_hashes)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id
        self._requests.pop(req_id, None)
        self._request_block_ids.pop(req_id, None)
        self._next_stored_block_idx.pop(req_id, None)

        request_being_stored = req_id in self._reqs_being_stored
        return request_being_stored, None

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Take the KV cache events from the connector.

        Yields:
            One ``BlockRemoved`` or ``BlockStored`` event per event taken
            from the offloading manager.
        """
        for event in self.manager.take_events():
            if event.removed:
                yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
            else:
                yield BlockStored(
                    block_hashes=event.block_hashes,
                    parent_block_hash=None,
                    token_ids=[],
                    lora_id=None,
                    block_size=event.block_size,
                    medium=event.medium,
                )
|
||||
|
||||
|
||||
class OffloadingConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.spec = spec
        self.worker = OffloadingWorker()

        # monotonically increasing source of job IDs
        self._job_counter = 0

        # job_id -> (req_id, store); store is True for a store job and
        # False for a load job
        self._jobs: dict[int, tuple[ReqId, bool]] = {}
        # req_id -> the single active load job ID
        self._load_job: dict[ReqId, int] = {}
        # req_id -> set(active store job IDs)
        self._store_jobs = defaultdict[ReqId, set[int]](set)

        # finished requests whose store jobs have not all completed yet
        self._finished_reqs_waiting_for_store: set[ReqId] = set()

    def _generate_job_id(self) -> int:
        """Return the next unused job ID."""
        job_id = self._job_counter
        self._job_counter = job_id + 1
        return job_id

    def _register_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        """Register every transfer handler the spec provides for these caches."""
        for src_cls, dst_cls, handler in self.spec.get_handlers(
            kv_caches, attn_backends
        ):
            self.worker.register_handler(src_cls, dst_cls, handler)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register handlers for per-layer KV caches.

        The attention backend of each layer is looked up from the vLLM config
        by layer name.
        """
        layer_names = list(kv_caches.keys())
        layers = get_layers_from_vllm_config(
            self.spec.vllm_config, Attention, layer_names
        )
        attn_backends = {
            layer_name: layers[layer_name].get_attn_backend()
            for layer_name in layer_names
        }
        self._register_handlers(kv_caches, attn_backends)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """Register handlers for one KV cache tensor shared by all layers.

        The tensor is registered under the single pseudo layer name
        ``"ALL_LAYERS"``.
        """
        cross_layer_name = "ALL_LAYERS"
        kv_caches = {cross_layer_name: kv_cache}
        attn_backends = {cross_layer_name: attn_backend}
        self._register_handlers(kv_caches, attn_backends)

    def start_load_kv(self, metadata: OffloadingConnectorMetadata):
        """Submit one asynchronous load job per request in ``reqs_to_load``."""
        for req_id, transfer_spec in metadata.reqs_to_load.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, False)
            assert req_id not in self._load_job
            self._load_job[req_id] = job_id
            # Submit outside the assert so the transfer still happens when
            # assertions are stripped (python -O).
            submitted = self.worker.transfer_async(job_id, transfer_spec)
            assert submitted

    def start_store_kv(self, metadata: OffloadingConnectorMetadata):
        """Submit one asynchronous store job per request in ``reqs_to_store``."""
        for req_id, transfer_spec in metadata.reqs_to_store.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, True)
            self._store_jobs[req_id].add(job_id)
            # Submit outside the assert so the transfer still happens when
            # assertions are stripped (python -O).
            submitted = self.worker.transfer_async(job_id, transfer_spec)
            assert submitted

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.
        Returns the request IDs that finished loading or storing.

        A request is reported as finished sending only once it is in
        ``finished_req_ids`` (now or in an earlier call) and has no store
        job left.

        Returns:
            ids of requests that have finished asynchronous transfer
            tuple of (sending/saving ids, recving/loading ids).
        """
        finished_sending = set()
        finished_recving = set()
        for job_id, success in self.worker.get_finished():
            # we currently do not support job failures
            assert success
            req_id, store = self._jobs.pop(job_id)
            if store:
                req_jobs = self._store_jobs[req_id]
                req_jobs.remove(job_id)
                if req_jobs:
                    # other store jobs of this request are still running
                    continue

                # last store job done: report only if the request itself
                # was already marked finished
                if req_id in self._finished_reqs_waiting_for_store:
                    self._finished_reqs_waiting_for_store.remove(req_id)
                    finished_sending.add(req_id)
                    del self._store_jobs[req_id]
            else:
                req_job = self._load_job[req_id]
                assert job_id == req_job
                del self._load_job[req_id]
                finished_recving.add(req_id)

        for req_id in finished_req_ids:
            pending_req_jobs = self._store_jobs.get(req_id)
            if pending_req_jobs:
                # still storing: report once the last job completes
                self._finished_reqs_waiting_for_store.add(req_id)
            elif pending_req_jobs is not None:
                # had store jobs and all of them already completed
                finished_sending.add(req_id)
                del self._store_jobs[req_id]

        return finished_sending, finished_recving
|
||||
|
||||
|
||||
def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """
    Iterate over the new and then the cached requests of a scheduler output.

    Yields:
        (req_id, new_block_id_groups, preempted)
        ``preempted`` is always False for new requests; for cached requests
        it is True when the request ID is in ``resumed_req_ids``.
    """
    # new requests
    for new_req in scheduler_output.scheduled_new_reqs:
        yield new_req.req_id, new_req.block_ids, False

    # cached requests
    cached = scheduler_output.scheduled_cached_reqs
    for req_id, new_block_ids in zip(cached.req_ids, cached.new_block_ids):
        yield req_id, new_block_ids, req_id in cached.resumed_req_ids
|
||||
@@ -0,0 +1,531 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
|
||||
P2pNcclEngine,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ReqMeta:
    """Metadata describing one request handled by the connector."""

    # Request Id
    request_id: str
    # Request block ids
    block_ids: torch.Tensor
    # Request num tokens
    num_tokens: int

    @staticmethod
    def make_meta(
        request_id: str, token_ids: list[int], block_ids: list[int], block_size: int
    ) -> "ReqMeta":
        """Build a ``ReqMeta`` from plain Python lists.

        ``block_ids`` is converted to a tensor and ``num_tokens`` is the
        length of ``token_ids``. ``block_size`` is accepted but not used.
        """
        return ReqMeta(
            request_id=request_id,
            block_ids=torch.tensor(block_ids),
            num_tokens=len(token_ids),
        )
|
||||
|
||||
|
||||
@dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata):
    """Connector metadata holding one ``ReqMeta`` per added request."""

    requests: list[ReqMeta]

    def __init__(self):
        # Starts empty; entries are appended through add_request().
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        """Create a ``ReqMeta`` from the arguments and append it to ``requests``."""
        req_meta = ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)
        self.requests.append(req_meta)
|
||||
|
||||
|
||||
class P2pNcclConnector(KVConnectorBase_V1):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._requests_need_load: dict[str, Any] = {}
|
||||
self.is_producer = self._kv_transfer_config.is_kv_producer
|
||||
self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}
|
||||
|
||||
self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
|
||||
self._local_rank = (
|
||||
get_world_group().local_rank if role == KVConnectorRole.WORKER else 0
|
||||
)
|
||||
|
||||
self.p2p_nccl_engine = (
|
||||
P2pNcclEngine(
|
||||
local_rank=self._local_rank,
|
||||
config=self._kv_transfer_config,
|
||||
hostname="",
|
||||
port_offset=self._rank,
|
||||
)
|
||||
if role == KVConnectorRole.WORKER
|
||||
else None
|
||||
)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
    """Start loading the KV cache from the connector buffer to vLLM's
    paged KV buffer.

    Producer (prefill) instances return immediately. On a consumer, for
    every request in the connector metadata one tensor is received per
    layer that exposes a ``kv_cache`` attribute, and that tensor is written
    into the request's blocks of the layer.

    Args:
        forward_context (ForwardContext): the forward context; its
            ``attn_metadata``, ``no_compile_layers`` and
            ``virtual_engine`` are read.
        **kwargs: additional arguments for the load operation (unused).

    Raises:
        ValueError: if a request id carries no prefill address (raised by
            ``parse_request_id``).
        NotImplementedError: if a received tensor does not match the layer
            shape outside the block dimension (raised by
            ``check_tensors_except_dim``).
    """

    # Only consumer/decode loads KV Cache
    if self.is_producer:
        return

    assert self.p2p_nccl_engine is not None

    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return

    def inject_kv_into_layer(
        layer: torch.Tensor,
        kv_cache: torch.Tensor,
        block_ids: torch.Tensor,
        request_id: str,
    ) -> None:
        """
        Inject KV cache data into a given attention layer tensor.

        This function updates `layer` in-place with values from `kv_cache`,
        handling different backend layouts:
        - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
          indexed along the first dimension.
        - FlashAttention: KV tensors are indexed along the second
          dimension.

        If the number of provided block IDs does not match the number of KV
        blocks, only the overlapping portion (the first
        ``min(len(block_ids), num_block)`` blocks) is updated, and a warning
        is logged. Any other layout is left untouched.

        Args:
            layer (torch.Tensor): The attention layer KV tensor to update.
            kv_cache (torch.Tensor): The KV cache tensor to inject.
            block_ids (torch.Tensor): Indices of the blocks to update.
            request_id (str): Request identifier used for logging.

        Returns:
            None. The function modifies `layer` in-place.
        """
        if isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2:
            block_dim = 0  # MLA or FlashInfer
        elif layer.shape[0] == 2:
            block_dim = 1  # FlashAttention
        else:
            # Unrecognised layout: nothing is written.
            return

        num_block = kv_cache.shape[block_dim]
        self.check_tensors_except_dim(layer, kv_cache, block_dim)

        num_ids = len(block_ids)
        if num_ids != num_block:
            logger.warning(
                "🚧kv_cache does not match, block_ids:%d, "
                "num_block:%d, request_id:%s",
                num_ids,
                num_block,
                request_id,
            )

        # Trim BOTH sides to the overlap. Trimming only block_ids fails with
        # a shape mismatch whenever fewer ids than received blocks exist.
        overlap = min(num_ids, num_block)
        if block_dim == 0:
            layer[block_ids[:overlap], ...] = kv_cache[:overlap]
        else:
            layer[:, block_ids[:overlap], ...] = kv_cache[:, :overlap]

    # Get the metadata
    metadata: KVConnectorMetadata = self._get_connector_metadata()
    assert isinstance(metadata, P2pNcclConnectorMetadata)

    # Load the KV for each request each layer
    for request in metadata.requests:
        request_id = request.request_id
        ip, port = self.parse_request_id(request_id, False)
        remote_address = ip + ":" + str(port + self._rank)
        for layer_name, module in forward_context.no_compile_layers.items():
            # Only process layers that have kv_cache
            # attribute (attention layers) Skip non-attention
            # layers like FusedMoE
            paged_kv = getattr(module, "kv_cache", None)
            if paged_kv is None:
                continue

            layer = paged_kv[forward_context.virtual_engine]

            kv_cache = self.p2p_nccl_engine.recv_tensor(
                request.request_id + "#" + layer_name, remote_address
            )

            if kv_cache is None:
                logger.warning("🚧kv_cache is None, %s", request.request_id)
                continue

            inject_kv_into_layer(
                layer, kv_cache, request.block_ids, request.request_id
            )
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
    """Hook for waiting until one layer's KV is in vLLM's paged buffer.

    The interface exists for layer-by-layer pipelining. This connector does
    no work here and returns immediately.

    Args:
        layer_name: the name of that layer (unused).
    """
    return None
|
||||
|
||||
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs: Any,
) -> None:
    """Start saving the KV cache of the layer from vLLM's paged buffer
    to the connector.

    Consumer (decode) instances return immediately. On a producer, the
    blocks of every request in the connector metadata are extracted from
    ``kv_layer`` and handed to the engine under the tensor id
    ``"<request_id>#<layer_name>"``.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation (unused).

    Raises:
        ValueError: if a request id carries no decode address (raised by
            ``parse_request_id``).
    """

    # Only producer/prefill saves KV Cache
    if not self.is_producer:
        return

    assert self.p2p_nccl_engine is not None

    def extract_kv_from_layer(
        layer: torch.Tensor,
        block_ids: torch.Tensor,
    ) -> torch.Tensor | None:
        """
        Extract KV cache slices from a given attention layer tensor.

        This function handles multiple backend layouts:
        - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
          indexed along the first dimension.
        - FlashAttention: KV tensors are indexed along the second
          dimension.

        Args:
            layer (torch.Tensor): The KV cache from the attention layer.
            block_ids (torch.Tensor): Indices of blocks to extract.

        Returns:
            A tensor containing the extracted KV slices, or None if the
            layout is unsupported.
        """
        if isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2:
            # MLA or FlashInfer
            return layer[block_ids, ...]

        if layer.shape[0] == 2:  # FlashAttention
            return layer[:, block_ids, ...]

        return None

    connector_metadata = self._get_connector_metadata()
    assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
    for request in connector_metadata.requests:
        request_id = request.request_id
        ip, port = self.parse_request_id(request_id, True)
        remote_address = ip + ":" + str(port + self._rank)

        kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
        if kv_cache is None:
            # The engine dereferences the tensor (shape / element_size), so
            # a None payload must not be handed to it.
            logger.warning(
                "🚧Unsupported KV layout, skip sending, layer_name:%s, "
                "request_id:%s, shape:%s",
                layer_name,
                request_id,
                tuple(kv_layer.shape),
            )
            continue

        self.p2p_nccl_engine.send_tensor(
            request_id + "#" + layer_name, kv_cache, remote_address
        )
|
||||
|
||||
def wait_for_save(self):
    """On a producer, delegate to the engine's ``wait_for_sent()``.

    Consumer instances have nothing to wait for and return immediately.
    """
    if not self.is_producer:
        return
    assert self.p2p_nccl_engine is not None
    self.p2p_nccl_engine.wait_for_sent()
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str], **kwargs: Any
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer,
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
|
||||
assert self.p2p_nccl_engine is not None
|
||||
|
||||
no_compile_layers = self._vllm_config.compilation_config.static_forward_context
|
||||
return self.p2p_nccl_engine.get_finished(finished_req_ids, no_compile_layers)
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Get number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A pair whose first element is the number of tokens that can be
        loaded from the external KV cache beyond what is already computed
        (all prompt tokens except the last one, never negative; always 0 on
        a producer). The second element is always False.
    """
    if self.is_producer:
        return 0, False

    prompt_len = len(request.prompt_token_ids or [])
    # The last prompt token is excluded from the external count.
    return max(prompt_len - 1 - num_computed_tokens, 0), False
|
||||
|
||||
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """
    Update KVConnector state after block allocation.

    On a consumer with at least one external token, the request and the
    first entry of ``blocks.get_block_ids()`` are remembered so that
    ``build_connector_meta`` can schedule the load.
    """
    if self.is_producer or num_external_tokens <= 0:
        return

    first_group_block_ids = blocks.get_block_ids()[0]
    self._requests_need_load[request.request_id] = (request, first_group_block_ids)
|
||||
|
||||
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the connector metadata for this step.

    This function should NOT modify any fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    On a producer, a request is added once its whole prompt has been
    scheduled; while the prompt is still being prefilled in chunks, its
    block ids are accumulated in ``self.chunked_prefill``. On a consumer,
    a request is added when it is recorded in ``self._requests_need_load``
    (new requests, and cached requests resumed from preemption).

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        The ``P2pNcclConnectorMetadata`` listing the requests whose KV
        cache is transferred in this step.
    """

    meta = P2pNcclConnectorMetadata()

    for new_req in scheduler_output.scheduled_new_reqs:
        if self.is_producer:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                new_req.req_id
            ]
            num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
            # the request's prompt is chunked prefill
            if num_tokens < len(new_req.prompt_token_ids or []):
                # 'CachedRequestData' has no attribute 'prompt_token_ids',
                # so the prompt is remembered for the following steps.
                self.chunked_prefill[new_req.req_id] = (
                    new_req.block_ids[0],
                    new_req.prompt_token_ids,
                )
                continue
            # the request's prompt is not chunked prefill
            meta.add_request(
                request_id=new_req.req_id,
                token_ids=new_req.prompt_token_ids or [],
                block_ids=new_req.block_ids[0],
                block_size=self._block_size,
            )
            continue
        if new_req.req_id in self._requests_need_load:
            meta.add_request(
                request_id=new_req.req_id,
                token_ids=new_req.prompt_token_ids or [],
                block_ids=new_req.block_ids[0],
                block_size=self._block_size,
            )
            self._requests_need_load.pop(new_req.req_id)

    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        num_computed_tokens = cached_reqs.num_computed_tokens[i]
        new_block_ids = cached_reqs.new_block_ids[i]
        resumed_from_preemption = req_id in cached_reqs.resumed_req_ids

        if self.is_producer:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_tokens = num_scheduled_tokens + num_computed_tokens
            assert req_id in self.chunked_prefill
            assert new_block_ids is not None
            block_ids = new_block_ids[0]
            if not resumed_from_preemption:
                # Not resumed: new_block_ids only holds the newly allocated
                # blocks, so prepend the ones collected in earlier chunks.
                block_ids = self.chunked_prefill[req_id][0] + block_ids
            prompt_token_ids = self.chunked_prefill[req_id][1]
            assert prompt_token_ids is not None
            # the request's prompt is chunked prefill again
            if num_tokens < len(prompt_token_ids):
                self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
                continue
            # the request's prompt is all prefilled finally
            meta.add_request(
                request_id=req_id,
                token_ids=prompt_token_ids,
                block_ids=block_ids,
                block_size=self._block_size,
            )
            self.chunked_prefill.pop(req_id, None)
            continue

        # Only resumed requests need a reload. Skip the others one by one
        # instead of stopping at the first: membership is looked up per
        # request, so no position of resumed requests inside req_ids is
        # assumed here.
        if not resumed_from_preemption:
            continue
        if req_id in self._requests_need_load:
            request, _ = self._requests_need_load.pop(req_id)
            total_tokens = num_computed_tokens + 1
            token_ids = request.all_token_ids[:total_tokens]

            # NOTE(rob): For resumed req, new_block_ids is all
            # of the block_ids for the request.
            assert new_block_ids is not None
            block_ids = new_block_ids[0]

            meta.add_request(
                request_id=req_id,
                token_ids=token_ids,
                block_ids=block_ids,
                block_size=self._block_size,
            )

    self._requests_need_load.clear()
    return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
|
||||
self.chunked_prefill.pop(request.request_id, None)
|
||||
|
||||
return False, None
|
||||
|
||||
# ==============================
|
||||
# Static methods
|
||||
# ==============================
|
||||
|
||||
@staticmethod
|
||||
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
|
||||
# Regular expression to match the string hostname and integer port
|
||||
if is_prefill:
|
||||
pattern = r"___decode_addr_(.*):(\d+)"
|
||||
else:
|
||||
pattern = r"___prefill_addr_(.*):(\d+)___"
|
||||
|
||||
# Use re.search to find the pattern in the request_id
|
||||
match = re.search(pattern, request_id)
|
||||
if match:
|
||||
# Extract the ranks
|
||||
ip = match.group(1)
|
||||
port = int(match.group(2))
|
||||
|
||||
return ip, port
|
||||
raise ValueError(f"Request id {request_id} does not contain hostname and port")
|
||||
|
||||
@staticmethod
|
||||
def check_tensors_except_dim(tensor1, tensor2, dim):
|
||||
shape1 = tensor1.size()
|
||||
shape2 = tensor2.size()
|
||||
|
||||
if len(shape1) != len(shape2) or not all(
|
||||
s1 == s2 for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
|
||||
"and others will be supported in future PRs."
|
||||
)
|
||||
@@ -0,0 +1,632 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary,
|
||||
buffer_type,
|
||||
cudaStream_t,
|
||||
ncclComm_t,
|
||||
ncclDataTypeEnum,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
|
||||
TensorMemoryPool,
|
||||
)
|
||||
from vllm.utils.network_utils import get_ip
|
||||
from vllm.utils.torch_utils import current_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default for the `mem_pool_size_gb` extra-config option, in GiB; the value is
# multiplied by 1024**3 when sizing the TensorMemoryPool.
DEFAULT_MEM_POOL_SIZE_GB = 32
|
||||
|
||||
|
||||
@contextmanager
def set_p2p_nccl_context(num_channels: str):
    """Temporarily override NCCL environment variables.

    Inside the ``with`` body ``NCCL_MAX_NCHANNELS`` and
    ``NCCL_MIN_NCHANNELS`` are set to ``num_channels`` and
    ``NCCL_CUMEM_ENABLE`` to ``"1"``. On exit every variable in the
    tracked list is restored to its previous value, or removed if it was
    unset before.

    Args:
        num_channels: Channel count to export. Non-string values (for
            example an int coming from a JSON config) are converted with
            ``str()`` because ``os.environ`` only accepts strings.
    """
    env_vars = [
        "NCCL_MAX_NCHANNELS",
        "NCCL_MIN_NCHANNELS",
        "NCCL_CUMEM_ENABLE",
        "NCCL_BUFFSIZE",
        "NCCL_PROTO",  # LL,LL128,SIMPLE
        "NCCL_ALGO",  # RING,TREE
    ]
    original_values: dict[str, Any] = {var: os.environ.get(var) for var in env_vars}

    logger.info("set_p2p_nccl_context, original_values: %s", original_values)

    num_channels = str(num_channels)
    try:
        os.environ["NCCL_MAX_NCHANNELS"] = num_channels
        os.environ["NCCL_MIN_NCHANNELS"] = num_channels
        os.environ["NCCL_CUMEM_ENABLE"] = "1"
        yield
    finally:
        for var, previous in original_values.items():
            if previous is not None:
                os.environ[var] = previous
            else:
                os.environ.pop(var, None)
|
||||
|
||||
|
||||
@dataclass
class SendQueueItem:
    """A tensor waiting to be sent to a remote engine."""

    # Identifier of the tensor; callers in this file build it as
    # "<request_id>#<layer_name>".
    tensor_id: str
    # "host:port" address of the receiving engine.
    remote_address: str
    # The payload to transfer.
    tensor: torch.Tensor
|
||||
|
||||
|
||||
class P2pNcclEngine:
|
||||
def __init__(
    self,
    local_rank: int,
    config: KVTransferConfig,
    hostname: str = "",
    port_offset: int = 0,
    library_path: str | None = None,
) -> None:
    """Bind the ZMQ router socket and start the engine's background threads.

    Args:
        local_rank: Index of the CUDA device used by this engine.
        config: KV-transfer configuration; ``kv_port``, ``kv_buffer_size``
            and several extra-config entries are read.
        hostname: Address to bind to; ``get_ip()`` is used when empty.
        port_offset: Added to ``config.kv_port`` to obtain this engine's
            port. It is also stored as ``self.rank``.
        library_path: Passed through to ``NCCLLibrary``.

    Raises:
        ValueError: if the resulting port is 0, or if a proxy is
            configured without ``http_port``.
    """
    self.config = config
    self.rank = port_offset
    self.local_rank = local_rank
    self.device = torch.device(f"cuda:{self.local_rank}")
    self.nccl = NCCLLibrary(library_path)

    if not hostname:
        hostname = get_ip()
    port = int(self.config.kv_port) + port_offset
    if port == 0:
        raise ValueError("Port cannot be 0")
    self._hostname = hostname
    self._port = port

    # Each card corresponds to a ZMQ address.
    self.zmq_address = f"{self._hostname}:{self._port}"

    # If `proxy_ip` or `proxy_port` is `""`,
    # then the ping thread will not be enabled.
    proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
    proxy_port = self.config.get_from_extra_config("proxy_port", "")
    if proxy_ip == "" or proxy_port == "":
        self.proxy_address = ""
        self.http_address = ""
    else:
        # Formatting (rather than `+`) also accepts a port that was
        # configured as a JSON number instead of a string.
        self.proxy_address = f"{proxy_ip}:{proxy_port}"
        # the `http_port` must be consistent with the port of OpenAI.
        http_port = self.config.get_from_extra_config("http_port", None)
        if http_port is None:
            example_cfg = {
                "kv_connector": "P2pNcclConnector",
                "kv_connector_extra_config": {"http_port": 8000},
            }
            example = (
                f"--port=8000 --kv-transfer-config='{json.dumps(example_cfg)}'"
            )
            raise ValueError(
                "kv_connector_extra_config.http_port is required. "
                f"Example: {example}"
            )
        self.http_address = f"{self._hostname}:{http_port}"

    self.context = zmq.Context()
    self.router_socket = self.context.socket(zmq.ROUTER)
    self.router_socket.bind(f"tcp://{self.zmq_address}")

    self.poller = zmq.Poller()
    self.poller.register(self.router_socket, zmq.POLLIN)

    self.send_store_cv = threading.Condition()
    self.send_queue_cv = threading.Condition()
    self.recv_store_cv = threading.Condition()

    self.send_stream = torch.cuda.Stream()
    self.recv_stream = torch.cuda.Stream()

    mem_pool_size_gb = float(
        self.config.get_from_extra_config(
            "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB
        )
    )
    self.pool = TensorMemoryPool(
        max_block_size=int(mem_pool_size_gb * 1024**3)
    )  # GB

    # The sending type includes three mutually exclusive options:
    # PUT, GET, PUT_ASYNC.
    self.send_type = self.config.get_from_extra_config("send_type", "PUT_ASYNC")
    if self.send_type == "GET":
        # tensor_id: torch.Tensor
        self.send_store: dict[str, torch.Tensor] = {}
    else:
        # PUT or PUT_ASYNC
        # queue of SendQueueItem
        self.send_queue: deque[SendQueueItem] = deque()
        if self.send_type == "PUT_ASYNC":
            self._send_thread = threading.Thread(
                target=self.send_async, daemon=True
            )
            self._send_thread.start()

    # tensor_id: torch.Tensor/(addr, dtype, shape)
    self.recv_store: dict[str, Any] = {}
    self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
    self.socks: dict[str, Any] = {}  # remote_address: client socket
    self.comms: dict[str, Any] = {}  # remote_address: (ncclComm_t, rank)

    self.buffer_size = 0
    self.buffer_size_threshold = float(self.config.kv_buffer_size)

    # Exported through os.environ later on, which only accepts strings.
    self.nccl_num_channels = str(
        self.config.get_from_extra_config("nccl_num_channels", "8")
    )

    self._listener_thread = threading.Thread(
        target=self.listen_for_requests, daemon=True
    )
    self._listener_thread.start()

    # Only the engine with port offset 0 pings the proxy.
    self._ping_thread = None
    if port_offset == 0 and self.proxy_address != "":
        self._ping_thread = threading.Thread(target=self.ping, daemon=True)
        self._ping_thread.start()

    logger.info(
        "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
        "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
        "threshold:%.2f, nccl_num_channels:%s",
        self.rank,
        self.local_rank,
        self.http_address,
        self.zmq_address,
        self.proxy_address,
        self.send_type,
        self.buffer_size_threshold,
        self.nccl_num_channels,
    )
|
||||
|
||||
def create_connect(self, remote_address: str | None = None):
    """Return the ``(socket, (comm, rank))`` pair for ``remote_address``.

    On first use a DEALER socket identified by this engine's ZMQ address
    is connected to the peer. Unless a communicator for that peer already
    exists, a ``NEW`` message carrying a fresh NCCL unique id is sent and
    a two-rank communicator is initialised with this side as rank 0.
    """
    assert remote_address is not None
    if remote_address in self.socks:
        return self.socks[remote_address], self.comms[remote_address]

    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    sock.connect(f"tcp://{remote_address}")
    self.socks[remote_address] = sock

    if remote_address in self.comms:
        logger.info(
            "👋comm exists, remote_address:%s, comms:%s",
            remote_address,
            self.comms,
        )
        return sock, self.comms[remote_address]

    unique_id = self.nccl.ncclGetUniqueId()
    handshake = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
    sock.send(msgpack.dumps(handshake))

    with torch.cuda.device(self.device):
        rank = 0
        with set_p2p_nccl_context(self.nccl_num_channels):
            comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank)
        self.comms[remote_address] = (comm, rank)
        logger.info(
            "🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
            self.zmq_address,
            remote_address,
            rank,
        )

    return self.socks[remote_address], self.comms[remote_address]
|
||||
|
||||
def send_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
tensor: torch.Tensor,
|
||||
remote_address: str | None = None,
|
||||
) -> bool:
|
||||
if remote_address is None:
|
||||
with self.recv_store_cv:
|
||||
self.recv_store[tensor_id] = tensor
|
||||
self.recv_store_cv.notify()
|
||||
return True
|
||||
|
||||
item = SendQueueItem(
|
||||
tensor_id=tensor_id, remote_address=remote_address, tensor=tensor
|
||||
)
|
||||
|
||||
if self.send_type == "PUT":
|
||||
return self.send_sync(item)
|
||||
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
with self.send_queue_cv:
|
||||
self.send_queue.append(item)
|
||||
self.send_queue_cv.notify()
|
||||
return True
|
||||
|
||||
# GET
|
||||
with self.send_store_cv:
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
if tensor_size > self.buffer_size_threshold:
|
||||
logger.warning(
|
||||
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
|
||||
"buffer size threshold :%d, skip send to %s, rank:%d",
|
||||
tensor_id,
|
||||
tensor_size,
|
||||
self.buffer_size_threshold,
|
||||
remote_address,
|
||||
self.rank,
|
||||
)
|
||||
return False
|
||||
while self.buffer_size + tensor_size > self.buffer_size_threshold:
|
||||
assert len(self.send_store) > 0
|
||||
oldest_tensor_id = next(iter(self.send_store))
|
||||
oldest_tensor = self.send_store.pop(oldest_tensor_id)
|
||||
oldest_tensor_size = (
|
||||
oldest_tensor.element_size() * oldest_tensor.numel()
|
||||
)
|
||||
self.buffer_size -= oldest_tensor_size
|
||||
logger.debug(
|
||||
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
|
||||
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
|
||||
remote_address,
|
||||
tensor_id,
|
||||
tensor_size,
|
||||
self.buffer_size,
|
||||
oldest_tensor_size,
|
||||
self.rank,
|
||||
)
|
||||
|
||||
self.send_store[tensor_id] = tensor
|
||||
self.buffer_size += tensor_size
|
||||
logger.debug(
|
||||
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
|
||||
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
|
||||
remote_address,
|
||||
tensor_id,
|
||||
tensor_size,
|
||||
tensor.shape,
|
||||
self.rank,
|
||||
self.buffer_size,
|
||||
self.buffer_size / self.buffer_size_threshold * 100,
|
||||
)
|
||||
return True
|
||||
|
||||
def recv_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
remote_address: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
|
||||
start_time = time.time()
|
||||
with self.recv_store_cv:
|
||||
while tensor_id not in self.recv_store:
|
||||
self.recv_store_cv.wait()
|
||||
tensor = self.recv_store[tensor_id]
|
||||
|
||||
if tensor is not None:
|
||||
if isinstance(tensor, tuple):
|
||||
addr, dtype, shape = tensor
|
||||
tensor = self.pool.load_tensor(addr, dtype, shape, self.device)
|
||||
else:
|
||||
self.buffer_size -= tensor.element_size() * tensor.numel()
|
||||
else:
|
||||
duration = time.time() - start_time
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, rank:%d",
|
||||
remote_address,
|
||||
tensor_id,
|
||||
duration * 1000,
|
||||
self.rank,
|
||||
)
|
||||
return tensor
|
||||
|
||||
# GET
|
||||
if remote_address is None:
|
||||
return None
|
||||
|
||||
if remote_address not in self.socks:
|
||||
self.create_connect(remote_address)
|
||||
|
||||
sock = self.socks[remote_address]
|
||||
comm, rank = self.comms[remote_address]
|
||||
|
||||
data = {"cmd": "GET", "tensor_id": tensor_id}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
message = sock.recv()
|
||||
data = msgpack.loads(message)
|
||||
if data["ret"] != 0:
|
||||
logger.warning(
|
||||
"🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
|
||||
remote_address,
|
||||
tensor_id,
|
||||
data["ret"],
|
||||
)
|
||||
return None
|
||||
|
||||
with torch.cuda.stream(self.recv_stream):
|
||||
tensor = torch.empty(
|
||||
data["shape"], dtype=getattr(torch, data["dtype"]), device=self.device
|
||||
)
|
||||
|
||||
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
|
||||
|
||||
return tensor
|
||||
|
||||
def listen_for_requests(self):
    """Serve peer commands arriving on the router socket, forever.

    Handled commands:
    - ``NEW``: join the NCCL communicator announced by the peer as rank 1.
    - ``PUT``: allocate a tensor, acknowledge with ``b"0"``, receive the
      data and publish it in the receive store (spilling to the memory
      pool when the buffer threshold would be exceeded). On CUDA
      out-of-memory ``b"1"`` is replied and ``None`` is published.
    - ``GET``: reply with shape/dtype of the requested tensor and send it,
      or reply ``{"ret": 1}`` if it is not in the send store.
    Anything else is logged as unexpected.
    """
    while True:
        ready = dict(self.poller.poll())
        if self.router_socket not in ready:
            continue

        remote_address, message = self.router_socket.recv_multipart()
        data = msgpack.loads(message)
        cmd = data["cmd"]
        if cmd == "NEW":
            unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"]))
            with torch.cuda.device(self.device):
                rank = 1
                with set_p2p_nccl_context(self.nccl_num_channels):
                    comm: ncclComm_t = self.nccl.ncclCommInitRank(
                        2, unique_id, rank
                    )
                self.comms[remote_address.decode()] = (comm, rank)
                logger.info(
                    "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
                    self.zmq_address,
                    remote_address.decode(),
                    rank,
                )
        elif cmd == "PUT":
            tensor_id = data["tensor_id"]
            try:
                with torch.cuda.stream(self.recv_stream):
                    tensor = torch.empty(
                        data["shape"],
                        dtype=getattr(torch, data["dtype"]),
                        device=self.device,
                    )
                # Acknowledge only after the allocation succeeded; the
                # sender starts the NCCL transfer on this reply.
                self.router_socket.send_multipart([remote_address, b"0"])
                peer = remote_address.decode()
                comm, rank = self.comms[peer]
                self.recv(comm, tensor, rank ^ 1, self.recv_stream)
                tensor_size = tensor.element_size() * tensor.numel()
                if self.buffer_size + tensor_size > self.buffer_size_threshold:
                    # Store Tensor in memory pool
                    addr = self.pool.store_tensor(tensor)
                    tensor = (addr, tensor.dtype, tensor.shape)
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Threshold, "
                        "%s👈%s, data:%s, addr:%d",
                        self.zmq_address,
                        peer,
                        data,
                        addr,
                    )
                else:
                    self.buffer_size += tensor_size

            except torch.cuda.OutOfMemoryError:
                self.router_socket.send_multipart([remote_address, b"1"])
                tensor = None
                logger.warning(
                    "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, data:%s",
                    self.zmq_address,
                    remote_address.decode(),
                    data,
                )

            with self.recv_store_cv:
                self.recv_store[tensor_id] = tensor
                self.have_received_tensor_id(tensor_id)
                self.recv_store_cv.notify()

        elif cmd == "GET":
            tensor_id = data["tensor_id"]
            with self.send_store_cv:
                tensor = self.send_store.pop(tensor_id, None)
                if tensor is None:
                    reply = {"ret": 1}
                else:
                    reply = {
                        "ret": 0,
                        "shape": tensor.shape,
                        "dtype": str(tensor.dtype).replace("torch.", ""),
                    }
                    # LRU: re-inserting moves the entry to the newest end.
                    self.send_store[tensor_id] = tensor
                    self.have_sent_tensor_id(tensor_id)

            self.router_socket.send_multipart([remote_address, msgpack.dumps(reply)])

            if tensor is not None:
                comm, rank = self.comms[remote_address.decode()]
                self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
        else:
            logger.warning(
                "🚧Unexpected, Received message from %s, data:%s",
                remote_address,
                data,
            )
|
||||
|
||||
def have_sent_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as sent, grouped by its request id.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition("#")
    sent_ids = self.send_request_id_to_tensor_ids.setdefault(request_id, set())
    sent_ids.add(tensor_id)
|
||||
|
||||
def have_received_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as received, grouped by its request id.

    The request id is the part of ``tensor_id`` before the first ``#``.
    """
    request_id, _, _ = tensor_id.partition("#")
    received_ids = self.recv_request_id_to_tensor_ids.setdefault(request_id, set())
    received_ids.add(tensor_id)
|
||||
|
||||
def send_async(self):
    """Sender-thread loop: pop queued items one by one and send them.

    The loop sleeps on ``send_queue_cv`` while the queue is empty. After
    popping the last queued item it notifies the condition so that a
    thread blocked in ``wait_for_sent`` can re-check the queue. The send
    itself happens outside the lock.

    A failing send is logged and skipped; it must not terminate this
    thread, because nothing else drains the queue.
    """
    while True:
        with self.send_queue_cv:
            while not self.send_queue:
                self.send_queue_cv.wait()
            item = self.send_queue.popleft()
            if not self.send_queue:
                self.send_queue_cv.notify()
        try:
            self.send_sync(item)
        except Exception:
            logger.exception(
                "🔴[PUT_ASYNC]Failed to send tensor_id:%s to %s",
                item.tensor_id,
                item.remote_address,
            )
|
||||
|
||||
def wait_for_sent(self):
    """In PUT_ASYNC mode, block until the send queue is empty.

    The time spent waiting is reported at debug level. In every other
    send mode the call returns immediately.
    """
    if self.send_type != "PUT_ASYNC":
        return

    wait_started = time.time()
    with self.send_queue_cv:
        while self.send_queue:
            self.send_queue_cv.wait()
    elapsed = time.time() - wait_started
    logger.debug(
        "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
        " to be empty, rank:%d",
        elapsed * 1000,
        self.rank,
    )
|
||||
|
||||
def send_sync(self, item: SendQueueItem) -> bool:
    """Send one queued tensor to its peer and wait for the handshake.

    A ``PUT`` header (tensor id, shape, dtype) is sent over ZMQ first.
    The tensor itself is transferred over NCCL only when the peer replies
    ``b"0"``.

    Returns:
        True when the tensor was sent, False when the item has no remote
        address or the peer refused the transfer.
    """
    peer_address = item.remote_address
    if peer_address is None:
        return False
    if peer_address not in self.socks:
        self.create_connect(peer_address)

    tensor = item.tensor

    sock = self.socks[peer_address]
    comm, rank = self.comms[peer_address]
    header = {
        "cmd": "PUT",
        "tensor_id": item.tensor_id,
        "shape": tensor.shape,
        "dtype": str(tensor.dtype).replace("torch.", ""),
    }
    sock.send(msgpack.dumps(header))

    response = sock.recv()
    if response != b"0":
        size_gb = tensor.element_size() * tensor.numel() / 1024**3
        logger.error(
            "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
            "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
            self.zmq_address,
            peer_address,
            rank,
            header,
            tensor.shape,
            size_gb,
            response.decode(),
        )
        return False

    # The peer is the other rank of the two-rank communicator.
    self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)

    if self.send_type == "PUT_ASYNC":
        self.have_sent_tensor_id(item.tensor_id)

    return True
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str], no_compile_layers
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer,
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
|
||||
# Clear the buffer upon request completion.
|
||||
for request_id in finished_req_ids:
|
||||
for layer_name in no_compile_layers:
|
||||
tensor_id = request_id + "#" + layer_name
|
||||
if tensor_id in self.recv_store:
|
||||
with self.recv_store_cv:
|
||||
tensor = self.recv_store.pop(tensor_id, None)
|
||||
self.send_request_id_to_tensor_ids.pop(request_id, None)
|
||||
self.recv_request_id_to_tensor_ids.pop(request_id, None)
|
||||
if isinstance(tensor, tuple):
|
||||
addr, _, _ = tensor
|
||||
self.pool.free(addr)
|
||||
|
||||
# TODO:Retrieve requests that have already sent the KV cache.
|
||||
finished_sending: set[str] = set()
|
||||
|
||||
# TODO:Retrieve requests that have already received the KV cache.
|
||||
finished_recving: set[str] = set()
|
||||
|
||||
return finished_sending or None, finished_recving or None
|
||||
|
||||
def ping(self):
    """Announce this instance to the proxy every three seconds.

    Opens a DEALER socket whose identity is ``zmq_address``, connects it to
    ``proxy_address`` and keeps sending a msgpack-encoded message with the
    instance role ("P" for a KV producer, "D" otherwise) and its HTTP and
    ZMQ addresses. This loop never returns.
    """
    heartbeat_sock = self.context.socket(zmq.DEALER)
    heartbeat_sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    heartbeat_sock.connect(f"tcp://{self.proxy_address}")

    role = "P" if self.config.is_kv_producer else "D"
    registration = {
        "type": role,
        "http_address": self.http_address,
        "zmq_address": self.zmq_address,
    }
    while True:
        heartbeat_sock.send(msgpack.dumps(registration))
        time.sleep(3)
|
||||
|
||||
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
    """Send ``tensor`` to rank ``dst`` with ``ncclSend`` and wait for it.

    Args:
        comm: Communicator handle passed through to ``ncclSend``.
        tensor: Tensor to send; must live on ``self.device``.
        dst: Destination rank.
        stream: Stream to issue the send on; defaults to the value
            returned by ``current_stream()``.
    """
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}"
    )
    active_stream = current_stream() if stream is None else stream

    with torch.cuda.stream(active_stream):
        send_buffer = buffer_type(tensor.data_ptr())
        element_count = tensor.numel()
        nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
        raw_stream = cudaStream_t(active_stream.cuda_stream)
        self.nccl.ncclSend(
            send_buffer, element_count, nccl_dtype, dst, comm, raw_stream
        )
    # Block the caller until the work queued on the stream has completed.
    active_stream.synchronize()
|
||||
|
||||
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
    """Receive into ``tensor`` from rank ``src`` with ``ncclRecv`` and wait.

    Args:
        comm: Communicator handle passed through to ``ncclRecv``.
        tensor: Destination tensor; must live on ``self.device``.
        src: Source rank.
        stream: Stream to issue the receive on; defaults to the value
            returned by ``current_stream()``.
    """
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}"
    )
    active_stream = current_stream() if stream is None else stream

    with torch.cuda.stream(active_stream):
        recv_buffer = buffer_type(tensor.data_ptr())
        element_count = tensor.numel()
        nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
        raw_stream = cudaStream_t(active_stream.cuda_stream)
        self.nccl.ncclRecv(
            recv_buffer, element_count, nccl_dtype, src, comm, raw_stream
        )
    # Block the caller until the work queued on the stream has completed.
    active_stream.synchronize()
|
||||
|
||||
def close(self) -> None:
    """Join the engine's background threads.

    The listener thread is always joined. The send thread is joined only
    when ``send_type`` is ``"PUT_ASYNC"``, and the ping thread only when
    one exists.
    """
    self._listener_thread.join()

    uses_async_sender = self.send_type == "PUT_ASYNC"
    if uses_async_sender:
        self._send_thread.join()

    ping_thread = self._ping_thread
    if ping_thread is None:
        return
    ping_thread.join()
|
||||
@@ -0,0 +1,273 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import atexit
|
||||
import ctypes
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class MemoryBlock:
    """A contiguous byte range managed by ``TensorMemoryPool``."""

    size: int  # block size in bytes
    addr: int  # absolute address of the block's first byte


class TensorMemoryPool:
    """A memory pool for managing pinned host memory allocations for tensors.

    This class implements a buddy allocation system to efficiently manage
    pinned host memory for tensor storage. It supports allocation,
    deallocation, and tensor storage/retrieval operations.

    Key Features:
    - Uses power-of-two block sizes for efficient buddy allocation
    - Supports splitting and merging of memory blocks
    - Provides methods to store CUDA tensors in pinned host memory
    - Allows loading tensors from pinned memory back to device
    - Automatically cleans up memory on destruction

    Attributes:
        max_block_size (int): Maximum block size (rounded up to a power of
            two); also the total size of the pinned region
        min_block_size (int): Minimum block size (rounded up to a power of
            two)
        free_lists (dict): Free memory blocks, keyed by block size and then
            by block address
        allocated_blocks (dict): Currently allocated blocks, keyed by address
        base_tensor (torch.Tensor): Base pinned memory tensor
        base_address (int): Base memory address of the pinned memory region

    Example:
        >>> pool = TensorMemoryPool(max_block_size=1024*1024)
        >>> tensor = torch.randn(100, device='cuda')
        >>> addr = pool.store_tensor(tensor)
        >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
        ...                                  tensor.shape, 'cuda')
        >>> pool.free(addr)
    """

    def __init__(self, max_block_size: int, min_block_size: int = 512):
        """Initializes the memory pool with given size constraints.

        Args:
            max_block_size (int): Maximum size of memory blocks to manage
            min_block_size (int, optional): Minimum size of memory blocks
                to manage. Defaults to 512.

        Raises:
            ValueError: If block sizes are invalid or max_block_size is less
                than min_block_size
        """
        if max_block_size <= 0 or min_block_size <= 0:
            raise ValueError("Block sizes must be positive")
        if max_block_size < min_block_size:
            raise ValueError("Max block size must be greater than min block size")

        self.max_block_size = self._round_to_power_of_two(max_block_size)
        self.min_block_size = self._round_to_power_of_two(min_block_size)

        self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
        self.allocated_blocks: dict[int, MemoryBlock] = {}

        self._initialize_free_lists()
        self._allocate_pinned_memory()

        # Register the exit-time cleanup through a weak reference.
        # Registering the bound method directly would make atexit hold a
        # strong reference to the pool, keeping it (and its pinned buffer)
        # alive until interpreter shutdown even after all users dropped it.
        import weakref

        pool_ref = weakref.ref(self)

        def _cleanup_at_exit() -> None:
            pool = pool_ref()
            if pool is not None:
                pool.cleanup()

        atexit.register(_cleanup_at_exit)

    def _round_to_power_of_two(self, size: int) -> int:
        """Returns the smallest power of two that is >= ``size``."""
        return 1 << (size - 1).bit_length()

    def _initialize_free_lists(self):
        """Creates one empty free list per block size, from max down to min."""
        size = self.max_block_size
        while size >= self.min_block_size:
            self.free_lists[size] = {}
            size //= 2

    def _allocate_pinned_memory(self):
        """Allocates the pinned backing buffer and registers it as one block."""
        # float32 elements are 4 bytes each, hence max_block_size // 4
        # elements for a region of max_block_size bytes.
        self.base_tensor = torch.empty(
            self.max_block_size // 4, dtype=torch.float32, pin_memory=True
        )
        self.base_address = self.base_tensor.data_ptr()
        initial_block = MemoryBlock(size=self.max_block_size, addr=self.base_address)
        self.free_lists[self.max_block_size][initial_block.addr] = initial_block

        logger.debug(
            "TensorMemoryPool, base_address:%d, max_block_size:%d",
            self.base_address,
            self.max_block_size,
        )

    def allocate(self, size: int) -> int:
        """Allocates a memory block of at least the requested size.

        Args:
            size (int): Minimum size of memory to allocate

        Returns:
            int: Address of the allocated memory block

        Raises:
            ValueError: If size is invalid or insufficient memory is available
        """
        if size <= 0:
            raise ValueError("Allocation size must be positive")

        required_size = self._round_to_power_of_two(max(size, self.min_block_size))
        if required_size > self.max_block_size:
            raise ValueError("Requested size exceeds maximum block size")

        # Take the smallest free block that fits and split off the excess.
        current_size = required_size
        while current_size <= self.max_block_size:
            if self.free_lists[current_size]:
                _, block = self.free_lists[current_size].popitem()
                self._split_block(block, required_size)
                self.allocated_blocks[block.addr] = block
                return block.addr
            current_size *= 2

        raise ValueError("Insufficient memory")

    def _split_block(self, block: MemoryBlock, required_size: int):
        """Halves ``block`` in place until it is ``required_size`` bytes.

        The upper half produced by each split is put on the free list of its
        size; ``block`` keeps its address and shrinks.
        """
        while block.size > required_size and block.size // 2 >= self.min_block_size:
            buddy_size = block.size // 2
            buddy_addr = block.addr + buddy_size

            buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
            block.size = buddy_size

            self.free_lists[buddy_size][buddy.addr] = buddy

    def free(self, addr: int):
        """Frees an allocated memory block.

        Args:
            addr (int): Address of the block to free

        Raises:
            ValueError: If address is invalid or not allocated
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to free")

        block = self.allocated_blocks.pop(addr)
        self._merge_buddies(block)

    def _merge_buddies(self, block: MemoryBlock):
        """Returns ``block`` to the free lists, merging it with free buddies.

        Merging repeats until the buddy of the current block is not free or
        the block has grown to ``max_block_size``, so pools of any depth are
        fully coalesced.
        """
        while block.size < self.max_block_size:
            # A block whose offset is a multiple of twice its size is the
            # lower half of its parent, so its buddy lies above it;
            # otherwise the buddy lies below.
            offset = block.addr - self.base_address
            if offset % (2 * block.size) == 0:
                buddy_addr = block.addr + block.size
            else:
                buddy_addr = block.addr - block.size

            buddy = self.free_lists[block.size].pop(buddy_addr, None)
            if buddy is None:
                break
            block = MemoryBlock(
                size=block.size * 2, addr=min(block.addr, buddy.addr)
            )
        self.free_lists[block.size][block.addr] = block

    def store_tensor(self, tensor: torch.Tensor) -> int:
        """Stores a CUDA tensor in pinned host memory.

        Args:
            tensor (torch.Tensor): CUDA tensor to store

        Returns:
            int: Address where the tensor is stored

        Raises:
            ValueError: If tensor is not on CUDA or allocation fails
        """
        if not tensor.is_cuda:
            raise ValueError("Only CUDA tensors can be stored")

        size = tensor.element_size() * tensor.numel()
        addr = self.allocate(size)
        block = self.allocated_blocks[addr]

        if block.size < size:
            self.free(addr)
            raise ValueError(
                f"Allocated block size {block.size} is smaller than "
                f"required size {size}"
            )

        try:
            # View the block's bytes as a host tensor of the source's
            # dtype and shape, without copying.
            buffer = (ctypes.c_byte * block.size).from_address(block.addr)
            cpu_tensor = torch.frombuffer(
                buffer, dtype=tensor.dtype, count=tensor.numel()
            ).reshape(tensor.shape)
        except ValueError as err:
            self.free(addr)
            raise ValueError(f"Failed to create tensor view: {err}") from err

        try:
            cpu_tensor.copy_(tensor)
        except Exception:
            # Do not leak the block when the device-to-host copy fails.
            self.free(addr)
            raise

        return addr

    def load_tensor(
        self,
        addr: int,
        dtype: torch.dtype,
        shape: tuple[int, ...],
        device: torch.device,
    ) -> torch.Tensor:
        """Loads a tensor from pinned host memory to the specified device.

        The block stays allocated; call ``free`` to release it.

        Args:
            addr (int): Address where tensor is stored
            dtype (torch.dtype): Data type of the tensor
            shape (tuple[int, ...]): Shape of the tensor
            device: Target device for the loaded tensor

        Returns:
            torch.Tensor: The loaded tensor on the specified device

        Raises:
            ValueError: If address is invalid or sizes don't match
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to load")

        block = self.allocated_blocks[addr]
        num_elements = math.prod(shape)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        required_size = num_elements * dtype_size

        if required_size > block.size:
            raise ValueError("Requested tensor size exceeds block size")

        buffer = (ctypes.c_byte * block.size).from_address(block.addr)
        cpu_tensor = torch.frombuffer(buffer, dtype=dtype, count=num_elements).reshape(
            shape
        )

        cuda_tensor = torch.empty(shape, dtype=dtype, device=device)

        cuda_tensor.copy_(cpu_tensor)

        return cuda_tensor

    def cleanup(self):
        """Cleans up all memory resources and resets the pool state.

        Safe to call more than once, and on a pool whose ``__init__`` did
        not complete (``__del__`` still runs in that case).
        """
        free_lists = getattr(self, "free_lists", None)
        if free_lists is not None:
            free_lists.clear()
        allocated_blocks = getattr(self, "allocated_blocks", None)
        if allocated_blocks is not None:
            allocated_blocks.clear()
        if hasattr(self, "base_tensor"):
            del self.base_tensor

    def __del__(self):
        self.cleanup()
|
||||
78
vllm/distributed/kv_transfer/kv_transfer_state.py
Normal file
78
vllm/distributed/kv_transfer/kv_transfer_state.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None
|
||||
|
||||
|
||||
def get_kv_transfer_group() -> KVConnectorBaseType:
    """Return the process-wide KV connector.

    Fails an assertion if no connector has been created yet.
    """
    agent = _KV_CONNECTOR_AGENT
    assert agent is not None, (
        "disaggregated KV cache transfer parallel group is not initialized"
    )
    return agent
|
||||
|
||||
|
||||
def has_kv_transfer_group() -> bool:
    """Return True when the process-wide KV connector has been created."""
    agent = _KV_CONNECTOR_AGENT
    return agent is not None
|
||||
|
||||
|
||||
def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> bool:
    """Tell whether a KV connector is a v1 connector.

    Args:
        connector: The KV connector to check. When None, the global KV
            connector is checked instead.

    Returns:
        True if the checked connector is a ``KVConnectorBase_V1`` instance;
        False otherwise, including when there is no connector at all.

    Note:
        This function will no longer be needed after the v1 KV connector
        becomes the default.
    """
    candidate = _KV_CONNECTOR_AGENT if connector is None else connector
    if candidate is None:
        return False
    return isinstance(candidate, KVConnectorBase_V1)
|
||||
|
||||
|
||||
def ensure_kv_transfer_initialized(
    vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None
) -> None:
    """
    Create the process-wide KV connector if it is needed and missing.

    Does nothing when the config has no ``kv_transfer_config``, when this
    instance is not a KV transfer instance, or when a connector already
    exists. Otherwise a connector with the WORKER role is created through
    ``KVConnectorFactory`` and stored as the global connector.
    """

    global _KV_CONNECTOR_AGENT

    transfer_config = vllm_config.kv_transfer_config
    if transfer_config is None:
        return
    if not transfer_config.is_kv_transfer_instance:
        return
    if _KV_CONNECTOR_AGENT is not None:
        return

    _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
        config=vllm_config,
        role=KVConnectorRole.WORKER,
        kv_cache_config=kv_cache_config,
    )
|
||||
|
||||
|
||||
def ensure_kv_transfer_shutdown() -> None:
    """Shut down and forget the process-wide KV connector, if one exists."""
    global _KV_CONNECTOR_AGENT
    if _KV_CONNECTOR_AGENT is None:
        return
    _KV_CONNECTOR_AGENT.shutdown()
    _KV_CONNECTOR_AGENT = None
|
||||
Reference in New Issue
Block a user