[bugfix][refactor]fix torchair_w8a8 (#2569)

### What this PR does / why we need it?
Separate the torchair w8a8 and w4a8 implementations from fused_moe, due to
the refactor of and changes to fused_moe.

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
vLLM version: main
vLLM main:
ab9f2cfd19


- vLLM version: v0.10.1.1
- vLLM main:
69244e67e6

Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
Wang Yixuan
2025-08-28 09:10:31 +08:00
committed by GitHub
parent a955e5d404
commit 936c102105
2 changed files with 61 additions and 27 deletions

View File

@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2) torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
@@ -267,19 +267,34 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
assert router_logits.shape[ assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch" 1] == global_num_experts, "Number of global experts mismatch"
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern if global_num_experts == 256:
topk_weights, topk_ids = select_experts( topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
hidden_states=x, router_logits,
router_logits=router_logits, k=top_k, # topk currently is 8
top_k=top_k, bias=e_score_correction_bias,
use_grouped_topk=use_grouped_topk, k_group=topk_group, # fix: 4
renormalize=renormalize, group_count=num_expert_group, # fix 8
topk_group=topk_group, group_select_mode=
num_expert_group=num_expert_group, 1, # 0: the maximum in the group; 1: topk2.sum(fix)
custom_routing_function=custom_routing_function, renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
scoring_func=scoring_func, norm_type=1, # 0: softmax; 1: sigmoid(fix)
e_score_correction_bias=e_score_correction_bias, # out_flag=False, # todo new api; should the third output be output
global_num_experts=global_num_experts) # y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = torchair_select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
fused_moe_state = get_forward_context().fused_moe_state fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None shared_gate_up, shared_dequant_scale = None, None

View File

@@ -27,7 +27,7 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version) dispose_tensor, get_ascend_soc_version)
@@ -904,18 +904,37 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
assert router_logits.shape[ assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch" 1] == global_num_experts, "Number of global experts mismatch"
topk_weights, topk_ids = select_experts( is_deepseek_v3_r1 = global_num_experts == 256
hidden_states=x,
router_logits=router_logits, # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
top_k=top_k, if is_deepseek_v3_r1:
use_grouped_topk=use_grouped_topk, topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
renormalize=renormalize, router_logits,
topk_group=topk_group, k=top_k, # topk currently is 8
num_expert_group=num_expert_group, bias=e_score_correction_bias,
custom_routing_function=custom_routing_function, k_group=topk_group, # fix: 4
scoring_func=scoring_func, group_count=num_expert_group, # fix 8
e_score_correction_bias=e_score_correction_bias, group_select_mode=
global_num_experts=global_num_experts) 1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = torchair_select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
fused_moe_state = get_forward_context().fused_moe_state fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None shared_gate_up, shared_dequant_scale = None, None