[A5][bugfix] Fix fused MoE A5 MXFP8 scale normalization, load-balance routing and gating_topk ops (#7573)

### What this PR does / why we need it?
This PR fixes A5 MXFP8 MoE scale handling in the fused MoE path.

- It normalizes MXFP8 activation scales into the packed 3D layout expected
by the A5 kernels, covering both the precomputed dynamic_scale inputs and the
gmm1 output scales before they are consumed by the downstream grouped matmul
ops (see the sketch after this list).
- It also refines the MXFP8 force load-balancing path used during profiling
runs so that each token is routed to distinct, uniformly sampled experts.
- It also switches to npu_gating_top_k from torch_npu instead of the custom
op when running on the Ascend 950 chip.
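
A minimal, hypothetical sketch of the scale-layout normalization described in
the first bullet. The helper name, packing factor, and exact axis order are
assumptions for illustration only and are not taken from the patch:

```python
import torch


def pack_mxfp8_scales(scale: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Fold a flat 2D per-block scale tensor into a 3D (groups, tokens, blocks)
    layout, assuming that is the shape the grouped matmul kernels consume."""
    if scale.dim() == 3:  # already in the packed layout, pass through
        return scale
    tokens, blocks = scale.shape
    assert tokens % num_groups == 0, "tokens must split evenly across groups"
    return scale.view(num_groups, tokens // num_groups, blocks).contiguous()


# usage (illustrative): normalize both the precomputed dynamic_scale and the
# gmm1 output scale before handing them to the downstream grouped matmul
# dynamic_scale = pack_mxfp8_scales(dynamic_scale, num_groups=ep_size)
```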
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI and E2E serving tests on Ascend950DT passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Commit d452d04656 (parent e0e585a109) by linfeng-yuan, committed via GitHub
on 2026-03-25 17:20:28 +08:00.
4 changed files with 86 additions and 9 deletions

```diff
@@ -196,7 +196,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+            random_matrix = torch.rand(
+                topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
+            )
+            topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
         topk_weights = topk_weights.to(x.dtype)
```
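
For reference, a standalone toy version of the new force-load-balance routing
(the sizes and the placeholder topk_ids below are made up for the example):
drawing random scores and taking argsort gives each token k distinct expert
ids, whereas the previous randint_like could assign the same expert to a token
more than once.

```python
import torch

# toy sizes, chosen only for illustration
num_tokens, top_k = 8, 2
global_num_experts, global_redundant_expert_num = 16, 0

# placeholder routing tensor standing in for the real router output
topk_ids = torch.zeros(num_tokens, top_k, dtype=torch.int32)

random_matrix = torch.rand(
    topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)

print(topk_ids)  # every row holds top_k distinct expert indices in [0, num_experts)
```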