[Feature] Mooncake connector get remote ptp size (#5822)

### What this PR does / why we need it?
To support elastic scaling when using the mooncake connector, we need to
support **configuring different TP sizes for different nodes**.
To that end, we transfer prefill-node information, such as its TP size,
through **the request's kv_transfer_params**.
The decode nodes **get the prefill TP size** from the request's
kv_transfer_params instead of from the mooncake connector's configuration.
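
A minimal sketch of the new flow (the host/port numbers, the literal sizes, and `configured_prefill_tp_size` are hypothetical, not part of this diff): the prefill node publishes its TP size inside `kv_transfer_params`, and the decode node prefers that per-request value over its static connector configuration.

```python
# Hypothetical kv_transfer_params as sent by a prefill node; the keys
# match the diff below, the values are made up for illustration.
kv_transfer_params = {
    "remote_host": "10.0.0.1",
    "remote_port": 15500,
    "remote_pcp_size": 1,
    "remote_dcp_size": 1,
    "remote_ptp_size": 8,  # the prefill node's tensor-parallel size
}

# Decode side: fall back to the statically configured prefill TP size
# when an older prefill node does not send remote_ptp_size.
configured_prefill_tp_size = 4  # illustrative connector-config value
remote_ptp_size = kv_transfer_params.get("remote_ptp_size")
prefill_tp_size = remote_ptp_size if remote_ptp_size else configured_prefill_tp_size
```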

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
Commit 0bb1f91c2c (parent 611e223b7d), authored by yuxinshan on 2026-01-26 14:28:33 +08:00 and committed by GitHub. 2 changed files with 124 additions and 36 deletions.


@@ -87,6 +87,7 @@ class ReqMeta:
     remote_request_id: str
     remote_pcp_size: int
     remote_dcp_size: int
+    remote_ptp_size: int | None
     remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
     num_prompt_blocks: int
@@ -773,6 +774,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
             remote_port=kv_transfer_params["remote_port"],
             remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
             remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
+            remote_ptp_size=kv_transfer_params.get("remote_ptp_size"),
             remote_multi_nodes_meta_mapping=kv_transfer_params.get("remote_multi_nodes_meta_mapping", {}),
             num_prompt_blocks=kv_transfer_params.get("num_prompt_blocks", 0),
         )
@@ -890,6 +892,7 @@ class MooncakeConnectorScheduler:
         self.side_channel_host = get_ip()
         self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
         self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.max_device_id = (
             vllm_config.parallel_config.tensor_parallel_size
             * vllm_config.parallel_config.data_parallel_size
@@ -1039,6 +1042,7 @@ class MooncakeConnectorScheduler:
             remote_port=self.side_channel_port,
             remote_pcp_size=self.pcp_size,
             remote_dcp_size=self.dcp_size,
+            remote_ptp_size=self.tp_size,
             last_token_id=request.output_token_ids[-1],
             remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
             num_prompt_blocks=num_prompt_blocks,
@@ -1324,8 +1328,9 @@ class MooncakeConnectorWorker:
         In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
         Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
         """
+        prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
         if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
-            choosen_rank_list = self._get_remote_rank(req_id)
+            choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
             remote_handshake_port_list = [[x + meta.remote_port for x in choosen_rank_list]]
             local_block_ids_list, remote_block_ids_list = [meta.local_block_ids], [meta.remote_block_ids]
             return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
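
As a toy illustration of the port math in the hunk above (all numbers hypothetical): each chosen remote prefill rank becomes a handshake port by offsetting it from the prefill node's base side-channel port.

```python
remote_port = 15500         # meta.remote_port: prefill base handshake port
choosen_rank_list = [2, 3]  # remote prefill TP ranks chosen for this decode rank
remote_handshake_port_list = [[rank + remote_port for rank in choosen_rank_list]]
print(remote_handshake_port_list)  # [[15502, 15503]]
```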
@@ -1333,7 +1338,7 @@ class MooncakeConnectorWorker:
         def context_parallel_parameters_check():
             assert (meta.remote_pcp_size * meta.remote_dcp_size) % (self.pcp_size * self.dcp_size) == 0
             if not self.use_mla:
-                p_node_heads_per_rank = math.ceil(self.num_key_value_heads / self._prefill_tp_size)
+                p_node_heads_per_rank = math.ceil(self.num_key_value_heads / prefill_tp_size)
                 d_node_heads_per_rank = math.ceil(self.num_key_value_heads / self.tp_size)
                 assert d_node_heads_per_rank % p_node_heads_per_rank == 0
@@ -1387,7 +1392,7 @@ class MooncakeConnectorWorker:
         def get_local_remote_block_port_mappings():
             context_parallel_parameters_check()
             p_node_cp_group_meta = get_cp_group_meta(
-                self._prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
+                prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
             )
             d_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size, self.dcp_size, self.side_channel_port)
             local_remote_block_port_mappings: dict[int, list[list[int]]] = {}
@@ -1427,7 +1432,7 @@ class MooncakeConnectorWorker:
             local_remote_block_port_mappings: dict[int, list[list[int]]],
         ) -> dict[int, RemotePortInfo]:
             remote_port_send_num: dict[int, RemotePortInfo] = {}
-            for port in range(self._prefill_tp_size * meta.remote_pcp_size):
+            for port in range(prefill_tp_size * meta.remote_pcp_size):
                 remote_host_info = meta.remote_multi_nodes_meta_mapping.get(str(port), None)
                 if remote_host_info is None:
                     remote_host = meta.remote_host
@@ -1518,8 +1523,9 @@ class MooncakeConnectorWorker:
             )
             local_block_offset += num_blocks_to_pull
-        assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), (
-            f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
+        assert tp_num_need_pulls == len(remote_handshake_port_list[0]), (
+            f"tp_num_need_pulls: {tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
         )
         return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
@@ -1535,13 +1541,17 @@ class MooncakeConnectorWorker:
             len(meta.local_block_ids),
             len(meta.remote_block_ids),
         )
+        prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
         if meta.remote_pcp_size * meta.remote_dcp_size > 1:
             remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
                 req_id, meta
             )
             for pcp_dcp_rank in range(len(remote_handshake_port_list)):
-                for i in range(self.tp_num_need_pulls):
+                for i in range(tp_num_need_pulls):
                     assert self.kv_recv_thread is not None
                     remote_host, remote_engine_id = self._get_remote_host_info_by_port(
                         meta.remote_port,
@@ -1559,16 +1569,16 @@ class MooncakeConnectorWorker:
                         remote_host=remote_host,
                         remote_handshake_port=remote_handshake_port_list[pcp_dcp_rank][i],
                         offset=i,
-                        tp_num_need_pulls=self.tp_num_need_pulls,
+                        tp_num_need_pulls=tp_num_need_pulls,
                         remote_port_send_num=self.remote_port_send_num[meta.remote_engine_id],
                         all_task_done=(
-                            pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == self.tp_num_need_pulls - 1
+                            pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == tp_num_need_pulls - 1
                         ),
                     )
         else:  # TODO: support prefill context parallel and pipeline parallel open at the same time
-            choosen_rank_list = self._get_remote_rank(req_id)
+            choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
             remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list]
-            for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
+            for i in range(tp_num_need_pulls * self._prefill_pp_size):
                 assert self.kv_recv_thread is not None
                 remote_host, remote_engine_id = self._get_remote_host_info_by_port(
                     meta.remote_port,
@@ -1586,8 +1596,8 @@ class MooncakeConnectorWorker:
                     remote_host=remote_host,
                     remote_handshake_port=remote_handshake_port_list[i][0],
                     offset=i,
-                    tp_num_need_pulls=self.tp_num_need_pulls,
-                    all_task_done=(i == self.tp_num_need_pulls * self._prefill_pp_size - 1),
+                    tp_num_need_pulls=tp_num_need_pulls,
+                    all_task_done=(i == tp_num_need_pulls * self._prefill_pp_size - 1),
                 )

         for req_id in metadata.reqs_in_batch:
@@ -1607,6 +1617,21 @@ class MooncakeConnectorWorker:
         for req_id, delay_start_time in metadata.requests_to_send.items():
             self.kv_send_thread.add_delayed_request(req_id, delay_start_time)

+    def _get_tp_num_need_pulls(self, prefill_tp_size: int) -> int:
+        if prefill_tp_size is None:
+            prefill_tp_size = self._prefill_tp_size
+        if prefill_tp_size == self._prefill_tp_size:
+            return self.tp_num_need_pulls
+        if self.vllm_config.model_config.is_deepseek_mla:
+            tp_num_need_pulls = 1
+        else:
+            num_d_block_heads = max(1, self.num_key_value_heads // self.tp_size)
+            num_p_block_heads = max(1, self.num_key_value_heads // prefill_tp_size)
+            tp_num_need_pulls = num_d_block_heads // num_p_block_heads
+        return tp_num_need_pulls
+
     def _get_remote_host_info_by_port(
         self,
         base_port: int,
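
A worked example of the head-count arithmetic in `_get_tp_num_need_pulls` above (hypothetical sizes): with 8 KV heads, a decode TP of 4, and a prefill TP of 8, each decode rank owns 2 heads while each prefill rank owns 1, so every decode rank must pull from 2 prefill ranks.

```python
num_key_value_heads = 8
decode_tp_size = 4   # self.tp_size on the decode node
prefill_tp_size = 8  # now taken from the request's remote_ptp_size

num_d_block_heads = max(1, num_key_value_heads // decode_tp_size)   # 2 heads per decode rank
num_p_block_heads = max(1, num_key_value_heads // prefill_tp_size)  # 1 head per prefill rank
tp_num_need_pulls = num_d_block_heads // num_p_block_heads
print(tp_num_need_pulls)  # 2
```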
@@ -1624,16 +1649,17 @@ class MooncakeConnectorWorker:
     def _prefill_get_remote_rank(self, req_id: str) -> list[int]:
         return sum(self._get_remote_ranks_for_req(req_id), [])

-    def _get_remote_rank(self, req_id: str) -> list[int]:
-        return self._get_remote_ranks_for_req(req_id)[self.tp_rank]
+    def _get_remote_rank(self, req_id: str, prefill_tp_size: int | None = None) -> list[int]:
+        return self._get_remote_ranks_for_req(req_id, prefill_tp_size)[self.tp_rank]

     def _get_remote_tp_ranks(
-        self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int
+        self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int, prefill_tp_size: int
     ) -> list[list[int]]:
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
         # random split prefill tp list
         tp_sampled_nums = []
         if (
-            self._prefill_tp_size > self.num_key_value_heads
+            prefill_tp_size > self.num_key_value_heads
             or self.vllm_config.model_config.is_deepseek_mla
             or self.use_sparse
         ):
@@ -1641,24 +1667,27 @@ class MooncakeConnectorWorker:
             choosen_group = tp_ori_data[:, [rand_group_index]]
             flattened = choosen_group.reshape(-1).tolist()
             tp_sampled_nums = [
-                flattened[i : i + self.tp_num_need_pulls] for i in range(0, len(flattened), self.tp_num_need_pulls)
+                flattened[i : i + tp_num_need_pulls] for i in range(0, len(flattened), tp_num_need_pulls)
             ]
         # non-random split
         else:
-            group_size = self._prefill_tp_size // self._decode_tp_size
+            group_size = prefill_tp_size // self._decode_tp_size
             for i in range(self._decode_tp_size):
                 slice = tp_ori_data[i * group_size : (i + 1) * group_size]
                 tp_sampled_nums.append(slice.tolist())
         return tp_sampled_nums

-    def _get_remote_ranks_for_req(self, req_id: str) -> list[list[int]]:
+    def _get_remote_ranks_for_req(self, req_id: str, prefill_tp_size: int | None = None) -> list[list[int]]:
+        if prefill_tp_size is None:
+            prefill_tp_size = self._prefill_tp_size
         # Divide the ports according to the TP within the PP
         sampled_nums = []
-        if self._prefill_tp_size == self._decode_tp_size:
+        if prefill_tp_size == self._decode_tp_size:
             sampled_nums = list(
                 map(
-                    lambda tp: [tp + pp * self._prefill_tp_size for pp in range(self._prefill_pp_size)],
-                    range(self._prefill_tp_size),
+                    lambda tp: [tp + pp * prefill_tp_size for pp in range(self._prefill_pp_size)],
+                    range(prefill_tp_size),
                 )
             )
             return sampled_nums
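
A toy run of the equal-TP branch above (hypothetical sizes): with `prefill_tp_size` equal to the decode TP size of 4 and a prefill PP of 2, each decode TP rank maps to the same TP rank in every prefill PP stage.

```python
prefill_tp_size, prefill_pp_size = 4, 2
sampled_nums = [
    [tp + pp * prefill_tp_size for pp in range(prefill_pp_size)]
    for tp in range(prefill_tp_size)
]
print(sampled_nums)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```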
@@ -1667,7 +1696,7 @@ class MooncakeConnectorWorker:
             num_kv_head = 1
         else:
             num_kv_head = self.num_key_value_heads
-        ori_data = np.arange(self._prefill_tp_size * self._prefill_pp_size)
+        ori_data = np.arange(prefill_tp_size * self._prefill_pp_size)
         seed = string_to_int64_hash(req_id)
         rand = random.Random(seed)
         # random split prefill tp list
@@ -1679,7 +1708,7 @@ class MooncakeConnectorWorker:
             range(num_groups), (max(self._decode_tp_size // num_kv_head, 1))
         )  # random choose a group
         all_results = [
-            self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups)
+            self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups, prefill_tp_size)
             for pp_index in range(self._prefill_pp_size)
         ]
         for group_index in range(len(all_results[0])):