[main] remove torch.cat and replace it by List[0] (#2153)

### What this PR does / why we need it?
torch_npu.npu_grouped_matmul:

https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_grouped_matmul.md

According to that document, when `split_item` is 2 or 3,
`torch_npu.npu_grouped_matmul` returns a list containing a single element.
Therefore, the `torch.cat` call applied to the output of
`torch_npu.npu_grouped_matmul` is unnecessary.

### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
ut and e2e covered: `tests/ut/ops/test_fused_ops.py`,
`tests/e2e/singlecard/ops/test_fused_moe.py`

**performance**:
(qwen3 30B, 2k->20k)

base:
Total Token throughput (tok/s): 667.76

remove cat:
Total Token throughput (tok/s): 680.82


- vLLM version: v0.10.0
- vLLM main:
fa00c5d75b

Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
huangxialu
2025-08-07 17:20:19 +08:00
committed by GitHub
parent b2598c3271
commit dceef080b1
2 changed files with 13 additions and 27 deletions

View File

@@ -112,7 +112,7 @@ def mock_moe_env(mocker: MockerFixture):
             torch.randn(16, 2)
         )), \
         patch("torch_npu.npu_grouped_matmul", return_value=(
-            (torch.randn(8, 2), torch.randn(8, 2))
+            [torch.randn(16, 2)]
         )), \
         patch("torch_npu.npu_swiglu", return_value=(
             torch.randn(16, 2)

View File

@@ -205,11 +205,9 @@ def fused_experts_with_mc2(
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    )[0]
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)

     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -219,9 +217,7 @@ def fused_experts_with_mc2(
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
-    down_out_list = torch.cat(down_out_list, dim=0)
+    )[0]

     # moeCombine
     kwargs_mc2 = {
@@ -312,9 +308,8 @@ def apply_mlp(
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
-    hidden_states = torch.cat(hidden_states, dim=0)
+    )[0]
     hidden_states = torch_npu.npu_swiglu(hidden_states)

     w2 = w2.transpose(1, 2)
@@ -325,9 +320,8 @@ def apply_mlp(
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
-    hidden_states = torch.cat(hidden_states, dim=0)
+    )[0]

     return hidden_states
@@ -417,23 +411,19 @@ def fused_experts_with_all2all(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-    # TODO: Remove this in the future.
-    hidden_states = torch.cat(gate_up_out_list, dim=0)
-    hidden_states = torch_npu.npu_swiglu(hidden_states)
+    )[0]
+    hidden_states = torch_npu.npu_swiglu(gate_up_out_list)

     w2 = w2.transpose(1, 2)

-    down_out_list = torch_npu.npu_grouped_matmul(
+    hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w2],
         split_item=2,
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-    hidden_states = torch.cat(down_out_list, dim=0)
+    )[0]

     if expert_map is not None:
         resorted_idx = torch.argsort(sorted_idx)
@@ -823,11 +813,9 @@ def fused_experts(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    )[0]
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)

     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -837,9 +825,7 @@ def fused_experts(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-    down_out_list = torch.cat(down_out_list, dim=0)
+    )[0]

     if expert_map is not None:
         weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)