[BugFix] Fix group list type of mc2. (#3890)

### What this PR does / why we need it?
Fix the precision issue caused by the inconsistency between the group
list type used by MC2 and the group list type expected by EPLB.

---------

Signed-off-by: offline0806 <3337230449@qq.com>
This commit is contained in:
offline893
2025-10-30 21:44:14 +08:00
committed by GitHub
parent c506ba60fb
commit d5a9aba03f
4 changed files with 18 additions and 14 deletions

View File

@@ -1,8 +1,5 @@
from __future__ import annotations from __future__ import annotations
import os
from unittest.mock import patch
import pytest import pytest
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode from vllm.config import CompilationConfig, CUDAGraphMode

View File

@@ -244,7 +244,8 @@ class AscendFusedMoE(FusedMoE):
self.expert_map != -1) if self.expert_map is not None else self.expert_map != -1) if self.expert_map is not None else
self.global_num_experts) self.global_num_experts)
if self.dynamic_eplb: if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) self.moe_load = torch.zeros(local_num_experts,
dtype=torch.int64).npu()
self.moe_config.num_experts = self.global_num_experts self.moe_config.num_experts = self.global_num_experts
self.moe_config.num_local_experts = self.local_num_experts self.moe_config.num_local_experts = self.local_num_experts
@@ -340,9 +341,9 @@ class AscendFusedMoE(FusedMoE):
if isinstance(final_hidden_states, tuple): if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states final_hidden_states, group_list_type, expert_tokens = final_hidden_states
if self.dynamic_eplb: if self.dynamic_eplb:
self.moe_load += expert_tokens if group_list_type else \ self.moe_load += expert_tokens if group_list_type == 1 else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
final_hidden_states = forward_context.moe_comm_method.finalize( final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states, hidden_states=final_hidden_states,

View File

@@ -130,7 +130,8 @@ class MoECommMethod(ABC):
dynamic_scale_for_share=dynamic_scale_for_share, dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=self.mc2_mask, mc2_mask=self.mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8) with_quant=use_int8_w8a8 or use_int4_w4a8,
dynamic_eplb=dynamic_eplb)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \ permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")

View File

@@ -69,7 +69,8 @@ class MoETokenDispatcher(ABC):
dynamic_scale_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False): with_quant: bool = False,
dynamic_eplb: bool = False):
raise NotImplementedError("Dispatch function not implemented.") raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod @abstractmethod
@@ -175,7 +176,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False): with_quant: bool = False,
dynamic_eplb: bool = False):
self.with_quant = with_quant self.with_quant = with_quant
self.expert_map = expert_map self.expert_map = expert_map
self.topk_ids = topk_ids self.topk_ids = topk_ids
@@ -210,7 +212,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
if shared_experts is not None: if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up) self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 0 group_list_type = 1 if dynamic_eplb else 0
return { return {
"group_list_type": group_list_type, "group_list_type": group_list_type,
"hidden_states": expand_x, "hidden_states": expand_x,
@@ -333,7 +335,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False): with_quant: bool = False,
dynamic_eplb: bool = False):
self.with_quant = with_quant self.with_quant = with_quant
self.original_shape = hidden_states.shape self.original_shape = hidden_states.shape
@@ -424,7 +427,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False): with_quant: bool = False,
dynamic_eplb: bool = False):
self.bsz, _ = hidden_states.shape self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1) flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -521,7 +525,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False): with_quant: bool = False,
dynamic_eplb: bool = False):
self.with_quant = with_quant self.with_quant = with_quant
self.hidden_shape = hidden_states.shape self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights self.topk_weights = topk_weights