# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from dataclasses import dataclass
from typing import List, Optional

from vllm.forward_context import DPMetadata
@dataclass
class MLUDPMetadata(DPMetadata):
    """DP metadata extended with MLU-platform-specific communication fields.

    On top of the base ``DPMetadata``, this records per-group token splits
    that drive the reduce-scatter / all-gather communication optimizations
    used on the MLU platform.
    """

    # --- MLU platform arguments ---

    # Token count for the current dp group.
    token_num: Optional[int] = None

    # Token offset of the current dp group (sum of the token counts of the
    # preceding dp groups).
    token_num_offset: Optional[int] = None

    # Whether reduce-scatter can be used for both the attn and the mlp layer.
    layer_use_reduce_scatter: bool = False

    # Token count prefill must be padded to so reduce-scatter + all-gather
    # can be used to optimize communication time (-1 means no padding).
    prefill_pad_to_token_num: int = -1

    # Token count in each dp group; list length == attn data parallel size.
    # Used for the all-gather in dp groups after all-reduce in attn.
    token_split_list: Optional[List[int]] = None

    # Token count on each card; list length == world size.
    # Used for the all-gather on all cards after reduce-scatter in attn.
    attn_token_split_list_reduce_scatter: Optional[List[int]] = None

    # Token count in each tp group; list length == tensor parallel size.
    # Used for the all-gather in tp groups after reduce-scatter in moe.
    moe_token_split_list_reduce_scatter: Optional[List[int]] = None

    # Prefill-or-decode stage flag for each dp group.
    dp_is_prefill: Optional[List[bool]] = None

    # --- Additional fields for merged compute and communication ---

    # Global sequence lengths for each batch for the prefill stage.
    seq_lens: Optional[List[int]] = None

    # Batch sizes for each attn dp rank for the prefill stage.
    batch_sizes: Optional[List[int]] = None

    # --- Additional fields for custom splits of the embedding, logits and
    # dense mlp layers ---

    # Token count in each emb tp group; list length == tensor parallel size.
    # Used for the all-gather in emb tp groups after reduce-scatter in moe.
    emb_token_split_list: Optional[List[int]] = None

    # Batch sizes in each logits tp group; list length == tensor parallel
    # size. Used for the all-gather in logits tp groups after reduce-scatter
    # in moe.
    logits_batch_split_list: Optional[List[int]] = None

    # Token count in each dense mlp group; list length == dense mlp tp size.
    # Used for one extra all-gather after dense mlp, before reduce-scatter.
    dense_attn_token_split_list: Optional[List[int]] = None

    @staticmethod
    def make_oot(
        data_parallel_rank: int,
        data_parallel_size: int,
        tensor_parallel_size: int,
        dp_token_nums: List[int],
        dp_is_prefill: List[bool],
        prefill_dispatch_use_RS_AG: bool,
        seq_lens: Optional[List[int]] = None,
        batch_sizes: Optional[List[int]] = None,
        emb_query_lens: Optional[List[int]] = None,
        logits_batch_sizes: Optional[List[int]] = None,
        dense_attn_token_split_list: Optional[List[int]] = None,
    ) -> "MLUDPMetadata":
        """Build ``MLUDPMetadata`` from per-dp-group token counts.

        Decides whether reduce-scatter can be used (all groups carry an
        equal, non-zero, tp-size-divisible token count) and, failing that,
        whether prefill padding enables a reduce-scatter + all-gather
        dispatch when every dp group is in the prefill stage.

        Args:
            data_parallel_rank: Rank of the current dp group.
            data_parallel_size: Number of dp groups.
            tensor_parallel_size: Tensor-parallel size within a dp group.
            dp_token_nums: Token count of every dp group.
            dp_is_prefill: Per-dp-group prefill flag.
            prefill_dispatch_use_RS_AG: Allow the padded reduce-scatter +
                all-gather prefill dispatch path.
            seq_lens, batch_sizes, emb_query_lens, logits_batch_sizes,
            dense_attn_token_split_list: Optional extra split metadata,
                passed through to the corresponding fields.

        Returns:
            A populated ``MLUDPMetadata`` instance (the base
            ``max_tokens_across_dp_cpu`` / ``num_tokens_across_dp_cpu``
            fields are left as ``None``).
        """
        token_num_offset = sum(dp_token_nums[:data_parallel_rank])
        token_num = dp_token_nums[data_parallel_rank]
        token_split_list = dp_token_nums

        # Reduce-scatter requires every dp group's token count to be a
        # non-zero multiple of the tp size ...
        attn_can_use_reduce_scatter = all(
            num != 0 and num % tensor_parallel_size == 0
            for num in token_split_list
        )
        # ... and every group to carry the same token count.
        all_split_token_num_equal = all(
            num == token_split_list[0] for num in token_split_list
        )
        layer_can_use_reduce_scatter = (
            attn_can_use_reduce_scatter and all_split_token_num_equal
        )

        attn_token_split_list_reduce_scatter = None
        moe_token_split_list_reduce_scatter = None
        prefill_pad_to_token_num = -1
        tp_world_size = data_parallel_size * tensor_parallel_size
        if layer_can_use_reduce_scatter:
            # Even split: every card holds token_num / tp_size tokens.
            attn_token_split_list_reduce_scatter = (
                [token_split_list[0] // tensor_parallel_size] * tp_world_size
            )
            moe_token_split_list_reduce_scatter = (
                attn_token_split_list_reduce_scatter[:tensor_parallel_size]
            )
        elif prefill_dispatch_use_RS_AG and all(dp_is_prefill):
            # All dp groups are prefilling: pad the largest group's token
            # count up to a multiple of tp size so RS + AG becomes possible.
            dp_group_max_token_nums = max(dp_token_nums)
            prefill_pad_to_token_num = (
                (dp_group_max_token_nums + tensor_parallel_size - 1)
                // tensor_parallel_size
            ) * tensor_parallel_size
            attn_token_split_list_reduce_scatter = (
                [prefill_pad_to_token_num // tensor_parallel_size]
                * tp_world_size
            )

        return MLUDPMetadata(
            max_tokens_across_dp_cpu=None,
            num_tokens_across_dp_cpu=None,
            token_num=token_num,
            token_num_offset=token_num_offset,
            token_split_list=token_split_list,
            layer_use_reduce_scatter=layer_can_use_reduce_scatter,
            prefill_pad_to_token_num=prefill_pad_to_token_num,
            attn_token_split_list_reduce_scatter=attn_token_split_list_reduce_scatter,
            moe_token_split_list_reduce_scatter=moe_token_split_list_reduce_scatter,
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
            dp_is_prefill=dp_is_prefill,
            emb_token_split_list=emb_query_lens,
            logits_batch_split_list=logits_batch_sizes,
            dense_attn_token_split_list=dense_attn_token_split_list,
        )
|