diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 669f2b9..90a5f59 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -716,8 +716,19 @@ class TestSelectExperts(TestBase): self.hidden_states = torch.randn(self.num_tokens, self.hidden_size) self.router_logits = torch.randn(self.num_tokens, self.num_experts) - def test_softmax_scoring(self): + @patch('torch_npu.npu_moe_gating_top_k_softmax') + def test_softmax_scoring(self, mock_topk): """Test softmax scoring function""" + mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.long), + torch.arange(0, + self.num_tokens * self.top_k, + dtype=torch.int32).view( + self.top_k, + -1).permute(1, + 0).contiguous()) weights, ids, _ = select_experts(hidden_states=self.hidden_states, router_logits=self.router_logits, @@ -816,13 +827,19 @@ class TestSelectExperts(TestBase): self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.dtype, torch.int32) - @patch('torch.topk') + @patch('torch_npu.npu_moe_gating_top_k_softmax') def test_renormalize(self, mock_topk): - """Test weight renormalization""" + """Test renormalization""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), torch.zeros(self.num_tokens, self.top_k, - dtype=torch.long)) + dtype=torch.long), + torch.arange(0, + self.num_tokens * self.top_k, + dtype=torch.int32).view( + self.top_k, + -1).permute(1, + 0).contiguous()) weights, ids, _ = select_experts( hidden_states=self.hidden_states, @@ -836,13 +853,19 @@ class TestSelectExperts(TestBase): sums = weights.sum(dim=-1) self.assertTrue(torch.allclose(sums, torch.ones_like(sums))) - @patch('torch.topk') + @patch('torch_npu.npu_moe_gating_top_k_softmax') def test_output_dtypes(self, mock_topk): """Test output dtypes""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), torch.zeros(self.num_tokens, self.top_k, - dtype=torch.long)) + dtype=torch.long), + torch.arange(0, + self.num_tokens * self.top_k, + dtype=torch.int32).view( + self.top_k, + -1).permute(1, + 0).contiguous()) weights, ids, _ = select_experts( hidden_states=self.hidden_states, diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index e86f77d..14396c1 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -173,8 +173,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts, - is_unquantized=True) + global_num_experts=global_num_experts) topk_weights = topk_weights.to(x.dtype) # this is a naive implementation for experts load balance so as diff --git a/vllm_ascend/ops/layers/experts_selector.py b/vllm_ascend/ops/layers/experts_selector.py index c1f9312..eace164 100644 --- a/vllm_ascend/ops/layers/experts_selector.py +++ b/vllm_ascend/ops/layers/experts_selector.py @@ -43,7 +43,6 @@ def select_experts(hidden_states: torch.Tensor, routed_scaling_factor=1.0, e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, - is_unquantized: bool = False, global_num_experts: int = -1): """ Fused experts with select experts. @@ -60,7 +59,6 @@ def select_experts(hidden_states: torch.Tensor, scoring_func: Scoring function to use. e_score_correction_bias: Correction bias to apply to expert scores. indices_type: dtype of indices - is_unquantized: Whether the data are unquantized. global_num_experts: Global number of experts. Returns: @@ -80,8 +78,7 @@ def select_experts(hidden_states: torch.Tensor, custom_routing_function=custom_routing_function, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, - global_num_experts=global_num_experts, - is_unquantized=is_unquantized) + global_num_experts=global_num_experts) if topk_weights is None: topk_weights, topk_ids = _native_select_experts( @@ -183,8 +180,7 @@ def _select_experts_with_fusion_ops( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", routed_scaling_factor=1.0, - global_num_experts: int = -1, - is_unquantized: bool = False): + global_num_experts: int = -1): topk_weights, topk_ids, row_idx = None, None, None # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern @@ -205,7 +201,7 @@ def _select_experts_with_fusion_ops( routed_scaling_factor=1, eps=float(1e-20)) row_idx = return_row_idx(hidden_states, top_k) - if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized: + if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax": topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax( x=router_logits, finished=None, k=top_k) topk_ids = topk_ids.to(torch.int32)