# --- Source listing header (extraction artifact, preserved as a comment) ---
# Files: enginex-mlu590-vllm/vllm_mlu/mlu_forward_context.py
# 2026-04-24 09:58:03 +08:00
# 121 lines, 5.2 KiB, Python
# ---------------------------------------------------------------------------
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from dataclasses import dataclass
from typing import List, Optional

from vllm.forward_context import DPMetadata
@dataclass
class MLUDPMetadata(DPMetadata):
    """Data-parallel metadata extended with MLU-platform communication info.

    Extends vLLM's ``DPMetadata`` with per-rank token-split bookkeeping used
    to decide between all-reduce and reduce-scatter + all-gather communication
    patterns for the attention and MLP/MoE layers.
    """

    # mlu platform arguments
    # token num for current dp group
    token_num: Optional[int] = None
    # token num offset for current dp group (sum of token nums of all
    # dp groups preceding this rank)
    token_num_offset: Optional[int] = None
    # whether we can use reduce scatter for both attn layer and mlp layer
    layer_use_reduce_scatter: bool = False
    # token num need to be pad for prefill, then we can do reduce scatter +
    # all gather to optimize comm time; -1 means no padding is applied
    prefill_pad_to_token_num: int = -1
    # token num in each dp group, the list length is attn data parallel size
    # used to do all gather in dp groups after all reduce in attn
    token_split_list: Optional[List[int]] = None
    # token num in each card, the list length is world size
    # used to do all gather in all cards after reduce scatter in attn
    attn_token_split_list_reduce_scatter: Optional[List[int]] = None
    # token num in each tp group, the list length is tensor parallel size
    # used to do all gather in tp groups after reduce scatter in moe
    moe_token_split_list_reduce_scatter: Optional[List[int]] = None
    # prefill or decode stage in each dp group (True == prefill)
    dp_is_prefill: Optional[List[bool]] = None
    # ADDITIONAL fields for merged compute and communication.
    # Global sequence lengths for each batch size for prefill stage.
    seq_lens: Optional[List[int]] = None
    # Batch sizes for each attn dp rank for prefill stage.
    batch_sizes: Optional[List[int]] = None
    # ADDITIONAL fields for custom split for embedding, logits and dense
    # mlp layer.
    # token num in each emb tp group, the list length is tensor parallel size
    # used to do all gather in emb tp groups after reduce scatter in moe
    emb_token_split_list: Optional[List[int]] = None
    # batch sizes in each logits tp group, the list length is tensor parallel
    # size; used to do all gather in logits tp groups after reduce scatter
    # in moe
    logits_batch_split_list: Optional[List[int]] = None
    # token num in each dense mlp group, the list length is dense mlp tp size
    # used to do one more all gather after dense mlp and before reduce scatter
    dense_attn_token_split_list: Optional[List[int]] = None

    @staticmethod
    def make_oot(
        data_parallel_rank: int,
        data_parallel_size: int,
        tensor_parallel_size: int,
        dp_token_nums: List[int],
        dp_is_prefill: List[bool],
        prefill_dispatch_use_RS_AG: bool,
        seq_lens: Optional[List[int]] = None,
        batch_sizes: Optional[List[int]] = None,
        emb_query_lens: Optional[List[int]] = None,
        logits_batch_sizes: Optional[List[int]] = None,
        dense_attn_token_split_list: Optional[List[int]] = None,
    ) -> "MLUDPMetadata":
        """Build the MLU metadata for the given data-parallel rank.

        Reduce-scatter is enabled for both attn and mlp layers only when
        every DP group has the same non-zero token count and that count is
        divisible by ``tensor_parallel_size``. Failing that, if every DP
        group is in the prefill stage and ``prefill_dispatch_use_RS_AG`` is
        set, the largest per-group token count is padded up to a multiple of
        ``tensor_parallel_size`` so reduce-scatter + all-gather can still be
        used for attn dispatch.

        Args:
            data_parallel_rank: Index of this rank within the DP groups.
            data_parallel_size: Number of DP groups.
            tensor_parallel_size: TP size within each DP group.
            dp_token_nums: Token count per DP group (length == DP size).
            dp_is_prefill: Per-DP-group prefill flag (True == prefill).
            prefill_dispatch_use_RS_AG: Allow the padded reduce-scatter +
                all-gather fallback when all groups are in prefill.
            seq_lens: Global sequence lengths for the prefill stage.
            batch_sizes: Batch sizes per attn DP rank for the prefill stage.
            emb_query_lens: Token split per embedding TP group.
            logits_batch_sizes: Batch split per logits TP group.
            dense_attn_token_split_list: Token split per dense-mlp group.

        Returns:
            A populated ``MLUDPMetadata`` instance.
        """
        token_num_offset = sum(dp_token_nums[:data_parallel_rank])
        token_num = dp_token_nums[data_parallel_rank]
        token_split_list = dp_token_nums
        # Reduce scatter in attn requires every group's token count to split
        # evenly across the TP ranks (and to be non-empty).
        attn_can_use_reduce_scatter = all(
            (num != 0 and num % tensor_parallel_size == 0)
            for num in token_split_list
        )
        all_split_token_num_equal = all(
            num == token_split_list[0] for num in token_split_list
        )
        layer_can_use_reduce_scatter = (
            attn_can_use_reduce_scatter and all_split_token_num_equal
        )
        attn_token_split_list_reduce_scatter = None
        moe_token_split_list_reduce_scatter = None
        prefill_pad_to_token_num = -1
        tp_world_size = data_parallel_size * tensor_parallel_size
        if layer_can_use_reduce_scatter:
            # Uniform split: every card holds the same shard size.
            attn_token_split_list_reduce_scatter = (
                [token_split_list[0] // tensor_parallel_size] * tp_world_size
            )
            moe_token_split_list_reduce_scatter = (
                attn_token_split_list_reduce_scatter[:tensor_parallel_size]
            )
        elif prefill_dispatch_use_RS_AG and all(dp_is_prefill):
            # All groups are prefilling: pad the max group token count up to
            # a multiple of tensor_parallel_size so attn can still use
            # reduce scatter + all gather.
            dp_group_max_token_nums = max(dp_token_nums)
            prefill_pad_to_token_num = (
                (dp_group_max_token_nums + tensor_parallel_size - 1)
                // tensor_parallel_size
            ) * tensor_parallel_size
            attn_token_split_list_reduce_scatter = (
                [prefill_pad_to_token_num // tensor_parallel_size]
                * tp_world_size
            )
        # NOTE(review): the base-class CPU token tensors are not used on this
        # path — confirm downstream consumers tolerate None here.
        return MLUDPMetadata(
            max_tokens_across_dp_cpu=None,
            num_tokens_across_dp_cpu=None,
            token_num=token_num,
            token_num_offset=token_num_offset,
            token_split_list=token_split_list,
            layer_use_reduce_scatter=layer_can_use_reduce_scatter,
            prefill_pad_to_token_num=prefill_pad_to_token_num,
            attn_token_split_list_reduce_scatter=attn_token_split_list_reduce_scatter,
            moe_token_split_list_reduce_scatter=moe_token_split_list_reduce_scatter,
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
            dp_is_prefill=dp_is_prefill,
            emb_token_split_list=emb_query_lens,
            logits_batch_split_list=logits_batch_sizes,
            dense_attn_token_split_list=dense_attn_token_split_list,
        )