diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index 3f432378..a54ef6cb 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -76,6 +76,7 @@ class ReqMeta: remote_host: str remote_port: int remote_engine_id: str + remote_request_id: str remote_pcp_size: int remote_dcp_size: int remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]] @@ -375,6 +376,7 @@ class KVCacheRecvingThread(threading.Thread): def add_request(self, request_id: str, + remote_request_id: str, local_block_ids: list[int], remote_block_ids: list[int], remote_engine_id: str, @@ -391,6 +393,7 @@ class KVCacheRecvingThread(threading.Thread): "local_block_ids": local_block_ids, "remote_block_ids": remote_block_ids, "remote_engine_id": remote_engine_id, + "remote_request_id": remote_request_id, "remote_host": remote_host, "remote_handshake_port": remote_handshake_port, "offset": offset, @@ -423,6 +426,7 @@ class KVCacheRecvingThread(threading.Thread): def _handle_request(self, req_meta: dict[str, Any]): request_id = req_meta["request_id"] + remote_request_id = req_meta["remote_request_id"] remote_host = req_meta["remote_host"] remote_handshake_port = req_meta["remote_handshake_port"] remote_port_send_num = req_meta["remote_port_send_num"] @@ -430,14 +434,14 @@ class KVCacheRecvingThread(threading.Thread): try: logger.debug( - f"Starting to transfer KV cache for request {request_id}.") + f"Starting to transfer KV cache for request {remote_request_id}.") self._transfer_kv_cache(req_meta) logger.debug( - f"Finished transferring KV cache for request {request_id}.") + f"Finished transferring KV cache for request {remote_request_id}.") except Exception as e: logger.error( "Failed to transfer KV cache for request " - f"{request_id}: {e}", + f"{remote_request_id}: {e}", exc_info=True) finally: if all_task_done: @@ -448,10 +452,10 @@ class KVCacheRecvingThread(threading.Thread): # Always send the done signal to the remote host to ensure proper # resource cleanup. Failing to do so may cause a memory leak on the # remote host. - self._send_done_recv_signal(request_id, remote_host, + self._send_done_recv_signal(remote_request_id, remote_host, remote_handshake_port, remote_port_send_num) - self._send_done_signal_to_free_remote_port(request_id, remote_host, + self._send_done_signal_to_free_remote_port(remote_request_id, remote_host, remote_port_send_num) def _send_done_signal_to_free_remote_port(self, request_id, remote_host, @@ -472,7 +476,7 @@ class KVCacheRecvingThread(threading.Thread): def _transfer_kv_cache(self, req_meta: dict[str, Any]): """Handle a KV cache transfer request.""" - request_id = req_meta["request_id"] + remote_request_id = req_meta["remote_request_id"] remote_block_ids = req_meta["remote_block_ids"] local_block_ids = req_meta["local_block_ids"] remote_engine_id = req_meta["remote_engine_id"] @@ -558,7 +562,7 @@ class KVCacheRecvingThread(threading.Thread): dst_list, length_list) if ret < 0: logger.error("Mooncake transfer failed for request %s", - req_meta["request_id"]) + req_meta["remote_request_id"]) raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") req_end_time = time.perf_counter() @@ -566,7 +570,7 @@ class KVCacheRecvingThread(threading.Thread): logger.info( "KV cache transfer for request %s took %.2f ms (%d groups," " %d blocks). local_ip %s local_device_id %s remote_session_id %s", - request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, + remote_request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, get_ip(), self.tp_rank, session_id) # Determine if the current position is the offset position at the end of @@ -791,6 +795,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): num_external_tokens=num_external_tokens, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_request_id=kv_transfer_params["remote_request_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1), @@ -996,7 +1001,7 @@ class MooncakeConnectorScheduler: if params is not None and params.get("do_remote_prefill"): if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", - "remote_port")): + "remote_port", "remote_request_id")): local_block_ids = (blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []) # Get unhashed blocks to pull from remote. @@ -1074,6 +1079,7 @@ class MooncakeConnectorScheduler: do_remote_decode=False, remote_block_ids=computed_block_ids, remote_engine_id=self.engine_id, + remote_request_id=request.request_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, remote_pcp_size=self.pcp_size, @@ -1583,6 +1589,7 @@ class MooncakeConnectorWorker: meta.remote_multi_nodes_meta_mapping) self.kv_recv_thread.add_request( request_id=req_id, + remote_request_id=meta.remote_request_id, local_block_ids=local_block_ids_list[pcp_dcp_rank], remote_block_ids=remote_block_ids_list[ pcp_dcp_rank], @@ -1610,6 +1617,7 @@ class MooncakeConnectorWorker: meta.remote_multi_nodes_meta_mapping) self.kv_recv_thread.add_request( request_id=req_id, + remote_request_id=meta.remote_request_id, local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, remote_engine_id=remote_engine_id,