GMM custom operator optimization in small batch scenarios (#7100)

### What this PR does / why we need it?
GMM custom operator optimization in small batch scenarios

### How was this patch tested?

Qwen3-30B input: 4k, output: 1k

batch 1:
TPOT 7.9 ms -> 7.0 ms
Output Token Throughput 125.4651 token/s -> 140.6278 token/s

batch 2:
TPOT 9.4 ms -> 8.8 ms
Output Token Throughput 211.8187 token/s -> 225.2254 token/s

batch 16:
TPOT 13.6 ms -> 13.5 ms
Output Token Throughput 1159.8213 token/s -> 1165.0982 token/s

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: chenxi-hh <chen464822955@163.com>
This commit is contained in:
chenxi-hh
2026-03-19 16:10:30 +08:00
committed by GitHub
parent 8e0ebb470a
commit 42bcad7e9b
3 changed files with 71 additions and 30 deletions

View File

@@ -17,6 +17,7 @@
#
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm_ascend.device.mxfp_compat import (
FLOAT4_E2M1FN_X2_DTYPE,
@@ -27,6 +28,8 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
class BaseDeviceAdaptor:
small_batch_gmm_batch_num = 16
@classmethod
def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
torch_npu._npu_reshape_and_cache(
@@ -46,17 +49,32 @@ class BaseDeviceAdaptor:
active_expert_range=None,
quant_mode: int = -1,
):
return torch.ops._C_ascend.npu_moe_init_routing_custom(
hidden_states,
topk_ids,
scale=scale,
active_num=active_num,
expert_num=expert_num,
expert_tokens_num_type=expert_tokens_num_type,
expert_tokens_num_flag=expert_tokens_num_flag,
active_expert_range=active_expert_range,
quant_mode=quant_mode,
)
# In small batch and non-quantization scenarios, npu_moe_init_routing_v2 is more efficient.
# It is expected that further improvements will be made after it is incorporated into CANN on June 30th.
if quant_mode == -1 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
return torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
scale=scale,
active_num=active_num,
expert_num=expert_num,
expert_tokens_num_type=2,
expert_tokens_num_flag=expert_tokens_num_flag,
active_expert_range=active_expert_range,
quant_mode=quant_mode,
)
else:
return torch.ops._C_ascend.npu_moe_init_routing_custom(
hidden_states,
topk_ids,
scale=scale,
active_num=active_num,
expert_num=expert_num,
expert_tokens_num_type=expert_tokens_num_type,
expert_tokens_num_flag=expert_tokens_num_flag,
active_expert_range=active_expert_range,
quant_mode=quant_mode,
)
@staticmethod
def npu_dynamic_quant(