From 936c102105b72a4e36dd284f900f32338a232696 Mon Sep 17 00:00:00 2001 From: Wang Yixuan <88923622+hust17yixuan@users.noreply.github.com> Date: Thu, 28 Aug 2025 09:10:31 +0800 Subject: [PATCH] [bugfix][refactor]fix torchair_w8a8 (#2569) ### What this PR does / why we need it? Separate the torchair w8a8 and w4a8 paths from fused_moe, because of the refactoring of and changes to fused_moe ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? vLLM version: main vLLM main: https://github.com/vllm-project/vllm/commit/ab9f2cfd1942f7ddfee658ce86ea96b4789862af - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/69244e67e6822f1c15816f887659e1ccc18c2632 Signed-off-by: hust17yixuan <303660421@qq.com> --- .../quantization/torchair_w4a8_dynamic.py | 43 ++++++++++++------ .../quantization/torchair_w8a8_dynamic.py | 45 +++++++++++++------ 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py index 0354b47..f38e2d8 100644 --- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py @@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2) from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor @@ -267,19 +267,34 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" - 
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) + if global_num_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = torchair_select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) fused_moe_state = get_forward_context().fused_moe_state shared_gate_up, shared_dequant_scale = None, None diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 9de9cc7..5c3fa95 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -27,7 +27,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from 
vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, dispose_tensor, get_ascend_soc_version) @@ -904,18 +904,37 @@ class TorchairAscendW8A8DynamicFusedMoEMethod: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) + is_deepseek_v3_r1 = global_num_experts == 256 + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if is_deepseek_v3_r1: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=top_k, # topk currently is 8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode= + 1, # 0: the maximum in the group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; should the third output be output + # y2_flag=False, # old api; should the third output be output + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = torchair_select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + 
custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) fused_moe_state = get_forward_context().fused_moe_state shared_gate_up, shared_dequant_scale = None, None