[BugFix]Fix group list type of mc2. (#3890)

### What this PR does / why we need it?
Fix the precision issue caused by the inconsistency between the group
list type used by mc2 and that of eplb.

---------

Signed-off-by: offline0806 <3337230449@qq.com>
This commit is contained in:
offline893
2025-10-30 21:44:14 +08:00
committed by GitHub
parent c506ba60fb
commit d5a9aba03f
4 changed files with 18 additions and 14 deletions

View File

@@ -69,7 +69,8 @@ class MoETokenDispatcher(ABC):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
with_quant: bool = False,
dynamic_eplb: bool = False):
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
@@ -175,7 +176,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
with_quant: bool = False,
dynamic_eplb: bool = False):
self.with_quant = with_quant
self.expert_map = expert_map
self.topk_ids = topk_ids
@@ -210,7 +212,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 0
group_list_type = 1 if dynamic_eplb else 0
return {
"group_list_type": group_list_type,
"hidden_states": expand_x,
@@ -333,7 +335,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
with_quant: bool = False,
dynamic_eplb: bool = False):
self.with_quant = with_quant
self.original_shape = hidden_states.shape
@@ -424,7 +427,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
with_quant: bool = False,
dynamic_eplb: bool = False):
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -521,7 +525,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
with_quant: bool = False,
dynamic_eplb: bool = False):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights