diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index b88e78f..ed9cc9a 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -754,6 +754,14 @@ class TestSelectExperts(TestBase):
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
 
+        # Mock custom routing
+        self.mock_custom_routing = MagicMock()
+        self.mock_custom_routing.return_value = (torch.ones(
+            self.num_tokens, self.top_k),
+                                                 torch.zeros(
+                                                     self.num_tokens,
+                                                     self.top_k,
+                                                     dtype=torch.int32))
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
 
@@ -763,7 +771,7 @@ class TestSelectExperts(TestBase):
         self.addCleanup(patcher.stop)
         patcher.start()
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
                                   torch.zeros(self.num_tokens, self.top_k,
@@ -790,12 +798,14 @@ class TestSelectExperts(TestBase):
 
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="sigmoid",
+            custom_routing_function=self.mock_custom_routing)
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -808,7 +818,8 @@ class TestSelectExperts(TestBase):
                            top_k=self.top_k,
                            use_grouped_topk=False,
                            renormalize=False,
-                           scoring_func="invalid_func")
+                           scoring_func="invalid_func",
+                           custom_routing_function=self.mock_custom_routing)
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -818,13 +829,15 @@ class TestSelectExperts(TestBase):
                                               self.top_k,
                                               dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=True,
+            renormalize=False,
+            topk_group=4,
+            num_expert_group=2,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -846,7 +859,8 @@ class TestSelectExperts(TestBase):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -854,27 +868,20 @@ class TestSelectExperts(TestBase):
 
     def test_custom_routing_function(self):
         """Test custom routing function"""
-        mock_custom_routing = MagicMock()
-        mock_custom_routing.return_value = (torch.ones(self.num_tokens,
-                                                       self.top_k),
-                                            torch.zeros(self.num_tokens,
-                                                        self.top_k,
-                                                        dtype=torch.int32))
-
         weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=False,
-            custom_routing_function=mock_custom_routing)
+            custom_routing_function=self.mock_custom_routing)
 
-        mock_custom_routing.assert_called_once()
+        self.mock_custom_routing.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_renormalize(self, mock_topk):
         """Test renormalization"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -900,13 +907,13 @@ class TestSelectExperts(TestBase):
         sums = weights.sum(dim=-1)
         self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_output_dtypes(self, mock_topk):
         """Test output dtypes"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
                                   torch.zeros(self.num_tokens,
                                               self.top_k,
-                                              dtype=torch.long),
+                                              dtype=torch.int32),
                                   torch.arange(0,
                                                self.num_tokens * self.top_k,
                                                dtype=torch.int32).view(
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index a700fbf..580508a 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -96,6 +96,7 @@ def set_ascend_forward_context(
         ep_size = (get_ep_group().world_size
                    if vllm_config.parallel_config.enable_expert_parallel else
                    1)
+        # fused_moe_state is only used by torchair; it will be removed along with torchair.
         is_deepseek_v3_r1 = hasattr(
             vllm_config.model_config.hf_config, 'n_routed_experts'
         ) and vllm_config.model_config.hf_config.n_routed_experts == 256
diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py
index e511d6b..eb3fc84 100644
--- a/vllm_ascend/ops/moe/experts_selector.py
+++ b/vllm_ascend/ops/moe/experts_selector.py
@@ -20,8 +20,6 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.ascend_config import get_ascend_config
-
 
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
@@ -62,21 +60,20 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
-        hidden_states=hidden_states,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        topk_group=topk_group,
-        renormalize=renormalize,
-        e_score_correction_bias=e_score_correction_bias,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        routed_scaling_factor=routed_scaling_factor,
-        global_num_experts=global_num_experts)
-
-    if topk_weights is None:
+    if custom_routing_function is None:
+        topk_weights, topk_ids = _select_experts_with_fusion_ops(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            renormalize=renormalize,
+            e_score_correction_bias=e_score_correction_bias,
+            num_expert_group=num_expert_group,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            global_num_experts=global_num_experts)
+    else:
         topk_weights, topk_ids = _native_select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -171,34 +168,34 @@ def _select_experts_with_fusion_ops(
         e_score_correction_bias: Optional[torch.Tensor],
         topk_group: Optional[int],
         num_expert_group: Optional[int],
-        custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    topk_weights, topk_ids = None, None
-    # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
-    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-    if is_deepseek_v3_r1:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-            router_logits,
-            k=top_k,  # topk currently 8
-            bias=e_score_correction_bias,
-            k_group=topk_group,  # fix: 4
-            group_count=num_expert_group,  # fix 8
-            group_select_mode=
-            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-            # out_flag=False, # todo new api; should the third output be output
-            # y2_flag=False, # old api; should the third output be output
-            routed_scaling_factor=1,
-            eps=float(1e-20))
-    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-            x=router_logits, finished=None, k=top_k)
-        topk_ids = topk_ids.to(torch.int32)
+    if scoring_func == "softmax":
+        norm_type = 0
+        topk_group = 1
+        num_expert_group = 1
+    else:
+        norm_type = 1
+        if e_score_correction_bias is not None and \
+                e_score_correction_bias.dtype != router_logits.dtype:
+            e_score_correction_bias = e_score_correction_bias.to(
+                router_logits.dtype)
+    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+        router_logits,
+        k=top_k,
+        bias=e_score_correction_bias,
+        k_group=topk_group,
+        group_count=num_expert_group,
+        group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+        renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+        norm_type=norm_type,  # 0: softmax; 1: sigmoid
+        # out_flag=False, # todo new api; should the third output be output
+        # y2_flag=False, # old api; should the third output be output
+        routed_scaling_factor=1,
+        eps=float(1e-20))
+
+    if scoring_func == "softmax":
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
     return topk_weights, topk_ids
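
Reviewer note: the sketch below is not part of the patch; it approximates the new select_experts control flow in plain torch, for review convenience. Two assumptions are made here: the keyword names passed to custom_routing_function follow vLLM's usual routing-function convention (hidden_states, gating_output, topk, renormalize), which this patch does not show, and the fused path is emulated in torch because the real one calls torch_npu.npu_moe_gating_top_k, which only runs on Ascend NPUs (and select_experts itself reads the forward context, so in practice it runs inside a model forward pass).

    # Illustrative approximation of the patched dispatch; see assumptions above.
    import torch
    from typing import Callable, Optional


    def select_experts_sketch(hidden_states: torch.Tensor,
                              router_logits: torch.Tensor,
                              top_k: int,
                              renormalize: bool,
                              scoring_func: str = "softmax",
                              custom_routing_function: Optional[Callable] = None):
        if custom_routing_function is not None:
            # Native path (new behavior): the caller-supplied function owns the
            # whole routing computation; the fused NPU op is bypassed entirely.
            # Keyword names are an assumption based on vLLM's FusedMoE convention.
            return custom_routing_function(hidden_states=hidden_states,
                                           gating_output=router_logits,
                                           topk=top_k,
                                           renormalize=renormalize)
        # Fused path, emulated: npu_moe_gating_top_k applies softmax when
        # norm_type=0 and sigmoid when norm_type=1.
        if scoring_func == "softmax":
            scores = router_logits.softmax(dim=-1)
        elif scoring_func == "sigmoid":
            scores = router_logits.sigmoid()
        else:
            raise ValueError(f"Unsupported scoring function: {scoring_func}")
        topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)
        if scoring_func == "softmax" and renormalize:
            # Mirrors the _renormalize_topk_weights call on the softmax path.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids.to(torch.int32)


    # Example: default (fused-style) routing vs. a custom routing function.
    w, ids = select_experts_sketch(torch.randn(4, 16), torch.randn(4, 8),
                                   top_k=2, renormalize=True)

This also makes the test changes above easier to read: every call that does not exercise the fused op now passes custom_routing_function=self.mock_custom_routing, forcing the native branch so no torch_npu kernel is needed.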