[Feature] Mooncake connector get remote ptp size (#5822)

### What this PR does / why we need it?
To support elastic scaling when using the mooncake connector, we need to
support **configuring different TP sizes for different nodes**.
To that end, we transfer prefill-node information, such as its TP size,
through **the request's kv_transfer_params**.
The decode nodes then **get the prefill TP size** from the request's
kv_transfer_params instead of from the mooncake connector's static
configuration.
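
A minimal sketch of the idea, with the connector plumbing stripped away
(`build_kv_transfer_params` and `resolve_prefill_tp_size` are illustrative
helpers, not the real connector API; the field names match the diff below):

```python
from typing import Any, Optional


def build_kv_transfer_params(tp_size: int, side_channel_port: int) -> dict[str, Any]:
    # Prefill side: advertise our own TP size alongside the existing fields.
    return {
        "remote_port": side_channel_port,
        "remote_ptp_size": tp_size,  # new field carried by this PR
    }


def resolve_prefill_tp_size(kv_transfer_params: dict[str, Any],
                            configured_prefill_tp_size: int) -> int:
    # Decode side: prefer the per-request value and fall back to the
    # statically configured prefill TP size when the field is absent.
    remote_ptp_size: Optional[int] = kv_transfer_params.get("remote_ptp_size")
    return remote_ptp_size if remote_ptp_size else configured_prefill_tp_size
```

Because of the fallback, requests from prefill nodes that do not yet send
`remote_ptp_size` are handled exactly as before.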

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
Authored by yuxinshan on 2026-01-26 14:28:33 +08:00, committed by GitHub.
parent 611e223b7d · commit 0bb1f91c2c
2 changed files with 124 additions and 36 deletions


@@ -7,7 +7,7 @@ import time
import types
import unittest
from collections import defaultdict, deque
-from typing import Any, Dict, OrderedDict
+from typing import Any, Dict, OrderedDict, Optional
from unittest.mock import MagicMock, patch
import msgspec
@@ -691,7 +691,8 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
"remote_host": "localhost",
"remote_port": 5000,
"remote_pcp_size": 1,
"remote_dcp_size": 1
"remote_dcp_size": 1,
"remote_ptp_size": 2
})
self.assertEqual(len(meta.requests), 1)
@@ -702,6 +703,7 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
self.assertEqual(req_meta.remote_engine_id, "remote_engine")
self.assertEqual(req_meta.remote_host, "localhost")
self.assertEqual(req_meta.remote_port, 5000)
+self.assertEqual(req_meta.remote_ptp_size, 2)
class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
@@ -1209,9 +1211,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
def test_get_remote_tp_rank(self):
-def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int,
-                decode_tp_size: int, num_kv_heads: int,
-                tp_num_need_pulls: int, is_deepseek_mla: bool):
+def get_tp_rank(prefill_tp_size: int,
+                prefill_pp_size: int,
+                decode_tp_size: int,
+                num_kv_heads: int,
+                tp_num_need_pulls: int,
+                is_deepseek_mla: bool,
+                remote_ptp_size: Optional[int] = None):
with patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()), \
@@ -1226,7 +1232,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
self.engine_id)
worker.tp_num_need_pulls = tp_num_need_pulls
worker.use_sparse = 0
-return worker._get_remote_ranks_for_req('test')
+return worker._get_remote_ranks_for_req(
+    'test', remote_ptp_size)
self.assertIn(
get_tp_rank(16, 1, 1, 4, 4, False)[0],
@@ -1285,13 +1292,30 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
get_tp_rank(4, 4, 4, 1, 1, True),
[[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
+# check remote ptp size
+self.assertListEqual(get_tp_rank(16, 1, 2, 4, 2, False, 8),
+                     get_tp_rank(8, 1, 2, 4, 2, False))
+self.assertListEqual(get_tp_rank(8, 1, 2, 4, 2, False, 4),
+                     get_tp_rank(4, 1, 2, 4, 2, False))
+self.assertListEqual(get_tp_rank(4, 1, 2, 4, 1, False, 2),
+                     get_tp_rank(2, 1, 2, 4, 1, False))
def test_get_kv_split_metadata(self):
-def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
-                          tp_rank, pcp_rank, _prefill_tp_size,
-                          remote_pcp_size, remote_dcp_size,
-                          remote_port, remote_block_ids,
-                          local_block_ids, remote_engine_id):
+def get_kv_split_metadata(use_mla,
+                          pcp_size,
+                          dcp_size,
+                          tp_size,
+                          tp_rank,
+                          pcp_rank,
+                          _prefill_tp_size,
+                          remote_pcp_size,
+                          remote_dcp_size,
+                          remote_port,
+                          remote_block_ids,
+                          local_block_ids,
+                          remote_engine_id,
+                          remote_ptp_size=None):
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
@@ -1310,6 +1334,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
meta.remote_pcp_size = remote_pcp_size
meta.remote_dcp_size = remote_dcp_size
+meta.remote_ptp_size = remote_ptp_size
meta.remote_port = remote_port
meta.remote_block_ids = remote_block_ids
meta.local_block_ids = local_block_ids
@@ -1367,6 +1392,40 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
[1, 2, 3], [1, 2, 3, 4, 5], 0),
([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]))
+# check remote ptp size
+self.assertEqual(
+    get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
+                          [1], 0, 16),
+    get_kv_split_metadata(True, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
+                          [1], 0)
+)
+self.assertEqual(
+    get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
+                          [1], 0, 16),
+    get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
+                          [1], 0)
+)
+self.assertEqual(
+    get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
+                          [1], 0, 16),
+    get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 2, 8, 30000, [1],
+                          [1], 0)
+)
+
+def test_get_tp_num_need_pulls(self):
+    worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
+    worker.num_key_value_heads = 8
+    tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=4)
+    self.assertEqual(tp_num_need_pulls, 1)
+    worker.vllm_config.model_config.is_deepseek_mla = False
+    tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=4)
+    self.assertEqual(tp_num_need_pulls, 2)
+    tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=None)
+    self.assertEqual(tp_num_need_pulls, 1)
if __name__ == '__main__':
unittest.main()
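
The assertions above pin down the pull-count arithmetic: with MLA a single
pull suffices, otherwise each decode rank must pull from as many prefill ranks
as it takes to cover its share of the KV heads. A standalone sketch of that
formula (`tp_num_need_pulls` is an illustrative free function; it omits the
cached fast path the real `_get_tp_num_need_pulls` takes when the size equals
the configured one):

```python
def tp_num_need_pulls(num_kv_heads: int, decode_tp_size: int,
                      prefill_tp_size: int, is_deepseek_mla: bool) -> int:
    # MLA stores one shared latent KV cache per rank, so one pull is enough.
    if is_deepseek_mla:
        return 1
    # KV heads owned by one decode rank vs. one prefill rank; the ratio is
    # how many prefill ranks a decode rank must pull from.
    num_d_block_heads = max(1, num_kv_heads // decode_tp_size)
    num_p_block_heads = max(1, num_kv_heads // prefill_tp_size)
    return num_d_block_heads // num_p_block_heads


# 8 KV heads, prefill TP 4: each prefill rank owns 2 heads. A decode rank
# owning 4 heads (decode TP 2, presumably the test fixture's setting) must
# therefore pull twice; with MLA it pulls once.
assert tp_num_need_pulls(8, 2, 4, is_deepseek_mla=False) == 2
assert tp_num_need_pulls(8, 2, 4, is_deepseek_mla=True) == 1
```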


@@ -87,6 +87,7 @@ class ReqMeta:
remote_request_id: str
remote_pcp_size: int
remote_dcp_size: int
+remote_ptp_size: int | None
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
num_prompt_blocks: int
@@ -773,6 +774,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
remote_port=kv_transfer_params["remote_port"],
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
+remote_ptp_size=kv_transfer_params.get("remote_ptp_size"),
remote_multi_nodes_meta_mapping=kv_transfer_params.get("remote_multi_nodes_meta_mapping", {}),
num_prompt_blocks=kv_transfer_params.get("num_prompt_blocks", 0),
)
@@ -890,6 +892,7 @@ class MooncakeConnectorScheduler:
self.side_channel_host = get_ip()
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.max_device_id = (
vllm_config.parallel_config.tensor_parallel_size
* vllm_config.parallel_config.data_parallel_size
@@ -1039,6 +1042,7 @@ class MooncakeConnectorScheduler:
remote_port=self.side_channel_port,
remote_pcp_size=self.pcp_size,
remote_dcp_size=self.dcp_size,
+remote_ptp_size=self.tp_size,
last_token_id=request.output_token_ids[-1],
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
num_prompt_blocks=num_prompt_blocks,
@@ -1324,8 +1328,9 @@ class MooncakeConnectorWorker:
In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
"""
+prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
-choosen_rank_list = self._get_remote_rank(req_id)
+choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
remote_handshake_port_list = [[x + meta.remote_port for x in choosen_rank_list]]
local_block_ids_list, remote_block_ids_list = [meta.local_block_ids], [meta.remote_block_ids]
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
@@ -1333,7 +1338,7 @@ class MooncakeConnectorWorker:
def context_parallel_parameters_check():
assert (meta.remote_pcp_size * meta.remote_dcp_size) % (self.pcp_size * self.dcp_size) == 0
if not self.use_mla:
-p_node_heads_per_rank = math.ceil(self.num_key_value_heads / self._prefill_tp_size)
+p_node_heads_per_rank = math.ceil(self.num_key_value_heads / prefill_tp_size)
d_node_heads_per_rank = math.ceil(self.num_key_value_heads / self.tp_size)
assert d_node_heads_per_rank % p_node_heads_per_rank == 0
@@ -1387,7 +1392,7 @@ class MooncakeConnectorWorker:
def get_local_remote_block_port_mappings():
context_parallel_parameters_check()
p_node_cp_group_meta = get_cp_group_meta(
-    self._prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
+    prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
)
d_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size, self.dcp_size, self.side_channel_port)
local_remote_block_port_mappings: dict[int, list[list[int]]] = {}
@@ -1427,7 +1432,7 @@ class MooncakeConnectorWorker:
local_remote_block_port_mappings: dict[int, list[list[int]]],
) -> dict[int, RemotePortInfo]:
remote_port_send_num: dict[int, RemotePortInfo] = {}
-for port in range(self._prefill_tp_size * meta.remote_pcp_size):
+for port in range(prefill_tp_size * meta.remote_pcp_size):
remote_host_info = meta.remote_multi_nodes_meta_mapping.get(str(port), None)
if remote_host_info is None:
remote_host = meta.remote_host
@@ -1518,8 +1523,9 @@ class MooncakeConnectorWorker:
)
local_block_offset += num_blocks_to_pull
-assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), (
-    f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
+tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
+assert tp_num_need_pulls == len(remote_handshake_port_list[0]), (
+    f"tp_num_need_pulls: {tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
)
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
@@ -1535,13 +1541,17 @@ class MooncakeConnectorWorker:
len(meta.local_block_ids),
len(meta.remote_block_ids),
)
+prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
+tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
if meta.remote_pcp_size * meta.remote_dcp_size > 1:
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
req_id, meta
)
for pcp_dcp_rank in range(len(remote_handshake_port_list)):
-for i in range(self.tp_num_need_pulls):
+for i in range(tp_num_need_pulls):
assert self.kv_recv_thread is not None
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
meta.remote_port,
@@ -1559,16 +1569,16 @@ class MooncakeConnectorWorker:
remote_host=remote_host,
remote_handshake_port=remote_handshake_port_list[pcp_dcp_rank][i],
offset=i,
-tp_num_need_pulls=self.tp_num_need_pulls,
+tp_num_need_pulls=tp_num_need_pulls,
remote_port_send_num=self.remote_port_send_num[meta.remote_engine_id],
all_task_done=(
-pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == self.tp_num_need_pulls - 1
+pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == tp_num_need_pulls - 1
),
)
else: # TODO: support prefill context parallel and pipeline parallel open at the same time
-choosen_rank_list = self._get_remote_rank(req_id)
+choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list]
-for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
+for i in range(tp_num_need_pulls * self._prefill_pp_size):
assert self.kv_recv_thread is not None
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
meta.remote_port,
@@ -1586,8 +1596,8 @@ class MooncakeConnectorWorker:
remote_host=remote_host,
remote_handshake_port=remote_handshake_port_list[i][0],
offset=i,
-tp_num_need_pulls=self.tp_num_need_pulls,
-all_task_done=(i == self.tp_num_need_pulls * self._prefill_pp_size - 1),
+tp_num_need_pulls=tp_num_need_pulls,
+all_task_done=(i == tp_num_need_pulls * self._prefill_pp_size - 1),
)
for req_id in metadata.reqs_in_batch:
@@ -1607,6 +1617,21 @@ class MooncakeConnectorWorker:
for req_id, delay_start_time in metadata.requests_to_send.items():
self.kv_send_thread.add_delayed_request(req_id, delay_start_time)
+def _get_tp_num_need_pulls(self, prefill_tp_size: int) -> int:
+    if prefill_tp_size is None:
+        prefill_tp_size = self._prefill_tp_size
+    if prefill_tp_size == self._prefill_tp_size:
+        return self.tp_num_need_pulls
+    if self.vllm_config.model_config.is_deepseek_mla:
+        tp_num_need_pulls = 1
+    else:
+        num_d_block_heads = max(1, self.num_key_value_heads // self.tp_size)
+        num_p_block_heads = max(1, self.num_key_value_heads // prefill_tp_size)
+        tp_num_need_pulls = num_d_block_heads // num_p_block_heads
+    return tp_num_need_pulls
def _get_remote_host_info_by_port(
self,
base_port: int,
@@ -1624,16 +1649,17 @@ class MooncakeConnectorWorker:
def _prefill_get_remote_rank(self, req_id: str) -> list[int]:
return sum(self._get_remote_ranks_for_req(req_id), [])
-def _get_remote_rank(self, req_id: str) -> list[int]:
-    return self._get_remote_ranks_for_req(req_id)[self.tp_rank]
+def _get_remote_rank(self, req_id: str, prefill_tp_size: int | None = None) -> list[int]:
+    return self._get_remote_ranks_for_req(req_id, prefill_tp_size)[self.tp_rank]
def _get_remote_tp_ranks(
-    self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int
+    self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int, prefill_tp_size: int
) -> list[list[int]]:
+tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
# random split prefill tp list
tp_sampled_nums = []
if (
-    self._prefill_tp_size > self.num_key_value_heads
+    prefill_tp_size > self.num_key_value_heads
or self.vllm_config.model_config.is_deepseek_mla
or self.use_sparse
):
@@ -1641,24 +1667,27 @@ class MooncakeConnectorWorker:
choosen_group = tp_ori_data[:, [rand_group_index]]
flattened = choosen_group.reshape(-1).tolist()
tp_sampled_nums = [
-    flattened[i : i + self.tp_num_need_pulls] for i in range(0, len(flattened), self.tp_num_need_pulls)
+    flattened[i : i + tp_num_need_pulls] for i in range(0, len(flattened), tp_num_need_pulls)
]
# non-random split
else:
-group_size = self._prefill_tp_size // self._decode_tp_size
+group_size = prefill_tp_size // self._decode_tp_size
for i in range(self._decode_tp_size):
slice = tp_ori_data[i * group_size : (i + 1) * group_size]
tp_sampled_nums.append(slice.tolist())
return tp_sampled_nums
-def _get_remote_ranks_for_req(self, req_id: str) -> list[list[int]]:
+def _get_remote_ranks_for_req(self, req_id: str, prefill_tp_size: int | None = None) -> list[list[int]]:
+    if prefill_tp_size is None:
+        prefill_tp_size = self._prefill_tp_size
# Divide the ports according to the TP within the PP
sampled_nums = []
-if self._prefill_tp_size == self._decode_tp_size:
+if prefill_tp_size == self._decode_tp_size:
sampled_nums = list(
map(
-        lambda tp: [tp + pp * self._prefill_tp_size for pp in range(self._prefill_pp_size)],
-        range(self._prefill_tp_size),
+        lambda tp: [tp + pp * prefill_tp_size for pp in range(self._prefill_pp_size)],
+        range(prefill_tp_size),
)
)
return sampled_nums
@@ -1667,7 +1696,7 @@ class MooncakeConnectorWorker:
num_kv_head = 1
else:
num_kv_head = self.num_key_value_heads
-ori_data = np.arange(self._prefill_tp_size * self._prefill_pp_size)
+ori_data = np.arange(prefill_tp_size * self._prefill_pp_size)
seed = string_to_int64_hash(req_id)
rand = random.Random(seed)
# random split prefill tp list
@@ -1679,7 +1708,7 @@ class MooncakeConnectorWorker:
range(num_groups), (max(self._decode_tp_size // num_kv_head, 1))
) # random choose a group
all_results = [
-    self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups)
+    self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups, prefill_tp_size)
for pp_index in range(self._prefill_pp_size)
]
for group_index in range(len(all_results[0])):
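
The non-random split in `_get_remote_tp_ranks` now divides `prefill_tp_size`
(possibly taken from the request) rather than the configured
`self._prefill_tp_size`: each decode rank maps onto a contiguous group of
`prefill_tp_size // decode_tp_size` prefill ranks. A simplified sketch of that
branch, assuming a single PP stage and enough KV heads per rank to skip the
random path (`non_random_split` is an illustrative name):

```python
import numpy as np


def non_random_split(prefill_tp_size: int, decode_tp_size: int) -> list[list[int]]:
    # Mirrors the `else` branch above: decode rank i pulls from the
    # contiguous block of prefill ranks [i * g, (i + 1) * g).
    tp_ori_data = np.arange(prefill_tp_size)
    group_size = prefill_tp_size // decode_tp_size
    return [
        tp_ori_data[i * group_size:(i + 1) * group_size].tolist()
        for i in range(decode_tp_size)
    ]


# With remote_ptp_size=8 and decode TP 2, decode rank 0 pulls from prefill
# ranks 0-3 and decode rank 1 from ranks 4-7.
assert non_random_split(8, 2) == [[0, 1, 2, 3], [4, 5, 6, 7]]
```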