[cherry-pick][refactor] support GatingTopK operator generalization (#4050)

### What this PR does / why we need it?
Cherry-picked from https://github.com/vllm-project/vllm-ascend/pull/2958

Before:
`npu_moe_gating_top_k` only supported the `group_count=256` pattern.

After:
1. `npu_moe_gating_top_k` supports any `group_count`.
2. The functionality of `torch_npu.npu_moe_gating_top_k_softmax` is now covered by `torch_npu.npu_moe_gating_top_k`, so the softmax path uses the same unified operator.

CANN: requires 8.3.RC1.

Performance:
1. GLM4.5-w8a8: TPS improves by 6%.
2. Qwen3: unchanged.
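
For illustration, a minimal sketch (not part of this PR's diff) of how the old softmax top-k call maps onto the unified operator, assuming an Ascend NPU with CANN 8.3.RC1 and the argument semantics used in the diff below:

```python
import torch
import torch_npu  # requires an Ascend NPU and CANN 8.3.RC1

logits = torch.randn(16, 128).npu()

# Previously: torch_npu.npu_moe_gating_top_k_softmax(x=logits, finished=None, k=8)
# Now the same routing is the degenerate grouped case of the unified operator:
weights, ids, _ = torch_npu.npu_moe_gating_top_k(
    logits,
    k=8,
    bias=None,
    k_group=1,            # select one group ...
    group_count=1,        # ... out of one group, i.e. no expert grouping
    group_select_mode=1,
    renorm=0,             # normalize scores first, then take top-k
    norm_type=0,          # 0: softmax scoring, 1: sigmoid
    routed_scaling_factor=1,
    eps=float(1e-20))
```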


Signed-off-by: 1092626063 <1092626063@qq.com>
Author: 1092626063
Date: 2025-11-19 10:39:28 +08:00
Committer: GitHub
Parent: ddf3e75800
Commit: c87a77e8b4
3 changed files with 74 additions and 69 deletions

View File

@@ -754,6 +754,14 @@ class TestSelectExperts(TestBase):
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
 
+        """Mock custom routing"""
+        self.mock_custom_routing = MagicMock()
+        self.mock_custom_routing.return_value = (torch.ones(
+            self.num_tokens, self.top_k),
+                                                 torch.zeros(
+                                                     self.num_tokens,
+                                                     self.top_k,
+                                                     dtype=torch.int32))
+
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
@@ -763,7 +771,7 @@ class TestSelectExperts(TestBase):
         self.addCleanup(patcher.stop)
         patcher.start()
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -790,12 +798,14 @@ class TestSelectExperts(TestBase):
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="sigmoid",
+            custom_routing_function=self.mock_custom_routing)
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -808,7 +818,8 @@ class TestSelectExperts(TestBase):
                            top_k=self.top_k,
                            use_grouped_topk=False,
                            renormalize=False,
-                           scoring_func="invalid_func")
+                           scoring_func="invalid_func",
+                           custom_routing_function=self.mock_custom_routing)
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -818,13 +829,15 @@ class TestSelectExperts(TestBase):
                                               self.top_k,
                                               dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=True,
+            renormalize=False,
+            topk_group=4,
+            num_expert_group=2,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -846,7 +859,8 @@ class TestSelectExperts(TestBase):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -854,27 +868,20 @@ class TestSelectExperts(TestBase):
     def test_custom_routing_function(self):
         """Test custom routing function"""
-        mock_custom_routing = MagicMock()
-        mock_custom_routing.return_value = (torch.ones(self.num_tokens,
-                                                       self.top_k),
-                                            torch.zeros(self.num_tokens,
-                                                        self.top_k,
-                                                        dtype=torch.int32))
         weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=False,
-            custom_routing_function=mock_custom_routing)
+            custom_routing_function=self.mock_custom_routing)
 
-        mock_custom_routing.assert_called_once()
+        self.mock_custom_routing.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_renormalize(self, mock_topk):
         """Test renormalization"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -900,13 +907,13 @@ class TestSelectExperts(TestBase):
         sums = weights.sum(dim=-1)
         self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_output_dtypes(self, mock_topk):
         """Test output dtypes"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
                                   torch.zeros(self.num_tokens,
                                               self.top_k,
-                                              dtype=torch.long),
+                                              dtype=torch.int32),
                                   torch.arange(0,
                                                self.num_tokens * self.top_k,
                                                dtype=torch.int32).view(

View File

@@ -96,6 +96,7 @@ def set_ascend_forward_context(
     ep_size = (get_ep_group().world_size if
                vllm_config.parallel_config.enable_expert_parallel else 1)
 
+    # fused_moe_state is used in torchair, it will be deleted along with torchair
     is_deepseek_v3_r1 = hasattr(
         vllm_config.model_config.hf_config, 'n_routed_experts'
     ) and vllm_config.model_config.hf_config.n_routed_experts == 256

View File

@@ -20,8 +20,6 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.ascend_config import get_ascend_config
-
 
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
@@ -62,21 +60,20 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
-        hidden_states=hidden_states,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        topk_group=topk_group,
-        renormalize=renormalize,
-        e_score_correction_bias=e_score_correction_bias,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        routed_scaling_factor=routed_scaling_factor,
-        global_num_experts=global_num_experts)
-
-    if topk_weights is None:
+    if custom_routing_function is None:
+        topk_weights, topk_ids = _select_experts_with_fusion_ops(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            renormalize=renormalize,
+            e_score_correction_bias=e_score_correction_bias,
+            num_expert_group=num_expert_group,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            global_num_experts=global_num_experts)
+    else:
         topk_weights, topk_ids = _native_select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
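
With this change, passing a `custom_routing_function` skips the fused kernel entirely and takes the native path. A sketch of what such a function can look like, assuming vLLM's usual `(hidden_states, gating_output, topk, renormalize)` routing-function signature and the `(weights, int32 ids)` return contract the unit tests above assert:

```python
import torch

def naive_topk_routing(hidden_states: torch.Tensor,
                       gating_output: torch.Tensor,
                       topk: int,
                       renormalize: bool):
    """Toy replacement for the fused kernel: softmax, then plain top-k."""
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # select_experts expects int32 expert ids, as the tests assert
    return topk_weights, topk_ids.to(torch.int32)
```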
@@ -171,34 +168,34 @@ def _select_experts_with_fusion_ops(
         e_score_correction_bias: Optional[torch.Tensor],
         topk_group: Optional[int],
         num_expert_group: Optional[int],
-        custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    topk_weights, topk_ids = None, None
-    # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
-    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-    if is_deepseek_v3_r1:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-            router_logits,
-            k=top_k,  # topk currently 8
-            bias=e_score_correction_bias,
-            k_group=topk_group,  # fix: 4
-            group_count=num_expert_group,  # fix 8
-            group_select_mode=
-            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-            # out_flag=False, # todo new api; should the third output be output
-            # y2_flag=False, # old api; should the third output be output
-            routed_scaling_factor=1,
-            eps=float(1e-20))
-    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-            x=router_logits, finished=None, k=top_k)
-        topk_ids = topk_ids.to(torch.int32)
+    if scoring_func == "softmax":
+        norm_type = 0
+        topk_group = 1
+        num_expert_group = 1
+    else:
+        norm_type = 1
+        if e_score_correction_bias is not None and \
+                e_score_correction_bias.dtype != router_logits.dtype:
+            e_score_correction_bias = e_score_correction_bias.to(
+                router_logits.dtype)
+    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+        router_logits,
+        k=top_k,
+        bias=e_score_correction_bias,
+        k_group=topk_group,
+        group_count=num_expert_group,
+        group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+        renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+        norm_type=norm_type,  # 0: softmax; 1: sigmoid
+        # out_flag=False, # todo new api; should the third output be output
+        # y2_flag=False, # old api; should the third output be output
+        routed_scaling_factor=1,
+        eps=float(1e-20))
+
+    if scoring_func == "softmax":
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
     return topk_weights, topk_ids