[Feature] Mooncake connector get remote ptp size (#5822)
### What this PR does / why we need it?
To support elastic scaling when using mooncake connector, we should
support to **configure different tp sizes for different nodes**.
As a result, we transfer the prefill node information, such as tp size,
through **the request's kv_transfer_params**.
The decode nodes **get the prefill tp size** through the request's
kv_transfer_params, instead of getting it from the configuration of the
mooncake connector .
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
This commit is contained in:
@@ -87,6 +87,7 @@ class ReqMeta:
|
||||
remote_request_id: str
|
||||
remote_pcp_size: int
|
||||
remote_dcp_size: int
|
||||
remote_ptp_size: int | None
|
||||
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
|
||||
num_prompt_blocks: int
|
||||
|
||||
@@ -773,6 +774,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
|
||||
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
|
||||
remote_ptp_size=kv_transfer_params.get("remote_ptp_size"),
|
||||
remote_multi_nodes_meta_mapping=kv_transfer_params.get("remote_multi_nodes_meta_mapping", {}),
|
||||
num_prompt_blocks=kv_transfer_params.get("num_prompt_blocks", 0),
|
||||
)
|
||||
@@ -890,6 +892,7 @@ class MooncakeConnectorScheduler:
|
||||
self.side_channel_host = get_ip()
|
||||
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
self.max_device_id = (
|
||||
vllm_config.parallel_config.tensor_parallel_size
|
||||
* vllm_config.parallel_config.data_parallel_size
|
||||
@@ -1039,6 +1042,7 @@ class MooncakeConnectorScheduler:
|
||||
remote_port=self.side_channel_port,
|
||||
remote_pcp_size=self.pcp_size,
|
||||
remote_dcp_size=self.dcp_size,
|
||||
remote_ptp_size=self.tp_size,
|
||||
last_token_id=request.output_token_ids[-1],
|
||||
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
|
||||
num_prompt_blocks=num_prompt_blocks,
|
||||
@@ -1324,8 +1328,9 @@ class MooncakeConnectorWorker:
|
||||
In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
|
||||
Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
|
||||
"""
|
||||
prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
|
||||
if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
|
||||
choosen_rank_list = self._get_remote_rank(req_id)
|
||||
choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
|
||||
remote_handshake_port_list = [[x + meta.remote_port for x in choosen_rank_list]]
|
||||
local_block_ids_list, remote_block_ids_list = [meta.local_block_ids], [meta.remote_block_ids]
|
||||
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
|
||||
@@ -1333,7 +1338,7 @@ class MooncakeConnectorWorker:
|
||||
def context_parallel_parameters_check():
|
||||
assert (meta.remote_pcp_size * meta.remote_dcp_size) % (self.pcp_size * self.dcp_size) == 0
|
||||
if not self.use_mla:
|
||||
p_node_heads_per_rank = math.ceil(self.num_key_value_heads / self._prefill_tp_size)
|
||||
p_node_heads_per_rank = math.ceil(self.num_key_value_heads / prefill_tp_size)
|
||||
d_node_heads_per_rank = math.ceil(self.num_key_value_heads / self.tp_size)
|
||||
assert d_node_heads_per_rank % p_node_heads_per_rank == 0
|
||||
|
||||
@@ -1387,7 +1392,7 @@ class MooncakeConnectorWorker:
|
||||
def get_local_remote_block_port_mappings():
|
||||
context_parallel_parameters_check()
|
||||
p_node_cp_group_meta = get_cp_group_meta(
|
||||
self._prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
|
||||
prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port
|
||||
)
|
||||
d_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size, self.dcp_size, self.side_channel_port)
|
||||
local_remote_block_port_mappings: dict[int, list[list[int]]] = {}
|
||||
@@ -1427,7 +1432,7 @@ class MooncakeConnectorWorker:
|
||||
local_remote_block_port_mappings: dict[int, list[list[int]]],
|
||||
) -> dict[int, RemotePortInfo]:
|
||||
remote_port_send_num: dict[int, RemotePortInfo] = {}
|
||||
for port in range(self._prefill_tp_size * meta.remote_pcp_size):
|
||||
for port in range(prefill_tp_size * meta.remote_pcp_size):
|
||||
remote_host_info = meta.remote_multi_nodes_meta_mapping.get(str(port), None)
|
||||
if remote_host_info is None:
|
||||
remote_host = meta.remote_host
|
||||
@@ -1518,8 +1523,9 @@ class MooncakeConnectorWorker:
|
||||
)
|
||||
local_block_offset += num_blocks_to_pull
|
||||
|
||||
assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), (
|
||||
f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
|
||||
tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
|
||||
assert tp_num_need_pulls == len(remote_handshake_port_list[0]), (
|
||||
f"tp_num_need_pulls: {tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
|
||||
)
|
||||
|
||||
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
|
||||
@@ -1535,13 +1541,17 @@ class MooncakeConnectorWorker:
|
||||
len(meta.local_block_ids),
|
||||
len(meta.remote_block_ids),
|
||||
)
|
||||
|
||||
prefill_tp_size = meta.remote_ptp_size if getattr(meta, "remote_ptp_size", None) else self._prefill_tp_size
|
||||
tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
|
||||
|
||||
if meta.remote_pcp_size * meta.remote_dcp_size > 1:
|
||||
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
|
||||
req_id, meta
|
||||
)
|
||||
|
||||
for pcp_dcp_rank in range(len(remote_handshake_port_list)):
|
||||
for i in range(self.tp_num_need_pulls):
|
||||
for i in range(tp_num_need_pulls):
|
||||
assert self.kv_recv_thread is not None
|
||||
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
||||
meta.remote_port,
|
||||
@@ -1559,16 +1569,16 @@ class MooncakeConnectorWorker:
|
||||
remote_host=remote_host,
|
||||
remote_handshake_port=remote_handshake_port_list[pcp_dcp_rank][i],
|
||||
offset=i,
|
||||
tp_num_need_pulls=self.tp_num_need_pulls,
|
||||
tp_num_need_pulls=tp_num_need_pulls,
|
||||
remote_port_send_num=self.remote_port_send_num[meta.remote_engine_id],
|
||||
all_task_done=(
|
||||
pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == self.tp_num_need_pulls - 1
|
||||
pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == tp_num_need_pulls - 1
|
||||
),
|
||||
)
|
||||
else: # TODO: support prefill context parallel and pipeline parallel open at the same time
|
||||
choosen_rank_list = self._get_remote_rank(req_id)
|
||||
choosen_rank_list = self._get_remote_rank(req_id, prefill_tp_size)
|
||||
remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list]
|
||||
for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
|
||||
for i in range(tp_num_need_pulls * self._prefill_pp_size):
|
||||
assert self.kv_recv_thread is not None
|
||||
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
||||
meta.remote_port,
|
||||
@@ -1586,8 +1596,8 @@ class MooncakeConnectorWorker:
|
||||
remote_host=remote_host,
|
||||
remote_handshake_port=remote_handshake_port_list[i][0],
|
||||
offset=i,
|
||||
tp_num_need_pulls=self.tp_num_need_pulls,
|
||||
all_task_done=(i == self.tp_num_need_pulls * self._prefill_pp_size - 1),
|
||||
tp_num_need_pulls=tp_num_need_pulls,
|
||||
all_task_done=(i == tp_num_need_pulls * self._prefill_pp_size - 1),
|
||||
)
|
||||
|
||||
for req_id in metadata.reqs_in_batch:
|
||||
@@ -1607,6 +1617,21 @@ class MooncakeConnectorWorker:
|
||||
for req_id, delay_start_time in metadata.requests_to_send.items():
|
||||
self.kv_send_thread.add_delayed_request(req_id, delay_start_time)
|
||||
|
||||
def _get_tp_num_need_pulls(self, prefill_tp_size: int) -> int:
|
||||
if prefill_tp_size is None:
|
||||
prefill_tp_size = self._prefill_tp_size
|
||||
|
||||
if prefill_tp_size == self._prefill_tp_size:
|
||||
return self.tp_num_need_pulls
|
||||
|
||||
if self.vllm_config.model_config.is_deepseek_mla:
|
||||
tp_num_need_pulls = 1
|
||||
else:
|
||||
num_d_block_heads = max(1, self.num_key_value_heads // self.tp_size)
|
||||
num_p_block_heads = max(1, self.num_key_value_heads // prefill_tp_size)
|
||||
tp_num_need_pulls = num_d_block_heads // num_p_block_heads
|
||||
return tp_num_need_pulls
|
||||
|
||||
def _get_remote_host_info_by_port(
|
||||
self,
|
||||
base_port: int,
|
||||
@@ -1624,16 +1649,17 @@ class MooncakeConnectorWorker:
|
||||
def _prefill_get_remote_rank(self, req_id: str) -> list[int]:
|
||||
return sum(self._get_remote_ranks_for_req(req_id), [])
|
||||
|
||||
def _get_remote_rank(self, req_id: str) -> list[int]:
|
||||
return self._get_remote_ranks_for_req(req_id)[self.tp_rank]
|
||||
def _get_remote_rank(self, req_id: str, prefill_tp_size: int | None = None) -> list[int]:
|
||||
return self._get_remote_ranks_for_req(req_id, prefill_tp_size)[self.tp_rank]
|
||||
|
||||
def _get_remote_tp_ranks(
|
||||
self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int
|
||||
self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int, prefill_tp_size: int
|
||||
) -> list[list[int]]:
|
||||
tp_num_need_pulls = self._get_tp_num_need_pulls(prefill_tp_size)
|
||||
# random split prefill tp list
|
||||
tp_sampled_nums = []
|
||||
if (
|
||||
self._prefill_tp_size > self.num_key_value_heads
|
||||
prefill_tp_size > self.num_key_value_heads
|
||||
or self.vllm_config.model_config.is_deepseek_mla
|
||||
or self.use_sparse
|
||||
):
|
||||
@@ -1641,24 +1667,27 @@ class MooncakeConnectorWorker:
|
||||
choosen_group = tp_ori_data[:, [rand_group_index]]
|
||||
flattened = choosen_group.reshape(-1).tolist()
|
||||
tp_sampled_nums = [
|
||||
flattened[i : i + self.tp_num_need_pulls] for i in range(0, len(flattened), self.tp_num_need_pulls)
|
||||
flattened[i : i + tp_num_need_pulls] for i in range(0, len(flattened), tp_num_need_pulls)
|
||||
]
|
||||
# non-random split
|
||||
else:
|
||||
group_size = self._prefill_tp_size // self._decode_tp_size
|
||||
group_size = prefill_tp_size // self._decode_tp_size
|
||||
for i in range(self._decode_tp_size):
|
||||
slice = tp_ori_data[i * group_size : (i + 1) * group_size]
|
||||
tp_sampled_nums.append(slice.tolist())
|
||||
return tp_sampled_nums
|
||||
|
||||
def _get_remote_ranks_for_req(self, req_id: str) -> list[list[int]]:
|
||||
def _get_remote_ranks_for_req(self, req_id: str, prefill_tp_size: int | None = None) -> list[list[int]]:
|
||||
if prefill_tp_size is None:
|
||||
prefill_tp_size = self._prefill_tp_size
|
||||
|
||||
# Divide the ports according to the TP within the PP
|
||||
sampled_nums = []
|
||||
if self._prefill_tp_size == self._decode_tp_size:
|
||||
if prefill_tp_size == self._decode_tp_size:
|
||||
sampled_nums = list(
|
||||
map(
|
||||
lambda tp: [tp + pp * self._prefill_tp_size for pp in range(self._prefill_pp_size)],
|
||||
range(self._prefill_tp_size),
|
||||
lambda tp: [tp + pp * prefill_tp_size for pp in range(self._prefill_pp_size)],
|
||||
range(prefill_tp_size),
|
||||
)
|
||||
)
|
||||
return sampled_nums
|
||||
@@ -1667,7 +1696,7 @@ class MooncakeConnectorWorker:
|
||||
num_kv_head = 1
|
||||
else:
|
||||
num_kv_head = self.num_key_value_heads
|
||||
ori_data = np.arange(self._prefill_tp_size * self._prefill_pp_size)
|
||||
ori_data = np.arange(prefill_tp_size * self._prefill_pp_size)
|
||||
seed = string_to_int64_hash(req_id)
|
||||
rand = random.Random(seed)
|
||||
# random split prefill tp list
|
||||
@@ -1679,7 +1708,7 @@ class MooncakeConnectorWorker:
|
||||
range(num_groups), (max(self._decode_tp_size // num_kv_head, 1))
|
||||
) # random choose a group
|
||||
all_results = [
|
||||
self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups)
|
||||
self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups, prefill_tp_size)
|
||||
for pp_index in range(self._prefill_pp_size)
|
||||
]
|
||||
for group_index in range(len(all_results[0])):
|
||||
|
||||
Reference in New Issue
Block a user