From c7157af8f7c6c53c5fa07a3e456d04f90119ad49 Mon Sep 17 00:00:00 2001 From: wangxiaoteng888 <56506195+wangxiaoteng888@users.noreply.github.com> Date: Wed, 18 Mar 2026 10:50:02 +0800 Subject: [PATCH] [P/D] LayerwiseConnector supports the virtual push functionality on node D. (#7361) ### What this PR does / why we need it? LayerwiseConnector supports the virtual push functionality on node D. By adding a do_virtual flag to request metadata, the system can now identify and process certain requests virtually, bypassing the actual KV cache transfer process. This allows for immediate completion of these requests from the consumer's perspective, potentially enabling optimizations or specific testing scenarios where physical data transfer is not required. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By CI - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d --------- Signed-off-by: wangxiaoteng --- .../kv_p2p/mooncake_layerwise_connector.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index f4de1323..bd6b0975 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -114,6 +114,7 @@ class ReqMeta: remote_cache_tokens: int = 0 local_computed_tokens: int = 0 local_transed_tokens: int = 0 + do_virtual: bool = False @dataclass @@ -587,6 +588,7 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): remote_tp_size=kv_transfer_params.get("remote_tp_size"), remote_pcp_size=kv_transfer_params.get("remote_pcp_size"), remote_dcp_size=kv_transfer_params.get("remote_dcp_size"), + do_virtual=kv_transfer_params.get("do_virtual"), chunk_finish=chunk_finish, 
remote_cache_tokens=remote_cache_tokens, local_computed_tokens=local_computed_tokens, @@ -763,6 +765,7 @@ class MooncakeLayerwiseConnectorScheduler: def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): params = request.kv_transfer_params + do_virtual = params.get("do_virtual") logger.debug( "MooncakeLayerwiseConnector update_state_after_alloc: num_external_tokens=%s, kv_transfer_params=%s", num_external_tokens, @@ -804,16 +807,16 @@ class MooncakeLayerwiseConnectorScheduler: remote_dcp_size=self.vllm_config.parallel_config.decode_context_parallel_size, remote_cached_tokens=remote_cached_tokens, ) + if not do_virtual: + future = self.executor.submit( + self._access_metaserver, url=params.get("metaserver", None), message=kv_transfer_params + ) - future = self.executor.submit( - self._access_metaserver, url=params.get("metaserver", None), message=kv_transfer_params - ) + def handle_exception(future): + if future.exception(): + logger.error(f"Access metaserver fail: {future.exception()}") - def handle_exception(future): - if future.exception(): - logger.error(f"Access metaserver fail: {future.exception()}") - - future.add_done_callback(handle_exception) + future.add_done_callback(handle_exception) # Layerwise prefiller add request need send if params is not None and params.get("do_remote_decode"): @@ -1034,6 +1037,7 @@ class MooncakeLayerwiseConnectorWorker: # TODO(kunpengW-code): Reuse k_buffer, v_buffer self.k_quant_buffer: torch.Tensor | None = None self.v_quant_buffer: torch.Tensor | None = None + self.virtual_request: set[str] = set() def create_kv_buffer(self, first_kv_cache_tuple): alignment = 2 * 1024 * 1024 @@ -1231,6 +1235,8 @@ class MooncakeLayerwiseConnectorWorker: if self.vllm_config.kv_transfer_config.is_kv_consumer else set() ) + done_recving.update(self.virtual_request) + self.virtual_request = set() if len(done_recving) > 0: logger.info( f"Number of completed KV cache recv requests: 
{len(done_recving)}, receive requests: {done_recving}" @@ -1403,6 +1409,9 @@ class MooncakeLayerwiseConnectorWorker: self.current_layer = 0 if self.vllm_config.kv_transfer_config.is_kv_consumer: for req_id, meta in metadata.requests.items(): + if meta.do_virtual: + self.virtual_request.add(req_id) + continue external_req_id = get_external_request_id(req_id) assert self.kv_recv_layer_thread is not None with self.kv_recv_layer_thread.lock: