[A5][bugfix] Fix fused MoE A5 MXFP8 scale normalization, load-balance routing and gating_topk ops (#7573)
### What this PR does / why we need it? This PR fixes A5 MXFP8 MoE scale handling in the fused MoE path. - It normalizes MXFP8 activation scales to the packed 3D layout expected by A5 kernels, including both precomputed dynamic_scale inputs and gmm1 output scales before they are consumed by downstream grouped matmul ops. - It also refines the MXFP8 force load-balancing path in profiling runs. - This PR also enables npu_gating_top_k from torch_npu instead of custom op when running ascend950 chip. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI and E2E serving tests on Ascend950DT passed. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -58,6 +58,40 @@ class BaseDeviceAdaptor:
|
||||
quant_mode=quant_mode,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def maybe_normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
|
||||
return scale
|
||||
|
||||
@staticmethod
|
||||
def moe_gating_top_k(
|
||||
x: torch.Tensor,
|
||||
*,
|
||||
k: int,
|
||||
k_group: int,
|
||||
group_count: int,
|
||||
group_select_mode: int,
|
||||
renorm: int,
|
||||
norm_type: int,
|
||||
out_flag: bool,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
eps: float = 1e-20,
|
||||
bias_opt: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, out = torch.ops._C_ascend.moe_gating_top_k(
|
||||
x,
|
||||
k=k,
|
||||
k_group=k_group,
|
||||
group_count=group_count,
|
||||
group_select_mode=group_select_mode,
|
||||
renorm=renorm,
|
||||
norm_type=norm_type,
|
||||
out_flag=out_flag,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=eps,
|
||||
bias_opt=bias_opt,
|
||||
)
|
||||
return topk_weights, topk_ids.to(torch.int32), out
|
||||
|
||||
@staticmethod
|
||||
def npu_dynamic_quant(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -198,6 +232,46 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
|
||||
quant_mode=quant_mode,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def maybe_normalize_mxfp_scale_layout(scale: torch.Tensor | None) -> torch.Tensor | None:
|
||||
if scale is None or scale.ndim != 2:
|
||||
return scale
|
||||
if scale.shape[-1] % 2 != 0:
|
||||
raise ValueError(f"Invalid MXFP8 scale shape: {tuple(scale.shape)}")
|
||||
return scale.reshape(scale.shape[0], scale.shape[1] // 2, 2)
|
||||
|
||||
@staticmethod
|
||||
def moe_gating_top_k(
|
||||
x: torch.Tensor,
|
||||
*,
|
||||
k: int,
|
||||
k_group: int,
|
||||
group_count: int,
|
||||
group_select_mode: int,
|
||||
renorm: int,
|
||||
norm_type: int,
|
||||
out_flag: bool,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
eps: float = 1e-20,
|
||||
bias_opt: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, out = torch_npu.npu_moe_gating_top_k(
|
||||
x,
|
||||
k=k,
|
||||
bias=bias_opt,
|
||||
k_group=k_group,
|
||||
group_count=group_count,
|
||||
group_select_mode=group_select_mode,
|
||||
renorm=0,
|
||||
norm_type=norm_type,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=eps,
|
||||
)
|
||||
if norm_type == 0 and renorm == 1:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids.to(torch.int32), out
|
||||
|
||||
@staticmethod
|
||||
def npu_dynamic_quant(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -215,12 +289,9 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
|
||||
)
|
||||
|
||||
if dynamic_scale is None:
|
||||
return torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)
|
||||
hidden_states, dynamic_scale = torch_npu.npu_dynamic_mx_quant(hidden_states, dst_type=act_quant_type)
|
||||
|
||||
if dynamic_scale.ndim == 2:
|
||||
dynamic_scale = dynamic_scale.reshape(dynamic_scale.shape[0], dynamic_scale.shape[1] // 2, 2)
|
||||
|
||||
return hidden_states, dynamic_scale
|
||||
return hidden_states, A5DeviceAdaptor.maybe_normalize_mxfp_scale_layout(dynamic_scale)
|
||||
|
||||
@staticmethod
|
||||
def npu_grouped_matmul_swiglu_quant(
|
||||
@@ -257,7 +328,7 @@ class A5DeviceAdaptor(BaseDeviceAdaptor):
|
||||
weight_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||
x_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||
)
|
||||
return out, out_scale, None
|
||||
return out, A5DeviceAdaptor.maybe_normalize_mxfp_scale_layout(out_scale), None
|
||||
|
||||
@staticmethod
|
||||
def get_quant_gmm2_kwargs(
|
||||
|
||||
Reference in New Issue
Block a user