From 5ebb9bd8d26395e86f9bd01183a327477a966760 Mon Sep 17 00:00:00 2001
From: ChrisGelhLan <33011886+xlan-huawei@users.noreply.github.com>
Date: Thu, 11 Dec 2025 16:32:28 +0800
Subject: [PATCH] 【Bugfix】bugfix_for_bmm_transpose (#4899)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Due to shape limitations, the bmm_transpose operator in version 3.2 can
only be used in the decode stage, so the fused path is now skipped
during prefill.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: ChrisGelhLan <33011886+xlan-huawei@users.noreply.github.com>
---
 vllm_ascend/attention/sfa_v1.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 8f14aa3d..3a962b87 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -490,9 +490,11 @@ class AscendSFAImpl(MLAAttentionImpl):
         self._process_weights_for_fused_mlapo(act_dtype)
 
     def _v_up_proj(self, x):
+        forward_context = get_forward_context()
         if x.dtype in [torch.float16, torch.bfloat16] \
             and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
-            and not self.enable_sfa_cp:
+            and not self.enable_sfa_cp \
+            and not forward_context.with_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
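
For context, below is a minimal standalone sketch of the gating this patch
produces. It is not the vllm-ascend implementation: the function name, the
w_vc weight layout, the einsum fallback, and the out-parameter signature of
batch_matmul_transpose are assumptions made for illustration; only the
four-clause condition mirrors the patched code.

import torch

def v_up_proj_sketch(x, w_vc, num_heads, kv_lora_rank, v_head_dim,
                     enable_sfa_cp, with_prefill):
    # Flatten tokens to (batch, num_heads, kv_lora_rank).
    x = x.view(-1, num_heads, kv_lora_rank)
    use_fused = (
        x.dtype in (torch.float16, torch.bfloat16)
        and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
        and not enable_sfa_cp
        and not with_prefill  # the fix: decode-only, skip during prefill
    )
    if use_fused:
        b = x.shape[0]
        res = torch.empty((b, num_heads, v_head_dim),
                          dtype=x.dtype, device=x.device)
        # Hypothetical out-parameter call; the real kernel's signature
        # may differ.
        torch.ops._C_ascend.batch_matmul_transpose(x, w_vc, res)
    else:
        # Fallback: per-head batched matmul with the value up-projection,
        # with w_vc assumed shaped (num_heads, kv_lora_rank, v_head_dim).
        res = torch.einsum("bhc,hcd->bhd", x, w_vc)
    return res.reshape(-1, num_heads * v_head_dim)

With with_prefill=True the fused kernel is bypassed even on hardware that
provides it, which is the behavioral change this patch makes; on a build
without the _C_ascend op the hasattr check fails and the fallback runs.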