[PERF]support MERRouter (#1421)

### What this PR does / why we need it?
This PR introduces an expert rearrange algorithm for PanguProMoE model.
Different from the original grouped topk, it filters out the top experts
that are allocated more tokens. Therefore, we can load less experts when
calculating gmm.

We have test this algorithm for PanguProMoE-72B on 300I Duo platform and
800I A2 platform. On 300I Duo platform, we find that `num_voted_experts`
set to 5 achieves both good performance and accuracy. While on 800I A2,
we still set it to 8 to use original pangu grouped topk.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
Angazenn
2025-06-28 16:14:49 +08:00
committed by GitHub
parent 8fa188111d
commit c59d69d9e6
3 changed files with 84 additions and 37 deletions

View File

@@ -57,6 +57,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.utils import is_310p
logger = init_logger(__name__)
@@ -339,41 +340,81 @@ class PanguProMoEMLP(nn.Module):
return x
class PanguProMoESparseMoeBlock(nn.Module):
def topk_wrapper(num_voted_experts):
@staticmethod
def pangu_group8_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
renormalize: bool = False,
num_expert_group: int = 0,
topk_group: int = 0,
global_num_experts: int = 0,
):
scores = F.softmax(gating_output, dim=1)
num_tokens = scores.shape[0]
router_scale = _ROUTER_SCALE.squeeze( # type: ignore
)
ep_size = get_ep_group().world_size
local_num_experts = global_num_experts // ep_size
local_num_group = topk // ep_size
router_scale = _ROUTER_SCALE.squeeze() # type: ignore
experts_per_group = global_num_experts // topk
local_group_start = get_ep_group().rank_in_group * local_num_experts
local_group_end = (get_ep_group().rank_in_group +
1) * local_num_experts
scores = F.softmax(gating_output, dim=1)
scores = scores[...,
get_ep_group().rank_in_group *
local_num_experts:(get_ep_group().rank_in_group + 1) *
local_num_experts]
scores = scores[..., local_group_start:local_group_end]
router_weights = router_scale[get_ep_group().rank_in_group *
local_num_experts:
(get_ep_group().rank_in_group + 1) *
local_num_experts]
topk_weights, topk_ids = torch.max(scores.view(scores.shape[0],
local_num_group, -1),
dim=-1)
bias = torch.arange(0,
local_num_experts,
topk,
device=scores.device,
dtype=torch.int32).unsqueeze(0)
topk_ids = topk_ids.to(torch.int32) + bias
router_weights = router_scale[local_group_start:local_group_end]
if num_voted_experts == 8:
# use original topk
topk_weights, topk_ids = torch.max(scores.view(
scores.shape[0], local_num_group, -1),
dim=-1)
bias = torch.arange(0,
local_num_experts,
experts_per_group,
device=scores.device,
dtype=torch.int32).unsqueeze(0)
topk_ids = topk_ids.to(torch.int32) + bias
else:
group_expert_indices = torch.arange(experts_per_group,
dtype=torch.int32,
device=scores.device).view(
1, 1, -1)
group_expert_offset = (torch.arange(
local_num_group, dtype=torch.int32, device=scores.device) *
experts_per_group).unsqueeze(0)
expert_index_range = torch.arange(experts_per_group,
dtype=torch.int32,
device=scores.device)
scores_grouped = scores.view(num_tokens, local_num_group,
experts_per_group)
best_expert_idx = torch.argmax(scores_grouped,
dim=2) # (num_tokens, num_groups)
vote_mask = (best_expert_idx.unsqueeze(-1).to(
torch.int32) == group_expert_indices)
expert_vote_freq = vote_mask.sum(dim=0)
sorted_indices = torch.argsort(expert_vote_freq,
dim=1,
descending=True).to(torch.int32)
topk_experts = sorted_indices[:, :num_voted_experts]
keep_mask = ((
topk_experts.unsqueeze(-1) == expert_index_range).any(
dim=1)).unsqueeze(0)
masked_scores = torch.where(keep_mask, scores_grouped, 0)
topk_weights, best_pos_in_group = masked_scores.max(dim=2)
best_pos_in_group = best_pos_in_group.to(torch.int32)
topk_ids = (best_pos_in_group + group_expert_offset).to(
torch.int32)
flatten_topk_ids = topk_ids.view(-1)
router_weights = router_weights.index_select(0, flatten_topk_ids).view(
@@ -382,6 +423,11 @@ class PanguProMoESparseMoeBlock(nn.Module):
return topk_weights, topk_ids
return pangu_group8_topk
class PanguProMoESparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -397,14 +443,15 @@ class PanguProMoESparseMoeBlock(nn.Module):
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.local_num_group = config.num_experts_per_tok // get_ep_group(
).world_size
self.num_experts_per_tok = config.num_experts_per_tok
self.local_num_experts = config.num_experts // get_ep_group(
).world_size
self.router_scale = torch.nn.Parameter(
torch.ones((1, self.num_experts)))
# on 300I Duo platform, we find that num_voted_experts set to 5 achieves
# good performance without sacrifice too much accuracy. for other platform,
# this is set to 8 to use original pangu grouped topk.
num_voted_experts = 5 if is_310p() else 8
self.experts = FusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
@@ -412,8 +459,7 @@ class PanguProMoESparseMoeBlock(nn.Module):
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
quant_config=quant_config,
custom_routing_function=PanguProMoESparseMoeBlock.
pangu_group8_topk,
custom_routing_function=topk_wrapper(num_voted_experts),
prefix=f"{prefix}.experts",
)

View File

@@ -21,7 +21,7 @@ import torch
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_310p,
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
select_experts)
from vllm_ascend.utils import is_310p
@@ -58,9 +58,9 @@ def forward_oot(
e_score_correction_bias=e_score_correction_bias,
)
if is_310p():
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_310p(
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,

View File

@@ -39,7 +39,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_fused_moe_state, npu_stream_switch,
get_fused_moe_state, is_310p, npu_stream_switch,
npu_wait_tensor)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -548,8 +548,7 @@ def fused_experts_with_all2all_buffer(
return final_hidden_states
# Currently, fused_experts on 310p only supports PanguProMoE.
def fused_experts_310p(
def fused_experts_moge(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -614,8 +613,11 @@ def fused_experts_310p(
group_list=group_list,
)[0]
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out *= topk_scales
w2 = w2.transpose(1, 2)
@@ -628,8 +630,7 @@ def fused_experts_310p(
group_list=group_list,
)[0]
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(
torch.int32) + torch.Tensor([0]).to(torch.int32).npu()
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)