From bc486d9530f30cd4198d69674d904193bbccd02f Mon Sep 17 00:00:00 2001 From: wangxiaochao6 Date: Mon, 19 Jan 2026 16:35:13 +0800 Subject: [PATCH] [main][bugfix] fix mooncake kv cache transfer when one P has multi nodes (#5960) ### What this PR does / why we need it? In PD disaggregation case, when P has multi nodes, mooncake fails to send data. Fix the issue in this PR. The details: If a P rank does not need to transfer kv cache to any one D rank, D node should send a message to P node to release the kv cache in P node. If P has multi nodes, D node should know the corresponding IP in each P node, then D node can send message to the right P node. Otherwise, send data error will happen. This PR fix this issue by providing P nodes IP to D node through Parameter `remote_port_send_num`. - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060 --------- Signed-off-by: wangxiaochao Co-authored-by: wangxiaochao --- .../kv_transfer/kv_p2p/mooncake_connector.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index a43e856b..3f432378 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -269,7 +269,7 @@ class KVCacheSendingThread(threading.Thread): self.prefill_tp_size handshake_port = self.side_channel_port + device_index if self.port_send_num[request_id] >= \ - remote_port_send_num[handshake_port]: + remote_port_send_num[handshake_port]['num']: self.task_tracker.update_done_task_count( request_id) del self.port_send_num[request_id] @@ -382,7 +382,7 @@ class KVCacheRecvingThread(threading.Thread): remote_handshake_port: int, offset: int, tp_num_need_pulls: int, - remote_port_send_num: dict[int, int] = {}, + remote_port_send_num: dict[int, dict[str, int | str]] = {}, all_task_done: bool = False): """Add a new request to the queue for processing.""" logger.debug(f"Adding request {request_id} to the queue.") @@ -463,8 +463,9 @@ class KVCacheRecvingThread(threading.Thread): self.proc_not_transfer_request[request_id] = True if self.proc_not_transfer_request[request_id]: for remote_port in remote_port_send_num.keys(): - if remote_port_send_num[remote_port] == 0: - self._send_done_recv_signal(request_id, remote_host, + if remote_port_send_num[remote_port]['num'] == 0: + remote_host_ = remote_port_send_num[remote_port]['host'] + self._send_done_recv_signal(request_id, remote_host_, remote_port, remote_port_send_num) self.proc_not_transfer_request[request_id] = False @@ -705,7 +706,7 @@ class KVCacheRecvingThread(threading.Thread): def _send_done_recv_signal(self, request_id: str, remote_host: str, remote_handshake_port: int, - remote_port_send_num: dict[int, int]): + remote_port_send_num: dict[int, dict[str, int | str]]): logger.debug("Sending done recving signal for request %s to %s:%d", request_id, remote_host, remote_handshake_port) sock: Optional[zmq.Socket] = None # type: ignore @@ -1170,7 +1171,7 @@ class MooncakeConnectorWorker: self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads self.local_remote_block_port_mapping: dict[ str, Optional[List[List[int]]]] = {} - self.remote_port_send_num: dict[str, dict[int, int]] = {} + self.remote_port_send_num: dict[str, dict[int, dict[str, int | str]]] = {} def _get_prefill_decode_size(self, vllm_config: VllmConfig): # get prefill tp and dp size from extra config @@ -1457,19 +1458,23 @@ class MooncakeConnectorWorker: return local_remote_block_port_mappings def get_remote_port_send_num(local_remote_block_port_mappings): - remote_port_send_num: dict[int, int] = {} + remote_port_send_num: dict[int, dict[str, int | str]] = {} for port in range(self._prefill_tp_size * meta.remote_pcp_size): - remote_port_send_num[meta.remote_port + port] = 0 + remote_host = meta.remote_multi_nodes_meta_mapping[str(port)]['host'] + remote_port_send_num[meta.remote_port + port] = {} + remote_port_send_num[meta.remote_port + port]['num'] = 0 + remote_port_send_num[meta.remote_port + port]['host'] = remote_host for local_port in local_remote_block_port_mappings.keys(): remote_port_head_list = local_remote_block_port_mappings[ local_port] for remote_port_list in remote_port_head_list: for remote_port in remote_port_list: - remote_port_send_num[remote_port] += 1 + remote_port_send_num[remote_port]['num'] += 1 return remote_port_send_num if meta.remote_engine_id not in self.local_remote_block_port_mapping: self.local_remote_block_port_mapping[meta.remote_engine_id] = None + if self.local_remote_block_port_mapping[meta.remote_engine_id] is None: local_remote_block_port_mappings = get_local_remote_block_port_mappings( ) @@ -1837,4 +1842,4 @@ def get_prefill_pp_indices( f"{sum(partitions)=} does not match {num_hidden_layers=}.") start_layer = sum(partitions[:pp_rank]) end_layer = start_layer + partitions[pp_rank] - return (start_layer, end_layer) + return (start_layer, end_layer) \ No newline at end of file