diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index 705c794..e3365de 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -383,7 +383,7 @@ class TestTorchairAscendUnquantizedFusedMoEMethod: else: assert result.shape == x.shape - @pytest.mark.parametrize("others_param", [16, 1, 4]) + @pytest.mark.parametrize("others_param", [16, 4]) def test_apply_with_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): """ @@ -396,9 +396,18 @@ class TestTorchairAscendUnquantizedFusedMoEMethod: is_prefill = False forward_context = MagicMock( fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True)) - with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \ - patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3): + if ep_size == 1: + top_k_return = (torch.randn(16, 2), torch.randint(0, 16, + (16, 2)), None) + expert_map = torch.tensor( + [0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]) + else: + top_k_return = (torch.randn(8, 2), torch.randint(0, 8, + (8, 2)), None) expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) + with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \ + patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ + patch('torch_npu.npu_moe_gating_top_k', return_value=top_k_return): moe_method.ep_size = ep_size x = torch.randn(8, 2, 2) if ep_size == 1: diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index d42d023..2825c67 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -857,38 +857,20 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): shared_experts: Optional[Any] = None, **kwargs, ) -> torch.Tensor: - global_redundant_expert_num = get_ascend_config( - ).init_redundancy_expert - is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = torchair_select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) topk_weights = topk_weights.to(x.dtype) # this is a naive implementation for experts load balance so as diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py index c61ddf3..b8cc365 100644 --- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py @@ -27,7 +27,6 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2) from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor @@ -322,34 +321,20 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: assert router_logits.shape[ 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" - if global_num_experts == 256: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = torchair_select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) fused_moe_state = get_forward_context().fused_moe_state shared_gate_up, shared_dequant_scale = None, None diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index f639270..26b0940 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -25,7 +25,6 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor, super_kernel) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, @@ -938,8 +937,6 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: assert router_logits.shape[ 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" - is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 - fused_moe_state = get_forward_context().fused_moe_state if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: fused_moe_state = FusedMoEState.All2All @@ -948,35 +945,20 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: with super_kernel(prefix, "stream-fusion=1", enabled=running_in_super_kernel): - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = torchair_select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: with npu_stream_switch("moe_secondary", 0):