From c59d69d9e65de4b91628411ef415eca6bf512b44 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Sat, 28 Jun 2025 16:14:49 +0800 Subject: [PATCH] [PERF]support MERRouter (#1421) ### What this PR does / why we need it? This PR introduces an expert rearrange algorithm for PanguProMoE model. Different from the original grouped topk, it filters out the top experts that are allocated more tokens. Therefore, we can load less experts when calculating gmm. We have test this algorithm for PanguProMoE-72B on 300I Duo platform and 800I A2 platform. On 300I Duo platform, we find that `num_voted_experts` set to 5 achieves both good performance and accuracy. While on 800I A2, we still set it to 8 to use original pangu grouped topk. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Signed-off-by: angazenn Co-authored-by: angazenn --- vllm_ascend/models/pangu_moe.py | 100 ++++++++++++++++++++-------- vllm_ascend/ops/common_fused_moe.py | 6 +- vllm_ascend/ops/fused_moe.py | 15 +++-- 3 files changed, 84 insertions(+), 37 deletions(-) diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py index 131e1e0..644a00e 100644 --- a/vllm_ascend/models/pangu_moe.py +++ b/vllm_ascend/models/pangu_moe.py @@ -57,6 +57,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm_ascend.distributed.parallel_state import get_ep_group +from vllm_ascend.utils import is_310p logger = init_logger(__name__) @@ -339,41 +340,81 @@ class PanguProMoEMLP(nn.Module): return x -class PanguProMoESparseMoeBlock(nn.Module): +def topk_wrapper(num_voted_experts): - @staticmethod def pangu_group8_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, - renormalize: bool, + renormalize: bool = False, num_expert_group: int = 0, topk_group: int = 0, global_num_experts: int = 0, ): + scores = F.softmax(gating_output, dim=1) + num_tokens = scores.shape[0] + router_scale = _ROUTER_SCALE.squeeze( # type: ignore + ) + ep_size = get_ep_group().world_size local_num_experts = global_num_experts // ep_size local_num_group = topk // ep_size - router_scale = _ROUTER_SCALE.squeeze() # type: ignore + experts_per_group = global_num_experts // topk + local_group_start = get_ep_group().rank_in_group * local_num_experts + local_group_end = (get_ep_group().rank_in_group + + 1) * local_num_experts scores = F.softmax(gating_output, dim=1) - scores = scores[..., - get_ep_group().rank_in_group * - local_num_experts:(get_ep_group().rank_in_group + 1) * - local_num_experts] + scores = scores[..., local_group_start:local_group_end] - router_weights = router_scale[get_ep_group().rank_in_group * - local_num_experts: - (get_ep_group().rank_in_group + 1) * - local_num_experts] - topk_weights, topk_ids = torch.max(scores.view(scores.shape[0], - local_num_group, -1), - dim=-1) - bias = torch.arange(0, - local_num_experts, - topk, - device=scores.device, - dtype=torch.int32).unsqueeze(0) - topk_ids = topk_ids.to(torch.int32) + bias + router_weights = router_scale[local_group_start:local_group_end] + + if num_voted_experts == 8: + # use original topk + topk_weights, topk_ids = torch.max(scores.view( + scores.shape[0], local_num_group, -1), + dim=-1) + bias = torch.arange(0, + local_num_experts, + experts_per_group, + device=scores.device, + dtype=torch.int32).unsqueeze(0) + topk_ids = topk_ids.to(torch.int32) + bias + + else: + group_expert_indices = torch.arange(experts_per_group, + dtype=torch.int32, + device=scores.device).view( + 1, 1, -1) + group_expert_offset = (torch.arange( + local_num_group, dtype=torch.int32, device=scores.device) * + experts_per_group).unsqueeze(0) + expert_index_range = torch.arange(experts_per_group, + dtype=torch.int32, + device=scores.device) + + scores_grouped = scores.view(num_tokens, local_num_group, + experts_per_group) + best_expert_idx = torch.argmax(scores_grouped, + dim=2) # (num_tokens, num_groups) + vote_mask = (best_expert_idx.unsqueeze(-1).to( + torch.int32) == group_expert_indices) + + expert_vote_freq = vote_mask.sum(dim=0) + + sorted_indices = torch.argsort(expert_vote_freq, + dim=1, + descending=True).to(torch.int32) + topk_experts = sorted_indices[:, :num_voted_experts] + keep_mask = (( + topk_experts.unsqueeze(-1) == expert_index_range).any( + dim=1)).unsqueeze(0) + + masked_scores = torch.where(keep_mask, scores_grouped, 0) + + topk_weights, best_pos_in_group = masked_scores.max(dim=2) + best_pos_in_group = best_pos_in_group.to(torch.int32) + topk_ids = (best_pos_in_group + group_expert_offset).to( + torch.int32) flatten_topk_ids = topk_ids.view(-1) router_weights = router_weights.index_select(0, flatten_topk_ids).view( @@ -382,6 +423,11 @@ class PanguProMoESparseMoeBlock(nn.Module): return topk_weights, topk_ids + return pangu_group8_topk + + +class PanguProMoESparseMoeBlock(nn.Module): + def __init__( self, config: PretrainedConfig, @@ -397,14 +443,15 @@ class PanguProMoESparseMoeBlock(nn.Module): f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}.") - self.local_num_group = config.num_experts_per_tok // get_ep_group( - ).world_size self.num_experts_per_tok = config.num_experts_per_tok - self.local_num_experts = config.num_experts // get_ep_group( - ).world_size self.router_scale = torch.nn.Parameter( torch.ones((1, self.num_experts))) + # on 300I Duo platform, we find that num_voted_experts set to 5 achieves + # good performance without sacrifice too much accuracy. for other platform, + # this is set to 8 to use original pangu grouped topk. + num_voted_experts = 5 if is_310p() else 8 + self.experts = FusedMoE( num_experts=config.num_experts, top_k=config.num_experts_per_tok, @@ -412,8 +459,7 @@ class PanguProMoESparseMoeBlock(nn.Module): intermediate_size=config.moe_intermediate_size, reduce_results=False, quant_config=quant_config, - custom_routing_function=PanguProMoESparseMoeBlock. - pangu_group8_topk, + custom_routing_function=topk_wrapper(num_voted_experts), prefix=f"{prefix}.experts", ) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 3c84f23..4e21c74 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -21,7 +21,7 @@ import torch from vllm.model_executor.layers.fused_moe.layer import \ UnquantizedFusedMoEMethod -from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_310p, +from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge, select_experts) from vllm_ascend.utils import is_310p @@ -58,9 +58,9 @@ def forward_oot( e_score_correction_bias=e_score_correction_bias, ) - if is_310p(): + if topk_ids.shape[1] < top_k or is_310p(): assert global_num_experts is not None - return fused_experts_310p( + return fused_experts_moge( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index cad90a4..bea0dc5 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -39,7 +39,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.utils import (FusedMoEState, dispose_tensor, - get_fused_moe_state, npu_stream_switch, + get_fused_moe_state, is_310p, npu_stream_switch, npu_wait_tensor) MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER @@ -548,8 +548,7 @@ def fused_experts_with_all2all_buffer( return final_hidden_states -# Currently, fused_experts on 310p only supports PanguProMoE. -def fused_experts_310p( +def fused_experts_moge( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -614,8 +613,11 @@ def fused_experts_310p( group_list=group_list, )[0] - gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( - torch.float16) + if is_310p(): + gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( + torch.float16) + else: + gate_up_out = torch_npu.npu_swiglu(gate_up_out) gate_up_out *= topk_scales w2 = w2.transpose(1, 2) @@ -628,8 +630,7 @@ def fused_experts_310p( group_list=group_list, )[0] - unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to( - torch.int32) + torch.Tensor([0]).to(torch.int32).npu() + unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids) final_hidden_states = unsorted_hidden_states.reshape( bsz, top_k // ep_size, -1).sum(1)