[Kernel] Remove cumsum in groupedmatmul (#987)
### What this PR does / why we need it
Remove the cumsum operator in MoE to improve performance.

### How was this patch tested?
It should be tested on a case with the mc2 operator and graph mode enabled.

Signed-off-by: zhky <hahazhky@163.com>
Co-authored-by: 洪炜杰 <hongweijie1@huawei.com>
This commit is contained in:
@@ -88,15 +88,14 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
|||||||
0:5]
|
0:5]
|
||||||
|
|
||||||
w1 = w1.transpose(1, 2)
|
w1 = w1.transpose(1, 2)
|
||||||
expert_token_nums = torch.cumsum(expert_token_nums,
|
|
||||||
dim=0,
|
|
||||||
dtype=torch.int64)
|
|
||||||
group_list = expert_token_nums.to(torch.int64)
|
group_list = expert_token_nums.to(torch.int64)
|
||||||
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||||
x=[expand_x],
|
x=[expand_x],
|
||||||
weight=[w1],
|
weight=[w1],
|
||||||
split_item=2,
|
split_item=2,
|
||||||
group_list_type=0,
|
# 1 means count mode, to avoid cumulative operation of the group list
|
||||||
|
group_list_type=1,
|
||||||
group_type=0,
|
group_type=0,
|
||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
)
|
)
|
||||||
@@ -110,7 +109,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
|||||||
x=[gate_up_out],
|
x=[gate_up_out],
|
||||||
weight=[w2],
|
weight=[w2],
|
||||||
split_item=2,
|
split_item=2,
|
||||||
group_list_type=0,
|
group_list_type=1,
|
||||||
group_type=0,
|
group_type=0,
|
||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user