From 9fcaf66646d15671ab4d9bcf3530cccfcc6b2675 Mon Sep 17 00:00:00 2001
From: LICO67373 <110013619+LICO1314@users.noreply.github.com>
Date: Thu, 18 Dec 2025 16:48:55 +0800
Subject: [PATCH] fix: use batch_matmul_transpose operator in MLA _v_up_proj
 for better performance (#5142)

### What this PR does / why we need it?

This PR fixes a bug in the `AscendMLAImpl._v_up_proj` method where the optimized `batch_matmul_transpose` operator was not being used.

**Changes:**
- Modified the `_v_up_proj` method to use the `torch.ops._C_ascend.batch_matmul_transpose` operator for FP16/BF16 dtypes when the operator is available
- Added a fallback path that keeps the original `torch.bmm` implementation for all other cases
- This avoids unnecessary transpose operations and improves performance (see the equivalence sketch after the diff)

**Why needed:**
- The previous implementation only used `torch.bmm` with multiple transpose operations, which is less efficient
- The Ascend backend provides an optimized `batch_matmul_transpose` operator that performs the same computation more efficiently
- This fix improves inference performance for MLA (Multi-head Latent Attention) models on Ascend NPU

### Does this PR introduce _any_ user-facing change?

No. This is a performance optimization that preserves the existing functionality and output. Users will see faster inference for MLA-based models, but no API or interface changes are introduced. The fallback path maintains backward compatibility, ensuring correct behavior when the operator is not available or the dtype is unsupported.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: hwhaokun
Co-authored-by: weijinqian0 <1184188277@qq.com>
---
 vllm_ascend/attention/mla_v1.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index e250fdbc..0646b41c 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -764,12 +764,22 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
 
     def _v_up_proj(self, x):
-        # Convert from (B, N, L) to (N, B, L)
-        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        # # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-        x = torch.bmm(x, self.W_UV)
-        # # Convert from (N, B, V) to (B, N * V)
-        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        if x.dtype in [torch.float16, torch.bfloat16] \
+                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
+            x = x.view(-1, self.num_heads, self.kv_lora_rank)
+            b, _, _ = x.shape
+            res = torch.empty((b, self.num_heads, self.v_head_dim),
+                              dtype=x.dtype,
+                              device=x.device)
+            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
+            x = res.reshape(-1, self.num_heads * self.v_head_dim)
+        else:
+            # Convert from (B, N, L) to (N, B, L)
+            x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+            x = torch.bmm(x, self.W_UV)
+            # Convert from (N, B, V) to (B, N * V)
+            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
         return x
 
     # Return `ql_nope`, `q_pe`
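
Equivalence sketch for the two `_v_up_proj` paths. This is a plain-PyTorch illustration, not the Ascend code path: `torch.einsum` stands in for the `batch_matmul_transpose` operator (which is only available on the NPU), the concrete sizes are made up, and `W_UV` mirrors the per-head up-projection weight from the diff. It checks that the fused (B, N, L) x (N, L, V) -> (B, N, V) contraction produces the same result as the `torch.bmm` fallback with its two transposes.

```python
import torch

# Illustrative sizes: B tokens, N heads, L = kv_lora_rank, V = v_head_dim.
B, N, L, V = 4, 8, 512, 128

# float32 keeps the CPU comparison tight; on the NPU the fused operator
# targets float16/bfloat16 inputs.
x = torch.randn(B, N, L)
W_UV = torch.randn(N, L, V)

# Fallback path from the diff: (B, N, L) -> (N, B, L), bmm against
# (N, L, V), then back to (B, N * V).
ref = torch.bmm(x.transpose(0, 1), W_UV).transpose(0, 1).reshape(B, N * V)

# Stand-in for batch_matmul_transpose: the same per-head contraction
# expressed without the two explicit transposes.
fused = torch.einsum("bnl,nlv->bnv", x, W_UV).reshape(B, N * V)

torch.testing.assert_close(ref, fused)
```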