[A5][bugfix] Fix fused MoE A5 MXFP8 scale normalization, load-balance routing and gating_topk ops (#7573)
### What this PR does / why we need it?
This PR fixes A5 MXFP8 MoE scale handling in the fused MoE path.
- It normalizes MXFP8 activation scales to the packed 3D layout expected by A5 kernels. This covers both precomputed `dynamic_scale` inputs and the gmm1 output scales, which are normalized before downstream grouped matmul ops consume them.
- It refines the MXFP8 forced load-balancing path used in profiling runs.
- It uses `torch_npu.npu_moe_gating_top_k` instead of the custom op when running on the Ascend950 chip.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI and E2E serving tests on Ascend950DT passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -58,6 +58,40 @@ class BaseDeviceAdaptor:
|
|||||||
quant_mode=quant_mode,
|
quant_mode=quant_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def maybe_normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
|
||||||
|
return scale
|
||||||
|
|
||||||
|
@staticmethod
def moe_gating_top_k(
    x: torch.Tensor,
    *,
    k: int,
    k_group: int,
    group_count: int,
    group_select_mode: int,
    renorm: int,
    norm_type: int,
    out_flag: bool,
    routed_scaling_factor: float = 1.0,
    eps: float = 1e-20,
    bias_opt: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the custom ``_C_ascend`` MoE gating top-k op.

    Every argument is forwarded unchanged to
    ``torch.ops._C_ascend.moe_gating_top_k``; the only post-processing is
    casting the returned expert ids to ``torch.int32``.

    Returns:
        A ``(topk_weights, topk_ids, out)`` tuple in the order produced by
        the op, with ``topk_ids`` converted to ``torch.int32``.
    """
    gating_result = torch.ops._C_ascend.moe_gating_top_k(
        x,
        k=k,
        k_group=k_group,
        group_count=group_count,
        group_select_mode=group_select_mode,
        renorm=renorm,
        norm_type=norm_type,
        out_flag=out_flag,
        routed_scaling_factor=routed_scaling_factor,
        eps=eps,
        bias_opt=bias_opt,
    )
    weights, expert_ids, aux_out = gating_result
    # Callers receive the ids as int32 regardless of what the op emits.
    expert_ids = expert_ids.to(torch.int32)
    return weights, expert_ids, aux_out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def npu_dynamic_quant(
|
def npu_dynamic_quant(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -198,6 +232,46 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
|
|||||||
quant_mode=quant_mode,
|
quant_mode=quant_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def maybe_normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
|
||||||
|
if scale is None or scale.ndim != 2:
|
||||||
|
return scale
|
||||||
|
if scale.shape[-1] % 2 != 0:
|
||||||
|
raise ValueError(f"Invalid MXFP8 scale shape: {tuple(scale.shape)}")
|
||||||
|
return scale.reshape(scale.shape[0], scale.shape[1] // 2, 2)
|
||||||
|
|
||||||
|
@staticmethod
def moe_gating_top_k(
    x: torch.Tensor,
    *,
    k: int,
    k_group: int,
    group_count: int,
    group_select_mode: int,
    renorm: int,
    norm_type: int,
    out_flag: bool,
    routed_scaling_factor: float = 1.0,
    eps: float = 1e-20,
    bias_opt: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run MoE gating top-k through ``torch_npu.npu_moe_gating_top_k``.

    The signature mirrors the base adaptor's ``moe_gating_top_k``. Two
    arguments are treated specially:

    * ``bias_opt`` is passed to the torch_npu op under the name ``bias``.
    * ``renorm`` is never forwarded; the op is always called with
      ``renorm=0``. When ``norm_type == 0`` and ``renorm == 1`` the selected
      weights are instead normalized here so that each row sums to 1.
      NOTE(review): presumably the op only accepts ``renorm=0`` -- confirm
      against the torch_npu documentation.

    Returns:
        A ``(topk_weights, topk_ids, out)`` tuple, with ``topk_ids`` cast to
        ``torch.int32``.
    """
    topk_weights, topk_ids, out = torch_npu.npu_moe_gating_top_k(
        x,
        k=k,
        bias=bias_opt,
        k_group=k_group,
        group_count=group_count,
        group_select_mode=group_select_mode,
        renorm=0,
        norm_type=norm_type,
        # Forward the caller's flag instead of silently dropping it, so this
        # override honors the same contract as the base implementation.
        out_flag=out_flag,
        routed_scaling_factor=routed_scaling_factor,
        eps=eps,
    )
    if norm_type == 0 and renorm == 1:
        # Renormalize over the selected experts only (last dimension).
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids.to(torch.int32), out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def npu_dynamic_quant(
|
def npu_dynamic_quant(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -215,12 +289,9 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if dynamic_scale is None:
|
if dynamic_scale is None:
|
||||||
return torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)
|
hidden_states, dynamic_scale = torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)
|
||||||
|
|
||||||
if dynamic_scale.ndim == 2:
|
return hidden_states, A5DeviceAdaptor.maybe_normalize_mxfp_scale_layout(dynamic_scale)
|
||||||
dynamic_scale = dynamic_scale.reshape(dynamic_scale.shape[0], dynamic_scale.shape[1] // 2, 2)
|
|
||||||
|
|
||||||
return hidden_states, dynamic_scale
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def npu_grouped_matmul_swiglu_quant(
|
def npu_grouped_matmul_swiglu_quant(
|
||||||
@@ -257,7 +328,7 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
|
|||||||
weight_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
weight_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||||
x_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
x_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||||
)
|
)
|
||||||
return out, out_scale, None
|
return out, A5DeviceAdaptor.maybe_normalize_mxfp_scale_layout(out_scale), None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_quant_gmm2_kwargs(
|
def get_quant_gmm2_kwargs(
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from collections.abc import Callable
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm_ascend.device.device_op import DeviceOperator
|
||||||
from vllm_ascend.utils import get_weight_prefetch_method
|
from vllm_ascend.utils import get_weight_prefetch_method
|
||||||
|
|
||||||
|
|
||||||
@@ -216,7 +217,7 @@ def _select_experts_with_fusion_ops(
|
|||||||
norm_type = 0 if scoring_func == "softmax" else 1
|
norm_type = 0 if scoring_func == "softmax" else 1
|
||||||
if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
|
if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
|
||||||
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
|
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
|
||||||
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
|
topk_weights, topk_ids, _ = DeviceOperator.moe_gating_top_k(
|
||||||
router_logits,
|
router_logits,
|
||||||
k=top_k,
|
k=top_k,
|
||||||
k_group=topk_group,
|
k_group=topk_group,
|
||||||
|
|||||||
@@ -128,7 +128,9 @@ def quant_apply_mlp(
|
|||||||
quantized_hidden_states = None
|
quantized_hidden_states = None
|
||||||
else:
|
else:
|
||||||
unquantized_hidden_states = None
|
unquantized_hidden_states = None
|
||||||
pertoken_scale = dynamic_scale
|
pertoken_scale = (
|
||||||
|
DeviceOperator.maybe_normalize_mxfp_scale_layout(dynamic_scale) if use_mxfp_quant else dynamic_scale
|
||||||
|
)
|
||||||
quantized_hidden_states = hidden_states
|
quantized_hidden_states = hidden_states
|
||||||
|
|
||||||
bias1, bias2 = None, None
|
bias1, bias2 = None, None
|
||||||
|
|||||||
@@ -196,7 +196,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
# to avoid accumulating too much tokens on a single rank.
|
# to avoid accumulating too much tokens on a single rank.
|
||||||
# currently it is only activated when doing profile runs.
|
# currently it is only activated when doing profile runs.
|
||||||
if enable_force_load_balance:
|
if enable_force_load_balance:
|
||||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
random_matrix = torch.rand(
|
||||||
|
topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
|
||||||
|
)
|
||||||
|
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
|
||||||
|
|
||||||
topk_weights = topk_weights.to(x.dtype)
|
topk_weights = topk_weights.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user