From 0bb1f91c2cdb372a507ec5ba1c49c5fb68e3c41c Mon Sep 17 00:00:00 2001
From: yuxinshan <82206277+yuxinshan@users.noreply.github.com>
Date: Mon, 26 Jan 2026 14:28:33 +0800
Subject: [PATCH] [Feature] Mooncake connector get remote ptp size (#5822)

### What this PR does / why we need it?
To support elastic scaling when using the mooncake connector, we need to support **configuring different tp sizes for different nodes**. To achieve this, we transfer prefill node information, such as the tp size, through **the request's kv_transfer_params**. The decode nodes then **get the prefill tp size** from the request's kv_transfer_params instead of from the mooncake connector's configuration. (A minimal illustrative sketch of the decode-side handling is appended after the diff below.)

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d

Signed-off-by: yuxinshan
Signed-off-by: CalvinXKY
---
 .../kv_connector/test_mooncake_connector.py   | 81 ++++++++++++++++---
 .../kv_transfer/kv_p2p/mooncake_connector.py  | 79 ++++++++++++------
 2 files changed, 124 insertions(+), 36 deletions(-)

diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py
index 8cd1e6af..45918c4a 100644
--- a/tests/ut/kv_connector/test_mooncake_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_connector.py
@@ -7,7 +7,7 @@ import time
 import types
 import unittest
 from collections import defaultdict, deque
-from typing import Any, Dict, OrderedDict
+from typing import Any, Dict, OrderedDict, Optional
 from unittest.mock import MagicMock, patch
 
 import msgspec
@@ -691,7 +691,8 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
             "remote_host": "localhost",
             "remote_port": 5000,
             "remote_pcp_size": 1,
-            "remote_dcp_size": 1
+            "remote_dcp_size": 1,
+            "remote_ptp_size": 2
         })
 
         self.assertEqual(len(meta.requests), 1)
@@ -702,6 +703,7 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
         self.assertEqual(req_meta.remote_engine_id, "remote_engine")
         self.assertEqual(req_meta.remote_host, "localhost")
         self.assertEqual(req_meta.remote_port, 5000)
+        self.assertEqual(req_meta.remote_ptp_size, 2)
 
 
 class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
@@ -1209,9 +1211,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
 
     def test_get_remote_tp_rank(self):
 
-        def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int,
-                        decode_tp_size: int, num_kv_heads: int,
-                        tp_num_need_pulls: int, is_deepseek_mla: bool):
+        def get_tp_rank(prefill_tp_size: int,
+                        prefill_pp_size: int,
+                        decode_tp_size: int,
+                        num_kv_heads: int,
+                        tp_num_need_pulls: int,
+                        is_deepseek_mla: bool,
+                        remote_ptp_size: Optional[int] = None):
             with patch(
                     'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
                     return_value=MagicMock()), \
@@ -1226,7 +1232,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                                                  self.engine_id)
                 worker.tp_num_need_pulls = tp_num_need_pulls
                 worker.use_sparse = 0
-                return worker._get_remote_ranks_for_req('test')
+                return worker._get_remote_ranks_for_req(
+                    'test', remote_ptp_size)
 
         self.assertIn(
             get_tp_rank(16, 1, 1, 4, 4, False)[0],
@@ -1285,13 +1292,30 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
             get_tp_rank(4, 4, 4, 1, 1, True),
             [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
 
+        # check remote ptp size
+        self.assertListEqual(get_tp_rank(16, 1, 2, 4, 2, False, 8),
+                             get_tp_rank(8, 1, 2, 4, 2, False))
+        self.assertListEqual(get_tp_rank(8, 1, 2, 4, 2, False, 4),
+                             get_tp_rank(4, 1, 2, 4, 2, False))
+        self.assertListEqual(get_tp_rank(4, 1, 2, 4, 1, False, 2),
+                             get_tp_rank(2, 1, 2, 4, 1, False))
+
     def test_get_kv_split_metadata(self):
 
-        def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
-                                  tp_rank, pcp_rank, _prefill_tp_size,
-                                  remote_pcp_size, remote_dcp_size,
-                                  remote_port, remote_block_ids,
-                                  local_block_ids, remote_engine_id):
+        def get_kv_split_metadata(use_mla,
+                                  pcp_size,
+                                  dcp_size,
+                                  tp_size,
+                                  tp_rank,
+                                  pcp_rank,
+                                  _prefill_tp_size,
+                                  remote_pcp_size,
+                                  remote_dcp_size,
+                                  remote_port,
+                                  remote_block_ids,
+                                  local_block_ids,
+                                  remote_engine_id,
+                                  remote_ptp_size=None):
 
             worker = MooncakeConnectorWorker(self.vllm_config,
                                              self.engine_id)
@@ -1310,6 +1334,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
 
             meta.remote_pcp_size = remote_pcp_size
             meta.remote_dcp_size = remote_dcp_size
+            meta.remote_ptp_size = remote_ptp_size
             meta.remote_port = remote_port
             meta.remote_block_ids = remote_block_ids
             meta.local_block_ids = local_block_ids
@@ -1367,6 +1392,40 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                                   [1, 2, 3],
                                   [1, 2, 3, 4, 5], 0),
             ([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]))
+        # check remote ptp size
+        self.assertEqual(
+            get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
+                                  [1], 0, 16),
+            get_kv_split_metadata(True, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
+                                  [1], 0)
+        )
+        self.assertEqual(
+            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
+                                  [1], 0, 16),
+            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
+                                  [1], 0)
+        )
+        self.assertEqual(
+            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
+                                  [1], 0, 16),
+            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 2, 8, 30000, [1],
+                                  [1], 0)
+        )
+
+    def test_get_tp_num_need_pulls(self):
+        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
+        worker.num_key_value_heads = 8
+
+        tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=4)
+        self.assertEqual(tp_num_need_pulls, 1)
+
+        worker.vllm_config.model_config.is_deepseek_mla = False
+        tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=4)
+        self.assertEqual(tp_num_need_pulls, 2)
+
+        tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=None)
+        self.assertEqual(tp_num_need_pulls, 1)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py
index d03ba3ed..1464a2fc 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py
@@ -87,6 +87,7 @@ class ReqMeta:
     remote_request_id: str
     remote_pcp_size: int
     remote_dcp_size: int
+    remote_ptp_size: int | None
     remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
     num_prompt_blocks: int
 
@@ -773,6 +774,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
             remote_port=kv_transfer_params["remote_port"],
             remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
             remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
+            remote_ptp_size=kv_transfer_params.get("remote_ptp_size"),
             remote_multi_nodes_meta_mapping=kv_transfer_params.get("remote_multi_nodes_meta_mapping", {}),
             num_prompt_blocks=kv_transfer_params.get("num_prompt_blocks", 0),
         )
@@ -890,6 +892,7 @@ class MooncakeConnectorScheduler:
         self.side_channel_host = get_ip()
         self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
         self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.max_device_id = (
             vllm_config.parallel_config.tensor_parallel_size *
             vllm_config.parallel_config.data_parallel_size
@@ -1039,6 +1042,7 @@ class MooncakeConnectorScheduler:
             remote_port=self.side_channel_port,
             remote_pcp_size=self.pcp_size,
             remote_dcp_size=self.dcp_size,
+            remote_ptp_size=self.tp_size,
             last_token_id=request.output_token_ids[-1],
             remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
             num_prompt_blocks=num_prompt_blocks,
@@ -1324,8 +1328,9 @@ class MooncakeConnectorWorker:
         In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
         Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
         """
+        prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
         if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
-            choosen_rank_list = self._get_remote_rank(req_id)
+            choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
             remote_handshake_port_list = [[x + meta.remote_port for x in choosen_rank_list]]
             local_block_ids_list, remote_block_ids_list = [meta.local_block_ids], [meta.remote_block_ids]
             return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
@@ -1333,7 +1338,7 @@
         def context_parallel_parameters_check():
             assert (meta.remote_pcp_size * meta.remote_dcp_size) % (self.pcp_size * self.dcp_size) == 0
             if not self.use_mla:
-                p_node_heads_per_rank = math.ceil(self.num_key_value_heads / self._prefill_tp_size)
+                p_node_heads_per_rank = math.ceil(self.num_key_value_heads / prefill_tp_size)
                 d_node_heads_per_rank = math.ceil(self.num_key_value_heads / self.tp_size)
                 assert d_node_heads_per_rank % p_node_heads_per_rank == 0
 
@@ -1387,7 +1392,7 @@
         def get_local_remote_block_port_mappings():
             context_parallel_parameters_check()
             p_node_cp_group_meta = get_cp_group_meta(
-                self._prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
+                prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
             )
             d_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size, self.dcp_size, self.side_channel_port)
             local_remote_block_port_mappings: dict[int, list[list[int]]] = {}
@@ -1427,7 +1432,7 @@
             local_remote_block_port_mappings: dict[int, list[list[int]]],
         ) -> dict[int, RemotePortInfo]:
             remote_port_send_num: dict[int, RemotePortInfo] = {}
-            for port in range(self._prefill_tp_size * meta.remote_pcp_size):
+            for port in range(prefill_tp_size * meta.remote_pcp_size):
                 remote_host_info = meta.remote_multi_nodes_meta_mapping.get(str(port), None)
                 if remote_host_info is None:
                     remote_host = meta.remote_host
@@ -1518,8 +1523,9 @@
                 )
                 local_block_offset += num_blocks_to_pull
 
-        assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), (
-            f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
+        assert tp_num_need_pulls == len(remote_handshake_port_list[0]), (
+            f"tp_num_need_pulls: {tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
         )
 
         return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
@@ -1535,13 +1541,17 @@
             len(meta.local_block_ids),
             len(meta.remote_block_ids),
         )
+
+        prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
+
         if meta.remote_pcp_size * meta.remote_dcp_size > 1:
             remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
                 req_id, meta
             )
 
             for pcp_dcp_rank in range(len(remote_handshake_port_list)):
-                for i in range(self.tp_num_need_pulls):
+                for i in range(tp_num_need_pulls):
                     assert self.kv_recv_thread is not None
                     remote_host, remote_engine_id = self._get_remote_host_info_by_port(
                         meta.remote_port,
@@ -1559,16 +1569,16 @@
                         remote_host=remote_host,
                         remote_handshake_port=remote_handshake_port_list[pcp_dcp_rank][i],
                         offset=i,
-                        tp_num_need_pulls=self.tp_num_need_pulls,
+                        tp_num_need_pulls=tp_num_need_pulls,
                         remote_port_send_num=self.remote_port_send_num[meta.remote_engine_id],
                         all_task_done=(
-                            pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == self.tp_num_need_pulls - 1
+                            pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == tp_num_need_pulls - 1
                         ),
                     )
         else:
             # TODO: support prefill context parallel and pipeline parallel open at the same time
-            choosen_rank_list = self._get_remote_rank(req_id)
+            choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
             remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list]
-            for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
+            for i in range(tp_num_need_pulls * self._prefill_pp_size):
                 assert self.kv_recv_thread is not None
                 remote_host, remote_engine_id = self._get_remote_host_info_by_port(
                     meta.remote_port,
@@ -1586,8 +1596,8 @@
                     remote_host=remote_host,
                     remote_handshake_port=remote_handshake_port_list[i][0],
                     offset=i,
-                    tp_num_need_pulls=self.tp_num_need_pulls,
-                    all_task_done=(i == self.tp_num_need_pulls * self._prefill_pp_size - 1),
+                    tp_num_need_pulls=tp_num_need_pulls,
+                    all_task_done=(i == tp_num_need_pulls * self._prefill_pp_size - 1),
                 )
 
         for req_id in metadata.reqs_in_batch:
@@ -1607,6 +1617,21 @@
         for req_id, delay_start_time in metadata.requests_to_send.items():
             self.kv_send_thread.add_delayed_request(req_id, delay_start_time)
 
+    def _get_tp_num_need_pulls(self, prefill_tp_size: int) -> int:
+        if prefill_tp_size is None:
+            prefill_tp_size = self._prefill_tp_size
+
+        if prefill_tp_size == self._prefill_tp_size:
+            return self.tp_num_need_pulls
+
+        if self.vllm_config.model_config.is_deepseek_mla:
+            tp_num_need_pulls = 1
+        else:
+            num_d_block_heads = max(1, self.num_key_value_heads // self.tp_size)
+            num_p_block_heads = max(1, self.num_key_value_heads // prefill_tp_size)
+            tp_num_need_pulls = num_d_block_heads // num_p_block_heads
+        return tp_num_need_pulls
+
     def _get_remote_host_info_by_port(
         self,
         base_port: int,
@@ -1624,16 +1649,17 @@
     def _prefill_get_remote_rank(self, req_id: str) -> list[int]:
         return sum(self._get_remote_ranks_for_req(req_id), [])
 
-    def _get_remote_rank(self, req_id: str) -> list[int]:
-        return self._get_remote_ranks_for_req(req_id)[self.tp_rank]
+    def _get_remote_rank(self, req_id: str, prefill_tp_size: int | None = None) -> list[int]:
+        return self._get_remote_ranks_for_req(req_id, prefill_tp_size)[self.tp_rank]
 
     def _get_remote_tp_ranks(
-        self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int
+        self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int, prefill_tp_size: int
     ) -> list[list[int]]:
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
         # random split prefill tp list
         tp_sampled_nums = []
         if (
-            self._prefill_tp_size > self.num_key_value_heads
+            prefill_tp_size > self.num_key_value_heads
             or self.vllm_config.model_config.is_deepseek_mla
             or self.use_sparse
         ):
@@ -1641,24 +1667,27 @@
             choosen_group = tp_ori_data[:, [rand_group_index]]
             flattened = choosen_group.reshape(-1).tolist()
             tp_sampled_nums = [
-                flattened[i : i + self.tp_num_need_pulls] for i in range(0, len(flattened), self.tp_num_need_pulls)
+                flattened[i : i + tp_num_need_pulls] for i in range(0, len(flattened), tp_num_need_pulls)
             ]
         # non-random split
         else:
-            group_size = self._prefill_tp_size // self._decode_tp_size
+            group_size = prefill_tp_size // self._decode_tp_size
             for i in range(self._decode_tp_size):
                 slice = tp_ori_data[i * group_size : (i + 1) * group_size]
                 tp_sampled_nums.append(slice.tolist())
         return tp_sampled_nums
 
-    def _get_remote_ranks_for_req(self, req_id: str) -> list[list[int]]:
+    def _get_remote_ranks_for_req(self, req_id: str, prefill_tp_size: int | None = None) -> list[list[int]]:
+        if prefill_tp_size is None:
+            prefill_tp_size = self._prefill_tp_size
+
         # Divide the ports according to the TP within the PP
        sampled_nums = []
-        if self._prefill_tp_size == self._decode_tp_size:
+        if prefill_tp_size == self._decode_tp_size:
            sampled_nums = list(
                 map(
-                    lambda tp: [tp + pp * self._prefill_tp_size for pp in range(self._prefill_pp_size)],
-                    range(self._prefill_tp_size),
+                    lambda tp: [tp + pp * prefill_tp_size for pp in range(self._prefill_pp_size)],
+                    range(prefill_tp_size),
                 )
             )
             return sampled_nums
@@ -1667,7 +1696,7 @@
             num_kv_head = 1
         else:
             num_kv_head = self.num_key_value_heads
-        ori_data = np.arange(self._prefill_tp_size * self._prefill_pp_size)
+        ori_data = np.arange(prefill_tp_size * self._prefill_pp_size)
         seed = string_to_int64_hash(req_id)
         rand = random.Random(seed)
         # random split prefill tp list
@@ -1679,7 +1708,7 @@
             range(num_groups), (max(self._decode_tp_size // num_kv_head, 1))
         )  # random choose a group
         all_results = [
-            self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups)
+            self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups, prefill_tp_size)
             for pp_index in range(self._prefill_pp_size)
         ]
         for group_index in range(len(all_results[0])):
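
Below is a minimal, illustrative sketch (not part of the patch) of the decode-side handling described above: reading the prefill tp size from a request's kv_transfer_params and deriving how many prefill ranks each decode rank must pull KV blocks from. The function name and parameters are hypothetical and only mirror, in simplified form, the `_get_tp_num_need_pulls` logic added in this PR.

```python
# Illustrative sketch only: simplified mirror of the decode-side derivation.
from typing import Any, Optional


def tp_num_need_pulls_from_params(kv_transfer_params: dict[str, Any],
                                  decode_tp_size: int,
                                  num_key_value_heads: int,
                                  is_deepseek_mla: bool,
                                  configured_prefill_tp_size: int) -> int:
    # Prefer the prefill tp size sent by the P node; fall back to the locally
    # configured value for requests that do not carry "remote_ptp_size".
    prefill_tp_size: Optional[int] = kv_transfer_params.get("remote_ptp_size")
    if prefill_tp_size is None:
        prefill_tp_size = configured_prefill_tp_size

    if is_deepseek_mla:
        # The MLA KV cache is not split by KV head, so a single pull suffices.
        return 1

    # Otherwise each decode rank pulls from as many prefill ranks as needed
    # to cover the KV heads it owns.
    num_d_block_heads = max(1, num_key_value_heads // decode_tp_size)
    num_p_block_heads = max(1, num_key_value_heads // prefill_tp_size)
    return num_d_block_heads // num_p_block_heads


if __name__ == "__main__":
    # A decode node with tp=2 pulling from a prefill node that reported tp=8,
    # for a model with 8 KV heads, needs 4 pulls per decode rank.
    print(tp_num_need_pulls_from_params({"remote_ptp_size": 8}, 2, 8, False, 16))  # -> 4
```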