[cherry-pick][refactor]support gatingtopk operator generalization (#4050)
### What this PR does / why we need it? pick from : https://github.com/vllm-project/vllm-ascend/pull/2958 Past: npu_moe_gating_top_k can only support 'group_count=256' pattern Now: 1、npu_moe_gating_top_k support all size of group_count 2、the functionality of `torch_npu.npu_moe_gating_top_k_softmax` are included in `torch_npu.npu_moe_gating_top_k` CANN: depends on 8.3.RC1 Performance: 1. GLM4.5-w8a8, TPS improve 6% 2. Qwen3, the same as before Signed-off-by: 1092626063 <1092626063@qq.com>
This commit is contained in:
@@ -754,6 +754,14 @@ class TestSelectExperts(TestBase):
|
||||
|
||||
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
||||
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
||||
"""Mock custom routing"""
|
||||
self.mock_custom_routing = MagicMock()
|
||||
self.mock_custom_routing.return_value = (torch.ones(
|
||||
self.num_tokens, self.top_k),
|
||||
torch.zeros(
|
||||
self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.int32))
|
||||
|
||||
self.mock_ctx = MagicMock()
|
||||
self.mock_ctx.weight_prefetch_method = MagicMock()
|
||||
@@ -763,7 +771,7 @@ class TestSelectExperts(TestBase):
|
||||
self.addCleanup(patcher.stop)
|
||||
patcher.start()
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
|
||||
def test_softmax_scoring(self, mock_topk):
|
||||
"""Test softmax scoring function"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
@@ -790,12 +798,14 @@ class TestSelectExperts(TestBase):
|
||||
def test_sigmoid_scoring(self):
|
||||
"""Test sigmoid scoring function"""
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="sigmoid")
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="sigmoid",
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
@@ -808,7 +818,8 @@ class TestSelectExperts(TestBase):
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="invalid_func")
|
||||
scoring_func="invalid_func",
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
@patch('torch.topk')
|
||||
def test_grouped_topk(self, mock_topk):
|
||||
@@ -818,13 +829,15 @@ class TestSelectExperts(TestBase):
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2)
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2,
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
mock_topk.assert_called()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
@@ -846,7 +859,8 @@ class TestSelectExperts(TestBase):
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
mock_grouped_topk.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
@@ -854,27 +868,20 @@ class TestSelectExperts(TestBase):
|
||||
|
||||
def test_custom_routing_function(self):
|
||||
"""Test custom routing function"""
|
||||
mock_custom_routing = MagicMock()
|
||||
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
|
||||
self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.int32))
|
||||
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
custom_routing_function=mock_custom_routing)
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
mock_custom_routing.assert_called_once()
|
||||
self.mock_custom_routing.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
|
||||
def test_renormalize(self, mock_topk):
|
||||
"""Test renormalization"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
@@ -900,13 +907,13 @@ class TestSelectExperts(TestBase):
|
||||
sums = weights.sum(dim=-1)
|
||||
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
|
||||
def test_output_dtypes(self, mock_topk):
|
||||
"""Test output dtypes"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long),
|
||||
dtype=torch.int32),
|
||||
torch.arange(0,
|
||||
self.num_tokens * self.top_k,
|
||||
dtype=torch.int32).view(
|
||||
|
||||
Reference in New Issue
Block a user