From 0b12c2acf7d9fd192beebebf662298067d9a5435 Mon Sep 17 00:00:00 2001 From: hahazhky Date: Fri, 6 Jun 2025 19:17:27 +0800 Subject: [PATCH] [Kernel] Remove cumsum in groupedmatmul (#987) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it Remove the cumsum operator in MoE to improve performance. ### How was this patch tested? It should be tested on a case with the mc2 operator and graph mode enabled. Signed-off-by: zhky Co-authored-by: 洪炜杰 --- vllm_ascend/ops/fused_moe.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6aff62f..6036e60 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -88,15 +88,14 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, 0:5] w1 = w1.transpose(1, 2) - expert_token_nums = torch.cumsum(expert_token_nums, - dim=0, - dtype=torch.int64) + group_list = expert_token_nums.to(torch.int64) gate_up_out_list = torch_npu.npu_grouped_matmul( x=[expand_x], weight=[w1], split_item=2, - group_list_type=0, + # 1 means count mode, to avoid cumulative operation of the group list + group_list_type=1, group_type=0, group_list=group_list, ) @@ -110,7 +109,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, x=[gate_up_out], weight=[w2], split_item=2, - group_list_type=0, + group_list_type=1, group_type=0, group_list=group_list, )