diff --git a/tests/e2e/nightly/ops/test_fused_moe.py b/tests/e2e/nightly/ops/test_fused_moe.py
index 853baa93..8fcac0e4 100644
--- a/tests/e2e/nightly/ops/test_fused_moe.py
+++ b/tests/e2e/nightly/ops/test_fused_moe.py
@@ -28,7 +28,8 @@
 import torch
 import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import (
+    check_npu_moe_gating_top_k, select_experts)
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.token_dispatcher import \
     TokenDispatcherWithAllGather
@@ -303,7 +304,10 @@ def test_select_experts(
         e_score_correction_bias=e_score_correction_bias,
     )
 
-    if use_grouped_topk:
+    call_moe_gatingtopk = check_npu_moe_gating_top_k(
+        hidden_states, topk, topk_group, num_expert_group, scoring_func,
+        custom_routing_function)
+    if not call_moe_gatingtopk and use_grouped_topk:
         mock_native_grouped_topk.assert_called_once()
     else:
         mock_native_grouped_topk.assert_not_called()
diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index ce18023f..c574d998 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -823,8 +823,7 @@ class TestSelectExperts(TestBase):
                 top_k=self.top_k,
                 use_grouped_topk=False,
                 renormalize=False,
-                scoring_func="invalid_func",
-                custom_routing_function=self.mock_custom_routing)
+                scoring_func="invalid_func")
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -834,15 +833,13 @@ class TestSelectExperts(TestBase):
         mock_topk.return_value = (torch.randn(self.num_tokens, self.top_k),
                                   torch.randint(0,
                                                 self.num_experts,
                                                 (self.num_tokens,
                                                  self.top_k),
                                                 dtype=torch.long))
 
-        weights, ids = select_experts(
-            hidden_states=self.hidden_states,
-            router_logits=self.router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=True,
-            renormalize=False,
-            topk_group=4,
-            num_expert_group=2,
-            custom_routing_function=self.mock_custom_routing)
+        weights, ids = select_experts(hidden_states=self.hidden_states,
+                                      router_logits=self.router_logits,
+                                      top_k=self.top_k,
+                                      use_grouped_topk=True,
+                                      renormalize=False,
+                                      topk_group=4,
+                                      num_expert_group=2)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -864,8 +861,7 @@ class TestSelectExperts(TestBase):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias,
-            custom_routing_function=self.mock_custom_routing)
+            e_score_correction_bias=e_score_correction_bias)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
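Note on the test diffs above: `custom_routing_function` is dropped from the grouped-topk cases because, with this patch, supplying one disqualifies the fused `npu_moe_gating_top_k` path (see `check_npu_moe_gating_top_k` in the implementation diff below). A minimal sketch of that short-circuit, assuming only the patched module is importable; shapes and values are illustrative:

    import torch
    from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k

    x = torch.randn(4, 256)
    # Grouped top-k with no custom routing: eligible for the fused op.
    print(check_npu_moe_gating_top_k(x, top_k=2, topk_group=4, num_expert_group=8))  # True
    # Any custom routing function forces the native path, which is why the unit
    # tests above no longer pass mock_custom_routing alongside grouped top-k.
    print(check_npu_moe_gating_top_k(x, top_k=2,
                                     custom_routing_function=lambda *a: None))  # False
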
diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py
index eb3fc848..05ec0e38 100644
--- a/vllm_ascend/ops/fused_moe/experts_selector.py
+++ b/vllm_ascend/ops/fused_moe/experts_selector.py
@@ -60,7 +60,15 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    if custom_routing_function is None:
+    is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
+        hidden_states=hidden_states,
+        top_k=top_k,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        scoring_func=scoring_func,
+        custom_routing_function=custom_routing_function)
+
+    if is_support_npu_moe_gating_top_k:
         topk_weights, topk_ids = _select_experts_with_fusion_ops(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -90,6 +98,32 @@ def select_experts(hidden_states: torch.Tensor,
     return topk_weights, topk_ids
 
 
+def check_npu_moe_gating_top_k(
+        hidden_states: torch.Tensor,
+        top_k: int,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        scoring_func: str = "softmax",
+        custom_routing_function: Optional[Callable] = None):
+    if custom_routing_function is not None:
+        return False
+    if scoring_func != "softmax" and scoring_func != "sigmoid":
+        return False
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
+            == 0 and hidden_states.shape[-1] // num_expert_group > 2):
+        return False
+    if topk_group < 1 or topk_group > num_expert_group:
+        return False
+    if top_k < 1 or \
+            top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
+        return False
+    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
+        return False
+    return True
+
+
 def _native_grouped_topk(
     topk_weights: torch.Tensor,
     num_expert_group: Optional[int],
@@ -172,12 +206,9 @@ def _select_experts_with_fusion_ops(
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    if scoring_func == "softmax":
-        norm_type = 0
-        topk_group = 1
-        num_expert_group = 1
-    else:
-        norm_type = 1
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    norm_type = 0 if scoring_func == "softmax" else 1
     if e_score_correction_bias is not None and \
             e_score_correction_bias.dtype != router_logits.dtype:
         e_score_correction_bias = e_score_correction_bias.to(
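For reference, the boundary conditions the new predicate enforces can be exercised directly. A minimal sketch under the same assumptions as above (illustrative numbers, not part of the patch):

    import torch
    from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k

    x = torch.randn(2, 256)  # last dim = 256
    # num_expert_group must divide the last dim and leave groups larger than 2:
    assert not check_npu_moe_gating_top_k(x, top_k=2, num_expert_group=7)  # 256 % 7 != 0
    # topk_group must lie within [1, num_expert_group]:
    assert not check_npu_moe_gating_top_k(x, top_k=2, topk_group=9, num_expert_group=8)
    # top_k is capped at last_dim / (num_expert_group * topk_group), i.e. 8 here:
    assert not check_npu_moe_gating_top_k(x, top_k=9, topk_group=4, num_expert_group=8)
    # Both softmax and sigmoid scoring qualify; any other scoring_func does not:
    assert check_npu_moe_gating_top_k(x, top_k=2, scoring_func="sigmoid")
    assert not check_npu_moe_gating_top_k(x, top_k=2, scoring_func="invalid_func")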