From e9ada685ece798f9fe0d4a287e3f5246a8a7207b Mon Sep 17 00:00:00 2001
From: weijinqian0 <1184188277@qq.com>
Date: Sat, 7 Jun 2025 10:15:56 +0800
Subject: [PATCH] [CI]Moe alltoall communication optimization (#1067)

[CI]Moe alltoall communication optimization

The DeepSeek V3/R1 model has 256 routed experts. During parallel inference,
an EP rank with a heavy load slows down the overall communication and
computation, so uneven load distribution becomes a weakness of parallel
inference. The data volume in the prefill phase is large, and inter-card
communication time and computation time are closely tied to that volume, so a
small non-linear precision loss can be traded for a near-linear performance
improvement.

During parallel inference, communication involves a global synchronization:
cards with a low load finish their computation first and then wait for the
most heavily loaded card. When the load is unbalanced, the busiest card
therefore dictates the overall latency. Significant performance gains can be
achieved by discarding a small number of tokens. This is unacceptable in
precision-sensitive scenarios, but, much like quantization, it trades an
acceptable precision loss for performance in scenarios where that loss is
tolerable. The balance between performance and precision can also be tuned by
configuring the proportion of discarded tokens.

The test was performed on A3 with batch size 8 (B), prompt length 3.5K tokens
(S), and the following parallel configuration: AttnDP=2, AttnTP=8, MoeTP=1,
MoeEP=16. In this setting we observed a 10%-15% performance gain. The next
version will add an alltoallv-based MoE.

---------

Signed-off-by: weijinqian_v1
Co-authored-by: weijinqian_v1
---
 vllm_ascend/envs.py          |   5 +
 vllm_ascend/ops/fused_moe.py | 280 +++++++++++++++++++++++++++++++++--
 2 files changed, 273 insertions(+), 12 deletions(-)

diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 2d11e3d..23557f1 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -112,6 +112,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
                  ),
+    # MOE_ALL2ALL_BUFFER:
+    # 0: default, normal init.
+    # 1: enable moe_all2all_buffer.
+    "MOE_ALL2ALL_BUFFER":
+    lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
     # VLLM_ASCEND_ACL_OP_INIT_MODE:
     # 0: default, normal init.
     # 1: delay init until launch aclops.
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 6036e60..fd06d90 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -15,7 +15,7 @@
 # This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py -from typing import Callable, Optional +from typing import Callable, List, Optional import torch import torch.distributed as dist @@ -37,6 +37,71 @@ from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM +MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER + + +def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, + max_row_per_ep_rank: int, num_tokens: int, + top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + original_total_elements = num_tokens * top_k + device = topk_ids.device + original_dtype = topk_ids.dtype + + if original_total_elements == 0: + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + unpad_indices = torch.full((original_total_elements, ), + -1, + dtype=torch.long, + device=device) + return topk_ids_pad, unpad_indices + + experts_per_ep_rank_val = expert_num // ep_size + if experts_per_ep_rank_val == 0: + raise ValueError( + "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. " + "Ensure expert_num >= ep_size.") + + assigned_ep_rank = (topk_ids.float() / + experts_per_ep_rank_val).to(original_dtype) + indices_arange = torch.arange(topk_ids.shape[0], device=device) + + is_new_segment = torch.cat((torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1])) + temp_start_markers = torch.full_like(indices_arange, + -1, + dtype=indices_arange.dtype) + temp_start_markers[is_new_segment] = indices_arange[is_new_segment] + start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] + token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token + is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank + cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) + indices_in_rec_cond_list_for_all = cumsum_kept - 1 + unpad_indices = torch.where( + is_kept_mask, indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long)) + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + if topk_ids.shape[0] > 0: + all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + temp_pad_buffer = torch.full((output_len + 1, ), + expert_num, + dtype=original_dtype, + device=device) + output_len_tensor = torch.tensor(output_len, + dtype=torch.long, + device=device) + scatter_indices = torch.where(is_kept_mask, all_destination_indices, + output_len_tensor) + temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) + topk_ids_pad = temp_pad_buffer[:output_len] + return topk_ids_pad, unpad_indices def fused_experts_with_mc2(hidden_states: torch.Tensor, @@ -146,8 +211,62 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, return hidden_states -# currently expert parallelism implemented with all2all -# is under-optimized. +def apply_mlp(hidden_states_wrapper: List[torch.Tensor], + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + + Args: + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). 
+ w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). + transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + + Returns: + hidden_states: output hidden states after MLP. + """ + + assert len(hidden_states_wrapper) == 1 + hidden_states = hidden_states_wrapper.pop() + + w1 = w1.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + ) + + hidden_states = torch.cat(hidden_states, dim=0) + hidden_states = torch_npu.npu_swiglu(hidden_states) + + w2 = w2.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + ) + + hidden_states = torch.cat(hidden_states, dim=0) + return hidden_states + + def fused_experts_with_all2all( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -283,6 +402,133 @@ def fused_experts_with_all2all( return final_hidden_states +# currently expert parallelism implemented with all2all +# is under-optimized. +def fused_experts_with_all2all_buffer( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + max_model_len: int, + global_batch_size: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + device = hidden_states.device + + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, + device=device).view(top_k, + -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * + max_model_len // ep_group.world_size + + 1) * top_k * 2 + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, global_num_experts, ep_group.world_size, + max_row_per_ep_rank, num_tokens, top_k) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + non_pad_len = torch.sum( + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[ + expert_idx_buffer_scatter != global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) + + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + hidden_states_buffer_gather = torch.empty_like( + hidden_states_buffer_scatter, + dtype=hidden_states_buffer_scatter.dtype, + 
device=hidden_states_buffer_scatter.device) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, + group=ep_group.device_group) + dist.all_to_all_single(hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group) + mask = expert_idx_buffer_gather != global_num_experts + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( + global_num_experts // ep_group.world_size) + hidden_states = hidden_states_buffer_gather[mask] + idx_type = local_expert_idx.dtype + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) + sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx, local_num_experts).to(torch.int64) + hidden_states = hidden_states[sorted_idx] + group_list_type = 0 + + hidden_states_wrapper = [hidden_states] + del hidden_states + + hidden_states = apply_mlp(hidden_states_wrapper, + w1, + w2, + expert_tokens, + group_list_type=group_list_type) + + resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) + hidden_states = hidden_states[resorted_idx] + hidden_states_scatter = torch.zeros( + (mask.shape[0], hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states_scatter[mask] = hidden_states + hidden_states_gatter = torch.empty_like( + hidden_states_scatter, + dtype=hidden_states_scatter.dtype, + device=hidden_states_scatter.device) + dist.all_to_all_single(hidden_states_gatter, + hidden_states_scatter, + group=ep_group.device_group) + hidden_states_gatter = hidden_states_gatter[ + expert_idx_buffer_scatter != global_num_experts] + if hidden_states_gatter.shape[0] != row_idx_len: + hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states[unpad_indices != -1] = hidden_states_gatter + else: + # TODO: Reorder device memory 2 times here, replace the current + hidden_states = hidden_states_gatter + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -585,6 +831,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): self.ep_size = ep_group.world_size self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.local_batch_size = self.global_batch_size // self.ep_size + self.max_model_len = vllm_config.model_config.max_model_len ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled @@ -613,21 +860,22 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): self, layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, router_logits: torch.Tensor, + top_k: int, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + use_grouped_topk: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = 
False, enable_force_load_balance: bool = False, **kwargs, - ): + ) -> torch.Tensor: + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern if global_num_experts == 256: topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( @@ -683,11 +931,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): topk_ids=topk_ids, top_k=top_k, expert_map=expert_map) + elif MOE_ALL2ALL_BUFFER: + return fused_experts_with_all2all_buffer( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + max_model_len=self.max_model_len, + global_batch_size=self.global_batch_size, + expert_map=expert_map, + ep_group=get_ep_group()) else: - # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into fused_moe module. - # Therefore, all2all is needed no matter how dp/tp is set so as to - # dispatch/combine tokens. return fused_experts_with_all2all(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight,
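Usage note: the buffered path is opt-in and is only taken when the environment
variable MOE_ALL2ALL_BUFFER=1 is set (parsed in the vllm_ascend/envs.py change
above). Its per-rank buffer capacity is
max_row_per_ep_rank = (ceil(global_batch_size / ep_size) * max_model_len // ep_size + 1) * top_k * 2,
and tokens routed to an EP rank beyond that capacity are dropped by
process_topk_ids. The toy CPU example below is only an illustration of that
dropping/padding behaviour; the tensor values and sizes are made up, and the
only non-standard function it calls is the process_topk_ids added by this
patch.

    # Toy CPU illustration of the padding/dropping done by process_topk_ids.
    import torch

    from vllm_ascend.ops.fused_moe import process_topk_ids

    # Expanded expert ids, already sorted by expert id, as npu_moe_init_routing
    # emits them. Experts 0-1 live on EP rank 0, experts 2-3 on EP rank 1.
    expanded_expert_idx = torch.tensor([0, 0, 0, 1, 2, 3], dtype=torch.int32)
    expert_num, ep_size = 4, 2   # 4 global experts, 2 EP ranks
    max_row_per_ep_rank = 2      # per-rank buffer capacity
    num_tokens, top_k = 6, 1

    topk_ids_pad, unpad_indices = process_topk_ids(
        expanded_expert_idx, expert_num, ep_size, max_row_per_ep_rank,
        num_tokens, top_k)

    # Four tokens target EP rank 0 but only two fit, so two are dropped
    # (their unpad_indices entries are -1); unused buffer slots are padded
    # with expert_num so they are ignored after the all_to_all.
    print(topk_ids_pad)    # tensor([0, 0, 2, 3], dtype=torch.int32)
    print(unpad_indices)   # tensor([ 0,  1, -1, -1,  2,  3])

On the real path the dropped rows simply stay zero before
npu_moe_finalize_routing (the unpad_indices != -1 branch), which is where the
controlled precision loss described in the commit message comes from.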
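For readers without access to an NPU, the plain-PyTorch sketch below documents
what apply_mlp is expected to compute when it is called with group_list_type=0
(the cumulative-sum group list produced by npu_moe_compute_expert_tokens). It
assumes the usual SwiGLU split (SiLU on the first half of gate_up multiplied
by the second half) and the weight layouts described in the apply_mlp
docstring; apply_mlp_reference is a hypothetical helper for illustration only
and is not part of the patch.

    import torch
    import torch.nn.functional as F

    def apply_mlp_reference(hidden_states: torch.Tensor, w1: torch.Tensor,
                            w2: torch.Tensor,
                            group_list: torch.Tensor) -> torch.Tensor:
        """CPU sketch of apply_mlp with a cumulative group_list (type 0).

        hidden_states: (num_tokens, hidden_size), sorted by local expert id.
        w1: (num_local_experts, intermediate_size * 2, hidden_size)
        w2: (num_local_experts, hidden_size, intermediate_size)
        group_list: cumulative token counts per local expert.
        """
        out = torch.empty_like(hidden_states)
        start = 0
        for expert, end in enumerate(group_list.tolist()):
            x = hidden_states[start:end]                  # tokens of this expert
            gate_up = x @ w1[expert].transpose(0, 1)      # gate_up_proj
            gate, up = gate_up.chunk(2, dim=-1)
            x = F.silu(gate) * up                         # SwiGLU (assumed split)
            out[start:end] = x @ w2[expert].transpose(0, 1)  # down_proj
            start = end
        return out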