diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6aff62f..6036e60 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -88,15 +88,14 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, 0:5] w1 = w1.transpose(1, 2) - expert_token_nums = torch.cumsum(expert_token_nums, - dim=0, - dtype=torch.int64) + group_list = expert_token_nums.to(torch.int64) gate_up_out_list = torch_npu.npu_grouped_matmul( x=[expand_x], weight=[w1], split_item=2, - group_list_type=0, + # 1 means count mode, to avoid cumulative operation of the group list + group_list_type=1, group_type=0, group_list=group_list, ) @@ -110,7 +109,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, x=[gate_up_out], weight=[w2], split_item=2, - group_list_type=0, + group_list_type=1, group_type=0, group_list=group_list, )