[cherry-pick][refactor]support gatingtopk operator generalization (#4050)
### What this PR does / why we need it? pick from : https://github.com/vllm-project/vllm-ascend/pull/2958 Past: npu_moe_gating_top_k can only support 'group_count=256' pattern Now: 1、npu_moe_gating_top_k support all size of group_count 2、the functionality of `torch_npu.npu_moe_gating_top_k_softmax` are included in `torch_npu.npu_moe_gating_top_k` CANN: depends on 8.3.RC1 Performance: 1. GLM4.5-w8a8, TPS improve 6% 2. Qwen3, the same as before Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
@@ -754,6 +754,14 @@ class TestSelectExperts(TestBase):
|
||||
|
||||
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
||||
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
||||
"""Mock custom routing"""
|
||||
self.mock_custom_routing = MagicMock()
|
||||
self.mock_custom_routing.return_value = (torch.ones(
|
||||
self.num_tokens, self.top_k),
|
||||
torch.zeros(
|
||||
self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.int32))
|
||||
|
||||
self.mock_ctx = MagicMock()
|
||||
self.mock_ctx.weight_prefetch_method = MagicMock()
|
||||
@@ -763,7 +771,7 @@ class TestSelectExperts(TestBase):
|
||||
self.addCleanup(patcher.stop)
|
||||
patcher.start()
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
|
||||
def test_softmax_scoring(self, mock_topk):
|
||||
"""Test softmax scoring function"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
@@ -790,12 +798,14 @@ class TestSelectExperts(TestBase):
|
||||
def test_sigmoid_scoring(self):
|
||||
"""Test sigmoid scoring function"""
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="sigmoid")
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="sigmoid",
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
@@ -808,7 +818,8 @@ class TestSelectExperts(TestBase):
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="invalid_func")
|
||||
scoring_func="invalid_func",
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
@patch('torch.topk')
|
||||
def test_grouped_topk(self, mock_topk):
|
||||
@@ -818,13 +829,15 @@ class TestSelectExperts(TestBase):
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2)
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2,
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
mock_topk.assert_called()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
@@ -846,7 +859,8 @@ class TestSelectExperts(TestBase):
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
mock_grouped_topk.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
@@ -854,27 +868,20 @@ class TestSelectExperts(TestBase):
|
||||
|
||||
def test_custom_routing_function(self):
|
||||
"""Test custom routing function"""
|
||||
mock_custom_routing = MagicMock()
|
||||
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
|
||||
self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.int32))
|
||||
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
custom_routing_function=mock_custom_routing)
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
mock_custom_routing.assert_called_once()
|
||||
self.mock_custom_routing.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
|
||||
def test_renormalize(self, mock_topk):
|
||||
"""Test renormalization"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
@@ -900,13 +907,13 @@ class TestSelectExperts(TestBase):
|
||||
sums = weights.sum(dim=-1)
|
||||
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
|
||||
def test_output_dtypes(self, mock_topk):
|
||||
"""Test output dtypes"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long),
|
||||
dtype=torch.int32),
|
||||
torch.arange(0,
|
||||
self.num_tokens * self.top_k,
|
||||
dtype=torch.int32).view(
|
||||
|
||||
@@ -96,6 +96,7 @@ def set_ascend_forward_context(
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
|
||||
# fused_moe_state is used in torchair, it will be deleted along with torchair
|
||||
is_deepseek_v3_r1 = hasattr(
|
||||
vllm_config.model_config.hf_config, 'n_routed_experts'
|
||||
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
||||
|
||||
@@ -20,8 +20,6 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -62,21 +60,20 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||
hidden_states, "gate_up")
|
||||
topk_weights, topk_ids = _select_experts_with_fusion_ops(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_weights is None:
|
||||
if custom_routing_function is None:
|
||||
topk_weights, topk_ids = _select_experts_with_fusion_ops(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
num_expert_group=num_expert_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
global_num_experts=global_num_experts)
|
||||
else:
|
||||
topk_weights, topk_ids = _native_select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
@@ -171,34 +168,34 @@ def _select_experts_with_fusion_ops(
|
||||
e_score_correction_bias: Optional[torch.Tensor],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: Optional[int],
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1):
|
||||
|
||||
topk_weights, topk_ids = None, None
|
||||
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
|
||||
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
|
||||
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk currently 8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=
|
||||
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
|
||||
x=router_logits, finished=None, k=top_k)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
if scoring_func == "softmax":
|
||||
norm_type = 0
|
||||
topk_group = 1
|
||||
num_expert_group = 1
|
||||
else:
|
||||
norm_type = 1
|
||||
if e_score_correction_bias is not None and \
|
||||
e_score_correction_bias.dtype != router_logits.dtype:
|
||||
e_score_correction_bias = e_score_correction_bias.to(
|
||||
router_logits.dtype)
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k,
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group,
|
||||
group_count=num_expert_group,
|
||||
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=norm_type, # 0: softmax; 1: sigmoid
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
if scoring_func == "softmax":
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
Reference in New Issue
Block a user