[Feature] Mooncake connector get remote ptp size (#5822)
### What this PR does / why we need it?
To support elastic scaling when using mooncake connector, we should
support to **configure different tp sizes for different nodes**.
As a result, we transfer the prefill node information, such as tp size,
through **the request's kv_transfer_params**.
The decode nodes **get the prefill tp size** through the request's
kv_transfer_params, instead of getting it from the configuration of the
mooncake connector .
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
This commit is contained in:
@@ -7,7 +7,7 @@ import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Dict, OrderedDict
|
||||
from typing import Any, Dict, OrderedDict, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
@@ -691,7 +691,8 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 5000,
|
||||
"remote_pcp_size": 1,
|
||||
"remote_dcp_size": 1
|
||||
"remote_dcp_size": 1,
|
||||
"remote_ptp_size": 2
|
||||
})
|
||||
|
||||
self.assertEqual(len(meta.requests), 1)
|
||||
@@ -702,6 +703,7 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
|
||||
self.assertEqual(req_meta.remote_engine_id, "remote_engine")
|
||||
self.assertEqual(req_meta.remote_host, "localhost")
|
||||
self.assertEqual(req_meta.remote_port, 5000)
|
||||
self.assertEqual(req_meta.remote_ptp_size, 2)
|
||||
|
||||
|
||||
class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
||||
@@ -1209,9 +1211,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
|
||||
def test_get_remote_tp_rank(self):
|
||||
|
||||
def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int,
|
||||
decode_tp_size: int, num_kv_heads: int,
|
||||
tp_num_need_pulls: int, is_deepseek_mla: bool):
|
||||
def get_tp_rank(prefill_tp_size: int,
|
||||
prefill_pp_size: int,
|
||||
decode_tp_size: int,
|
||||
num_kv_heads: int,
|
||||
tp_num_need_pulls: int,
|
||||
is_deepseek_mla: bool,
|
||||
remote_ptp_size: Optional[int] = None):
|
||||
with patch(
|
||||
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
|
||||
return_value=MagicMock()), \
|
||||
@@ -1226,7 +1232,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
self.engine_id)
|
||||
worker.tp_num_need_pulls = tp_num_need_pulls
|
||||
worker.use_sparse = 0
|
||||
return worker._get_remote_ranks_for_req('test')
|
||||
return worker._get_remote_ranks_for_req(
|
||||
'test', remote_ptp_size)
|
||||
|
||||
self.assertIn(
|
||||
get_tp_rank(16, 1, 1, 4, 4, False)[0],
|
||||
@@ -1285,13 +1292,30 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
get_tp_rank(4, 4, 4, 1, 1, True),
|
||||
[[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
|
||||
|
||||
# check remote ptp size
|
||||
self.assertListEqual(get_tp_rank(16, 1, 2, 4, 2, False, 8),
|
||||
get_tp_rank(8, 1, 2, 4, 2, False))
|
||||
self.assertListEqual(get_tp_rank(8, 1, 2, 4, 2, False, 4),
|
||||
get_tp_rank(4, 1, 2, 4, 2, False))
|
||||
self.assertListEqual(get_tp_rank(4, 1, 2, 4, 1, False, 2),
|
||||
get_tp_rank(2, 1, 2, 4, 1, False))
|
||||
|
||||
def test_get_kv_split_metadata(self):
|
||||
|
||||
def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
|
||||
tp_rank, pcp_rank, _prefill_tp_size,
|
||||
remote_pcp_size, remote_dcp_size,
|
||||
remote_port, remote_block_ids,
|
||||
local_block_ids, remote_engine_id):
|
||||
def get_kv_split_metadata(use_mla,
|
||||
pcp_size,
|
||||
dcp_size,
|
||||
tp_size,
|
||||
tp_rank,
|
||||
pcp_rank,
|
||||
_prefill_tp_size,
|
||||
remote_pcp_size,
|
||||
remote_dcp_size,
|
||||
remote_port,
|
||||
remote_block_ids,
|
||||
local_block_ids,
|
||||
remote_engine_id,
|
||||
remote_ptp_size=None):
|
||||
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
|
||||
@@ -1310,6 +1334,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
|
||||
meta.remote_pcp_size = remote_pcp_size
|
||||
meta.remote_dcp_size = remote_dcp_size
|
||||
meta.remote_ptp_size = remote_ptp_size
|
||||
meta.remote_port = remote_port
|
||||
meta.remote_block_ids = remote_block_ids
|
||||
meta.local_block_ids = local_block_ids
|
||||
@@ -1367,6 +1392,40 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
[1, 2, 3], [1, 2, 3, 4, 5], 0),
|
||||
([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]))
|
||||
|
||||
# check remote ptp size
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
|
||||
[1], 0, 16),
|
||||
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
|
||||
[1], 0)
|
||||
)
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
|
||||
[1], 0, 16),
|
||||
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
|
||||
[1], 0)
|
||||
)
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
|
||||
[1], 0, 16),
|
||||
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 2, 8, 30000, [1],
|
||||
[1], 0)
|
||||
)
|
||||
|
||||
def test_get_tp_num_need_pulls(self):
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
worker.num_key_value_heads = 8
|
||||
|
||||
tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=4)
|
||||
self.assertEqual(tp_num_need_pulls, 1)
|
||||
|
||||
worker.vllm_config.model_config.is_deepseek_mla = False
|
||||
tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=4)
|
||||
self.assertEqual(tp_num_need_pulls, 2)
|
||||
|
||||
tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=None)
|
||||
self.assertEqual(tp_num_need_pulls, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user