[BugFix] Fix group list type of mc2. (#3864)

### What this PR does / why we need it?
Fix a precision issue caused by an inconsistency between the group list
type reported by the MC2 token dispatcher and the one used by EPLB.
`token_dispatch` previously hard-coded `group_list_type` to `0`; the
dispatchers now accept a `dynamic_eplb` flag and report
`group_list_type = 1` when dynamic EPLB is enabled.
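
A minimal sketch of the behavioural change (the standalone helper
`build_dispatch_output` is illustrative only, not the actual vllm-ascend API):

```python
import torch


def build_dispatch_output(expert_token_nums: torch.Tensor,
                          dynamic_eplb: bool = False) -> dict:
    # Previously the MC2 dispatcher always reported group_list_type = 0.
    # With dynamic EPLB enabled that value is inconsistent with the group
    # list type used on the EPLB side, so the flag now selects it explicitly.
    group_list_type = 1 if dynamic_eplb else 0
    return {
        "group_list_type": group_list_type,
        "group_list": expert_token_nums,
    }


# Example: bookkeeping for 4 local experts after a dispatch step.
expert_token_nums = torch.tensor([3, 0, 5, 2])
print(build_dispatch_output(expert_token_nums)["group_list_type"])                     # 0
print(build_dispatch_output(expert_token_nums, dynamic_eplb=True)["group_list_type"])  # 1
```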

- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19

---------

Signed-off-by: offline0806 <3337230449@qq.com>
offline893 authored on 2025-10-30 21:39:01 +08:00, committed by GitHub
parent 655a229455
commit 627f20ce26
3 changed files with 44 additions and 40 deletions


@@ -69,7 +69,8 @@ class MoETokenDispatcher(ABC):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
-with_quant: bool = False):
+with_quant: bool = False,
+dynamic_eplb: bool = False):
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
@@ -156,21 +157,20 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2
-def token_dispatch(
-self,
-hidden_states: torch.Tensor,
-topk_weights: torch.Tensor,
-topk_ids: torch.Tensor,
-expert_map: Optional[torch.Tensor] = None,
-log2phy: Optional[torch.Tensor] = None,
-global_redundant_expert_num: int = 0,
-shared_experts: Optional[Any] = None,
-quantized_x_for_share: Optional[Any] = None,
-dynamic_scale_for_share: Optional[Any] = None,
-mc2_mask: Optional[torch.Tensor] = None,
-apply_router_weight_on_input: bool = False,
-with_quant: bool = False,
-):
+def token_dispatch(self,
+hidden_states: torch.Tensor,
+topk_weights: torch.Tensor,
+topk_ids: torch.Tensor,
+expert_map: Optional[torch.Tensor] = None,
+log2phy: Optional[torch.Tensor] = None,
+global_redundant_expert_num: int = 0,
+shared_experts: Optional[Any] = None,
+quantized_x_for_share: Optional[Any] = None,
+dynamic_scale_for_share: Optional[Any] = None,
+mc2_mask: Optional[torch.Tensor] = None,
+apply_router_weight_on_input: bool = False,
+with_quant: bool = False,
+dynamic_eplb: bool = False):
self.with_quant = with_quant
# Apply log2phy if needed
@@ -221,8 +221,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"expand_scales": expand_scales
}
+group_list_type = 1 if dynamic_eplb else 0
return {
-"group_list_type": 0,
+"group_list_type": group_list_type,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
@@ -336,7 +338,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
-with_quant: bool = False):
+with_quant: bool = False,
+dynamic_eplb: bool = False):
self.with_quant = with_quant
self.original_shape = hidden_states.shape
@@ -426,7 +429,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
-with_quant: bool = False):
+with_quant: bool = False,
+dynamic_eplb: bool = False):
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -501,21 +505,20 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous"
-def token_dispatch(
-self,
-hidden_states: torch.Tensor,
-topk_weights: torch.Tensor,
-topk_ids: torch.Tensor,
-expert_map: Optional[torch.Tensor] = None,
-log2phy: Optional[torch.Tensor] = None,
-global_redundant_expert_num: int = 0,
-shared_experts: Optional[Any] = None,
-quantized_x_for_share: Optional[Any] = None,
-dynamic_scale_for_share: Optional[Any] = None,
-mc2_mask: Optional[torch.Tensor] = None,
-apply_router_weight_on_input: bool = False,
-with_quant: bool = False,
-):
+def token_dispatch(self,
+hidden_states: torch.Tensor,
+topk_weights: torch.Tensor,
+topk_ids: torch.Tensor,
+expert_map: Optional[torch.Tensor] = None,
+log2phy: Optional[torch.Tensor] = None,
+global_redundant_expert_num: int = 0,
+shared_experts: Optional[Any] = None,
+quantized_x_for_share: Optional[Any] = None,
+dynamic_scale_for_share: Optional[Any] = None,
+mc2_mask: Optional[torch.Tensor] = None,
+apply_router_weight_on_input: bool = False,
+with_quant: bool = False,
+dynamic_eplb: bool = False):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape