【Bugfix】bugfix_for_bmm_transpose (#4899)

Due to shape limitations, the bmm_transpose operator in version 3.2 can only be used in the decode stage, so this fix adds a `with_prefill` check that skips the fused path whenever the forward context contains prefill requests.
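
A minimal sketch of the patched guard, runnable on any PyTorch build. The flag names `enable_sfa_cp` and `with_prefill` come from the diff below; the helper function itself is hypothetical and only illustrates the dispatch decision:

```python
import torch

def should_use_bmm_transpose(x: torch.Tensor,
                             enable_sfa_cp: bool,
                             with_prefill: bool) -> bool:
    # Mirrors the guard in AscendSFAImpl._v_up_proj after this fix
    # (illustrative helper, not part of the actual patch).
    return (
        # The fused kernel only handles half-precision inputs.
        x.dtype in (torch.float16, torch.bfloat16)
        # The custom op is only registered on Ascend builds of the extension;
        # hasattr is False elsewhere, so this sketch runs on any machine.
        and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
        and not enable_sfa_cp
        # New condition from this fix: the op's shape constraints only hold
        # for decode batches, so any prefill traffic takes the fallback path.
        and not with_prefill
    )

# A batch that includes prefill must now fall back to the generic matmul path.
x = torch.randn(4, 8, 512, dtype=torch.bfloat16)
assert not should_use_bmm_transpose(x, enable_sfa_cp=False, with_prefill=True)
```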

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: ChrisGelhLan <33011886+xlan-huawei@users.noreply.github.com>
Author: ChrisGelhLan
Date: 2025-12-11 16:32:28 +08:00 (committed by GitHub)
Parent: 78bf211539
Commit: 5ebb9bd8d2


@@ -490,9 +490,11 @@ class AscendSFAImpl(MLAAttentionImpl):
         self._process_weights_for_fused_mlapo(act_dtype)
 
     def _v_up_proj(self, x):
+        forward_context = get_forward_context()
         if x.dtype in [torch.float16, torch.bfloat16] \
                 and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
-                and not self.enable_sfa_cp:
+                and not self.enable_sfa_cp \
+                and not forward_context.with_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),