From cef04b3555d935174deb1df5556eea452950fbfc Mon Sep 17 00:00:00 2001 From: JiangWeixiang <854746559@qq.com> Date: Thu, 22 Jan 2026 10:48:40 +0800 Subject: [PATCH] [bugfix] adapt_remote_request_id (#6051) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR addresses a request ID mismatch issue in the PD (Prefill-Decoding) separation deployment scenario for vllm-ascend. Upstream vLLM recently mitigated request ID collisions by appending a random suffix to each request_id (e.g., req-123 → req-123-abc), refer to [PR-27987](https://github.com/vllm-project/vllm/pull/27987 ) & [PR-29665](https://github.com/vllm-project/vllm/pull/29665). While this works in single-node deployments, it breaks compatibility in PD-separated setups: the Producer (Prefill node) and Consumer (Decoding node) end up with different request_id values, preventing the Consumer from correctly retrieving the KV cache generated by the Producer. To resolve this, this PR introduces a new field remote_request_id in the metadata passed via mooncake_connector. The Producer preserves and forwards the original (unmodified) request_id as remote_request_id. The Consumer then uses this remote_request_id—instead of its locally generated suffixed ID—to fetch the correct KV cache from the Prefill node. This ensures consistent request identification across PD nodes while maintaining compatibility with upstream vLLM’s request ID deduplication mechanism. image - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: ghphotoframe <854746559@qq.com> Co-authored-by: jiangweixiang --- .../kv_transfer/kv_p2p/mooncake_connector.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index 3f432378..a54ef6cb 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -76,6 +76,7 @@ class ReqMeta: remote_host: str remote_port: int remote_engine_id: str + remote_request_id: str remote_pcp_size: int remote_dcp_size: int remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]] @@ -375,6 +376,7 @@ class KVCacheRecvingThread(threading.Thread): def add_request(self, request_id: str, + remote_request_id: str, local_block_ids: list[int], remote_block_ids: list[int], remote_engine_id: str, @@ -391,6 +393,7 @@ class KVCacheRecvingThread(threading.Thread): "local_block_ids": local_block_ids, "remote_block_ids": remote_block_ids, "remote_engine_id": remote_engine_id, + "remote_request_id": remote_request_id, "remote_host": remote_host, "remote_handshake_port": remote_handshake_port, "offset": offset, @@ -423,6 +426,7 @@ class KVCacheRecvingThread(threading.Thread): def _handle_request(self, req_meta: dict[str, Any]): request_id = req_meta["request_id"] + remote_request_id = req_meta["remote_request_id"] remote_host = req_meta["remote_host"] remote_handshake_port = req_meta["remote_handshake_port"] remote_port_send_num = req_meta["remote_port_send_num"] @@ -430,14 +434,14 @@ class KVCacheRecvingThread(threading.Thread): try: logger.debug( - f"Starting to transfer KV cache for request {request_id}.") + f"Starting to transfer KV cache for request {remote_request_id}.") self._transfer_kv_cache(req_meta) logger.debug( - f"Finished transferring KV cache for request {request_id}.") + f"Finished transferring KV cache for request {remote_request_id}.") except Exception as e: logger.error( "Failed to transfer KV cache for request " - f"{request_id}: {e}", + f"{remote_request_id}: {e}", exc_info=True) finally: if all_task_done: @@ -448,10 +452,10 @@ class KVCacheRecvingThread(threading.Thread): # Always send the done signal to the remote host to ensure proper # resource cleanup. Failing to do so may cause a memory leak on the # remote host. - self._send_done_recv_signal(request_id, remote_host, + self._send_done_recv_signal(remote_request_id, remote_host, remote_handshake_port, remote_port_send_num) - self._send_done_signal_to_free_remote_port(request_id, remote_host, + self._send_done_signal_to_free_remote_port(remote_request_id, remote_host, remote_port_send_num) def _send_done_signal_to_free_remote_port(self, request_id, remote_host, @@ -472,7 +476,7 @@ class KVCacheRecvingThread(threading.Thread): def _transfer_kv_cache(self, req_meta: dict[str, Any]): """Handle a KV cache transfer request.""" - request_id = req_meta["request_id"] + remote_request_id = req_meta["remote_request_id"] remote_block_ids = req_meta["remote_block_ids"] local_block_ids = req_meta["local_block_ids"] remote_engine_id = req_meta["remote_engine_id"] @@ -558,7 +562,7 @@ class KVCacheRecvingThread(threading.Thread): dst_list, length_list) if ret < 0: logger.error("Mooncake transfer failed for request %s", - req_meta["request_id"]) + req_meta["remote_request_id"]) raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") req_end_time = time.perf_counter() @@ -566,7 +570,7 @@ class KVCacheRecvingThread(threading.Thread): logger.info( "KV cache transfer for request %s took %.2f ms (%d groups," " %d blocks). local_ip %s local_device_id %s remote_session_id %s", - request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, + remote_request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, get_ip(), self.tp_rank, session_id) # Determine if the current position is the offset position at the end of @@ -791,6 +795,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): num_external_tokens=num_external_tokens, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_request_id=kv_transfer_params["remote_request_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1), @@ -996,7 +1001,7 @@ class MooncakeConnectorScheduler: if params is not None and params.get("do_remote_prefill"): if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", - "remote_port")): + "remote_port", "remote_request_id")): local_block_ids = (blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []) # Get unhashed blocks to pull from remote. @@ -1074,6 +1079,7 @@ class MooncakeConnectorScheduler: do_remote_decode=False, remote_block_ids=computed_block_ids, remote_engine_id=self.engine_id, + remote_request_id=request.request_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, remote_pcp_size=self.pcp_size, @@ -1583,6 +1589,7 @@ class MooncakeConnectorWorker: meta.remote_multi_nodes_meta_mapping) self.kv_recv_thread.add_request( request_id=req_id, + remote_request_id=meta.remote_request_id, local_block_ids=local_block_ids_list[pcp_dcp_rank], remote_block_ids=remote_block_ids_list[ pcp_dcp_rank], @@ -1610,6 +1617,7 @@ class MooncakeConnectorWorker: meta.remote_multi_nodes_meta_mapping) self.kv_recv_thread.add_request( request_id=req_id, + remote_request_id=meta.remote_request_id, local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, remote_engine_id=remote_engine_id,