[A5][bugfix] Fix fused MoE A5 MXFP8 scale normalization, load-balance routing and gating_topk ops (#7573)

### What this PR does / why we need it?
This PR fixes A5 MXFP8 MoE scale handling in the fused MoE path.

- It normalizes MXFP8 activation scales to the packed 3D layout expected
by A5 kernels, including both precomputed dynamic_scale inputs and gmm1
output scales before they are consumed by downstream grouped matmul ops.
- It also refines the MXFP8 force load-balancing path in profiling runs.
- This PR also enables npu_gating_top_k from torch_npu instead of custom
op when running ascend950 chip.
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI and E2E serving tests on Ascend950DT passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-25 17:20:28 +08:00
committed by GitHub
parent e0e585a109
commit d452d04656
4 changed files with 86 additions and 9 deletions

View File

@@ -18,6 +18,7 @@ from collections.abc import Callable
import torch
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.utils import get_weight_prefetch_method
@@ -216,7 +217,7 @@ def _select_experts_with_fusion_ops(
norm_type = 0 if scoring_func == "softmax" else 1
if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
topk_weights, topk_ids, _ = DeviceOperator.moe_gating_top_k(
router_logits,
k=top_k,
k_group=topk_group,

View File

@@ -128,7 +128,9 @@ def quant_apply_mlp(
quantized_hidden_states = None
else:
unquantized_hidden_states = None
pertoken_scale = dynamic_scale
pertoken_scale = (
DeviceOperator.maybe_normalize_mxfp_scale_layout(dynamic_scale) if use_mxfp_quant else dynamic_scale
)
quantized_hidden_states = hidden_states
bias1, bias2 = None, None