diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index e250fdbc..0646b41c 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -764,12 +764,22 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
 
     def _v_up_proj(self, x):
-        # Convert from (B, N, L) to (N, B, L)
-        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        # # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-        x = torch.bmm(x, self.W_UV)
-        # # Convert from (N, B, V) to (B, N * V)
-        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        if x.dtype in [torch.float16, torch.bfloat16] \
+                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
+            x = x.view(-1, self.num_heads, self.kv_lora_rank)
+            b, _, _ = x.shape
+            res = torch.empty((b, self.num_heads, self.v_head_dim),
+                              dtype=x.dtype,
+                              device=x.device)
+            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
+            x = res.reshape(-1, self.num_heads * self.v_head_dim)
+        else:
+            # Convert from (B, N, L) to (N, B, L)
+            x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+            x = torch.bmm(x, self.W_UV)
+            # Convert from (N, B, V) to (B, N * V)
+            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
         return x
 
     # Return `ql_nope`, `q_pe`
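
For context, the diff adds a fast path that calls the custom Ascend kernel `torch.ops._C_ascend.batch_matmul_transpose` for fp16/bf16 inputs, keeping the original `torch.bmm` path as a fallback. Below is a minimal sketch (not part of the patch) of the semantics the fused call is assumed to have: `res[b, n] = x[b, n] @ W_UV[n]`, i.e. the same per-head matmul as the fallback, with the `(B, N, L) <-> (N, B, L)` transposes handled inside the kernel. The function names here are illustrative only.

```python
import torch

def v_up_proj_fallback(x, w_uv, num_heads, kv_lora_rank, v_head_dim):
    # The existing bmm path: (B, N, L) -> (N, B, L) -> bmm -> (B, N * V).
    x = x.view(-1, num_heads, kv_lora_rank).transpose(0, 1)   # (N, B, L)
    out = torch.bmm(x, w_uv)                                   # (N, B, V)
    return out.transpose(0, 1).reshape(-1, num_heads * v_head_dim)

def v_up_proj_fused_semantics(x, w_uv, num_heads, kv_lora_rank, v_head_dim):
    # Assumed behavior of batch_matmul_transpose, written in plain PyTorch:
    # per-head matmul without materializing the (N, B, L) transpose.
    x = x.view(-1, num_heads, kv_lora_rank)                    # (B, N, L)
    res = torch.einsum("bnl,nlv->bnv", x, w_uv)                # (B, N, V)
    return res.reshape(-1, num_heads * v_head_dim)

# Both paths should agree up to kernel numerics:
# torch.testing.assert_close(
#     v_up_proj_fallback(x, w_uv, N, L, V),
#     v_up_proj_fused_semantics(x, w_uv, N, L, V))
```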