[BugFix]Fix group list type of mc2. (#3890)
### What this PR does / why we need it? Fix the precision issue caused by the inconsistency between the group list type used by mc2 and that of eplb. --------- Signed-off-by: offline0806 <3337230449@qq.com>
This commit is contained in:
@@ -1,8 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.config import CompilationConfig, CUDAGraphMode
|
from vllm.config import CompilationConfig, CUDAGraphMode
|
||||||
|
|||||||
@@ -244,7 +244,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
self.expert_map != -1) if self.expert_map is not None else
|
self.expert_map != -1) if self.expert_map is not None else
|
||||||
self.global_num_experts)
|
self.global_num_experts)
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
self.moe_load = torch.zeros(local_num_experts,
|
||||||
|
dtype=torch.int64).npu()
|
||||||
|
|
||||||
self.moe_config.num_experts = self.global_num_experts
|
self.moe_config.num_experts = self.global_num_experts
|
||||||
self.moe_config.num_local_experts = self.local_num_experts
|
self.moe_config.num_local_experts = self.local_num_experts
|
||||||
@@ -340,9 +341,9 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
if isinstance(final_hidden_states, tuple):
|
if isinstance(final_hidden_states, tuple):
|
||||||
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
||||||
|
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.moe_load += expert_tokens if group_list_type else \
|
self.moe_load += expert_tokens if group_list_type == 1 else \
|
||||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||||
|
|
||||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
final_hidden_states = forward_context.moe_comm_method.finalize(
|
||||||
hidden_states=final_hidden_states,
|
hidden_states=final_hidden_states,
|
||||||
|
|||||||
@@ -130,7 +130,8 @@ class MoECommMethod(ABC):
|
|||||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||||
mc2_mask=self.mc2_mask,
|
mc2_mask=self.mc2_mask,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
with_quant=use_int8_w8a8 or use_int4_w4a8)
|
with_quant=use_int8_w8a8 or use_int4_w4a8,
|
||||||
|
dynamic_eplb=dynamic_eplb)
|
||||||
|
|
||||||
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
|
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
|
||||||
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
|
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
|
||||||
|
|||||||
@@ -69,7 +69,8 @@ class MoETokenDispatcher(ABC):
|
|||||||
dynamic_scale_for_share: Optional[Any] = None,
|
dynamic_scale_for_share: Optional[Any] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
with_quant: bool = False):
|
with_quant: bool = False,
|
||||||
|
dynamic_eplb: bool = False):
|
||||||
raise NotImplementedError("Dispatch function not implemented.")
|
raise NotImplementedError("Dispatch function not implemented.")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -175,7 +176,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
dynamic_scale_for_share: Optional[Any] = None,
|
dynamic_scale_for_share: Optional[Any] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
with_quant: bool = False):
|
with_quant: bool = False,
|
||||||
|
dynamic_eplb: bool = False):
|
||||||
self.with_quant = with_quant
|
self.with_quant = with_quant
|
||||||
self.expert_map = expert_map
|
self.expert_map = expert_map
|
||||||
self.topk_ids = topk_ids
|
self.topk_ids = topk_ids
|
||||||
@@ -210,7 +212,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
if shared_experts is not None:
|
if shared_experts is not None:
|
||||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||||
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
self.shared_act = shared_experts.act_fn(shared_gate_up)
|
||||||
group_list_type = 0
|
group_list_type = 1 if dynamic_eplb else 0
|
||||||
return {
|
return {
|
||||||
"group_list_type": group_list_type,
|
"group_list_type": group_list_type,
|
||||||
"hidden_states": expand_x,
|
"hidden_states": expand_x,
|
||||||
@@ -333,7 +335,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
dynamic_scale_for_share: Optional[Any] = None,
|
dynamic_scale_for_share: Optional[Any] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
with_quant: bool = False):
|
with_quant: bool = False,
|
||||||
|
dynamic_eplb: bool = False):
|
||||||
self.with_quant = with_quant
|
self.with_quant = with_quant
|
||||||
self.original_shape = hidden_states.shape
|
self.original_shape = hidden_states.shape
|
||||||
|
|
||||||
@@ -424,7 +427,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
|
|||||||
dynamic_scale_for_share: Optional[Any] = None,
|
dynamic_scale_for_share: Optional[Any] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
with_quant: bool = False):
|
with_quant: bool = False,
|
||||||
|
dynamic_eplb: bool = False):
|
||||||
self.bsz, _ = hidden_states.shape
|
self.bsz, _ = hidden_states.shape
|
||||||
flatten_topk_ids = topk_ids.view(-1)
|
flatten_topk_ids = topk_ids.view(-1)
|
||||||
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||||
@@ -521,7 +525,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
dynamic_scale_for_share: Optional[Any] = None,
|
dynamic_scale_for_share: Optional[Any] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
with_quant: bool = False):
|
with_quant: bool = False,
|
||||||
|
dynamic_eplb: bool = False):
|
||||||
self.with_quant = with_quant
|
self.with_quant = with_quant
|
||||||
self.hidden_shape = hidden_states.shape
|
self.hidden_shape = hidden_states.shape
|
||||||
self.topk_weights = topk_weights
|
self.topk_weights = topk_weights
|
||||||
|
|||||||
Reference in New Issue
Block a user