Sync from v0.13
This commit is contained in:
597
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
597
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
@@ -0,0 +1,597 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
|
||||
communication in vLLM v1
|
||||
|
||||
The class provides the following primitives:
|
||||
Scheduler-side: runs in the scheduler, binds metadata, which
|
||||
is used by the worker-side to load/save KV cache.
|
||||
get_num_new_matched_tokens() - get number of new tokens
|
||||
that exist in the remote KV cache. Might be called multiple
|
||||
times for a given request and should be side-effect free.
|
||||
update_state_after_alloc() - update KVConnector state after
|
||||
temporary buffer alloc by the CacheManager.
|
||||
update_connector_output() - update KVConnector state after
|
||||
output is received from worker-side connectors.
|
||||
request_finished() - called once when a request is finished,
|
||||
with the computed kv cache blocks for the request.
|
||||
Returns whether KV cache should be freed now or if the
|
||||
connector now assumes responsibility for freeing the
|
||||
the blocks asynchronously. Also optionally returns KV
|
||||
transfer params.
|
||||
take_events() - returns new KV events that were collected
|
||||
by the connector since the last call.
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
start_load_kv() - starts loading all KVs (maybe async)
|
||||
wait_for_layer_load() - blocks until layer i load is done
|
||||
|
||||
save_kv_layer() - starts saving KV for layer i (maybe async)
|
||||
wait_for_save() - blocks until all saves are done
|
||||
|
||||
get_finished() - called with ids of finished requests, returns
|
||||
ids of requests that have completed async sending/recving.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
|
||||
CopyBlocksOp = Callable[
|
||||
[
|
||||
dict[str, torch.Tensor],
|
||||
dict[str, torch.Tensor],
|
||||
list[int],
|
||||
list[int],
|
||||
Literal["h2d", "d2h"],
|
||||
],
|
||||
None,
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SupportsHMA(ABC):
|
||||
"""
|
||||
The class that indicates the corresponding connector supports hybrid memory
|
||||
allocator (HMA).
|
||||
This is required to use the connector together with hybrid memory allocator.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished for all kv cache groups,
|
||||
before its blocks are freed for each group.
|
||||
|
||||
NOTE(Kuntai): This function is only supported by connectors that support HMA.
|
||||
|
||||
The connector may assumes responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def supports_hma(connector: Any) -> bool:
|
||||
if isinstance(connector, type):
|
||||
return issubclass(connector, SupportsHMA)
|
||||
else:
|
||||
return isinstance(connector, SupportsHMA)
|
||||
|
||||
|
||||
class KVConnectorRole(enum.Enum):
|
||||
# Connector running in the scheduler process
|
||||
SCHEDULER = 0
|
||||
|
||||
# Connector running in the worker process
|
||||
WORKER = 1
|
||||
|
||||
|
||||
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
|
||||
"""
|
||||
Metadata used for out of band connector handshake between
|
||||
P/D workers. This needs to serializeable.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class KVConnectorMetadata(ABC): # noqa: B024
|
||||
"""
|
||||
Abstract Metadata used to communicate between the
|
||||
Scheduler KVConnector and Worker KVConnector.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
Base class for KV connectors.
|
||||
|
||||
Attributes:
|
||||
prefer_cross_layer_blocks (bool): Indicates whether this connector
|
||||
prefers KV blocks that hold KV data for all layers (for speeding
|
||||
up KV data transfers).
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
prefer_cross_layer_blocks: ClassVar[bool] = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
logger.warning(
|
||||
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||
"subject to change in the future as we iterate the design."
|
||||
)
|
||||
self._connector_metadata: KVConnectorMetadata | None = None
|
||||
self._vllm_config = vllm_config
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self._kv_transfer_config = vllm_config.kv_transfer_config
|
||||
else:
|
||||
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
|
||||
self._kv_cache_config = kv_cache_config
|
||||
if self._kv_cache_config is None:
|
||||
logger.warning(
|
||||
"KVConnectorBase_V1 initialized without kv_cache_config. "
|
||||
"This is deprecated - please update your connector to accept "
|
||||
"kv_cache_config as the third constructor argument and pass it "
|
||||
"to super().__init__()."
|
||||
)
|
||||
self._role = role
|
||||
|
||||
@property
|
||||
def role(self) -> KVConnectorRole:
|
||||
return self._role
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
"""Set the connector metadata from the scheduler.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
before the model execution. The metadata will be used for runtime
|
||||
KV cache loading and saving.
|
||||
|
||||
Args:
|
||||
connector_metadata (dict): the connector metadata.
|
||||
"""
|
||||
self._connector_metadata = connector_metadata
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
"""Clear the connector metadata.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
after the model execution.
|
||||
"""
|
||||
self._connector_metadata = None
|
||||
|
||||
def _get_connector_metadata(self) -> KVConnectorMetadata:
|
||||
"""Get the connector metadata.
|
||||
|
||||
This function should only be called inside the connector.
|
||||
|
||||
Returns:
|
||||
ConnectorMetadata: the connector metadata.
|
||||
"""
|
||||
# Should only be called while set to valid metadata.
|
||||
assert self._connector_metadata is not None
|
||||
return self._connector_metadata
|
||||
|
||||
def has_connector_metadata(self) -> bool:
|
||||
"""Check whether the connector metadata is currently set.
|
||||
|
||||
Returns:
|
||||
bool: True if connector metadata exists, False otherwise.
|
||||
"""
|
||||
return self._connector_metadata is not None
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
|
||||
Args:
|
||||
kv_caches: dictionary of layer names, kv cache
|
||||
"""
|
||||
return
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
|
||||
):
|
||||
"""
|
||||
Initialize with a single KV cache tensor used by all layers.
|
||||
The first dimension should be num_layers.
|
||||
This function will only be called for models with uniform layers,
|
||||
and only if the prefers_cross_layer_blocks is set to True.
|
||||
Only one of the functions
|
||||
{register_kv_caches, register_cross_layers_kv_cache} will be called.
|
||||
|
||||
Args:
|
||||
kv_cache: a cross-layers kv cache tensor
|
||||
attn_backend: The attention backend that corresponds to all layers
|
||||
"""
|
||||
return
|
||||
|
||||
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
|
||||
"""
|
||||
Set the xPU-specific ops for copying KV between host and device.
|
||||
Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
|
||||
"""
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Start saving a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens on the worker.
|
||||
The scheduler process (via the Executors) will use this output
|
||||
to track which workers are done.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
|
||||
"""
|
||||
Get the set of block IDs that failed to load.
|
||||
|
||||
Returns:
|
||||
Set of block IDs that encountered load errors.
|
||||
Empty set if no load errors occurred.
|
||||
|
||||
Notes:
|
||||
- Applies to both sync- and async-loading requests.
|
||||
- Async loading: failed blocks may be reported in any forward pass
|
||||
up to and including the pass where the request ID is returned by
|
||||
`get_finished()`. Even if failures occur, the request must still
|
||||
be reported via `get_finished()`, and the failed block IDs must
|
||||
appear here no later than that same pass.
|
||||
- Sync loading: failed blocks should be reported in the forward
|
||||
pass in which they are detected.
|
||||
"""
|
||||
return set()
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Shutdown the connector. This is called when the worker process
|
||||
is shutting down to ensure that all the async operations are
|
||||
completed and the connector is cleaned up properly.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
|
||||
"""
|
||||
Get the KV connector stats collected during the last interval.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
|
||||
"""
|
||||
Get the KV connector kv cache events collected during the last interval.
|
||||
This function should be called by the model runner every time after the
|
||||
model execution and before cleanup.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||
"""
|
||||
Get the KVConnector handshake metadata for this connector.
|
||||
This metadata is used for out-of-band connector handshake
|
||||
between P/D workers.
|
||||
|
||||
Returns:
|
||||
KVConnectorHandshakeMetadata: the handshake metadata.
|
||||
None if no handshake metadata is available.
|
||||
"""
|
||||
return None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
@abstractmethod
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- An optional number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
If None, it means that the connector needs more time to
|
||||
determine the number of matched tokens, and the scheduler
|
||||
should query for this request again later.
|
||||
- `True` if external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps). Must be
|
||||
'False' if the first element is 0.
|
||||
|
||||
Notes:
|
||||
The connector should only consider the largest prefix of prompt-
|
||||
tokens for which KV cache is actually available at the time of the
|
||||
call. If the cache cannot be loaded for some tokens (e.g., due to
|
||||
connectivity issues or eviction), those tokens must not be taken
|
||||
into account.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If get_num_new_matched_tokens previously returned True for a
|
||||
request, this function may be called twice for that same request -
|
||||
first when blocks are allocated for the connector tokens to be
|
||||
asynchronously loaded into, and second when any additional blocks
|
||||
are allocated, after the load/transfer is complete.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
blocks (KVCacheBlocks): the blocks allocated for the request.
|
||||
num_external_tokens (int): the number of tokens that will be
|
||||
loaded from the external KV cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
return
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished, before its blocks are
|
||||
freed.
|
||||
|
||||
The connector may assumes responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return False, None
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
"""
|
||||
Take the KV cache events from the connector.
|
||||
|
||||
Yields:
|
||||
New KV cache events since the last call.
|
||||
"""
|
||||
return ()
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
|
||||
"""
|
||||
Get the required KV cache layout for this connector.
|
||||
Args:
|
||||
vllm_config (VllmConfig): the vllm config.
|
||||
|
||||
Returns:
|
||||
str: the required KV cache layout. e.g. HND, or NHD.
|
||||
None if the connector does not require a specific layout.
|
||||
"""
|
||||
|
||||
if cls is KVConnectorBase_V1:
|
||||
raise TypeError(
|
||||
"get_required_kvcache_layout should not be called "
|
||||
"on the abstract base class"
|
||||
)
|
||||
return None
|
||||
|
||||
def get_finished_count(self) -> int | None:
|
||||
"""
|
||||
Get the count of requests expected to complete send/receive operations
|
||||
via this connector. This method is used to initialize the
|
||||
KVOutputAggregator, overwriting the default world_size.
|
||||
|
||||
Returns:
|
||||
int: expected sending or receiving completion count.
|
||||
"""
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls, data: dict[str, Any] | None = None
|
||||
) -> Optional["KVConnectorStats"]:
|
||||
"""
|
||||
KVConnectorStats resolution method. This method allows dynamically
|
||||
registered connectors to return their own KVConnectorStats object,
|
||||
which can implement custom aggregation logic on the data dict.
|
||||
"""
|
||||
return None
|
||||
|
||||
def set_xfer_handshake_metadata(
|
||||
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||
) -> None:
|
||||
"""
|
||||
Set the KV connector handshake metadata for this connector.
|
||||
|
||||
Args:
|
||||
metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
) -> Optional["KVConnectorPromMetrics"]:
|
||||
"""
|
||||
Create a KVConnectorPromMetrics subclass which should register
|
||||
per-connector Prometheus metrics and implement observe() to
|
||||
expose connector transfer stats via Prometheus.
|
||||
"""
|
||||
return None
|
||||
|
||||
def reset_cache(self) -> bool | None:
|
||||
"""
|
||||
Reset the connector's internal cache.
|
||||
|
||||
Returns:
|
||||
bool: True if the cache was successfully reset, False otherwise.
|
||||
"""
|
||||
logger.debug(
|
||||
"Connector cache reset requested, but %s does not implement reset_cache().",
|
||||
type(self).__name__,
|
||||
)
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user