[Bugfix]: Pass scaling args to mc2 (#1202)
Pass `expert_scale` and `expand_scale` args to the dispatch and combine functions. Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -130,6 +130,7 @@ def fused_experts_with_mc2(
|
|||||||
"shared_expert_rank_num": 0,
|
"shared_expert_rank_num": 0,
|
||||||
"moe_expert_num": moe_expert_num,
|
"moe_expert_num": moe_expert_num,
|
||||||
"global_bs": global_bs,
|
"global_bs": global_bs,
|
||||||
|
"expert_scales": topk_weights.to(torch.float32),
|
||||||
}
|
}
|
||||||
|
|
||||||
rank = torch.distributed.get_rank()
|
rank = torch.distributed.get_rank()
|
||||||
@@ -158,8 +159,8 @@ def fused_experts_with_mc2(
|
|||||||
|
|
||||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
|
||||||
0:5]
|
0:7]
|
||||||
|
|
||||||
if shared_experts is not None:
|
if shared_experts is not None:
|
||||||
with npu_stream_switch("moe_secondary", 0):
|
with npu_stream_switch("moe_secondary", 0):
|
||||||
@@ -187,6 +188,7 @@ def fused_experts_with_mc2(
|
|||||||
"shared_expert_rank_num": 0,
|
"shared_expert_rank_num": 0,
|
||||||
"moe_expert_num": moe_expert_num,
|
"moe_expert_num": moe_expert_num,
|
||||||
"global_bs": 0,
|
"global_bs": 0,
|
||||||
|
"expand_scales": expand_scales,
|
||||||
}
|
}
|
||||||
tp_recv_counts = torch.empty(1,
|
tp_recv_counts = torch.empty(1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
|
|||||||
Reference in New Issue
Block a user