[main][bugfix] fix mooncake kv cache transfer when one P has multi nodes (#5960)
### What this PR does / why we need it?
In PD disaggregation case, when P has multi nodes, mooncake fails to
send data. Fix the issue in this PR.
The details:
If a P rank does not need to transfer kv cache to any one D rank, D node
should send a message to P node to release the kv
cache in P node. If P has multi nodes, D node should know the
corresponding IP in each P node, then D node can send message to the
right P node. Otherwise, send data error will happen. This PR fix this
issue by providing P nodes IP to D node through Parameter
`remote_port_send_num`.
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
This commit is contained in:
@@ -269,7 +269,7 @@ class KVCacheSendingThread(threading.Thread):
|
|||||||
self.prefill_tp_size
|
self.prefill_tp_size
|
||||||
handshake_port = self.side_channel_port + device_index
|
handshake_port = self.side_channel_port + device_index
|
||||||
if self.port_send_num[request_id] >= \
|
if self.port_send_num[request_id] >= \
|
||||||
remote_port_send_num[handshake_port]:
|
remote_port_send_num[handshake_port]['num']:
|
||||||
self.task_tracker.update_done_task_count(
|
self.task_tracker.update_done_task_count(
|
||||||
request_id)
|
request_id)
|
||||||
del self.port_send_num[request_id]
|
del self.port_send_num[request_id]
|
||||||
@@ -382,7 +382,7 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
remote_handshake_port: int,
|
remote_handshake_port: int,
|
||||||
offset: int,
|
offset: int,
|
||||||
tp_num_need_pulls: int,
|
tp_num_need_pulls: int,
|
||||||
remote_port_send_num: dict[int, int] = {},
|
remote_port_send_num: dict[int, dict[str, int | str]] = {},
|
||||||
all_task_done: bool = False):
|
all_task_done: bool = False):
|
||||||
"""Add a new request to the queue for processing."""
|
"""Add a new request to the queue for processing."""
|
||||||
logger.debug(f"Adding request {request_id} to the queue.")
|
logger.debug(f"Adding request {request_id} to the queue.")
|
||||||
@@ -463,8 +463,9 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
self.proc_not_transfer_request[request_id] = True
|
self.proc_not_transfer_request[request_id] = True
|
||||||
if self.proc_not_transfer_request[request_id]:
|
if self.proc_not_transfer_request[request_id]:
|
||||||
for remote_port in remote_port_send_num.keys():
|
for remote_port in remote_port_send_num.keys():
|
||||||
if remote_port_send_num[remote_port] == 0:
|
if remote_port_send_num[remote_port]['num'] == 0:
|
||||||
self._send_done_recv_signal(request_id, remote_host,
|
remote_host_ = remote_port_send_num[remote_port]['host']
|
||||||
|
self._send_done_recv_signal(request_id, remote_host_,
|
||||||
remote_port,
|
remote_port,
|
||||||
remote_port_send_num)
|
remote_port_send_num)
|
||||||
self.proc_not_transfer_request[request_id] = False
|
self.proc_not_transfer_request[request_id] = False
|
||||||
@@ -705,7 +706,7 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
|
|
||||||
def _send_done_recv_signal(self, request_id: str, remote_host: str,
|
def _send_done_recv_signal(self, request_id: str, remote_host: str,
|
||||||
remote_handshake_port: int,
|
remote_handshake_port: int,
|
||||||
remote_port_send_num: dict[int, int]):
|
remote_port_send_num: dict[int, dict[str, int | str]]):
|
||||||
logger.debug("Sending done recving signal for request %s to %s:%d",
|
logger.debug("Sending done recving signal for request %s to %s:%d",
|
||||||
request_id, remote_host, remote_handshake_port)
|
request_id, remote_host, remote_handshake_port)
|
||||||
sock: Optional[zmq.Socket] = None # type: ignore
|
sock: Optional[zmq.Socket] = None # type: ignore
|
||||||
@@ -1170,7 +1171,7 @@ class MooncakeConnectorWorker:
|
|||||||
self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads
|
self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads
|
||||||
self.local_remote_block_port_mapping: dict[
|
self.local_remote_block_port_mapping: dict[
|
||||||
str, Optional[List[List[int]]]] = {}
|
str, Optional[List[List[int]]]] = {}
|
||||||
self.remote_port_send_num: dict[str, dict[int, int]] = {}
|
self.remote_port_send_num: dict[str, dict[int, dict[str, int | str]]] = {}
|
||||||
|
|
||||||
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
|
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
|
||||||
# get prefill tp and dp size from extra config
|
# get prefill tp and dp size from extra config
|
||||||
@@ -1457,19 +1458,23 @@ class MooncakeConnectorWorker:
|
|||||||
return local_remote_block_port_mappings
|
return local_remote_block_port_mappings
|
||||||
|
|
||||||
def get_remote_port_send_num(local_remote_block_port_mappings):
|
def get_remote_port_send_num(local_remote_block_port_mappings):
|
||||||
remote_port_send_num: dict[int, int] = {}
|
remote_port_send_num: dict[int, dict[str, int | str]] = {}
|
||||||
for port in range(self._prefill_tp_size * meta.remote_pcp_size):
|
for port in range(self._prefill_tp_size * meta.remote_pcp_size):
|
||||||
remote_port_send_num[meta.remote_port + port] = 0
|
remote_host = meta.remote_multi_nodes_meta_mapping[str(port)]['host']
|
||||||
|
remote_port_send_num[meta.remote_port + port] = {}
|
||||||
|
remote_port_send_num[meta.remote_port + port]['num'] = 0
|
||||||
|
remote_port_send_num[meta.remote_port + port]['host'] = remote_host
|
||||||
for local_port in local_remote_block_port_mappings.keys():
|
for local_port in local_remote_block_port_mappings.keys():
|
||||||
remote_port_head_list = local_remote_block_port_mappings[
|
remote_port_head_list = local_remote_block_port_mappings[
|
||||||
local_port]
|
local_port]
|
||||||
for remote_port_list in remote_port_head_list:
|
for remote_port_list in remote_port_head_list:
|
||||||
for remote_port in remote_port_list:
|
for remote_port in remote_port_list:
|
||||||
remote_port_send_num[remote_port] += 1
|
remote_port_send_num[remote_port]['num'] += 1
|
||||||
return remote_port_send_num
|
return remote_port_send_num
|
||||||
|
|
||||||
if meta.remote_engine_id not in self.local_remote_block_port_mapping:
|
if meta.remote_engine_id not in self.local_remote_block_port_mapping:
|
||||||
self.local_remote_block_port_mapping[meta.remote_engine_id] = None
|
self.local_remote_block_port_mapping[meta.remote_engine_id] = None
|
||||||
|
|
||||||
if self.local_remote_block_port_mapping[meta.remote_engine_id] is None:
|
if self.local_remote_block_port_mapping[meta.remote_engine_id] is None:
|
||||||
local_remote_block_port_mappings = get_local_remote_block_port_mappings(
|
local_remote_block_port_mappings = get_local_remote_block_port_mappings(
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user