[Fix] ops gatingtopk: fix nightly CI error (#4340)

### What this PR does / why we need it?
PR https://github.com/vllm-project/vllm-ascend/pull/2958 generalized the
gatingtopk operator, but it broke the nightly CI.
This PR adds an eligibility check, `check_npu_moe_gating_top_k`, which
validates the routing configuration before dispatching to the fused-ops
gatingtopk path, and fixes the nightly CI.
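
For context, the new control flow in `experts_selector.py` looks roughly
like this (a minimal sketch, not the verbatim code; the fallback branch is
elided here and shown only as a comment, see the diff below for the exact
implementation):

```python
# Sketch of the new dispatch in select_experts(): the fused gating-topk
# path is only taken when check_npu_moe_gating_top_k() confirms the
# routing configuration is supported by the NPU fusion op.
supported = check_npu_moe_gating_top_k(
    hidden_states=hidden_states,
    top_k=top_k,
    topk_group=topk_group,
    num_expert_group=num_expert_group,
    scoring_func=scoring_func,
    custom_routing_function=custom_routing_function)
if supported:
    # fused path: _select_experts_with_fusion_ops(...)
    ...
else:
    # native (pure PyTorch) expert selection, unchanged by this PR
    ...
```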

- vLLM version: v0.12.0

Signed-off-by: 1092626063 <1092626063@qq.com>


```diff
@@ -28,7 +28,8 @@ import torch
 import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import (
+    check_npu_moe_gating_top_k, select_experts)
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.token_dispatcher import \
     TokenDispatcherWithAllGather
@@ -303,7 +304,10 @@ def test_select_experts(
         e_score_correction_bias=e_score_correction_bias,
     )
-    if use_grouped_topk:
+    call_moe_gatingtopk = check_npu_moe_gating_top_k(
+        hidden_states, topk, topk_group, num_expert_group, scoring_func,
+        custom_routing_function)
+    if not call_moe_gatingtopk and use_grouped_topk:
         mock_native_grouped_topk.assert_called_once()
     else:
         mock_native_grouped_topk.assert_not_called()
```


```diff
@@ -823,8 +823,7 @@ class TestSelectExperts(TestBase):
                 top_k=self.top_k,
                 use_grouped_topk=False,
                 renormalize=False,
-                scoring_func="invalid_func",
-                custom_routing_function=self.mock_custom_routing)
+                scoring_func="invalid_func")

     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -834,15 +833,13 @@ class TestSelectExperts(TestBase):
                                          self.top_k,
                                          dtype=torch.long))
-        weights, ids = select_experts(
-            hidden_states=self.hidden_states,
-            router_logits=self.router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=True,
-            renormalize=False,
-            topk_group=4,
-            num_expert_group=2,
-            custom_routing_function=self.mock_custom_routing)
+        weights, ids = select_experts(hidden_states=self.hidden_states,
+                                      router_logits=self.router_logits,
+                                      top_k=self.top_k,
+                                      use_grouped_topk=True,
+                                      renormalize=False,
+                                      topk_group=4,
+                                      num_expert_group=2)
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -864,8 +861,7 @@ class TestSelectExperts(TestBase):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias,
-            custom_routing_function=self.mock_custom_routing)
+            e_score_correction_bias=e_score_correction_bias)
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
```


```diff
@@ -60,7 +60,15 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    if custom_routing_function is None:
+    is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
+        hidden_states=hidden_states,
+        top_k=top_k,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        scoring_func=scoring_func,
+        custom_routing_function=custom_routing_function)
+    if is_support_npu_moe_gating_top_k:
         topk_weights, topk_ids = _select_experts_with_fusion_ops(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -90,6 +98,32 @@ def select_experts(hidden_states: torch.Tensor,
     return topk_weights, topk_ids


+def check_npu_moe_gating_top_k(
+        hidden_states: torch.Tensor,
+        top_k: int,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        scoring_func: str = "softmax",
+        custom_routing_function: Optional[Callable] = None):
+    if custom_routing_function is not None:
+        return False
+    if scoring_func != "softmax" and scoring_func != "sigmoid":
+        return False
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
+            == 0 and hidden_states.shape[-1] // num_expert_group > 2):
+        return False
+    if topk_group < 1 or topk_group > num_expert_group:
+        return False
+    if top_k < 1 or \
+            top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
+        return False
+    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
+        return False
+    return True
+
+
 def _native_grouped_topk(
     topk_weights: torch.Tensor,
     num_expert_group: Optional[int],
```
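
To make the eligibility rules concrete, here is a small usage sketch of the
new check (the tensor size, expert-group layout, and top-k values below are
made up for illustration; the check only inspects the last dimension of the
tensor it is given):

```python
import torch

from vllm_ascend.ops.fused_moe.experts_selector import \
    check_npu_moe_gating_top_k

# 256 entries in the last dimension, organized as 8 groups of 32.
logits = torch.randn(4, 256)

# Supported: softmax/sigmoid scoring, no custom routing function, and the
# group arithmetic holds (256 % 8 == 0, 256 // 8 > 2, 8 <= 256 / (8 * 4)).
assert check_npu_moe_gating_top_k(logits, top_k=8, topk_group=4,
                                  num_expert_group=8, scoring_func="sigmoid")

# Rejected: a custom routing function always forces the native path.
assert not check_npu_moe_gating_top_k(
    logits, top_k=8, custom_routing_function=lambda *args, **kwargs: None)

# Rejected: topk_group may not exceed num_expert_group.
assert not check_npu_moe_gating_top_k(logits, top_k=8, topk_group=16,
                                      num_expert_group=8)
```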
```diff
@@ -172,12 +206,9 @@ def _select_experts_with_fusion_ops(
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):

-    if scoring_func == "softmax":
-        norm_type = 0
-        topk_group = 1
-        num_expert_group = 1
-    else:
-        norm_type = 1
-        topk_group = topk_group if topk_group is not None else 1
-        num_expert_group = num_expert_group if num_expert_group is not None else 1
+    norm_type = 0 if scoring_func == "softmax" else 1
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
     if e_score_correction_bias is not None and \
             e_score_correction_bias.dtype != router_logits.dtype:
         e_score_correction_bias = e_score_correction_bias.to(
```
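
Worth noting: this last simplification is not purely cosmetic. Previously
`scoring_func == "softmax"` forced `topk_group` and `num_expert_group` to 1;
now the caller-supplied values are honored for both scoring functions, with
`check_npu_moe_gating_top_k` guaranteeing upstream that they are valid. A
tiny self-contained sketch of the new argument resolution
(`resolve_fusion_args` is a hypothetical name used only for illustration):

```python
def resolve_fusion_args(scoring_func, topk_group=None, num_expert_group=None):
    # Mirrors the simplified logic: norm_type 0 selects softmax and 1
    # selects sigmoid normalization in the fused kernel; the group
    # parameters simply default to 1 when not provided.
    norm_type = 0 if scoring_func == "softmax" else 1
    topk_group = topk_group if topk_group is not None else 1
    num_expert_group = num_expert_group if num_expert_group is not None else 1
    return norm_type, topk_group, num_expert_group

# softmax no longer clobbers the caller-supplied group parameters:
assert resolve_fusion_args("softmax", topk_group=4, num_expert_group=8) == (0, 4, 8)
assert resolve_fusion_args("sigmoid") == (1, 1, 1)
```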