From a2e4c3fe783dc8c3ce500a6e8794685bb90193e8 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Fri, 21 Nov 2025 23:03:20 +0800 Subject: [PATCH] Revert "[cherry-pick][refactor]support gatingtopk operator generalization (#4050)" (#4352) This reverts commit c87a77e8b4f1b435d8ec32af3b0c729e1cdb511d. it breaks ops e2e test Signed-off-by: wangxiyuan --- tests/ut/quantization/test_w8a8.py | 63 +++++++++----------- vllm_ascend/ascend_forward_context.py | 1 - vllm_ascend/ops/moe/experts_selector.py | 79 +++++++++++++------------ 3 files changed, 69 insertions(+), 74 deletions(-) diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index ed9cc9a..b88e78f 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -754,14 +754,6 @@ class TestSelectExperts(TestBase): self.hidden_states = torch.randn(self.num_tokens, self.hidden_size) self.router_logits = torch.randn(self.num_tokens, self.num_experts) - """Mock custom routing""" - self.mock_custom_routing = MagicMock() - self.mock_custom_routing.return_value = (torch.ones( - self.num_tokens, self.top_k), - torch.zeros( - self.num_tokens, - self.top_k, - dtype=torch.int32)) self.mock_ctx = MagicMock() self.mock_ctx.weight_prefetch_method = MagicMock() @@ -771,7 +763,7 @@ class TestSelectExperts(TestBase): self.addCleanup(patcher.stop) patcher.start() - @patch('torch_npu.npu_moe_gating_top_k') + @patch('torch_npu.npu_moe_gating_top_k_softmax') def test_softmax_scoring(self, mock_topk): """Test softmax scoring function""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), @@ -798,14 +790,12 @@ class TestSelectExperts(TestBase): def test_sigmoid_scoring(self): """Test sigmoid scoring function""" - weights, ids = select_experts( - hidden_states=self.hidden_states, - router_logits=self.router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=False, - scoring_func="sigmoid", - custom_routing_function=self.mock_custom_routing) + weights, ids = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="sigmoid") self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) @@ -818,8 +808,7 @@ class TestSelectExperts(TestBase): top_k=self.top_k, use_grouped_topk=False, renormalize=False, - scoring_func="invalid_func", - custom_routing_function=self.mock_custom_routing) + scoring_func="invalid_func") @patch('torch.topk') def test_grouped_topk(self, mock_topk): @@ -829,15 +818,13 @@ class TestSelectExperts(TestBase): self.top_k, dtype=torch.long)) - weights, ids = select_experts( - hidden_states=self.hidden_states, - router_logits=self.router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=False, - topk_group=4, - num_expert_group=2, - custom_routing_function=self.mock_custom_routing) + weights, ids = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=False, + topk_group=4, + num_expert_group=2) mock_topk.assert_called() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) @@ -859,8 +846,7 @@ class TestSelectExperts(TestBase): renormalize=False, topk_group=4, num_expert_group=2, - e_score_correction_bias=e_score_correction_bias, - custom_routing_function=self.mock_custom_routing) + e_score_correction_bias=e_score_correction_bias) mock_grouped_topk.assert_called_once() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) @@ -868,20 +854,27 @@ class TestSelectExperts(TestBase): def test_custom_routing_function(self): """Test custom routing function""" + mock_custom_routing = MagicMock() + mock_custom_routing.return_value = (torch.ones(self.num_tokens, + self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.int32)) + weights, ids = select_experts( hidden_states=self.hidden_states, router_logits=self.router_logits, top_k=self.top_k, use_grouped_topk=False, renormalize=False, - custom_routing_function=self.mock_custom_routing) + custom_routing_function=mock_custom_routing) - self.mock_custom_routing.assert_called_once() + mock_custom_routing.assert_called_once() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.dtype, torch.int32) - @patch('torch_npu.npu_moe_gating_top_k') + @patch('torch_npu.npu_moe_gating_top_k_softmax') def test_renormalize(self, mock_topk): """Test renormalization""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), @@ -907,13 +900,13 @@ class TestSelectExperts(TestBase): sums = weights.sum(dim=-1) self.assertTrue(torch.allclose(sums, torch.ones_like(sums))) - @patch('torch_npu.npu_moe_gating_top_k') + @patch('torch_npu.npu_moe_gating_top_k_softmax') def test_output_dtypes(self, mock_topk): """Test output dtypes""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), torch.zeros(self.num_tokens, self.top_k, - dtype=torch.int32), + dtype=torch.long), torch.arange(0, self.num_tokens * self.top_k, dtype=torch.int32).view( diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 580508a..a700fbf 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -96,7 +96,6 @@ def set_ascend_forward_context( ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) - # fused_moe_state is used in torchair, it will be deleted along with torchair is_deepseek_v3_r1 = hasattr( vllm_config.model_config.hf_config, 'n_routed_experts' ) and vllm_config.model_config.hf_config.n_routed_experts == 256 diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py index eb3fc84..e511d6b 100644 --- a/vllm_ascend/ops/moe/experts_selector.py +++ b/vllm_ascend/ops/moe/experts_selector.py @@ -20,6 +20,8 @@ import torch import torch_npu from vllm.forward_context import get_forward_context +from vllm_ascend.ascend_config import get_ascend_config + def select_experts(hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -60,20 +62,21 @@ def select_experts(hidden_states: torch.Tensor, if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( hidden_states, "gate_up") - if custom_routing_function is None: - topk_weights, topk_ids = _select_experts_with_fusion_ops( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - renormalize=renormalize, - e_score_correction_bias=e_score_correction_bias, - num_expert_group=num_expert_group, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - global_num_experts=global_num_experts) - else: + topk_weights, topk_ids = _select_experts_with_fusion_ops( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + global_num_experts=global_num_experts) + + if topk_weights is None: topk_weights, topk_ids = _native_select_experts( hidden_states=hidden_states, router_logits=router_logits, @@ -168,34 +171,34 @@ def _select_experts_with_fusion_ops( e_score_correction_bias: Optional[torch.Tensor], topk_group: Optional[int], num_expert_group: Optional[int], + custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", routed_scaling_factor=1.0, global_num_experts: int = -1): - if scoring_func == "softmax": - norm_type = 0 - topk_group = 1 - num_expert_group = 1 - else: - norm_type = 1 - if e_score_correction_bias is not None and \ - e_score_correction_bias.dtype != router_logits.dtype: - e_score_correction_bias = e_score_correction_bias.to( - router_logits.dtype) - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, - bias=e_score_correction_bias, - k_group=topk_group, - group_count=num_expert_group, - group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=norm_type, # 0: softmax; 1: sigmoid - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - if scoring_func == "softmax": + topk_weights, topk_ids = None, None + # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern + global_redundant_expert_num = get_ascend_config().init_redundancy_expert + is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 + if is_deepseek_v3_r1: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax": + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( + x=router_logits, finished=None, k=top_k) + topk_ids = topk_ids.to(torch.int32) topk_weights = _renormalize_topk_weights(topk_weights, renormalize) return topk_weights, topk_ids