[Bugfix]: Pass scaling args to mc2 (#1202)

Pass `expert_scale` and `expand_scale` args to the dispatch and combine
functions.

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2025-06-17 22:16:44 +08:00
committed by GitHub
parent f8029945c3
commit afc8edb046

View File

@@ -130,6 +130,7 @@ def fused_experts_with_mc2(
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
"expert_scales": topk_weights.to(torch.float32),
}
rank = torch.distributed.get_rank()
@@ -158,8 +159,8 @@ def fused_experts_with_mc2(
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
0:7]
if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0):
@@ -187,6 +188,7 @@ def fused_experts_with_mc2(
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"expand_scales": expand_scales,
}
tp_recv_counts = torch.empty(1,
dtype=torch.int32,