# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional

import torch
from ucm.integration.vllm.ucm_connector import UCMConnector
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT)
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.outputs import KVConnectorOutput
    from vllm.v1.request import Request


class UCMConnectorV1(KVConnectorBase_V1):
    """vLLM v1 KV connector that delegates every operation to UCMConnector.

    This class is a thin adapter: each ``KVConnectorBase_V1`` hook is
    forwarded to ``self._ucm_engine``. Every hook vLLM calls on the
    connector must be forwarded here. Any hook left out silently falls
    back to the base-class no-op default instead of reaching the engine.

    NOTE(review): the forwards below assume UCMConnector implements the
    same ``KVConnectorBase_V1`` hooks as this vLLM version (e.g. by
    subclassing it) — confirm against ucm.integration.vllm.ucm_connector.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        """Create the adapter and the underlying UCM engine.

        Args:
            vllm_config: the vLLM configuration; its ``kv_transfer_config``
                must be set.
            role: whether this instance runs on the scheduler or a worker.
            kv_cache_config: optional KV cache layout, passed only to the
                base class. The UCM engine is built from
                ``(vllm_config, role)`` alone.

        Raises:
            ValueError: if ``vllm_config.kv_transfer_config`` is None.
        """
        super().__init__(vllm_config=vllm_config,
                         role=role,
                         kv_cache_config=kv_cache_config)
        # Explicit check instead of `assert`, which is stripped under -O.
        if vllm_config.kv_transfer_config is None:
            raise ValueError(
                "kv_transfer_config must be set for UCMConnectorV1")
        self._ucm_engine = UCMConnector(vllm_config, role)

    # ==============================
    # Worker-side methods
    # ==============================

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: A dictionary mapping layer names to KV cache tensors.
        """
        self._ucm_engine.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation.
        """
        self._ucm_engine.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within attention layer to ensure
        async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer.
        """
        self._ucm_engine.wait_for_layer_load(layer_name)

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
                                       **kwargs)

    def wait_for_save(self) -> None:
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of paged KV buffer before saving is done.
        """
        self._ucm_engine.wait_for_save()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Report requests whose asynchronous KV transfers have completed.

        Without this forward, the base-class default (no finished
        transfers) is used. A request for which request_finished()
        returned True would then never have its blocks released.

        Args:
            finished_req_ids: ids of requests that have finished
                generating tokens on this worker.

        Returns:
            ``(finished_sending, finished_recving)`` request-id sets, as
            produced by the UCM engine (None means nothing finished).
        """
        return self._ucm_engine.get_finished(finished_req_ids)

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.
        """
        return self._ucm_engine.get_block_ids_with_load_errors()

    def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
        """
        Return the transfer stats collected by the UCM engine for this step.

        This is the per-step counterpart of build_kv_connector_stats().
        Without it, the wrapper would always report no stats.
        """
        return self._ucm_engine.get_kv_connector_stats()

    def shutdown(self) -> None:
        """Shut down the underlying UCM engine and release its resources."""
        self._ucm_engine.shutdown()

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        # Keep the base-class copy in sync so has_connector_metadata()
        # on this wrapper reflects the engine's state.
        super().clear_connector_metadata()
        self._ucm_engine.clear_connector_metadata()

    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        KV cache loading and saving.

        Args:
            connector_metadata (KVConnectorMetadata): the connector metadata.
        """
        # Mirror into the base class so has_connector_metadata() /
        # _get_connector_metadata() on this wrapper stay truthful.
        super().bind_connector_metadata(connector_metadata)
        self._ucm_engine.bind_connector_metadata(connector_metadata)

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request.

        Returns:
            A ``(num_tokens, is_async)`` pair as returned by the UCM
            engine. ``num_tokens`` is the number of tokens that can be
            loaded from the external KV cache beyond what is already
            computed.
        """
        return self._ucm_engine.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int) -> None:
        """
        Update KVConnector state after block allocation.
        """
        self._ucm_engine.update_state_after_alloc(request, blocks,
                                                  num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        return self._ucm_engine.build_connector_meta(scheduler_output)

    def update_connector_output(
            self, connector_output: "KVConnectorOutput") -> None:
        """
        Pass the workers' aggregated connector output (finished transfers,
        stats, ...) to the UCM engine's scheduler-side state.

        Args:
            connector_output: the KV connector output of the last step.
        """
        self._ucm_engine.update_connector_output(connector_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return self._ucm_engine.request_finished(request, block_ids)

    def take_events(self) -> Iterable["KVCacheEvent"]:
        """Collect and return the KV cache events buffered by the UCM engine."""
        return self._ucm_engine.take_events()

    # ==============================
    # Metrics & Stats
    # ==============================

    @classmethod
    def build_kv_connector_stats(
            cls,
            data: dict[str, Any] | None = None
    ) -> Optional["KVConnectorStats"]:
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,
        which can implement custom aggregation logic on the data dict.

        This implementation defers to ``UCMConnector``.
        """
        return UCMConnector.build_kv_connector_stats(data)

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> Optional["KVConnectorPromMetrics"]:
        """
        Create a KVConnectorPromMetrics subclass which should register
        per-connector Prometheus metrics and implement observe() to
        expose connector transfer stats via Prometheus.

        This implementation defers to ``UCMConnector``.
        """
        return UCMConnector.build_prom_metrics(
            vllm_config,
            metric_types,
            labelnames,
            per_engine_labelvalues,
        )