[refactor]support gatingtopk operator generalization (#4356)
### What this PR does / why we need it? This pr is cherry-pick from : https://github.com/vllm-project/vllm-ascend/pull/2958 and https://github.com/vllm-project/vllm-ascend/pull/4340 Past: npu_moe_gating_top_k can only support 'group_count=256' pattern Now: 1、npu_moe_gating_top_k support all size of group_count 2、the functionality of `torch_npu.npu_moe_gating_top_k_softmax` are included in `torch_npu.npu_moe_gating_top_k` CANN: depends on 8.3.RC1 Performance: 1. GLM4.5-w8a8, TPS improve 6% 2. Qwen3, the same as before --------- Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
@@ -28,7 +28,8 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.experts_selector import (check_npu_moe_gating_top_k,
|
||||
select_experts)
|
||||
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
|
||||
|
||||
@@ -296,7 +297,10 @@ def test_select_experts(
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if use_grouped_topk:
|
||||
call_moe_gatingtopk = check_npu_moe_gating_top_k(
|
||||
hidden_states, topk, topk_group, num_expert_group, scoring_func,
|
||||
custom_routing_function)
|
||||
if not call_moe_gatingtopk and use_grouped_topk:
|
||||
mock_native_grouped_topk.assert_called_once()
|
||||
else:
|
||||
mock_native_grouped_topk.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user