[Model] Support DeepSeek-V4
This commit is contained in:
120
vllm_mlu/mlu_forward_context.py
Normal file
120
vllm_mlu/mlu_forward_context.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from dataclasses import dataclass
from typing import List, Optional

from vllm.forward_context import DPMetadata
|
||||
|
||||
|
||||
@dataclass
class MLUDPMetadata(DPMetadata):
    """Data-parallel metadata extended with MLU-platform scheduling fields.

    Adds, on top of ``DPMetadata``, the per-step token split/padding
    information needed to decide between all-reduce and
    reduce-scatter + all-gather communication patterns on MLU.
    """

    # mlu platform arguments
    # token num for current dp group
    token_num: Optional[int] = None
    # token num offset for current dp group
    token_num_offset: Optional[int] = None
    # whether we can use reduce scatter for both attn layer and mlp layer
    layer_use_reduce_scatter: bool = False
    # token num need to be pad for prefill, then we can do reduce scatter +
    # all gather to optimize comm time (-1 means no padding is needed)
    prefill_pad_to_token_num: int = -1
    # token num in each dp group, the list length is attn data parallel size
    # used to do all gather in dp groups after all reduce in attn
    token_split_list: Optional[List[int]] = None
    # token num in each card, the list length is world size
    # used to do all gather in all cards after reduce scatter in attn
    attn_token_split_list_reduce_scatter: Optional[List[int]] = None
    # token num in each tp group, the list length is tensor parallel size
    # used to do all gather in tp groups after reduce scatter in moe
    moe_token_split_list_reduce_scatter: Optional[List[int]] = None
    # prefill or decode stage in each dp group
    dp_is_prefill: Optional[List[bool]] = None

    # ADDITIONAL fields for merged compute and communication.
    # Global sequence lengths for each batch size for prefill stage.
    seq_lens: Optional[List[int]] = None
    # Batch sizes for each attn dp rank for prefill stage.
    batch_sizes: Optional[List[int]] = None

    # ADDITIONAL fields for custom split for embedding, logits and dense mlp layer
    # token num in each emb tp group, the list length is tensor parallel size
    # used to do all gather in emb tp groups after reduce scatter in moe
    emb_token_split_list: Optional[List[int]] = None
    # batch sizes in each logits tp group, the list length is tensor parallel size
    # used to do all gather in logits tp groups after reduce scatter in moe
    logits_batch_split_list: Optional[List[int]] = None
    # token num in each dense mlp group, the list length is dense mlp tp size
    # used to do one more all gather after dense mlp and before reduce scatter
    dense_attn_token_split_list: Optional[List[int]] = None

    @staticmethod
    def make_oot(
        data_parallel_rank: int,
        data_parallel_size: int,
        tensor_parallel_size: int,
        dp_token_nums: List[int],
        dp_is_prefill: List[bool],
        prefill_dispatch_use_RS_AG: bool,
        seq_lens: Optional[List[int]] = None,
        batch_sizes: Optional[List[int]] = None,
        emb_query_lens: Optional[List[int]] = None,
        logits_batch_sizes: Optional[List[int]] = None,
        dense_attn_token_split_list: Optional[List[int]] = None,
    ) -> "MLUDPMetadata":
        """Build the MLU data-parallel metadata for the current step.

        Args:
            data_parallel_rank: rank of this worker's dp group.
            data_parallel_size: number of attn data-parallel groups.
            tensor_parallel_size: tensor-parallel size within a dp group.
            dp_token_nums: token count per dp group (length = dp size).
            dp_is_prefill: per-dp-group flag, True while that group prefills.
            prefill_dispatch_use_RS_AG: enable the padded reduce scatter +
                all gather path when every dp group is in prefill.
            seq_lens / batch_sizes: optional merged compute/comm metadata.
            emb_query_lens / logits_batch_sizes / dense_attn_token_split_list:
                optional custom splits for embedding, logits and dense mlp.

        Returns:
            A populated ``MLUDPMetadata`` instance.
        """
        token_num_offset = sum(dp_token_nums[:data_parallel_rank])
        token_num = dp_token_nums[data_parallel_rank]
        token_split_list = dp_token_nums

        # Reduce scatter in attn requires every dp group's token count to be
        # a non-zero multiple of the tp size; the layer-level path further
        # requires all dp groups to carry the same token count.
        attn_can_use_reduce_scatter = all(
            (num != 0 and num % tensor_parallel_size == 0)
            for num in token_split_list
        )
        all_split_token_num_equal = all(
            num == token_split_list[0] for num in token_split_list
        )
        layer_can_use_reduce_scatter = (
            attn_can_use_reduce_scatter and all_split_token_num_equal
        )

        attn_token_split_list_reduce_scatter = None
        moe_token_split_list_reduce_scatter = None
        prefill_pad_to_token_num = -1
        tp_world_size = data_parallel_size * tensor_parallel_size
        if layer_can_use_reduce_scatter:
            attn_token_split_list_reduce_scatter = (
                [token_split_list[0] // tensor_parallel_size] * tp_world_size
            )
            moe_token_split_list_reduce_scatter = (
                attn_token_split_list_reduce_scatter[:tensor_parallel_size]
            )
        elif prefill_dispatch_use_RS_AG and all(dp_is_prefill):
            # Pad every dp group up to the largest group's token count,
            # rounded up to a multiple of tp size, so reduce scatter +
            # all gather can still be used during prefill dispatch.
            dp_group_max_token_nums = max(dp_token_nums)
            prefill_pad_to_token_num = (
                (dp_group_max_token_nums + tensor_parallel_size - 1)
                // tensor_parallel_size
            ) * tensor_parallel_size
            attn_token_split_list_reduce_scatter = (
                [prefill_pad_to_token_num // tensor_parallel_size] * tp_world_size
            )

        return MLUDPMetadata(
            max_tokens_across_dp_cpu=None,
            num_tokens_across_dp_cpu=None,
            token_num=token_num,
            token_num_offset=token_num_offset,
            token_split_list=token_split_list,
            layer_use_reduce_scatter=layer_can_use_reduce_scatter,
            prefill_pad_to_token_num=prefill_pad_to_token_num,
            attn_token_split_list_reduce_scatter=attn_token_split_list_reduce_scatter,
            moe_token_split_list_reduce_scatter=moe_token_split_list_reduce_scatter,
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
            dp_is_prefill=dp_is_prefill,
            emb_token_split_list=emb_query_lens,
            logits_batch_split_list=logits_batch_sizes,
            dense_attn_token_split_list=dense_attn_token_split_list,
        )
|
||||
Reference in New Issue
Block a user