diff --git a/vllm_ascend/distributed/kv_transfer/__init__.py b/vllm_ascend/distributed/kv_transfer/__init__.py index 0450a104..dae05787 100644 --- a/vllm_ascend/distributed/kv_transfer/__init__.py +++ b/vllm_ascend/distributed/kv_transfer/__init__.py @@ -19,6 +19,13 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory def register_connector(): + # override multi_connector as ascend_multi_connector + if "MultiConnector" in KVConnectorFactory._registry: + KVConnectorFactory._registry.pop("MultiConnector") + KVConnectorFactory.register_connector( + "MultiConnector", "vllm_ascend.distributed.kv_transfer.ascend_multi_connector", "AscendMultiConnector" + ) + KVConnectorFactory.register_connector( "MooncakeConnectorV1", "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector", "MooncakeConnector" ) diff --git a/vllm_ascend/distributed/kv_transfer/ascend_multi_connector.py b/vllm_ascend/distributed/kv_transfer/ascend_multi_connector.py new file mode 100644 index 00000000..bd5be6c5 --- /dev/null +++ b/vllm_ascend/distributed/kv_transfer/ascend_multi_connector.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING + +from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector + +from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector import MooncakeLayerwiseConnector + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + + +class AscendMultiConnector(MultiConnector): + def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): + chosen_connector = self._requests_to_connector.get(request.request_id, -1) + empty_blocks = blocks.new_empty() + for i, c in enumerate(self._connectors): + if i == chosen_connector or isinstance(c, MooncakeLayerwiseConnector): + # Forward call to the chosen connector (if any). + c.update_state_after_alloc(request, blocks, num_external_tokens) + else: + # Call with empty blocks for other connectors. + c.update_state_after_alloc(request, empty_blocks, 0) diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index b8cbbdbd..39f4f646 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -504,6 +504,7 @@ class NPUWorker(WorkerBase): def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate NPU KV cache with the specified kv_cache_config.""" + ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config) if self.vllm_config.model_config.enable_sleep_mode: allocator = CaMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") @@ -575,7 +576,6 @@ class NPUWorker(WorkerBase): self.parallel_config.decode_context_parallel_size, ) init_ascend_model_parallel(self.parallel_config) - ensure_kv_transfer_initialized(self.vllm_config) ensure_ec_transfer_initialized(self.vllm_config) def _create_profiler(self, trace_name: str):