[Feature] Mooncake connector get remote ptp size (#5822)
### What this PR does / why we need it?
To support elastic scaling with the Mooncake connector, we need to be able to
**configure different TP sizes for different nodes**.
To that end, prefill-node information, such as the TP size, is transferred
through **the request's kv_transfer_params**.
Decode nodes now **read the prefill TP size** from the request's
kv_transfer_params instead of taking it from the Mooncake connector's
configuration.
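
For illustration, a decode node now receives a payload like the sketch below. All values are hypothetical, and `configured_prefill_tp_size` is a stand-in for the value the connector configuration would otherwise supply:

```python
# Hypothetical kv_transfer_params attached to a request. The new
# "remote_ptp_size" field carries the prefill node's TP size.
kv_transfer_params = {
    "remote_engine_id": "remote_engine",
    "remote_host": "localhost",
    "remote_port": 5000,
    "remote_pcp_size": 1,
    "remote_dcp_size": 1,
    "remote_ptp_size": 2,  # prefill tensor-parallel size
}

# If an older prefill node omits the field, the decode side falls back
# to the Mooncake connector configuration.
prefill_tp_size = (kv_transfer_params.get("remote_ptp_size")
                   or configured_prefill_tp_size)
```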
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
@@ -7,7 +7,7 @@ import time
 import types
 import unittest
 from collections import defaultdict, deque
-from typing import Any, Dict, OrderedDict
+from typing import Any, Dict, OrderedDict, Optional
 from unittest.mock import MagicMock, patch
 
 import msgspec
@@ -691,7 +691,8 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
             "remote_host": "localhost",
             "remote_port": 5000,
             "remote_pcp_size": 1,
-            "remote_dcp_size": 1
+            "remote_dcp_size": 1,
+            "remote_ptp_size": 2
         })
 
         self.assertEqual(len(meta.requests), 1)
@@ -702,6 +703,7 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
         self.assertEqual(req_meta.remote_engine_id, "remote_engine")
         self.assertEqual(req_meta.remote_host, "localhost")
         self.assertEqual(req_meta.remote_port, 5000)
+        self.assertEqual(req_meta.remote_ptp_size, 2)
 
 
 class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
@@ -1209,9 +1211,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
 
     def test_get_remote_tp_rank(self):
 
-        def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int,
-                        decode_tp_size: int, num_kv_heads: int,
-                        tp_num_need_pulls: int, is_deepseek_mla: bool):
+        def get_tp_rank(prefill_tp_size: int,
+                        prefill_pp_size: int,
+                        decode_tp_size: int,
+                        num_kv_heads: int,
+                        tp_num_need_pulls: int,
+                        is_deepseek_mla: bool,
+                        remote_ptp_size: Optional[int] = None):
             with patch(
                 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
                 return_value=MagicMock()), \
@@ -1226,7 +1232,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                     self.engine_id)
                 worker.tp_num_need_pulls = tp_num_need_pulls
                 worker.use_sparse = 0
-                return worker._get_remote_ranks_for_req('test')
+                return worker._get_remote_ranks_for_req(
+                    'test', remote_ptp_size)
 
         self.assertIn(
             get_tp_rank(16, 1, 1, 4, 4, False)[0],
@@ -1285,13 +1292,30 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
             get_tp_rank(4, 4, 4, 1, 1, True),
             [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
 
+        # check remote ptp size
+        self.assertListEqual(get_tp_rank(16, 1, 2, 4, 2, False, 8),
+                             get_tp_rank(8, 1, 2, 4, 2, False))
+        self.assertListEqual(get_tp_rank(8, 1, 2, 4, 2, False, 4),
+                             get_tp_rank(4, 1, 2, 4, 2, False))
+        self.assertListEqual(get_tp_rank(4, 1, 2, 4, 1, False, 2),
+                             get_tp_rank(2, 1, 2, 4, 1, False))
 
     def test_get_kv_split_metadata(self):
 
-        def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
-                                  tp_rank, pcp_rank, _prefill_tp_size,
-                                  remote_pcp_size, remote_dcp_size,
-                                  remote_port, remote_block_ids,
-                                  local_block_ids, remote_engine_id):
+        def get_kv_split_metadata(use_mla,
+                                  pcp_size,
+                                  dcp_size,
+                                  tp_size,
+                                  tp_rank,
+                                  pcp_rank,
+                                  _prefill_tp_size,
+                                  remote_pcp_size,
+                                  remote_dcp_size,
+                                  remote_port,
+                                  remote_block_ids,
+                                  local_block_ids,
+                                  remote_engine_id,
+                                  remote_ptp_size=None):
 
             worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
 
@@ -1310,6 +1334,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
 
             meta.remote_pcp_size = remote_pcp_size
             meta.remote_dcp_size = remote_dcp_size
+            meta.remote_ptp_size = remote_ptp_size
             meta.remote_port = remote_port
             meta.remote_block_ids = remote_block_ids
             meta.local_block_ids = local_block_ids
@@ -1367,6 +1392,40 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                 [1, 2, 3], [1, 2, 3, 4, 5], 0),
             ([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]))
 
+        # check remote ptp size
+        self.assertEqual(
+            get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
+                                  [1], 0, 16),
+            get_kv_split_metadata(True, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
+                                  [1], 0)
+        )
+        self.assertEqual(
+            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
+                                  [1], 0, 16),
+            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1],
+                                  [1], 0)
+        )
+        self.assertEqual(
+            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
+                                  [1], 0, 16),
+            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 2, 8, 30000, [1],
+                                  [1], 0)
+        )
+
+    def test_get_tp_num_need_pulls(self):
+        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
+        worker.num_key_value_heads = 8
+
+        tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=4)
+        self.assertEqual(tp_num_need_pulls, 1)
+
+        worker.vllm_config.model_config.is_deepseek_mla = False
+        tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=4)
+        self.assertEqual(tp_num_need_pulls, 2)
+
+        tp_num_need_pulls = worker._get_tp_num_need_pulls(prefill_tp_size=None)
+        self.assertEqual(tp_num_need_pulls, 1)
+
 
 if __name__ == '__main__':
     unittest.main()

vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py
@@ -87,6 +87,7 @@ class ReqMeta:
     remote_request_id: str
     remote_pcp_size: int
     remote_dcp_size: int
+    remote_ptp_size: int | None
     remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
     num_prompt_blocks: int
 
@@ -773,6 +774,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
             remote_port=kv_transfer_params["remote_port"],
             remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
             remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
+            remote_ptp_size=kv_transfer_params.get("remote_ptp_size"),
             remote_multi_nodes_meta_mapping=kv_transfer_params.get("remote_multi_nodes_meta_mapping", {}),
             num_prompt_blocks=kv_transfer_params.get("num_prompt_blocks", 0),
         )
@@ -890,6 +892,7 @@ class MooncakeConnectorScheduler:
         self.side_channel_host = get_ip()
         self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
         self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.max_device_id = (
             vllm_config.parallel_config.tensor_parallel_size
             * vllm_config.parallel_config.data_parallel_size
@@ -1039,6 +1042,7 @@ class MooncakeConnectorScheduler:
             remote_port=self.side_channel_port,
             remote_pcp_size=self.pcp_size,
             remote_dcp_size=self.dcp_size,
+            remote_ptp_size=self.tp_size,
             last_token_id=request.output_token_ids[-1],
             remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
             num_prompt_blocks=num_prompt_blocks,
@@ -1324,8 +1328,9 @@ class MooncakeConnectorWorker:
         In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
         Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
         """
+        prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
         if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
-            choosen_rank_list = self._get_remote_rank(req_id)
+            choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
             remote_handshake_port_list = [[x + meta.remote_port for x in choosen_rank_list]]
             local_block_ids_list, remote_block_ids_list = [meta.local_block_ids], [meta.remote_block_ids]
             return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
@@ -1333,7 +1338,7 @@ class MooncakeConnectorWorker:
         def context_parallel_parameters_check():
             assert (meta.remote_pcp_size * meta.remote_dcp_size) % (self.pcp_size * self.dcp_size) == 0
             if not self.use_mla:
-                p_node_heads_per_rank = math.ceil(self.num_key_value_heads / self._prefill_tp_size)
+                p_node_heads_per_rank = math.ceil(self.num_key_value_heads / prefill_tp_size)
                 d_node_heads_per_rank = math.ceil(self.num_key_value_heads / self.tp_size)
                 assert d_node_heads_per_rank % p_node_heads_per_rank == 0
 
@@ -1387,7 +1392,7 @@ class MooncakeConnectorWorker:
         def get_local_remote_block_port_mappings():
             context_parallel_parameters_check()
             p_node_cp_group_meta = get_cp_group_meta(
-                self._prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
+                prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
             )
             d_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size, self.dcp_size, self.side_channel_port)
             local_remote_block_port_mappings: dict[int, list[list[int]]] = {}
@@ -1427,7 +1432,7 @@ class MooncakeConnectorWorker:
             local_remote_block_port_mappings: dict[int, list[list[int]]],
         ) -> dict[int, RemotePortInfo]:
             remote_port_send_num: dict[int, RemotePortInfo] = {}
-            for port in range(self._prefill_tp_size * meta.remote_pcp_size):
+            for port in range(prefill_tp_size * meta.remote_pcp_size):
                 remote_host_info = meta.remote_multi_nodes_meta_mapping.get(str(port), None)
                 if remote_host_info is None:
                     remote_host = meta.remote_host
@@ -1518,8 +1523,9 @@ class MooncakeConnectorWorker:
             )
             local_block_offset += num_blocks_to_pull
 
-        assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), (
-            f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
+        assert tp_num_need_pulls == len(remote_handshake_port_list[0]), (
+            f"tp_num_need_pulls: {tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
         )
 
         return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
@@ -1535,13 +1541,17 @@ class MooncakeConnectorWorker:
             len(meta.local_block_ids),
             len(meta.remote_block_ids),
         )
 
+        prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
+
         if meta.remote_pcp_size * meta.remote_dcp_size > 1:
             remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
                 req_id, meta
             )
 
             for pcp_dcp_rank in range(len(remote_handshake_port_list)):
-                for i in range(self.tp_num_need_pulls):
+                for i in range(tp_num_need_pulls):
                     assert self.kv_recv_thread is not None
                     remote_host, remote_engine_id = self._get_remote_host_info_by_port(
                         meta.remote_port,
@@ -1559,16 +1569,16 @@ class MooncakeConnectorWorker:
                         remote_host=remote_host,
                         remote_handshake_port=remote_handshake_port_list[pcp_dcp_rank][i],
                         offset=i,
-                        tp_num_need_pulls=self.tp_num_need_pulls,
+                        tp_num_need_pulls=tp_num_need_pulls,
                         remote_port_send_num=self.remote_port_send_num[meta.remote_engine_id],
                         all_task_done=(
-                            pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == self.tp_num_need_pulls - 1
+                            pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == tp_num_need_pulls - 1
                         ),
                     )
         else:  # TODO: support prefill context parallel and pipeline parallel open at the same time
-            choosen_rank_list = self._get_remote_rank(req_id)
+            choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
             remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list]
-            for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
+            for i in range(tp_num_need_pulls * self._prefill_pp_size):
                 assert self.kv_recv_thread is not None
                 remote_host, remote_engine_id = self._get_remote_host_info_by_port(
                     meta.remote_port,
@@ -1586,8 +1596,8 @@ class MooncakeConnectorWorker:
                     remote_host=remote_host,
                     remote_handshake_port=remote_handshake_port_list[i][0],
                     offset=i,
-                    tp_num_need_pulls=self.tp_num_need_pulls,
-                    all_task_done=(i == self.tp_num_need_pulls * self._prefill_pp_size - 1),
+                    tp_num_need_pulls=tp_num_need_pulls,
+                    all_task_done=(i == tp_num_need_pulls * self._prefill_pp_size - 1),
                 )
 
         for req_id in metadata.reqs_in_batch:
@@ -1607,6 +1617,21 @@ class MooncakeConnectorWorker:
         for req_id, delay_start_time in metadata.requests_to_send.items():
             self.kv_send_thread.add_delayed_request(req_id, delay_start_time)
 
+    def _get_tp_num_need_pulls(self, prefill_tp_size: int) -> int:
+        if prefill_tp_size is None:
+            prefill_tp_size = self._prefill_tp_size
+
+        if prefill_tp_size == self._prefill_tp_size:
+            return self.tp_num_need_pulls
+
+        if self.vllm_config.model_config.is_deepseek_mla:
+            tp_num_need_pulls = 1
+        else:
+            num_d_block_heads = max(1, self.num_key_value_heads // self.tp_size)
+            num_p_block_heads = max(1, self.num_key_value_heads // prefill_tp_size)
+            tp_num_need_pulls = num_d_block_heads // num_p_block_heads
+        return tp_num_need_pulls
+
     def _get_remote_host_info_by_port(
         self,
         base_port: int,
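
For reference, the pull-count rule that `_get_tp_num_need_pulls` encodes can be checked in isolation. The following is a standalone sketch of the same arithmetic with hypothetical head counts and TP sizes, not connector code:

```python
# Standalone sketch of the pull-count arithmetic: each decode rank must
# pull from enough prefill ranks to assemble all of the KV heads it owns.
def tp_num_need_pulls(num_kv_heads: int, decode_tp_size: int,
                      prefill_tp_size: int,
                      is_deepseek_mla: bool = False) -> int:
    if is_deepseek_mla:
        return 1  # MLA KV cache is not split across heads, one pull suffices
    num_d_block_heads = max(1, num_kv_heads // decode_tp_size)   # heads per decode rank
    num_p_block_heads = max(1, num_kv_heads // prefill_tp_size)  # heads per prefill rank
    return num_d_block_heads // num_p_block_heads

# With 8 KV heads, decode TP 2 and prefill TP 4: a decode rank holds 4
# heads, a prefill rank holds 2, so each decode rank pulls from 2 ranks.
assert tp_num_need_pulls(8, 2, 4) == 2
assert tp_num_need_pulls(8, 4, 4) == 1
```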
@@ -1624,16 +1649,17 @@ class MooncakeConnectorWorker:
     def _prefill_get_remote_rank(self, req_id: str) -> list[int]:
         return sum(self._get_remote_ranks_for_req(req_id), [])
 
-    def _get_remote_rank(self, req_id: str) -> list[int]:
-        return self._get_remote_ranks_for_req(req_id)[self.tp_rank]
+    def _get_remote_rank(self, req_id: str, prefill_tp_size: int | None = None) -> list[int]:
+        return self._get_remote_ranks_for_req(req_id, prefill_tp_size)[self.tp_rank]
 
     def _get_remote_tp_ranks(
-        self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int
+        self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int, prefill_tp_size: int
     ) -> list[list[int]]:
+        tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
         # random split prefill tp list
         tp_sampled_nums = []
         if (
-            self._prefill_tp_size > self.num_key_value_heads
+            prefill_tp_size > self.num_key_value_heads
             or self.vllm_config.model_config.is_deepseek_mla
             or self.use_sparse
         ):
@@ -1641,24 +1667,27 @@ class MooncakeConnectorWorker:
             choosen_group = tp_ori_data[:, [rand_group_index]]
             flattened = choosen_group.reshape(-1).tolist()
             tp_sampled_nums = [
-                flattened[i : i + self.tp_num_need_pulls] for i in range(0, len(flattened), self.tp_num_need_pulls)
+                flattened[i : i + tp_num_need_pulls] for i in range(0, len(flattened), tp_num_need_pulls)
             ]
         # non-random split
         else:
-            group_size = self._prefill_tp_size // self._decode_tp_size
+            group_size = prefill_tp_size // self._decode_tp_size
             for i in range(self._decode_tp_size):
                 slice = tp_ori_data[i * group_size : (i + 1) * group_size]
                 tp_sampled_nums.append(slice.tolist())
         return tp_sampled_nums
 
-    def _get_remote_ranks_for_req(self, req_id: str) -> list[list[int]]:
+    def _get_remote_ranks_for_req(self, req_id: str, prefill_tp_size: int | None = None) -> list[list[int]]:
+        if prefill_tp_size is None:
+            prefill_tp_size = self._prefill_tp_size
+
         # Divide the ports according to the TP within the PP
         sampled_nums = []
-        if self._prefill_tp_size == self._decode_tp_size:
+        if prefill_tp_size == self._decode_tp_size:
             sampled_nums = list(
                 map(
-                    lambda tp: [tp + pp * self._prefill_tp_size for pp in range(self._prefill_pp_size)],
-                    range(self._prefill_tp_size),
+                    lambda tp: [tp + pp * prefill_tp_size for pp in range(self._prefill_pp_size)],
+                    range(prefill_tp_size),
                 )
             )
             return sampled_nums
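
The non-random split above is easy to verify by hand. Below is a minimal sketch of just that branch, assuming the prefill TP size is a multiple of the decode TP size; `split_prefill_ranks` is a hypothetical helper, not part of the connector:

```python
import numpy as np

# Minimal sketch of the non-random split: prefill ranks are carved into
# contiguous groups, one group per decode rank.
def split_prefill_ranks(prefill_tp_size: int, decode_tp_size: int) -> list[list[int]]:
    tp_ori_data = np.arange(prefill_tp_size)
    group_size = prefill_tp_size // decode_tp_size
    return [tp_ori_data[i * group_size:(i + 1) * group_size].tolist()
            for i in range(decode_tp_size)]

# Prefill TP 8 feeding decode TP 2: decode rank 0 pulls from prefill
# ranks 0-3, decode rank 1 from ranks 4-7.
assert split_prefill_ranks(8, 2) == [[0, 1, 2, 3], [4, 5, 6, 7]]
```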
@@ -1667,7 +1696,7 @@ class MooncakeConnectorWorker:
             num_kv_head = 1
         else:
             num_kv_head = self.num_key_value_heads
-        ori_data = np.arange(self._prefill_tp_size * self._prefill_pp_size)
+        ori_data = np.arange(prefill_tp_size * self._prefill_pp_size)
         seed = string_to_int64_hash(req_id)
         rand = random.Random(seed)
         # random split prefill tp list
@@ -1679,7 +1708,7 @@ class MooncakeConnectorWorker:
             range(num_groups), (max(self._decode_tp_size // num_kv_head, 1))
         )  # random choose a group
         all_results = [
-            self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups)
+            self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups, prefill_tp_size)
             for pp_index in range(self._prefill_pp_size)
         ]
         for group_index in range(len(all_results[0])):