[PERF]support MERRouter (#1421)

### What this PR does / why we need it?
This PR introduces an expert rearrange algorithm for PanguProMoE model.
Different from the original grouped topk, it filters out the top experts
that are allocated more tokens. Therefore, we can load less experts when
calculating gmm.

We have test this algorithm for PanguProMoE-72B on 300I Duo platform and
800I A2 platform. On 300I Duo platform, we find that `num_voted_experts`
set to 5 achieves both good performance and accuracy. While on 800I A2,
we still set it to 8 to use original pangu grouped topk.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
Angazenn
2025-06-28 16:14:49 +08:00
committed by GitHub
parent 8fa188111d
commit c59d69d9e6
3 changed files with 84 additions and 37 deletions

View File

@@ -21,7 +21,7 @@ import torch
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_310p,
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
select_experts)
from vllm_ascend.utils import is_310p
@@ -58,9 +58,9 @@ def forward_oot(
e_score_correction_bias=e_score_correction_bias,
)
if is_310p():
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_310p(
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,

View File

@@ -39,7 +39,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_fused_moe_state, npu_stream_switch,
get_fused_moe_state, is_310p, npu_stream_switch,
npu_wait_tensor)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -548,8 +548,7 @@ def fused_experts_with_all2all_buffer(
return final_hidden_states
# Currently, fused_experts on 310p only supports PanguProMoE.
def fused_experts_310p(
def fused_experts_moge(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -614,8 +613,11 @@ def fused_experts_310p(
group_list=group_list,
)[0]
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out *= topk_scales
w2 = w2.transpose(1, 2)
@@ -628,8 +630,7 @@ def fused_experts_310p(
group_list=group_list,
)[0]
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(
torch.int32) + torch.Tensor([0]).to(torch.int32).npu()
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)