From d452d04656eabc0bf719e1d07093973274eebd6e Mon Sep 17 00:00:00 2001
From: linfeng-yuan <1102311262@qq.com>
Date: Wed, 25 Mar 2026 17:20:28 +0800
Subject: [PATCH] [A5][bugfix] Fix fused MoE A5 MXFP8 scale normalization, load-balance routing and gating_topk ops (#7573)

### What this PR does / why we need it?

This PR fixes MXFP8 scale handling in the A5 fused MoE path.

- It normalizes MXFP8 activation scales to the packed 3D layout expected by A5 kernels, covering both precomputed dynamic_scale inputs and gmm1 output scales, before they are consumed by downstream grouped matmul ops.
- It also refines the MXFP8 force load-balancing path used in profiling runs, replacing randint-based expert selection with a duplicate-free random permutation of experts per token.
- It also uses torch_npu's npu_moe_gating_top_k instead of the custom op when running on the Ascend950 chip.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI and E2E serving tests on Ascend950DT passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
---
 vllm_ascend/device/device_op.py               | 83 +++++++++++++++++--
 vllm_ascend/ops/fused_moe/experts_selector.py |  3 +-
 vllm_ascend/ops/fused_moe/moe_mlp.py          |  4 +-
 .../quantization/methods/w8a8_mxfp8.py        |  5 +-
 4 files changed, 86 insertions(+), 9 deletions(-)

diff --git a/vllm_ascend/device/device_op.py b/vllm_ascend/device/device_op.py
index 46701d6a..cc7e7aa0 100644
--- a/vllm_ascend/device/device_op.py
+++ b/vllm_ascend/device/device_op.py
@@ -58,6 +58,40 @@ class BaseDeviceAdaptor:
             quant_mode=quant_mode,
         )
 
+    @staticmethod
+    def maybe_normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
+        return scale
+
+    @staticmethod
+    def moe_gating_top_k(
+        x: torch.Tensor,
+        *,
+        k: int,
+        k_group: int,
+        group_count: int,
+        group_select_mode: int,
+        renorm: int,
+        norm_type: int,
+        out_flag: bool,
+        routed_scaling_factor: float = 1.0,
+        eps: float = 1e-20,
+        bias_opt: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        topk_weights, topk_ids, out = torch.ops._C_ascend.moe_gating_top_k(
+            x,
+            k=k,
+            k_group=k_group,
+            group_count=group_count,
+            group_select_mode=group_select_mode,
+            renorm=renorm,
+            norm_type=norm_type,
+            out_flag=out_flag,
+            routed_scaling_factor=routed_scaling_factor,
+            eps=eps,
+            bias_opt=bias_opt,
+        )
+        return topk_weights, topk_ids.to(torch.int32), out
+
     @staticmethod
     def npu_dynamic_quant(
         hidden_states: torch.Tensor,
@@ -198,6 +232,46 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
             quant_mode=quant_mode,
         )
 
+    @staticmethod
+    def maybe_normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
+        if scale is None or scale.ndim != 2:
+            return scale
+        if scale.shape[-1] % 2 != 0:
+            raise ValueError(f"Invalid MXFP8 scale shape: {tuple(scale.shape)}")
+        return scale.reshape(scale.shape[0], scale.shape[1] // 2, 2)
+
+    @staticmethod
+    def moe_gating_top_k(
+        x: torch.Tensor,
+        *,
+        k: int,
+        k_group: int,
+        group_count: int,
+        group_select_mode: int,
+        renorm: int,
+        norm_type: int,
+        out_flag: bool,
+        routed_scaling_factor: float = 1.0,
+        eps: float = 1e-20,
+        bias_opt: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        topk_weights, topk_ids, out = torch_npu.npu_moe_gating_top_k(
+            x,
+            k=k,
+            bias=bias_opt,
+            k_group=k_group,
+            group_count=group_count,
+            group_select_mode=group_select_mode,
+            renorm=0,
+            norm_type=norm_type,
+            routed_scaling_factor=routed_scaling_factor,
+            eps=eps,
+        )
+        if norm_type == 0 and renorm == 1:
+            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+        return topk_weights, topk_ids.to(torch.int32), out
+
     @staticmethod
     def npu_dynamic_quant(
         hidden_states: torch.Tensor,
@@ -215,12 +289,9 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
         )
 
         if dynamic_scale is None:
-            return torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)
+            hidden_states, dynamic_scale = torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)
 
-        if dynamic_scale.ndim == 2:
-            dynamic_scale = dynamic_scale.reshape(dynamic_scale.shape[0], dynamic_scale.shape[1] // 2, 2)
-
-        return hidden_states, dynamic_scale
+        return hidden_states, A5DeviceAdaptor.maybe_normalize_mxfp_scale_layout(dynamic_scale)
 
     @staticmethod
     def npu_grouped_matmul_swiglu_quant(
@@ -257,7 +328,7 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
             weight_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
             x_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
         )
-        return out, out_scale, None
+        return out, A5DeviceAdaptor.maybe_normalize_mxfp_scale_layout(out_scale), None
 
     @staticmethod
     def get_quant_gmm2_kwargs(
diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py
index 3f7a3fd0..300b633f 100644
--- a/vllm_ascend/ops/fused_moe/experts_selector.py
+++ b/vllm_ascend/ops/fused_moe/experts_selector.py
@@ -18,6 +18,7 @@ from collections.abc import Callable
 
 import torch
 
+from vllm_ascend.device.device_op import DeviceOperator
 from vllm_ascend.utils import get_weight_prefetch_method
 
 
@@ -216,7 +217,7 @@ def _select_experts_with_fusion_ops(
     norm_type = 0 if scoring_func == "softmax" else 1
     if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
         e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
-    topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
+    topk_weights, topk_ids, _ = DeviceOperator.moe_gating_top_k(
         router_logits,
         k=top_k,
         k_group=topk_group,
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
index 00b2f41b..d649683f 100644
--- a/vllm_ascend/ops/fused_moe/moe_mlp.py
+++ b/vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -128,7 +128,9 @@ def quant_apply_mlp(
         quantized_hidden_states = None
     else:
         unquantized_hidden_states = None
-        pertoken_scale = dynamic_scale
+        pertoken_scale = (
+            DeviceOperator.maybe_normalize_mxfp_scale_layout(dynamic_scale) if use_mxfp_quant else dynamic_scale
+        )
         quantized_hidden_states = hidden_states
 
     bias1, bias2 = None, None
diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py
index 79fa9480..cb4a1f42 100644
--- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py
+++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py
@@ -196,7 +196,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+            random_matrix = torch.rand(
+                topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
+            )
+            topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(x.dtype)
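
Below the diff, for reference, is a minimal self-contained sketch (not part of the patch) of the two behaviors described in the PR summary: packing a 2D per-token MXFP8 scale tensor into the 3D layout consumed by the A5 grouped-matmul kernels, and the duplicate-free random routing used for forced load balancing during profile runs. Tensor shapes, dtypes, and the helper name are illustrative assumptions, not the repo's API.

```python
import torch


def normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
    # Pack [num_tokens, num_scale_blocks] into [num_tokens, num_scale_blocks // 2, 2],
    # the layout the A5 grouped-matmul kernels expect for MXFP8 activation scales.
    if scale is None or scale.ndim != 2:
        return scale
    if scale.shape[-1] % 2 != 0:
        raise ValueError(f"Invalid MXFP8 scale shape: {tuple(scale.shape)}")
    return scale.reshape(scale.shape[0], scale.shape[1] // 2, 2)


# E8M0 block scales stored one byte per block (uint8 here is an assumption).
scale_2d = torch.randint(0, 255, (16, 8), dtype=torch.uint8)
print(normalize_mxfp_scale_layout(scale_2d).shape)  # torch.Size([16, 4, 2])

# Forced load balancing for profile runs: argsort a random matrix so each token
# gets top_k distinct expert ids, unlike torch.randint_like, which may repeat experts.
num_tokens, num_experts, top_k = 16, 8, 2
random_matrix = torch.rand(num_tokens, num_experts)
topk_ids = torch.argsort(random_matrix, dim=1)[:, :top_k].to(torch.int32)
```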