Revert "[cherry-pick][refactor]support gatingtopk operator generalization (#4050)" (#4352)

This reverts commit c87a77e8b4.

it breaks ops e2e test

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-11-21 23:03:20 +08:00
committed by GitHub
parent 5ad0ccdc31
commit a2e4c3fe78
3 changed files with 69 additions and 74 deletions

View File

@@ -754,14 +754,6 @@ class TestSelectExperts(TestBase):
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
"""Mock custom routing"""
self.mock_custom_routing = MagicMock()
self.mock_custom_routing.return_value = (torch.ones(
self.num_tokens, self.top_k),
torch.zeros(
self.num_tokens,
self.top_k,
dtype=torch.int32))
self.mock_ctx = MagicMock()
self.mock_ctx.weight_prefetch_method = MagicMock()
@@ -771,7 +763,7 @@ class TestSelectExperts(TestBase):
self.addCleanup(patcher.stop)
patcher.start()
@patch('torch_npu.npu_moe_gating_top_k')
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_softmax_scoring(self, mock_topk):
"""Test softmax scoring function"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -798,14 +790,12 @@ class TestSelectExperts(TestBase):
def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid",
custom_routing_function=self.mock_custom_routing)
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -818,8 +808,7 @@ class TestSelectExperts(TestBase):
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid_func",
custom_routing_function=self.mock_custom_routing)
scoring_func="invalid_func")
@patch('torch.topk')
def test_grouped_topk(self, mock_topk):
@@ -829,15 +818,13 @@ class TestSelectExperts(TestBase):
self.top_k,
dtype=torch.long))
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2,
custom_routing_function=self.mock_custom_routing)
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -859,8 +846,7 @@ class TestSelectExperts(TestBase):
renormalize=False,
topk_group=4,
num_expert_group=2,
e_score_correction_bias=e_score_correction_bias,
custom_routing_function=self.mock_custom_routing)
e_score_correction_bias=e_score_correction_bias)
mock_grouped_topk.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -868,20 +854,27 @@ class TestSelectExperts(TestBase):
def test_custom_routing_function(self):
"""Test custom routing function"""
mock_custom_routing = MagicMock()
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.int32))
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
custom_routing_function=self.mock_custom_routing)
custom_routing_function=mock_custom_routing)
self.mock_custom_routing.assert_called_once()
mock_custom_routing.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)
@patch('torch_npu.npu_moe_gating_top_k')
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_renormalize(self, mock_topk):
"""Test renormalization"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -907,13 +900,13 @@ class TestSelectExperts(TestBase):
sums = weights.sum(dim=-1)
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
@patch('torch_npu.npu_moe_gating_top_k')
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_output_dtypes(self, mock_topk):
"""Test output dtypes"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.int32),
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(