### What this PR does / why we need it?
To adapt to different KV cache shapes, UCM moved the store
initialization into `register_kv_caches`. This update therefore adds
the `register_kv_caches` interface to UCMConnectorV1.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: UnifiedCacheManager <unifiedcachem@163.com>
247 lines · 8.7 KiB · Python
# SPDX-License-Identifier: Apache-2.0
|
|
from typing import TYPE_CHECKING, Any, Optional
|
|
|
|
import torch
|
|
from ucm.integration.vllm.ucm_connector import UCMConnector
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
|
from vllm.logger import init_logger
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
|
KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT)
|
|
from vllm.forward_context import ForwardContext
|
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
from vllm.v1.request import Request
|
|
|
|
|
|
class UCMConnectorV1(KVConnectorBase_V1):
    """vLLM v1 KV-connector facade backed by the UCM engine.

    Every scheduler-side and worker-side callback of the
    ``KVConnectorBase_V1`` interface is forwarded verbatim to an
    underlying :class:`UCMConnector` instance, which implements the
    actual KV-cache load/save logic.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        super().__init__(vllm_config=vllm_config,
                         role=role,
                         kv_cache_config=kv_cache_config)
        # The UCM engine is configured from the transfer config, so it
        # must be present; fail loudly with a clear message otherwise.
        assert vllm_config.kv_transfer_config is not None, (
            "UCMConnectorV1 requires vllm_config.kv_transfer_config")

        # All connector callbacks below delegate to this engine.
        self._ucm_engine = UCMConnector(vllm_config, role)

    # ==============================
    # Worker-side methods
    # ==============================
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: A dictionary mapping layer names to KV cache tensors.
        """
        self._ucm_engine.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        self._ucm_engine.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within attention layer to ensure
        async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        self._ucm_engine.wait_for_layer_load(layer_name)

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
                                       **kwargs)

    def wait_for_save(self) -> None:
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of paged KV buffer before saving done.
        """
        self._ucm_engine.wait_for_save()

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._ucm_engine.clear_connector_metadata()

    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        KV cache loading and saving.

        Args:
            connector_metadata (dict): the connector metadata.
        """
        self._ucm_engine.bind_connector_metadata(connector_metadata)

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.
        """
        return self._ucm_engine.get_block_ids_with_load_errors()

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        return self._ucm_engine.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int) -> None:
        """
        Update KVConnector state after block allocation.
        """
        self._ucm_engine.update_state_after_alloc(request, blocks,
                                                  num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        return self._ucm_engine.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return self._ucm_engine.request_finished(request, block_ids)

    # ==============================
    # Metrics & Stats
    # ==============================

    @classmethod
    def build_kv_connector_stats(
        cls,
        data: dict[str, Any] | None = None
    ) -> Optional["KVConnectorStats"]:
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,
        which can implement custom aggregation logic on the data dict.
        """
        return UCMConnector.build_kv_connector_stats(data)

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> Optional["KVConnectorPromMetrics"]:
        """
        Create a KVConnectorPromMetrics subclass which should register
        per-connector Prometheus metrics and implement observe() to
        expose connector transfer stats via Prometheus.

        This implementation forwards the call to the underlying
        UCMConnector engine.
        """
        return UCMConnector.build_prom_metrics(
            vllm_config,
            metric_types,
            labelnames,
            per_engine_labelvalues,
        )