From 14bd55f30c3a7aab092b1cde2ad589f6d6b16f3e Mon Sep 17 00:00:00 2001 From: zxr2333 <64738772+nwpu-zxr@users.noreply.github.com> Date: Thu, 29 Jan 2026 20:19:05 +0800 Subject: [PATCH] [P/D][BugFix] Fix layerwise P/D request_id error (#6360) ### What this PR does / why we need it? Fix layerwise Connector P/D request_id error, due to vllm pr: https://github.com/vllm-project/vllm/pull/27987, which will add a random suffix to request_id in EngineCore. - vLLM version: v0.14.1 - vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd --------- Signed-off-by: nwpu-zxr --- .../test_mooncake_layerwise_connector.py | 3 ++ .../kv_p2p/mooncake_layerwise_connector.py | 28 +++++++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index 7ea4b142..54275f25 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -240,6 +240,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): with th.lock: th.task_tracker["reqX"] = 0 + th.request_map["reqX"] = "reqX" th.update_task("reqX") with th.lock: @@ -313,6 +314,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): with th.lock: th.task_tracker["reqA"] = 0 + th.request_map["reqA"] = "reqA" with self.assertRaises(SystemExit): th.run() @@ -386,6 +388,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase): ready_event=self.ready_event) with th.lock: th.task_tracker["reqB"] = 0 + th.request_map["reqB"] = "reqB" with self.assertRaises(SystemExit): th.run() diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 97cb9e89..67463b90 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ 
b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -355,6 +355,7 @@ class KVCacheRecvingLayerThread(threading.Thread): self.lock = threading.Lock() self.done_requests = set[str]() self.task_tracker = dict[str, int]() + self.request_map = dict[str, str]() self.ready_event = ready_event self.metadata = metadata @@ -374,7 +375,8 @@ class KVCacheRecvingLayerThread(threading.Thread): self.task_tracker[req_id] += 1 if self.task_tracker[req_id] == self.pd_head_ratio: self.task_tracker.pop(req_id) - self.done_requests.add(req_id) + self.done_requests.add(self.request_map[req_id]) + self.request_map.pop(req_id) def run(self): """Run the thread to handle KV cache transfer requests.""" @@ -615,9 +617,11 @@ class MooncakeLayerwiseConnectorScheduler: logger.info(f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}") # All parameters here should appear in the returned dict of # request_finished in the scheduler side except "request_id". 
+ # change the format of request_id if vllm-version >= 0.14.0 + external_req_id = get_external_request_id(request.request_id) kv_transfer_params = dict( token_ids=[], - request_id=request.request_id, + request_id=external_req_id, do_remote_prefill=False, do_remote_decode=True, remote_block_ids=local_block_ids, @@ -1050,9 +1054,11 @@ class MooncakeLayerwiseConnectorWorker: self.current_layer = 0 if self.vllm_config.kv_transfer_config.is_kv_consumer: for req_id, meta in metadata.requests.items(): + external_req_id = get_external_request_id(req_id) assert self.kv_recv_layer_thread is not None with self.kv_recv_layer_thread.lock: - self.kv_recv_layer_thread.task_tracker[req_id] = 0 + self.kv_recv_layer_thread.task_tracker[external_req_id] = 0 + self.kv_recv_layer_thread.request_map[external_req_id] = req_id def save_kv_layer( self, @@ -1213,13 +1219,17 @@ class MooncakeLayerwiseConnectorWorker: return req_meta_update def send_done_send_signal(self, req_id, req_meta): + external_req_id = get_external_request_id(req_id) logger.info( - "Sending done sending signal for request %s to %s:%d", req_id, req_meta.remote_host, req_meta.remote_port + "Sending done sending signal for request %s to %s:%d", + external_req_id, + req_meta.remote_host, + req_meta.remote_port, ) try: path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port) msg_encoder = msgspec.msgpack.Encoder() - encoded_data = msg_encoder.encode((DONE_SENDING_MSG, req_id)) + encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id)) with zmq_ctx(zmq.REQ, path) as sock: # type: ignore ensure_zmq_send(sock, encoded_data) ack = sock.recv() @@ -1227,7 +1237,7 @@ class MooncakeLayerwiseConnectorWorker: raise ValueError(f"Unexpected ACK response: {ack}") except Exception as e: logger.error( - f"Sending done sending signal for request {req_id} to " + f"Sending done sending signal for request {external_req_id} to " f"{req_meta.remote_host}:{req_meta.remote_port} fail with error: {e}" ) @@ -1338,3 
+1348,9 @@ def ensure_zmq_recv( else: logger.error(f"Receive failed after all retries: {e}") raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}") + + +def get_external_request_id(request_id: str): + # NOTE(zxr): vLLM PR #27987 adds an additional suffix + # to the EngineCore request_id, with len(suffix) == 9 + return request_id[:-9]