GMM custom operator optimization in small batch scenarios (#7100)

### What this PR does / why we need it?
Optimizes the grouped matrix multiplication (GMM) path in `unquant_apply_mlp` for small-batch scenarios: when `group_list` is 2-D and the forward pass carries at most `DeviceOperator.small_batch_gmm_batch_num` tokens, dispatch to the custom `torch.ops._C_ascend.moe_grouped_matmul` operator; otherwise fall back to `torch_npu.npu_grouped_matmul`.
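Background for reviewers: both kernels compute the same per-expert grouped matmul over contiguous token slices. Below is a minimal plain-PyTorch sketch of that semantics, assuming a cumulative 1-D `group_list` (`group_list_type=0`); this is an illustrative reference, not the NPU implementation, and note the custom kernel in this PR additionally requires a 2-D `group_list` layout:

```python
import torch

def grouped_matmul_reference(x: torch.Tensor, weight: torch.Tensor,
                             group_list: torch.Tensor) -> torch.Tensor:
    """Reference semantics: rows of x are partitioned by expert according
    to the cumulative row counts in group_list, and each slice is
    multiplied by that expert's weight matrix."""
    out = x.new_empty(x.shape[0], weight.shape[-1])
    start = 0
    for expert, end in enumerate(group_list.tolist()):
        out[start:end] = x[start:end] @ weight[expert]
        start = end
    return out

# 5 tokens routed as [3, 2] across 2 experts, hidden size 4 -> 8:
x = torch.randn(5, 4)
w = torch.randn(2, 4, 8)
print(grouped_matmul_reference(x, w, torch.tensor([3, 5])).shape)  # (5, 8)
```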

### How was this patch tested?

Benchmarked with Qwen3-30B (input length: 4k tokens, output length: 1k tokens):

| batch | TPOT (ms) | Output Token Throughput (token/s) |
|---|---|---|
| 1 | 7.9 -> 7.0 | 125.4651 -> 140.6278 |
| 2 | 9.4 -> 8.8 | 211.8187 -> 225.2254 |
| 16 | 13.6 -> 13.5 | 1159.8213 -> 1165.0982 |
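For scale, the relative throughput gains work out to 140.6278 / 125.4651 ≈ +12.1% at batch 1, 225.2254 / 211.8187 ≈ +6.3% at batch 2, and 1165.0982 / 1159.8213 ≈ +0.5% at batch 16; as expected, the benefit concentrates in the small-batch regime the custom kernel targets.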

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: chenxi-hh <chen464822955@163.com>
Author: chenxi-hh
Committed: 2026-03-19 16:10:30 +08:00 (committed by GitHub)
Parent: 8e0ebb470a
Commit: 42bcad7e9b
3 changed files with 71 additions and 30 deletions


```diff
@@ -18,6 +18,7 @@
 import torch
 import torch_npu
 from torch.nn.functional import pad
+from vllm.forward_context import get_forward_context
 from vllm.triton_utils import HAS_TRITON
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
@@ -339,15 +340,26 @@ def unquant_apply_mlp(
     w1 = w1.transpose(1, 2)
     w2 = w2.transpose(1, 2)
-    gate_up_out = torch_npu.npu_grouped_matmul(
-        x=[hidden_states],
-        weight=[w1],
-        bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-    )[0]
+    # In the small batch scenario, use _C_ascend.moe_grouped_matmul
+    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
+        gate_up_out = torch.ops._C_ascend.moe_grouped_matmul(
+            x=hidden_states,
+            weight=w1,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
+    else:
+        gate_up_out = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[w1],
+            bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
     if activation == "swigluoai":
         num_experts, _, hidden_size = w1.shape
@@ -358,15 +370,26 @@ def unquant_apply_mlp(
     if topk_scales is not None:
         gate_up_out *= topk_scales
-    hidden_states = torch_npu.npu_grouped_matmul(
-        x=[gate_up_out],
-        weight=[w2],
-        bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-    )[0]
+    # In the small batch scenario, use _C_ascend.moe_grouped_matmul
+    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
+        hidden_states = torch.ops._C_ascend.moe_grouped_matmul(
+            x=gate_up_out,
+            weight=w2,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
+    else:
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[gate_up_out],
+            weight=[w2],
+            bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
     return hidden_states
```
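The new control flow reduces to a single predicate evaluated before each grouped matmul. Below is a standalone sketch of that dispatch rule (hypothetical helper name; the threshold is passed explicitly because `DeviceOperator.small_batch_gmm_batch_num` is defined outside this hunk):

```python
import torch

def prefers_custom_gmm(group_list: torch.Tensor, num_tokens: int,
                       small_batch_threshold: int) -> bool:
    # Mirrors the condition added in unquant_apply_mlp: take the custom
    # _C_ascend.moe_grouped_matmul path only when group_list uses the
    # 2-D layout the kernel expects and the batch is small.
    return group_list.dim() == 2 and num_tokens <= small_batch_threshold

# The tensor's rank is what drives the first clause; the 64-token
# threshold here is illustrative, not the value used on device.
gl_2d = torch.zeros(2, 2, dtype=torch.int64)
assert prefers_custom_gmm(gl_2d, num_tokens=8, small_batch_threshold=64)
assert not prefers_custom_gmm(gl_2d, num_tokens=512, small_batch_threshold=64)
```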