[main] remove torch.cat and replace it by List[0] (#2153)
### What this PR does / why we need it?
torch_npu.npu_grouped_matmul:
https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_grouped_matmul.md
According to the document, when `split_item` is 2 or 3,
`torch_npu.npu_grouped_matmul` will return a list which has one element.
Therefore, the `torch.cat` after `torch_npu.npu_grouped_matmul` is
unnecessary.
### Does this PR introduce _any_ user-facing change?
not involved
### How was this patch tested?
ut and e2e covered: `tests/ut/ops/test_fused_ops.py`,
`tests/e2e/singlecard/ops/test_fused_moe.py`
**performance**:
(qwen3 30B, 2k->20k)
base:
Total Token throughput (tok/s): 667.76
remove cat:
Total Token throughput (tok/s): 680.82
- vLLM version: v0.10.0
- vLLM main:
fa00c5d75b
Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
@@ -112,7 +112,7 @@ def mock_moe_env(mocker: MockerFixture):
|
||||
torch.randn(16, 2)
|
||||
)), \
|
||||
patch("torch_npu.npu_grouped_matmul", return_value=(
|
||||
(torch.randn(8, 2), torch.randn(8, 2))
|
||||
[torch.randn(16, 2)]
|
||||
)), \
|
||||
patch("torch_npu.npu_swiglu", return_value=(
|
||||
torch.randn(16, 2)
|
||||
|
||||
@@ -205,11 +205,9 @@ def fused_experts_with_mc2(
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
)[0]
|
||||
|
||||
# TODO: Remove this in the future.
|
||||
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
@@ -219,9 +217,7 @@ def fused_experts_with_mc2(
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
|
||||
down_out_list = torch.cat(down_out_list, dim=0)
|
||||
)[0]
|
||||
|
||||
# moeCombine
|
||||
kwargs_mc2 = {
|
||||
@@ -312,9 +308,8 @@ def apply_mlp(
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
)[0]
|
||||
|
||||
hidden_states = torch.cat(hidden_states, dim=0)
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
@@ -325,9 +320,8 @@ def apply_mlp(
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
)[0]
|
||||
|
||||
hidden_states = torch.cat(hidden_states, dim=0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -417,23 +411,19 @@ def fused_experts_with_all2all(
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
)
|
||||
)[0]
|
||||
|
||||
# TODO: Remove this in the future.
|
||||
hidden_states = torch.cat(gate_up_out_list, dim=0)
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat(down_out_list, dim=0)
|
||||
)[0]
|
||||
|
||||
if expert_map is not None:
|
||||
resorted_idx = torch.argsort(sorted_idx)
|
||||
@@ -823,11 +813,9 @@ def fused_experts(
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
)
|
||||
)[0]
|
||||
|
||||
# TODO: Remove this in the future.
|
||||
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
@@ -837,9 +825,7 @@ def fused_experts(
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
)
|
||||
|
||||
down_out_list = torch.cat(down_out_list, dim=0)
|
||||
)[0]
|
||||
|
||||
if expert_map is not None:
|
||||
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
|
||||
|
||||
Reference in New Issue
Block a user