[Bugfix]: Pass scaling args to mc2 (#1202)
Pass `expert_scales` and `expand_scales` args to the dispatch and combine functions.

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
@@ -130,6 +130,7 @@ def fused_experts_with_mc2(
         "shared_expert_rank_num": 0,
         "moe_expert_num": moe_expert_num,
         "global_bs": global_bs,
+        "expert_scales": topk_weights.to(torch.float32),
     }
 
     rank = torch.distributed.get_rank()
@@ -158,8 +159,8 @@ def fused_experts_with_mc2(
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
-    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
-        0:5]
+    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
+        0:7]
 
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
@@ -187,6 +188,7 @@ def fused_experts_with_mc2(
         "shared_expert_rank_num": 0,
         "moe_expert_num": moe_expert_num,
         "global_bs": 0,
+        "expand_scales": expand_scales,
     }
     tp_recv_counts = torch.empty(1,
                                  dtype=torch.int32,
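For context, here is a minimal sketch (not the repository code) of the data flow this patch establishes: the router's `topk_weights` are forwarded to dispatch as float32 `expert_scales`, and the `expand_scales` tensor returned by dispatch is handed back to the combine step. The function name `sketch_mc2_scaling` and its `dispatch_kwargs` / `combine_kwargs` parameters are hypothetical stand-ins for the two kwargs dicts built inside `fused_experts_with_mc2`; the combine call itself lies outside the hunks shown above.

```python
# Sketch of the scaling-arg flow after this patch (assumptions noted above).
# torch_npu is the Ascend NPU extension providing the mc2 dispatch/combine ops.
import torch
import torch_npu


def sketch_mc2_scaling(topk_weights: torch.Tensor,
                       dispatch_kwargs: dict,
                       combine_kwargs: dict) -> dict:
    # 1) Dispatch now also receives the routing weights as float32
    #    "expert_scales" (first hunk).
    dispatch_kwargs["expert_scales"] = topk_weights.to(torch.float32)
    output = torch_npu.npu_moe_distribute_dispatch(**dispatch_kwargs)

    # 2) The unpack is widened from 5 to 7 outputs so the per-token
    #    "expand_scales" produced by dispatch are captured (second hunk).
    (expand_x, dynamic_scale, expand_idx, expert_token_nums,
     ep_recv_counts, _, expand_scales) = output[0:7]

    # 3) The captured scales are forwarded to the combine step's kwargs
    #    (third hunk); the remaining combine arguments are unchanged.
    combine_kwargs["expand_scales"] = expand_scales
    return combine_kwargs
```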