[BugFix] Fix group list type of MC2. (#3864)
### What this PR does / why we need it?
Fix the precision issue caused by an inconsistency between the group-list
type used by MC2 and the group-list type expected by EPLB.
- vLLM version: v0.11.0rc3
- vLLM main:
83f478bb19
---------
Signed-off-by: offline0806 <3337230449@qq.com>
This commit is contained in:
@@ -266,7 +266,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.expert_map != -1) if self.expert_map is not None else
|
||||
self.global_num_experts)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
self.moe_load = torch.zeros(local_num_experts,
|
||||
dtype=torch.int64).npu()
|
||||
|
||||
self.moe_config.num_experts = self.global_num_experts
|
||||
self.moe_config.num_local_experts = self.local_num_experts
|
||||
@@ -362,9 +363,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
if isinstance(final_hidden_states, tuple):
|
||||
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
||||
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load += expert_tokens if group_list_type else \
|
||||
self.moe_load += expert_tokens if group_list_type == 1 else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
|
||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
||||
|
||||
@@ -130,7 +130,8 @@ class MoECommMethod(ABC):
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8)
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8,
|
||||
dynamic_eplb=dynamic_eplb)
|
||||
|
||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
|
||||
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
|
||||
|
||||
@@ -69,7 +69,8 @@ class MoETokenDispatcher(ABC):
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False):
|
||||
raise NotImplementedError("Dispatch function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
@@ -156,8 +157,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
@@ -170,7 +170,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
):
|
||||
dynamic_eplb: bool = False):
|
||||
self.with_quant = with_quant
|
||||
|
||||
# Apply log2phy if needed
|
||||
@@ -221,8 +221,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"expand_scales": expand_scales
|
||||
}
|
||||
|
||||
group_list_type = 1 if dynamic_eplb else 0
|
||||
|
||||
return {
|
||||
"group_list_type": 0,
|
||||
"group_list_type": group_list_type,
|
||||
"hidden_states": expand_x,
|
||||
"group_list": expert_token_nums,
|
||||
"dynamic_scale": dynamic_scale,
|
||||
@@ -336,7 +338,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.original_shape = hidden_states.shape
|
||||
|
||||
@@ -426,7 +429,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False):
|
||||
self.bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
@@ -501,8 +505,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
|
||||
def token_dispatch(
|
||||
self,
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
@@ -515,7 +518,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
):
|
||||
dynamic_eplb: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.hidden_shape = hidden_states.shape
|
||||
|
||||
|
||||
Reference in New Issue
Block a user