[main] adapt usage of npu_moe_gating_top_k_softmax and remove envs.SELECT_GATING_TOPK_SOTFMAX_EXPERTS (#2112)

backport of v0.9.1-dev:
https://github.com/vllm-project/vllm-ascend/pull/1902

origin main npu_moe_gating_top_k_softmax:
https://github.com/vllm-project/vllm-ascend/pull/1355

- vLLM version: v0.10.0
- vLLM main:
055bd3978e

Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
huangxialu
2025-07-31 21:05:56 +08:00
committed by GitHub
parent e8660d7978
commit 9c9a7cd90b
5 changed files with 146 additions and 89 deletions

View File

@@ -22,13 +22,10 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
select_experts,
select_gating_top_k_softmax_experts)
select_experts)
from vllm_ascend.utils import is_310p
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -61,26 +58,19 @@ def forward_oot(
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize)
else:
topk_weights, topk_ids = select_experts(
global_num_experts=global_num_experts,
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids = select_experts(
global_num_experts=global_num_experts,
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None