From dceef080b140305841a98d71cd8cfdeccbd390af Mon Sep 17 00:00:00 2001
From: huangxialu
Date: Thu, 7 Aug 2025 17:20:19 +0800
Subject: [PATCH] [main] remove torch.cat and replace it by List[0] (#2153)

### What this PR does / why we need it?
torch_npu.npu_grouped_matmul:
https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_grouped_matmul.md

According to the document, when `split_item` is 2 or 3,
`torch_npu.npu_grouped_matmul` will return a list which has one element.
Therefore, the `torch.cat` after `torch_npu.npu_grouped_matmul` is
unnecessary.

### Does this PR introduce _any_ user-facing change?
not involved

### How was this patch tested?
ut and e2e covered: `tests/ut/ops/test_fused_ops.py`,
`tests/e2e/singlecard/ops/test_fused_moe.py`

**performance**: (qwen3 30B, 2k->20k)
base: Total Token throughput (tok/s): 667.76
remove cat: Total Token throughput (tok/s): 680.82

- vLLM version: v0.10.0
- vLLM main: https://github.com/vllm-project/vllm/commit/fa00c5d75bc63c87f5822f839db1342f19e4acc8

Signed-off-by: huangxialu
---
 tests/ut/ops/test_fused_ops.py |  2 +-
 vllm_ascend/ops/fused_moe.py   | 38 +++++++++++-----------------
 2 files changed, 13 insertions(+), 27 deletions(-)

diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index 6c89f6f..46192f6 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -112,7 +112,7 @@ def mock_moe_env(mocker: MockerFixture):
             torch.randn(16, 2)
         )), \
         patch("torch_npu.npu_grouped_matmul", return_value=(
-            (torch.randn(8, 2), torch.randn(8, 2))
+            [torch.randn(16, 2)]
         )), \
         patch("torch_npu.npu_swiglu", return_value=(
             torch.randn(16, 2)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index f35fb10..f3c14bf 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -205,11 +205,9 @@
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -219,9 +217,7 @@
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
-
-    down_out_list = torch.cat(down_out_list, dim=0)
+    )[0]
 
     # moeCombine
     kwargs_mc2 = {
@@ -312,9 +308,8 @@
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    hidden_states = torch.cat(hidden_states, dim=0)
     hidden_states = torch_npu.npu_swiglu(hidden_states)
 
     w2 = w2.transpose(1, 2)
@@ -325,9 +320,8 @@
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    hidden_states = torch.cat(hidden_states, dim=0)
     return hidden_states
 
 
@@ -417,23 +411,19 @@
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    hidden_states = torch.cat(gate_up_out_list, dim=0)
-    hidden_states = torch_npu.npu_swiglu(hidden_states)
+    hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
-    down_out_list = torch_npu.npu_grouped_matmul(
+    hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w2],
         split_item=2,
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-
-    hidden_states = torch.cat(down_out_list, dim=0)
+    )[0]
 
     if expert_map is not None:
         resorted_idx = torch.argsort(sorted_idx)
@@ -823,11 +813,9 @@
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -837,9 +825,7 @@
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-
-    down_out_list = torch.cat(down_out_list, dim=0)
+    )[0]
 
     if expert_map is not None:
         weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)