[main] adapt usage of npu_moe_gating_top_k_softmax and remove envs.SELECT_GATING_TOPK_SOTFMAX_EXPERTS (#2112)

backport of v0.9.1-dev:
https://github.com/vllm-project/vllm-ascend/pull/1902

origin main npu_moe_gating_top_k_softmax:
https://github.com/vllm-project/vllm-ascend/pull/1355

- vLLM version: v0.10.0
- vLLM main:
055bd3978e

Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
huangxialu
2025-07-31 21:05:56 +08:00
committed by GitHub
parent e8660d7978
commit 9c9a7cd90b
5 changed files with 146 additions and 89 deletions

View File

@@ -22,13 +22,10 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
select_experts,
select_gating_top_k_softmax_experts)
select_experts)
from vllm_ascend.utils import is_310p
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -61,26 +58,19 @@ def forward_oot(
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize)
else:
topk_weights, topk_ids = select_experts(
global_num_experts=global_num_experts,
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids = select_experts(
global_num_experts=global_num_experts,
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None

View File

@@ -52,7 +52,6 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_rm_router_logits_state, is_310p)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
@@ -859,39 +858,6 @@ def fused_experts(
return final_hidden_states
def select_gating_top_k_softmax_experts(
hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
only supports float16、bfloat16、float32
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
router_logits: Router logits of shape (num_tokens, num_experts).
top_k: Number of experts to select.
renormalize: Whether to renormalize the routing weights.
Returns:
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
Raises:
ValueError: If an unsupported scoring function is provided.
"""
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
router_logits, None, k=top_k)
# # Required by npu_moe_init_routing
# topk_weights = topk_weights.to(hidden_states.dtype)
# topk_ids = topk_ids.to(torch.int32)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
@@ -953,8 +919,24 @@ def select_experts(
ValueError: If an unsupported scoring function is provided.
"""
def _renormalize_topk_weights(
topk_weights: torch.Tensor,
renormalize: bool,
):
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1,
keepdim=True)
return topk_weights
if scoring_func == "softmax":
# NOTE: vLLM use dtype=torch.float here
if not use_grouped_topk and custom_routing_function is None:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
@@ -988,10 +970,11 @@ def select_experts(
k=top_k,
dim=-1,
sorted=False)
elif custom_routing_function is None:
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
else:
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
@@ -1002,11 +985,12 @@ def select_experts(
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
@@ -1070,23 +1054,18 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk当前写8
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; 第三个输出是否输出
# y2_flag=False, # old api; 第三个输出是否输出
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize)
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,