From 675387f1fd8d40f5c1330ea0ee7ea6ac2624c1c3 Mon Sep 17 00:00:00 2001 From: zxr2333 <64738772+nwpu-zxr@users.noreply.github.com> Date: Mon, 9 Mar 2026 10:49:04 +0800 Subject: [PATCH] [P/D][KVPool]Mooncake Layerwise Connector supports kv_pool (#7032) ### What this PR does / why we need it? This PR creates and registers `ascend_multi_connector`, which allows the `mooncake_layerwise_connector` to use the kv_pooling feature. We unregister vLLM's original `MultiConnector` and replace it with `AscendMultiConnector` when registering the connectors. ### Does this PR introduce _any_ user-facing change? No. Users can use `MultiConnector` to initialize `AscendMultiConnector`. ### How was this patch tested? By CI. - vLLM version: v0.16.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d --------- Signed-off-by: nwpu-zxr --- .../distributed/kv_transfer/__init__.py | 7 ++++++ .../kv_transfer/ascend_multi_connector.py | 22 +++++++++++++++++++ vllm_ascend/worker/worker.py | 2 +- 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 vllm_ascend/distributed/kv_transfer/ascend_multi_connector.py diff --git a/vllm_ascend/distributed/kv_transfer/__init__.py b/vllm_ascend/distributed/kv_transfer/__init__.py index 0450a104..dae05787 100644 --- a/vllm_ascend/distributed/kv_transfer/__init__.py +++ b/vllm_ascend/distributed/kv_transfer/__init__.py @@ -19,6 +19,13 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory def register_connector(): + # override multi_connector as ascend_multi_connector + if "MultiConnector" in KVConnectorFactory._registry: + KVConnectorFactory._registry.pop("MultiConnector") + KVConnectorFactory.register_connector( + "MultiConnector", "vllm_ascend.distributed.kv_transfer.ascend_multi_connector", "AscendMultiConnector" + ) + KVConnectorFactory.register_connector( "MooncakeConnectorV1", "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector", "MooncakeConnector" 
) diff --git a/vllm_ascend/distributed/kv_transfer/ascend_multi_connector.py b/vllm_ascend/distributed/kv_transfer/ascend_multi_connector.py new file mode 100644 index 00000000..bd5be6c5 --- /dev/null +++ b/vllm_ascend/distributed/kv_transfer/ascend_multi_connector.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING + +from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import MultiConnector + +from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector import MooncakeLayerwiseConnector + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + + +class AscendMultiConnector(MultiConnector): + def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): + chosen_connector = self._requests_to_connector.get(request.request_id, -1) + empty_blocks = blocks.new_empty() + for i, c in enumerate(self._connectors): + if i == chosen_connector or isinstance(c, MooncakeLayerwiseConnector): + # Forward to the chosen connector (if any) and any MooncakeLayerwiseConnector. + c.update_state_after_alloc(request, blocks, num_external_tokens) + else: + # Call with empty blocks for other connectors. 
+ c.update_state_after_alloc(request, empty_blocks, 0) diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index b8cbbdbd..39f4f646 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -504,6 +504,7 @@ class NPUWorker(WorkerBase): def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate NPU KV cache with the specified kv_cache_config.""" + ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config) if self.vllm_config.model_config.enable_sleep_mode: allocator = CaMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") @@ -575,7 +576,6 @@ class NPUWorker(WorkerBase): self.parallel_config.decode_context_parallel_size, ) init_ascend_model_parallel(self.parallel_config) - ensure_kv_transfer_initialized(self.vllm_config) ensure_ec_transfer_initialized(self.vllm_config) def _create_profiler(self, trace_name: str):