[main][bugfix] fix mooncake kv cache transfer when one P has multi nodes (#5960)

### What this PR does / why we need it?
In the PD disaggregation case, when P spans multiple nodes, mooncake
fails to send data. This PR fixes the issue.

The details:
If a P rank does not need to transfer KV cache to any D rank, the D node
must send a message to the P node so that the P node releases its KV
cache. When P spans multiple nodes, the D node needs to know the IP that
corresponds to each P node so that it can send the message to the right
P node; otherwise, sending the data fails with an error. This PR fixes
the issue by providing the P nodes' IPs to the D node through the
`remote_port_send_num` parameter.

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
This commit is contained in:
wangxiaochao6
2026-01-19 16:35:13 +08:00
committed by GitHub
parent ebb940691f
commit bc486d9530

View File

@@ -269,7 +269,7 @@ class KVCacheSendingThread(threading.Thread):
self.prefill_tp_size
handshake_port = self.side_channel_port + device_index
if self.port_send_num[request_id] >= \
remote_port_send_num[handshake_port]:
remote_port_send_num[handshake_port]['num']:
self.task_tracker.update_done_task_count(
request_id)
del self.port_send_num[request_id]
@@ -382,7 +382,7 @@ class KVCacheRecvingThread(threading.Thread):
remote_handshake_port: int,
offset: int,
tp_num_need_pulls: int,
remote_port_send_num: dict[int, int] = {},
remote_port_send_num: dict[int, dict[str, int | str]] = {},
all_task_done: bool = False):
"""Add a new request to the queue for processing."""
logger.debug(f"Adding request {request_id} to the queue.")
@@ -463,8 +463,9 @@ class KVCacheRecvingThread(threading.Thread):
self.proc_not_transfer_request[request_id] = True
if self.proc_not_transfer_request[request_id]:
for remote_port in remote_port_send_num.keys():
if remote_port_send_num[remote_port] == 0:
self._send_done_recv_signal(request_id, remote_host,
if remote_port_send_num[remote_port]['num'] == 0:
remote_host_ = remote_port_send_num[remote_port]['host']
self._send_done_recv_signal(request_id, remote_host_,
remote_port,
remote_port_send_num)
self.proc_not_transfer_request[request_id] = False
@@ -705,7 +706,7 @@ class KVCacheRecvingThread(threading.Thread):
def _send_done_recv_signal(self, request_id: str, remote_host: str,
remote_handshake_port: int,
remote_port_send_num: dict[int, int]):
remote_port_send_num: dict[int, dict[str, int | str]]):
logger.debug("Sending done recving signal for request %s to %s:%d",
request_id, remote_host, remote_handshake_port)
sock: Optional[zmq.Socket] = None # type: ignore
@@ -1170,7 +1171,7 @@ class MooncakeConnectorWorker:
self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads
self.local_remote_block_port_mapping: dict[
str, Optional[List[List[int]]]] = {}
self.remote_port_send_num: dict[str, dict[int, int]] = {}
self.remote_port_send_num: dict[str, dict[int, dict[str, int | str]]] = {}
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
@@ -1457,19 +1458,23 @@ class MooncakeConnectorWorker:
return local_remote_block_port_mappings
def get_remote_port_send_num(local_remote_block_port_mappings):
remote_port_send_num: dict[int, int] = {}
remote_port_send_num: dict[int, dict[str, int | str]] = {}
for port in range(self._prefill_tp_size * meta.remote_pcp_size):
remote_port_send_num[meta.remote_port + port] = 0
remote_host = meta.remote_multi_nodes_meta_mapping[str(port)]['host']
remote_port_send_num[meta.remote_port + port] = {}
remote_port_send_num[meta.remote_port + port]['num'] = 0
remote_port_send_num[meta.remote_port + port]['host'] = remote_host
for local_port in local_remote_block_port_mappings.keys():
remote_port_head_list = local_remote_block_port_mappings[
local_port]
for remote_port_list in remote_port_head_list:
for remote_port in remote_port_list:
remote_port_send_num[remote_port] += 1
remote_port_send_num[remote_port]['num'] += 1
return remote_port_send_num
if meta.remote_engine_id not in self.local_remote_block_port_mapping:
self.local_remote_block_port_mapping[meta.remote_engine_id] = None
if self.local_remote_block_port_mapping[meta.remote_engine_id] is None:
local_remote_block_port_mappings = get_local_remote_block_port_mappings(
)
@@ -1837,4 +1842,4 @@ def get_prefill_pp_indices(
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
return (start_layer, end_layer)
return (start_layer, end_layer)