[Feature] Mooncake connector get remote ptp size (#5822)

### What this PR does / why we need it?
To support elastic scaling with the mooncake connector, we need to allow
**different tp sizes to be configured for different nodes**.
To that end, prefill node information, such as the tp size, is now passed
through **the request's kv_transfer_params**.
The decode nodes **get the prefill tp size** from the request's
kv_transfer_params instead of reading it from the configuration of the
mooncake connector.

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
This commit is contained in:
yuxinshan
2026-01-26 14:28:33 +08:00
committed by GitHub
parent 611e223b7d
commit 0bb1f91c2c
2 changed files with 124 additions and 36 deletions

View File

@@ -7,7 +7,7 @@ import time
import types
import unittest
from collections import defaultdict, deque
from typing import Any, Dict, OrderedDict
from typing import Any, Dict, OrderedDict, Optional
from unittest.mock import MagicMock, patch
import msgspec
@@ -691,7 +691,8 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
"remote_host": "localhost",
"remote_port": 5000,
"remote_pcp_size": 1,
"remote_dcp_size": 1
"remote_dcp_size": 1,
"remote_ptp_size": 2
})
self.assertEqual(len(meta.requests), 1)
@@ -702,6 +703,7 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
self.assertEqual(req_meta.remote_engine_id, "remote_engine")
self.assertEqual(req_meta.remote_host, "localhost")
self.assertEqual(req_meta.remote_port, 5000)
self.assertEqual(req_meta.remote_ptp_size, 2)
class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
@@ -1209,9 +1211,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
def test_get_remote_tp_rank(self):
def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int,
decode_tp_size: int, num_kv_heads: int,
tp_num_need_pulls: int, is_deepseek_mla: bool):
def get_tp_rank(prefill_tp_size: int,
prefill_pp_size: int,
decode_tp_size: int,
num_kv_heads: int,
tp_num_need_pulls: int,
is_deepseek_mla: bool,
remote_ptp_size: Optional[int] = None):
with patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()), \
@@ -1226,7 +1232,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
self.engine_id)
worker.tp_num_need_pulls = tp_num_need_pulls
worker.use_sparse = 0
return worker._get_remote_ranks_for_req('test')
return worker._get_remote_ranks_for_req(
'test', remote_ptp_size)
self.assertIn(
get_tp_rank(16, 1, 1, 4, 4, False)[0],
@@ -1285,13 +1292,30 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
get_tp_rank(4, 4, 4, 1, 1, True),
[[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
# check remote ptp size
self.assertListEqual(get_tp_rank(16, 1, 2, 4, 2, False, 8),
get_tp_rank(8, 1, 2, 4, 2, False))
self.assertListEqual(get_tp_rank(8, 1, 2, 4, 2, False, 4),
get_tp_rank(4, 1, 2, 4, 2, False))
self.assertListEqual(get_tp_rank(4, 1, 2, 4, 1, False, 2),
get_tp_rank(2, 1, 2, 4, 1, False))
def test_get_kv_split_metadata(self):
def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
tp_rank, pcp_rank, _prefill_tp_size,
remote_pcp_size, remote_dcp_size,
remote_port, remote_block_ids,
local_block_ids, remote_engine_id):
def get_kv_split_metadata(use_mla,
pcp_size,
dcp_size,
tp_size,
tp_rank,
pcp_rank,
_prefill_tp_size,
remote_pcp_size,
remote_dcp_size,
remote_port,
remote_block_ids,
local_block_ids,
remote_engine_id,
remote_ptp_size=None):
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
@@ -1310,6 +1334,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
meta.remote_pcp_size = remote_pcp_size
meta.remote_dcp_size = remote_dcp_size
meta.remote_ptp_size = remote_ptp_size
meta.remote_port = remote_port
meta.remote_block_ids = remote_block_ids
meta.local_block_ids = local_block_ids
@@ -1367,6 +1392,40 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
[1, 2, 3], [1, 2, 3, 4, 5], 0),
([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]))
# check remote ptp size
self.assertEqual(
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
[1], 0, 16),
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
[1], 0)
)
self.assertEqual(
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
[1], 0, 16),
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
[1], 0)
)
self.assertEqual(
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
[1], 0, 16),
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 2, 8, 30000, [1],
[1], 0)
)
def test_get_tp_num_need_pulls(self):
    """Exercise ``_get_tp_num_need_pulls`` with explicit and unset prefill tp sizes.

    The worker is built from the fixture's ``vllm_config`` and ``engine_id``
    and given 8 key/value heads. Three calls are checked in order:

    * ``prefill_tp_size=4`` with the fixture's model config untouched -> 1
    * ``prefill_tp_size=4`` after ``is_deepseek_mla`` is forced to False -> 2
    * ``prefill_tp_size=None`` (same config) -> 1
    """
    connector_worker = MooncakeConnectorWorker(self.vllm_config,
                                               self.engine_id)
    connector_worker.num_key_value_heads = 8

    # NOTE(review): presumably the fixture's model config reports MLA at this
    # point, which is why a single pull is expected — confirm against setUp.
    pulls_before_override = connector_worker._get_tp_num_need_pulls(
        prefill_tp_size=4)
    self.assertEqual(pulls_before_override, 1)

    # Disable the MLA flag on the worker's model config and repeat the call.
    connector_worker.vllm_config.model_config.is_deepseek_mla = False
    pulls_without_mla = connector_worker._get_tp_num_need_pulls(
        prefill_tp_size=4)
    self.assertEqual(pulls_without_mla, 2)

    # No remote prefill tp size supplied with the request.
    pulls_without_remote_size = connector_worker._get_tp_num_need_pulls(
        prefill_tp_size=None)
    self.assertEqual(pulls_without_remote_size, 1)
# Run this module's unittest test cases when the file is executed directly.
if __name__ == '__main__':
unittest.main()