use npu_moe_gating_top_k_softmax (#1355)
### What this PR does / why we need it?
The optimization solution for non-deepseek select_experts is to replace
gating_topk_softmax with softmax+topk+to, which is optimized from 37us
to 14us on bf16/fp16 of qwen3-235b
- vLLM version: v0.9.2
- vLLM main:
1a4f35e2ea
---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
@@ -22,10 +22,13 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.layer import \
|
||||
UnquantizedFusedMoEMethod
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
|
||||
select_experts)
|
||||
select_experts,
|
||||
select_gating_top_k_softmax_experts)
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||
|
||||
|
||||
@@ -54,19 +57,27 @@ def forward_oot(
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
global_num_experts=global_num_experts,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
||||
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
global_num_experts=global_num_experts,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if topk_ids.shape[1] < top_k or is_310p():
|
||||
assert global_num_experts is not None
|
||||
|
||||
Reference in New Issue
Block a user