From d5a9aba03f8e21171075448e1e6ecf9bfbfd4fcb Mon Sep 17 00:00:00 2001 From: offline893 <158537145+offline893@users.noreply.github.com> Date: Thu, 30 Oct 2025 21:44:14 +0800 Subject: [PATCH] [BugFix]Fix group list type of mc2. (#3890) ### What this PR does / why we need it? Fix the precision issue caused by the inconsistency between the group list type used by mc2 and that of eplb. --------- Signed-off-by: offline0806 <3337230449@qq.com> --- .../spec_decode_v1/test_v1_mtp_correctness.py | 3 --- vllm_ascend/ops/common_fused_moe.py | 9 +++++---- vllm_ascend/ops/moe/moe_comm_method.py | 3 ++- vllm_ascend/ops/moe/token_dispatcher.py | 17 +++++++++++------ 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 6dc2bc9..b6d8b66 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -1,8 +1,5 @@ from __future__ import annotations -import os -from unittest.mock import patch - import pytest from vllm import SamplingParams from vllm.config import CompilationConfig, CUDAGraphMode diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 1aceb89..aec0ffc 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -244,7 +244,8 @@ class AscendFusedMoE(FusedMoE): self.expert_map != -1) if self.expert_map is not None else self.global_num_experts) if self.dynamic_eplb: - self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) + self.moe_load = torch.zeros(local_num_experts, + dtype=torch.int64).npu() self.moe_config.num_experts = self.global_num_experts self.moe_config.num_local_experts = self.local_num_experts @@ -340,9 +341,9 @@ class AscendFusedMoE(FusedMoE): if isinstance(final_hidden_states, tuple): final_hidden_states, group_list_type, expert_tokens = final_hidden_states - if self.dynamic_eplb: - self.moe_load += expert_tokens if group_list_type else \ - torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + if self.dynamic_eplb: + self.moe_load += expert_tokens if group_list_type == 1 else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) final_hidden_states = forward_context.moe_comm_method.finalize( hidden_states=final_hidden_states, diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py index 8f49841..7c8a973 100644 --- a/vllm_ascend/ops/moe/moe_comm_method.py +++ b/vllm_ascend/ops/moe/moe_comm_method.py @@ -130,7 +130,8 @@ class MoECommMethod(ABC): dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=self.mc2_mask, apply_router_weight_on_input=apply_router_weight_on_input, - with_quant=use_int8_w8a8 or use_int4_w4a8) + with_quant=use_int8_w8a8 or use_int4_w4a8, + dynamic_eplb=dynamic_eplb) permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \ results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index 83da546..88c8cb6 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -69,7 +69,8 @@ class MoETokenDispatcher(ABC): dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - with_quant: bool = False): + with_quant: bool = False, + dynamic_eplb: bool = False): raise NotImplementedError("Dispatch function not implemented.") @abstractmethod @@ -175,7 +176,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - with_quant: bool = False): + with_quant: bool = False, + dynamic_eplb: bool = False): self.with_quant = with_quant self.expert_map = expert_map self.topk_ids = topk_ids @@ -210,7 +212,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): if shared_experts is not None: shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) self.shared_act = shared_experts.act_fn(shared_gate_up) - group_list_type = 0 + group_list_type = 1 if dynamic_eplb else 0 return { "group_list_type": group_list_type, "hidden_states": expand_x, @@ -333,7 +335,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - with_quant: bool = False): + with_quant: bool = False, + dynamic_eplb: bool = False): self.with_quant = with_quant self.original_shape = hidden_states.shape @@ -424,7 +427,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher): dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - with_quant: bool = False): + with_quant: bool = False, + dynamic_eplb: bool = False): self.bsz, _ = hidden_states.shape flatten_topk_ids = topk_ids.view(-1) self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) @@ -521,7 +525,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - with_quant: bool = False): + with_quant: bool = False, + dynamic_eplb: bool = False): self.with_quant = with_quant self.hidden_shape = hidden_states.shape self.topk_weights = topk_weights