[main] remove torch.cat and replace it by List[0] (#2153)

### What this PR does / why we need it?
torch_npu.npu_grouped_matmul:

https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_grouped_matmul.md

According to the document, when `split_item` is 2 or 3,
`torch_npu.npu_grouped_matmul` will return a list which has one element.
Therefore, the `torch.cat` after `torch_npu.npu_grouped_matmul` is
unnecessary.

### Does this PR introduce _any_ user-facing change?
not involved

### How was this patch tested?
ut and e2e covered: `tests/ut/ops/test_fused_ops.py`,
`tests/e2e/singlecard/ops/test_fused_moe.py`

**performance**:
(qwen3 30B, 2k->20k)

base:
Total Token throughput (tok/s):          667.76 

remove cat:
Total Token throughput (tok/s):          680.82 


- vLLM version: v0.10.0
- vLLM main:
fa00c5d75b

Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
huangxialu
2025-08-07 17:20:19 +08:00
committed by GitHub
parent b2598c3271
commit dceef080b1
2 changed files with 13 additions and 27 deletions

View File

@@ -112,7 +112,7 @@ def mock_moe_env(mocker: MockerFixture):
torch.randn(16, 2)
)), \
patch("torch_npu.npu_grouped_matmul", return_value=(
(torch.randn(8, 2), torch.randn(8, 2))
[torch.randn(16, 2)]
)), \
patch("torch_npu.npu_swiglu", return_value=(
torch.randn(16, 2)

View File

@@ -205,11 +205,9 @@ def fused_experts_with_mc2(
group_list_type=1,
group_type=0,
group_list=group_list,
)
)[0]
# TODO: Remove this in the future.
gate_up_out = torch.cat(gate_up_out_list, dim=0)
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul(
@@ -219,9 +217,7 @@ def fused_experts_with_mc2(
group_list_type=1,
group_type=0,
group_list=group_list,
)
down_out_list = torch.cat(down_out_list, dim=0)
)[0]
# moeCombine
kwargs_mc2 = {
@@ -312,9 +308,8 @@ def apply_mlp(
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)
)[0]
hidden_states = torch.cat(hidden_states, dim=0)
hidden_states = torch_npu.npu_swiglu(hidden_states)
w2 = w2.transpose(1, 2)
@@ -325,9 +320,8 @@ def apply_mlp(
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)
)[0]
hidden_states = torch.cat(hidden_states, dim=0)
return hidden_states
@@ -417,23 +411,19 @@ def fused_experts_with_all2all(
group_list_type=0,
group_type=0,
group_list=expert_tokens,
)
)[0]
# TODO: Remove this in the future.
hidden_states = torch.cat(gate_up_out_list, dim=0)
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul(
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
)
hidden_states = torch.cat(down_out_list, dim=0)
)[0]
if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
@@ -823,11 +813,9 @@ def fused_experts(
group_list_type=0,
group_type=0,
group_list=expert_tokens,
)
)[0]
# TODO: Remove this in the future.
gate_up_out = torch.cat(gate_up_out_list, dim=0)
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul(
@@ -837,9 +825,7 @@ def fused_experts(
group_list_type=0,
group_type=0,
group_list=expert_tokens,
)
down_out_list = torch.cat(down_out_list, dim=0)
)[0]
if expert_map is not None:
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)