diff --git a/vllm_ascend/distributed/ucm_connector.py b/vllm_ascend/distributed/ucm_connector.py index f44e5ee2..d38b6519 100644 --- a/vllm_ascend/distributed/ucm_connector.py +++ b/vllm_ascend/distributed/ucm_connector.py @@ -40,6 +40,15 @@ class UCMConnectorV1(KVConnectorBase_V1): # ============================== # Worker-side methods # ============================== + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None: + """ + Initialize with the KV caches (pre-registration hook): forwards + kv_caches unchanged to self._ucm_engine.register_kv_caches(). + Args: + kv_caches: A dictionary mapping layer names to KV cache tensors. + """ + self._ucm_engine.register_kv_caches(kv_caches) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """