GMM custom operator optimization in small batch scenarios (#7100)
### What this PR does / why we need it?
This PR optimizes the GMM (grouped matrix multiplication) custom operator for small-batch scenarios: when the step's token count is at or below a threshold of 16 tokens, MoE routing initialization and the expert grouped matmuls dispatch to faster kernels.
### How was this patch tested?
Benchmarked with Qwen3-30B, input: 4k, output: 1k:

| Batch | TPOT (before → after) | Output Token Throughput (before → after) |
|------:|-----------------------|-------------------------------------------|
| 1     | 7.9 ms → 7.0 ms       | 125.4651 → 140.6278 token/s               |
| 2     | 9.4 ms → 8.8 ms       | 211.8187 → 225.2254 token/s               |
| 16    | 13.6 ms → 13.5 ms     | 1159.8213 → 1165.0982 token/s             |
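In relative terms (plain arithmetic on the table above, runnable anywhere):

```python
# TPOT reduction and throughput gain per batch size, from the table above.
results = {1: (7.9, 7.0, 125.4651, 140.6278),
           2: (9.4, 8.8, 211.8187, 225.2254),
           16: (13.6, 13.5, 1159.8213, 1165.0982)}
for batch, (tpot_old, tpot_new, thr_old, thr_new) in results.items():
    print(f"batch {batch}: TPOT {100 * (1 - tpot_new / tpot_old):.1f}% lower, "
          f"throughput {100 * (thr_new / thr_old - 1):.1f}% higher")
# batch 1: TPOT 11.4% lower, throughput 12.1% higher
# batch 2: TPOT 6.4% lower, throughput 6.3% higher
# batch 16: TPOT 0.7% lower, throughput 0.5% higher
```

As expected for a small-batch optimization, the gain shrinks as batch size grows.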
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: chenxi-hh <chen464822955@163.com>
In the C++ custom operator `moe_grouped_matmul`, the kernel dispatch switches from the weight-NZ variant to plain `aclnnMoeGroupedMatmul`:

```diff
@@ -697,7 +697,7 @@ std::vector<at::Tensor> moe_grouped_matmul(
   y.emplace_back(y_0);
   at::TensorList result = at::TensorList(y);
 
-  EXEC_NPU_CMD(aclnnMoeGroupedMatmulWeightNz,
+  EXEC_NPU_CMD(aclnnMoeGroupedMatmul,
                x_list, weight_list, group_list, transpose_weight, result);
 
   return y;
```
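After a kernel swap like this, a quick parity check against the stock kernel is a reasonable sanity test. A minimal sketch, assuming the `_C_ascend` extension is loaded; the helper name and tolerances are illustrative, and the arguments mirror the two call sites in the diff below:

```python
import torch
import torch_npu

def check_gmm_parity(x, weight, group_list, group_list_type=1):
    """Compare the custom small-batch GMM against torch_npu.npu_grouped_matmul."""
    custom = torch.ops._C_ascend.moe_grouped_matmul(
        x=x, weight=weight, split_item=2,
        group_list_type=group_list_type, group_type=0, group_list=group_list,
    )[0]
    reference = torch_npu.npu_grouped_matmul(
        x=[x], weight=[weight], split_item=2,
        group_list_type=group_list_type, group_type=0, group_list=group_list,
    )[0]
    # Tolerances are illustrative; fp16 grouped matmuls won't match bit-exactly.
    torch.testing.assert_close(custom, reference, rtol=1e-3, atol=1e-3)
```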
In the device adaptor, a class-level small-batch threshold is added, and MoE routing initialization takes a fast path through `torch_npu.npu_moe_init_routing_v2` for small, non-quantized batches:

```diff
@@ -17,6 +17,7 @@
 #
 import torch
 import torch_npu
+from vllm.forward_context import get_forward_context
 
 from vllm_ascend.device.mxfp_compat import (
     FLOAT4_E2M1FN_X2_DTYPE,
@@ -27,6 +28,8 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 
 
 class BaseDeviceAdaptor:
+    small_batch_gmm_batch_num = 16
+
     @classmethod
     def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
         torch_npu._npu_reshape_and_cache(
@@ -46,6 +49,21 @@ class BaseDeviceAdaptor:
         active_expert_range=None,
         quant_mode: int = -1,
     ):
-        return torch.ops._C_ascend.npu_moe_init_routing_custom(
-            hidden_states,
-            topk_ids,
+        # In small batch and non-quantization scenarios, npu_moe_init_routing_v2 is more efficient.
+        # It is expected that further improvements will be made after it is incorporated into CANN on June 30th.
+        if quant_mode == -1 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
+            return torch_npu.npu_moe_init_routing_v2(
+                hidden_states,
+                topk_ids,
+                scale=scale,
+                active_num=active_num,
+                expert_num=expert_num,
+                expert_tokens_num_type=2,
+                expert_tokens_num_flag=expert_tokens_num_flag,
+                active_expert_range=active_expert_range,
+                quant_mode=quant_mode,
+            )
+        else:
+            return torch.ops._C_ascend.npu_moe_init_routing_custom(
+                hidden_states,
+                topk_ids,
```
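The routing fast path hinges on one condition; extracted as a standalone sketch (names follow the diff; `num_tokens` comes from vLLM's forward context, and `DeviceOperator` resolves to the adaptor class elsewhere in vllm_ascend):

```python
from vllm.forward_context import get_forward_context

def use_small_batch_routing(quant_mode: int, threshold: int = 16) -> bool:
    """True when npu_moe_init_routing_v2 should handle routing init.

    quant_mode == -1 marks the non-quantized path; num_tokens is the token
    count of the current forward pass, read from vLLM's forward context.
    """
    return quant_mode == -1 and get_forward_context().num_tokens <= threshold
```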
In `unquant_apply_mlp`, both grouped matmuls (the gate/up projection with `w1` and the down projection with `w2`) gain the same small-batch branch through the custom operator:

```diff
@@ -18,6 +18,7 @@
 import torch
 import torch_npu
 from torch.nn.functional import pad
+from vllm.forward_context import get_forward_context
 from vllm.triton_utils import HAS_TRITON
 
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
@@ -339,6 +340,17 @@ def unquant_apply_mlp(
     w1 = w1.transpose(1, 2)
     w2 = w2.transpose(1, 2)
 
-    gate_up_out = torch_npu.npu_grouped_matmul(
-        x=[hidden_states],
-        weight=[w1],
+    # In the small batch scenario, use _C_ascend.moe_grouped_matmul
+    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
+        gate_up_out = torch.ops._C_ascend.moe_grouped_matmul(
+            x=hidden_states,
+            weight=w1,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
+    else:
+        gate_up_out = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[w1],
@@ -358,6 +370,17 @@ def unquant_apply_mlp(
     if topk_scales is not None:
         gate_up_out *= topk_scales
 
-    hidden_states = torch_npu.npu_grouped_matmul(
-        x=[gate_up_out],
-        weight=[w2],
+    # In the small batch scenario, use _C_ascend.moe_grouped_matmul
+    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
+        hidden_states = torch.ops._C_ascend.moe_grouped_matmul(
+            x=gate_up_out,
+            weight=w2,
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=group_list,
+        )[0]
+    else:
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[gate_up_out],
+            weight=[w2],
```
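The branch is duplicated verbatim for the two projections; one way to express it once, as a hypothetical helper (`_grouped_matmul` is not in the PR, and the truncated arguments of `npu_grouped_matmul` are assumed to mirror the custom call):

```python
import torch
import torch_npu
from vllm.forward_context import get_forward_context

def _grouped_matmul(x, weight, group_list, group_list_type, threshold=16):
    """Hypothetical helper: one place for the small-batch vs. default GMM choice."""
    if group_list.dim() == 2 and get_forward_context().num_tokens <= threshold:
        # Small batch: the custom operator takes plain tensors.
        return torch.ops._C_ascend.moe_grouped_matmul(
            x=x, weight=weight, split_item=2,
            group_list_type=group_list_type, group_type=0, group_list=group_list,
        )[0]
    # Default path: torch_npu's grouped matmul takes lists of tensors.
    return torch_npu.npu_grouped_matmul(
        x=[x], weight=[weight], split_item=2,
        group_list_type=group_list_type, group_type=0, group_list=group_list,
    )[0]

# Both call sites in unquant_apply_mlp would then reduce to:
#   gate_up_out = _grouped_matmul(hidden_states, w1, group_list, group_list_type)
#   hidden_states = _grouped_matmul(gate_up_out, w2, group_list, group_list_type)
```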