[A5][bugfix] Fix fused MoE A5 MXFP8 scale normalization, load-balance routing and gating_topk ops (#7573)
### What this PR does / why we need it? This PR fixes A5 MXFP8 MoE scale handling in the fused MoE path. - It normalizes MXFP8 activation scales to the packed 3D layout expected by A5 kernels, including both precomputed dynamic_scale inputs and gmm1 output scales before they are consumed by downstream grouped matmul ops. - It also refines the MXFP8 force load-balancing path in profiling runs. - This PR also enables npu_gating_top_k from torch_npu instead of custom op when running ascend950 chip. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI and E2E serving tests on Ascend950DT passed. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -196,7 +196,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
random_matrix = torch.rand(
|
||||
topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
|
||||
)
|
||||
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user