# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from dataclasses import dataclass
from typing import List, Optional

from vllm.forward_context import DPMetadata


@dataclass
class MLUDPMetadata(DPMetadata):
    """Data-parallel metadata extended with MLU-specific fields.

    Carries per-dp-rank token counts and the split lists used to decide
    between plain all-reduce and reduce-scatter + all-gather communication
    patterns for the attention, MoE, embedding, logits and dense MLP layers.
    """

    # mlu platform arguments
    # token num for current dp group
    token_num: Optional[int] = None
    # token num offset for current dp group
    token_num_offset: Optional[int] = None
    # whether we can use reduce scatter for both attn layer and mlp layer
    layer_use_reduce_scatter: bool = False
    # token num need to be pad for prefill, then we can do reduce scatter +
    # all gather to optimize comm time; -1 means no padding required
    prefill_pad_to_token_num: int = -1
    # token num in each dp group, the list length is attn data parallel size
    # used to do all gather in dp groups after all reduce in attn
    token_split_list: Optional[List[int]] = None
    # token num in each card, the list length is world size
    # used to do all gather in all cards after reduce scatter in attn
    attn_token_split_list_reduce_scatter: Optional[List[int]] = None
    # token num in each tp group, the list length is tensor parallel size
    # used to do all gather in tp groups after reduce scatter in moe
    moe_token_split_list_reduce_scatter: Optional[List[int]] = None
    # prefill or decode stage in each dp group
    dp_is_prefill: Optional[List[bool]] = None

    # ADDITIONAL fields for merged compute and communication.
    # Global sequence lengths for each batch size for prefill stage.
    seq_lens: Optional[List[int]] = None
    # Batch sizes for each attn dp rank for prefill stage.
    batch_sizes: Optional[List[int]] = None

    # ADDITIONAL fields for custom split for embedding, logits and dense
    # mlp layer.
    # token num in each emb tp group, the list length is tensor parallel size
    # used to do all gather in emb tp groups after reduce scatter in moe
    emb_token_split_list: Optional[List[int]] = None
    # batch sizes in each logits tp group, the list length is tensor
    # parallel size
    # used to do all gather in logits tp groups after reduce scatter in moe
    logits_batch_split_list: Optional[List[int]] = None
    # token num in each dense mlp group, the list length is dense mlp tp size
    # used to do one more all gather after dense mlp and before reduce scatter
    dense_attn_token_split_list: Optional[List[int]] = None

    @staticmethod
    def make_oot(
        data_parallel_rank: int,
        data_parallel_size: int,
        tensor_parallel_size: int,
        dp_token_nums: List[int],
        dp_is_prefill: List[bool],
        prefill_dispatch_use_RS_AG: bool,
        seq_lens: Optional[List[int]] = None,
        batch_sizes: Optional[List[int]] = None,
        emb_query_lens: Optional[List[int]] = None,
        logits_batch_sizes: Optional[List[int]] = None,
        dense_attn_token_split_list: Optional[List[int]] = None,
    ) -> "MLUDPMetadata":
        """Build the MLU dp metadata for the current dp rank.

        Args:
            data_parallel_rank: Rank of the current dp group.
            data_parallel_size: Number of dp groups.
            tensor_parallel_size: tp size inside each dp group.
            dp_token_nums: Token count contributed by each dp group.
            dp_is_prefill: Per-dp-group flag: prefill (True) or decode.
            prefill_dispatch_use_RS_AG: Allow the padded reduce scatter +
                all gather path when every dp group is in prefill.
            seq_lens: Global sequence lengths for prefill (optional).
            batch_sizes: Batch sizes per attn dp rank for prefill (optional).
            emb_query_lens: Split list for the embedding tp groups (optional).
            logits_batch_sizes: Batch split list for the logits tp groups
                (optional).
            dense_attn_token_split_list: Split list for the dense mlp tp
                groups (optional).

        Returns:
            A populated ``MLUDPMetadata``. The inherited
            ``max_tokens_across_dp_cpu`` / ``num_tokens_across_dp_cpu``
            fields are deliberately left ``None`` on this path.
        """
        token_num_offset = sum(dp_token_nums[:data_parallel_rank])
        token_num = dp_token_nums[data_parallel_rank]
        token_split_list = dp_token_nums

        # Reduce scatter requires every dp group to contribute a non-zero
        # token count that is divisible by the tp size.
        attn_can_use_reduce_scatter = all(
            num != 0 and num % tensor_parallel_size == 0
            for num in token_split_list
        )
        all_split_token_num_equal = all(
            num == token_split_list[0] for num in token_split_list
        )
        layer_can_use_reduce_scatter = (
            attn_can_use_reduce_scatter and all_split_token_num_equal
        )

        attn_token_split_list_reduce_scatter: Optional[List[int]] = None
        moe_token_split_list_reduce_scatter: Optional[List[int]] = None
        prefill_pad_to_token_num = -1
        tp_world_size = data_parallel_size * tensor_parallel_size
        if layer_can_use_reduce_scatter:
            # Uniform, tp-aligned splits: each card holds an equal shard.
            attn_token_split_list_reduce_scatter = (
                [token_split_list[0] // tensor_parallel_size] * tp_world_size
            )
            moe_token_split_list_reduce_scatter = (
                attn_token_split_list_reduce_scatter[:tensor_parallel_size]
            )
        elif prefill_dispatch_use_RS_AG and all(dp_is_prefill):
            # All-prefill fallback: pad every dp group up to the largest
            # token count, rounded up to a multiple of the tp size, so the
            # reduce scatter + all gather path can still be used.
            dp_group_max_token_nums = max(dp_token_nums)
            prefill_pad_to_token_num = (
                (dp_group_max_token_nums + tensor_parallel_size - 1)
                // tensor_parallel_size
            ) * tensor_parallel_size
            attn_token_split_list_reduce_scatter = (
                [prefill_pad_to_token_num // tensor_parallel_size]
                * tp_world_size
            )

        return MLUDPMetadata(
            max_tokens_across_dp_cpu=None,
            num_tokens_across_dp_cpu=None,
            token_num=token_num,
            token_num_offset=token_num_offset,
            token_split_list=token_split_list,
            layer_use_reduce_scatter=layer_can_use_reduce_scatter,
            prefill_pad_to_token_num=prefill_pad_to_token_num,
            attn_token_split_list_reduce_scatter=(
                attn_token_split_list_reduce_scatter
            ),
            moe_token_split_list_reduce_scatter=(
                moe_token_split_list_reduce_scatter
            ),
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
            dp_is_prefill=dp_is_prefill,
            emb_token_split_list=emb_query_lens,
            logits_batch_split_list=logits_batch_sizes,
            dense_attn_token_split_list=dense_attn_token_split_list,
        )