moe_gating_top_k (#5271)

1. What this PR does / why we need it?
This PR supports the moe_gating_top_k operator, which enables
post-positioned renormalization (renorm) on the basis of softmax.
2. Does this PR introduce any user-facing change?
No user-facing changes are required.
3. How was this patch tested?
This patch was tested with the test_npu_moe_gating_top_k test case.
vLLM version: release/v0.13.0
vLLM main:
ad32e3e19c

---------

Signed-off-by: ZCG12345 <2097562023@qq.com>
Signed-off-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
This commit is contained in:
ZCG12345
2025-12-30 09:28:01 +08:00
committed by GitHub
parent 15d73f248e
commit 45c3c279e2
34 changed files with 4791 additions and 22 deletions

View File

@@ -17,7 +17,6 @@
from typing import Callable, Optional
import torch
import torch_npu
from vllm_ascend.utils import get_weight_prefetch_method
@@ -214,21 +213,19 @@ def _select_experts_with_fusion_ops(
e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(
router_logits.dtype)
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
_, topk_ids, topk_weights = torch.ops._C_ascend.moe_gating_top_k(
router_logits,
k=top_k,
bias=e_score_correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=norm_type, # 0: softmax; 1: sigmoid
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
if scoring_func == "softmax":
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
kGroup=topk_group,
groupCount=num_expert_group,
groupSelectMode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=1, # 0: softmax->topk(fix); 1: topk->softmax
normType=norm_type, # 0: softmax; 1: sigmoid
outFlag=False, # todo new api; should the third output be output
routedScalingFactor=1,
eps=float(1e-20),
biasOptional=e_score_correction_bias,
)
return topk_weights, topk_ids