From 42bcad7e9b68cf92d9b98a9072eff10edc8d6f33 Mon Sep 17 00:00:00 2001
From: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
Date: Thu, 19 Mar 2026 16:10:30 +0800
Subject: [PATCH] GMM custom operator optimization in small batch scenarios (#7100)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
Optimizes the GMM (grouped matrix multiplication) custom operator in small-batch scenarios: when the batch contains at most 16 tokens, MoE expert routing switches to npu_moe_init_routing_v2 (non-quantized case) and the MoE MLP's grouped matmuls switch to the custom _C_ascend.moe_grouped_matmul kernel. The C++ binding now invokes aclnnMoeGroupedMatmul instead of aclnnMoeGroupedMatmulWeightNz.

### How was this patch tested?
Qwen3-30B, input: 4k, output: 1k

batch 1:  TPOT 7.9 ms -> 7.0 ms; Output Token Throughput 125.4651 token/s -> 140.6278 token/s
batch 2:  TPOT 9.4 ms -> 8.8 ms; Output Token Throughput 211.8187 token/s -> 225.2254 token/s
batch 16: TPOT 13.6 ms -> 13.5 ms; Output Token Throughput 1159.8213 token/s -> 1165.0982 token/s

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: chenxi-hh
---
 csrc/torch_binding.cpp               |  2 +-
 vllm_ascend/device/device_op.py      | 40 +++++++++++++------
 vllm_ascend/ops/fused_moe/moe_mlp.py | 59 +++++++++++++++++++---------
 3 files changed, 71 insertions(+), 30 deletions(-)

diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index b22cb7b0..b79065ae 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -697,7 +697,7 @@ std::vector<at::Tensor> moe_grouped_matmul(
     y.emplace_back(y_0);
     at::TensorList result = at::TensorList(y);
 
-    EXEC_NPU_CMD(aclnnMoeGroupedMatmulWeightNz,
+    EXEC_NPU_CMD(aclnnMoeGroupedMatmul,
                  x_list, weight_list, group_list, transpose_weight, result);
 
     return y;
diff --git a/vllm_ascend/device/device_op.py b/vllm_ascend/device/device_op.py
index 9bb7b1a0..5d95544c 100644
--- a/vllm_ascend/device/device_op.py
+++ b/vllm_ascend/device/device_op.py
@@ -17,6 +17,7 @@
 #
 import torch
 import torch_npu
+from vllm.forward_context import get_forward_context
 
 from vllm_ascend.device.mxfp_compat import (
     FLOAT4_E2M1FN_X2_DTYPE,
@@ -27,6 +28,8 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 
 
 class BaseDeviceAdaptor:
+    small_batch_gmm_batch_num = 16
+
     @classmethod
     def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
         torch_npu._npu_reshape_and_cache(
@@ -46,17 +49,32 @@ class BaseDeviceAdaptor:
         active_expert_range=None,
         quant_mode: int = -1,
     ):
-        return torch.ops._C_ascend.npu_moe_init_routing_custom(
-            hidden_states,
-            topk_ids,
-            scale=scale,
-            active_num=active_num,
-            expert_num=expert_num,
-            expert_tokens_num_type=expert_tokens_num_type,
-            expert_tokens_num_flag=expert_tokens_num_flag,
-            active_expert_range=active_expert_range,
-            quant_mode=quant_mode,
-        )
+        # In small-batch, non-quantized scenarios, npu_moe_init_routing_v2 is more efficient.
+        # Further improvement is expected once this path is incorporated into CANN on June 30th.
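+        # small_batch_gmm_batch_num (16) is defined on BaseDeviceAdaptor above;
+        # quant_mode == -1 means quantization is disabled for this routing call.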
+        if quant_mode == -1 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
+            return torch_npu.npu_moe_init_routing_v2(
+                hidden_states,
+                topk_ids,
+                scale=scale,
+                active_num=active_num,
+                expert_num=expert_num,
+                expert_tokens_num_type=2,
+                expert_tokens_num_flag=expert_tokens_num_flag,
+                active_expert_range=active_expert_range,
+                quant_mode=quant_mode,
+            )
+        else:
+            return torch.ops._C_ascend.npu_moe_init_routing_custom(
+                hidden_states,
+                topk_ids,
+                scale=scale,
+                active_num=active_num,
+                expert_num=expert_num,
+                expert_tokens_num_type=expert_tokens_num_type,
+                expert_tokens_num_flag=expert_tokens_num_flag,
+                active_expert_range=active_expert_range,
+                quant_mode=quant_mode,
+            )
 
     @staticmethod
     def npu_dynamic_quant(
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
index 2033b61b..74b84f80 100644
--- a/vllm_ascend/ops/fused_moe/moe_mlp.py
+++ b/vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -18,6 +18,7 @@
 import torch
 import torch_npu
 from torch.nn.functional import pad
+from vllm.forward_context import get_forward_context
 from vllm.triton_utils import HAS_TRITON
 
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
@@ -339,15 +340,26 @@ def unquant_apply_mlp(
         w1 = w1.transpose(1, 2)
         w2 = w2.transpose(1, 2)
 
-    gate_up_out = torch_npu.npu_grouped_matmul(
-        x=[hidden_states],
-        weight=[w1],
-        bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-    )[0]
+    # In the small-batch scenario, use the custom _C_ascend.moe_grouped_matmul (no bias on this path).
+    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
+        gate_up_out = torch.ops._C_ascend.moe_grouped_matmul(
+            x=hidden_states,
+            weight=w1,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
+    else:
+        gate_up_out = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[w1],
+            bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
 
     if activation == "swigluoai":
         num_experts, _, hidden_size = w1.shape
@@ -358,15 +370,26 @@
     if topk_scales is not None:
         gate_up_out *= topk_scales
 
-    hidden_states = torch_npu.npu_grouped_matmul(
-        x=[gate_up_out],
-        weight=[w2],
-        bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-    )[0]
+    # In the small-batch scenario, use the custom _C_ascend.moe_grouped_matmul (no bias on this path).
+    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
+        hidden_states = torch.ops._C_ascend.moe_grouped_matmul(
+            x=gate_up_out,
+            weight=w2,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
+    else:
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[gate_up_out],
+            weight=[w2],
+            bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
 
     return hidden_states
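For readers skimming the diff, the dispatch heuristic this patch introduces can be summarized in isolation. Below is a minimal sketch, not code from the PR: the 16-token threshold and the `group_list.dim() == 2` guard come from the diff above, while `small_batch_kernel` and `default_kernel` are hypothetical stand-ins for `torch.ops._C_ascend.moe_grouped_matmul` and `torch_npu.npu_grouped_matmul`, which only exist on Ascend NPU builds.

```python
import torch

# Threshold from BaseDeviceAdaptor.small_batch_gmm_batch_num in this patch.
SMALL_BATCH_GMM_BATCH_NUM = 16


def dispatch_gmm(x, w, group_list, num_tokens, small_batch_kernel, default_kernel):
    """Choose a grouped-matmul kernel the way unquant_apply_mlp now does.

    The kernel callables are parameters (hypothetical stand-ins) because the
    real ops require an Ascend NPU runtime.
    """
    if group_list.dim() == 2 and num_tokens <= SMALL_BATCH_GMM_BATCH_NUM:
        # Small decode batches take the custom op; the PR's benchmarks show
        # it is faster at batch sizes 1-16.
        return small_batch_kernel(x, w, group_list)
    # Everything else keeps the general grouped matmul.
    return default_kernel(x, w, group_list)


# Smoke test with plain matmuls standing in for the NPU kernels.
x = torch.randn(4, 8)
w = torch.randn(8, 8)
group_list = torch.tensor([[0, 4]])  # 2-D group_list, as the fast path requires
out = dispatch_gmm(x, w, group_list, num_tokens=4,
                   small_batch_kernel=lambda x, w, g: x @ w,
                   default_kernel=lambda x, w, g: x @ w)
print(out.shape)  # torch.Size([4, 8])
```

Note that the real code checks the threshold against `get_forward_context().num_tokens`, i.e. the token count of the current forward pass, not the request batch size.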