[MM][Perf] Enable 2.7x faster for convolution computation with aclnn BatchMatMulV2 (#7017)

### What this PR does / why we need it?
Currently, we are using
e2b31243c0/vllm/model_executor/layers/conv.py (L219-L232)
for convolution computation, which is used in patch embedding for VL
models.

After profiling, we find that this linear method will take about **6.87
ms**, which is much slower than just using `F.conv3d()`. In
`F.conv3d()`, it will call aclnn `BatchMatMulV2` with optimization on
Ascend NPU, which only take about **2.50 ms** and is **2.7x faster**
than linear method.

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2026-03-06 14:26:37 +08:00
committed by GitHub
parent c49ce18ea5
commit a813eadd2d
2 changed files with 35 additions and 0 deletions

32
vllm_ascend/ops/conv.py Normal file
View File

@@ -0,0 +1,32 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer
class AscendConv2dLayer(Conv2dLayer):
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
# Use aclnn BatchMatMulV2 for better performance on Ascend NPU.
return self._forward_conv(x)
class AscendConv3dLayer(Conv3dLayer):
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
# Use aclnn BatchMatMulV2 for better performance on Ascend NPU.
return self._forward_conv(x)

View File

@@ -597,6 +597,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
from vllm.model_executor.custom_op import CustomOp
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.conv import AscendConv2dLayer, AscendConv3dLayer
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm, AscendRMSNormGated
from vllm_ascend.ops.linear import (
@@ -645,6 +646,8 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
"MMEncoderAttention": AscendMMEncoderAttention,
"ApplyRotaryEmb": AscendApplyRotaryEmb,
"RMSNormGated": AscendRMSNormGated,
"Conv2dLayer": AscendConv2dLayer,
"Conv3dLayer": AscendConv3dLayer,
}
# 310P: override selected ops with 310P implementations (keep minimal changes outside _310p)