From 1f9fb869ad8832fefed92ae6ddaf36552d694c89 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Sat, 24 May 2025 14:29:36 +0800 Subject: [PATCH] [BugFix] Fix accuracy bugs for unquantized deepseekv3 models (#897) ### What this PR does / why we need it? This PR fixes two accuracy bugs introduced by PR #819 when running deepseekv3 series models: 1. #819 adds `all_to_all` communication in quantized cases, but `all_gather` and `reduce_scatter` are removed in both the quantized and unquantized cases. When running unquantized deepseekv3 models with `ep_size == world_size`, the moe modules fail to communicate. Therefore, this PR adds `all_to_all` communication for the unquantized case to solve this accuracy issue. 2. Use `ep_size` rather than `dp_size` to decide whether to use `all_to_all` in moe. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with newly added and existing tests. --------- Signed-off-by: angazenn Co-authored-by: angazenn --- vllm_ascend/ops/fused_moe.py | 161 ++++++++++++++++++++++- vllm_ascend/quantization/quant_config.py | 3 +- vllm_ascend/quantization/w8a8_dynamic.py | 7 +- 3 files changed, 162 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 01da5be..6313a75 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -18,9 +18,11 @@ from typing import Callable, Optional import torch +import torch.distributed as dist import torch_npu from vllm.config import get_current_vllm_config -from vllm.distributed import (get_tensor_model_parallel_world_size, +from vllm.distributed import (GroupCoordinator, + get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.layer import ( @@ -154,6 +156,143 @@ def fused_experts_with_mc2( return hidden_states +# currently expert parallelism implemented 
with all2all +# is under-optimized. +def fused_experts_with_all2all( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + num_experts = w1.shape[0] + device = hidden_states.device + + if expert_map is not None: + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(top_k, -1).permute( + 1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + global_expert_tokens = torch.bincount(expanded_expert_idx, + minlength=global_num_experts) + scatter_sizes = global_expert_tokens.view(ep_group.world_size, + -1).sum(-1) + + gather_sizes = torch.empty_like(scatter_sizes) + dist.all_to_all_single(gather_sizes, + scatter_sizes, + group=ep_group.device_group) + scatter_size_list = scatter_sizes.cpu().tolist() + gather_size_list = gather_sizes.cpu().tolist() + + expanded_expert_idx = expanded_expert_idx % local_num_experts + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + scatter_size_list, + gather_size_list) + local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, + scatter_size_list, + gather_size_list) + + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx, local_num_experts).to(torch.int64) + + hidden_states = hidden_states[sorted_idx] + else: + row_idx_len = num_tokens * top_k + row_idx = torch.arange(0, + 
row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view( + top_k, -1).permute(1, 0).contiguous() + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts) + expert_tokens = expert_tokens.to(torch.int64) + + w1 = w1.transpose(1, 2) + gate_up_out_list = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + ) + + # TODO: Remove this in the future. + hidden_states = torch.cat(gate_up_out_list, dim=0) + hidden_states = torch_npu.npu_swiglu(hidden_states) + + w2 = w2.transpose(1, 2) + down_out_list = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + ) + + hidden_states = torch.cat(down_out_list, dim=0) + + if expert_map is not None: + resorted_idx = torch.argsort(sorted_idx) + hidden_states = hidden_states[resorted_idx] + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + gather_size_list, + scatter_size_list) + + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + else: + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. 
+ final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -494,7 +633,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill=False, + is_prefill: bool = False, **kwargs, ): # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern @@ -536,7 +675,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name) - else: + elif get_ep_group().world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -544,6 +683,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): topk_ids=topk_ids, top_k=top_k, expert_map=expert_map) + else: + # The current implementation of deepseek moe splits hidden_states + # according to tp_size before they are fed into fused_moe module. + # Therefore, all2all is needed no matter how dp/tp is set so as to + # dispatch/combine tokens. 
+ return fused_experts_with_all2all(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=get_ep_group()) class AscendFusedMoE(FusedMoE): @@ -721,8 +873,7 @@ class AscendFusedMoE(FusedMoE): scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, is_prefill=is_prefill, - enable_force_load_balance=enable_force_load_balance, - dp_size=self.dp_size) + enable_force_load_balance=enable_force_load_balance) if VLLM_ENABLE_MC2 and not is_prefill: ... diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 3b0d0c4..40dbae3 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -323,14 +323,13 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = False, - dp_size: int = 1, **kwargs, ) -> torch.Tensor: return self.quant_method.apply( layer, x, router_logits, top_k, renormalize, use_grouped_topk, global_num_experts, expert_map, topk_group, num_expert_group, custom_routing_function, scoring_func, e_score_correction_bias, - is_prefill, enable_force_load_balance, dp_size) + is_prefill, enable_force_load_balance) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 5d2b442..0f54b01 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -582,7 +582,6 @@ class AscendW8A8DynamicFusedMoEMethod: e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = True, - dp_size: int = 1, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -635,7 +634,7 @@ class 
AscendW8A8DynamicFusedMoEMethod: top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name) - elif dp_size == 1: + elif self.ep_group.world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, @@ -646,6 +645,10 @@ class AscendW8A8DynamicFusedMoEMethod: top_k=top_k, expert_map=expert_map) else: + # The current implementation of deepseek moe splits hidden_states + # according to tp_size before they are fed into fused_moe module. + # Therefore, all2all is needed no matter how dp/tp is set so as to + # dispatch/combine tokens. return fused_experts_with_all2all(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale,