diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index 4735a5f..8180858 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -28,7 +28,8 @@ import torch import torch_npu from vllm.model_executor.layers.activation import SiluAndMul -from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.experts_selector import (check_npu_moe_gating_top_k, + select_experts) from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather @@ -296,7 +297,10 @@ def test_select_experts( e_score_correction_bias=e_score_correction_bias, ) - if use_grouped_topk: + call_moe_gatingtopk = check_npu_moe_gating_top_k( + hidden_states, topk, topk_group, num_expert_group, scoring_func, + custom_routing_function) + if not call_moe_gatingtopk and use_grouped_topk: mock_native_grouped_topk.assert_called_once() else: mock_native_grouped_topk.assert_not_called() diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 6702d2b..2ad8088 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -753,6 +753,14 @@ class TestSelectExperts(TestBase): self.hidden_states = torch.randn(self.num_tokens, self.hidden_size) self.router_logits = torch.randn(self.num_tokens, self.num_experts) + """Mock custom routing""" + self.mock_custom_routing = MagicMock() + self.mock_custom_routing.return_value = (torch.ones( + self.num_tokens, self.top_k), + torch.zeros( + self.num_tokens, + self.top_k, + dtype=torch.int32)) self.mock_ctx = MagicMock() self.mock_ctx.weight_prefetch_method = MagicMock() @@ -762,7 +770,7 @@ class TestSelectExperts(TestBase): self.addCleanup(patcher.stop) patcher.start() - @patch('torch_npu.npu_moe_gating_top_k_softmax') + @patch('torch_npu.npu_moe_gating_top_k') def test_softmax_scoring(self, mock_topk): """Test softmax scoring function""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), @@ -789,12 +797,14 @@ class TestSelectExperts(TestBase): def test_sigmoid_scoring(self): """Test sigmoid scoring function""" - weights, ids = select_experts(hidden_states=self.hidden_states, - router_logits=self.router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=False, - scoring_func="sigmoid") + weights, ids = select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="sigmoid", + custom_routing_function=self.mock_custom_routing) self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) @@ -853,27 +863,20 @@ class TestSelectExperts(TestBase): def test_custom_routing_function(self): """Test custom routing function""" - mock_custom_routing = MagicMock() - mock_custom_routing.return_value = (torch.ones(self.num_tokens, - self.top_k), - torch.zeros(self.num_tokens, - self.top_k, - dtype=torch.int32)) - weights, ids = select_experts( hidden_states=self.hidden_states, router_logits=self.router_logits, top_k=self.top_k, use_grouped_topk=False, renormalize=False, - custom_routing_function=mock_custom_routing) + custom_routing_function=self.mock_custom_routing) - mock_custom_routing.assert_called_once() + self.mock_custom_routing.assert_called_once() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.dtype, torch.int32) - @patch('torch_npu.npu_moe_gating_top_k_softmax') + @patch('torch_npu.npu_moe_gating_top_k') def test_renormalize(self, mock_topk): """Test renormalization""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), @@ -899,13 +902,13 @@ class TestSelectExperts(TestBase): sums = weights.sum(dim=-1) self.assertTrue(torch.allclose(sums, torch.ones_like(sums))) - @patch('torch_npu.npu_moe_gating_top_k_softmax') + @patch('torch_npu.npu_moe_gating_top_k') def test_output_dtypes(self, mock_topk): """Test output dtypes""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), torch.zeros(self.num_tokens, self.top_k, - dtype=torch.long), + dtype=torch.int32), torch.arange(0, self.num_tokens * self.top_k, dtype=torch.int32).view( diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index a700fbf..580508a 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -96,6 +96,7 @@ def set_ascend_forward_context( ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) + # fused_moe_state is used in torchair, it will be deleted along with torchair is_deepseek_v3_r1 = hasattr( vllm_config.model_config.hf_config, 'n_routed_experts' ) and vllm_config.model_config.hf_config.n_routed_experts == 256 diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py index e511d6b..05ec0e3 100644 --- a/vllm_ascend/ops/moe/experts_selector.py +++ b/vllm_ascend/ops/moe/experts_selector.py @@ -20,8 +20,6 @@ import torch import torch_npu from vllm.forward_context import get_forward_context -from vllm_ascend.ascend_config import get_ascend_config - def select_experts(hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -62,21 +60,28 @@ def select_experts(hidden_states: torch.Tensor, if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( hidden_states, "gate_up") - topk_weights, topk_ids = _select_experts_with_fusion_ops( + is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k( hidden_states=hidden_states, - router_logits=router_logits, top_k=top_k, - use_grouped_topk=use_grouped_topk, topk_group=topk_group, - renormalize=renormalize, - e_score_correction_bias=e_score_correction_bias, num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - global_num_experts=global_num_experts) + custom_routing_function=custom_routing_function) - if topk_weights is None: + if is_support_npu_moe_gating_top_k: + topk_weights, topk_ids = _select_experts_with_fusion_ops( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + global_num_experts=global_num_experts) + else: topk_weights, topk_ids = _native_select_experts( hidden_states=hidden_states, router_logits=router_logits, @@ -93,6 +98,32 @@ def select_experts(hidden_states: torch.Tensor, return topk_weights, topk_ids +def check_npu_moe_gating_top_k( + hidden_states: torch.Tensor, + top_k: int, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + scoring_func: str = "softmax", + custom_routing_function: Optional[Callable] = None): + if custom_routing_function is not None: + return False + if scoring_func != "softmax" and scoring_func != "sigmoid": + return False + topk_group = topk_group if topk_group is not None else 1 + num_expert_group = num_expert_group if num_expert_group is not None else 1 + if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group + == 0 and hidden_states.shape[-1] // num_expert_group > 2): + return False + if topk_group < 1 or topk_group > num_expert_group: + return False + if top_k < 1 or \ + top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)): + return False + if topk_group * hidden_states.shape[-1] / num_expert_group < top_k: + return False + return True + + def _native_grouped_topk( topk_weights: torch.Tensor, num_expert_group: Optional[int], @@ -171,34 +202,31 @@ def _select_experts_with_fusion_ops( e_score_correction_bias: Optional[torch.Tensor], topk_group: Optional[int], num_expert_group: Optional[int], - custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", routed_scaling_factor=1.0, global_num_experts: int = -1): - topk_weights, topk_ids = None, None - # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern - global_redundant_expert_num = get_ascend_config().init_redundancy_expert - is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax": - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( - x=router_logits, finished=None, k=top_k) - topk_ids = topk_ids.to(torch.int32) + topk_group = topk_group if topk_group is not None else 1 + num_expert_group = num_expert_group if num_expert_group is not None else 1 + norm_type = 0 if scoring_func == "softmax" else 1 + if e_score_correction_bias is not None and \ + e_score_correction_bias.dtype != router_logits.dtype: + e_score_correction_bias = e_score_correction_bias.to( + router_logits.dtype) + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, + bias=e_score_correction_bias, + k_group=topk_group, + group_count=num_expert_group, + group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=norm_type, # 0: softmax; 1: sigmoid + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + if scoring_func == "softmax": topk_weights = _renormalize_topk_weights(topk_weights, renormalize) return topk_weights, topk_ids