diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 2196858..f88e01a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -559,12 +559,13 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.speculative_config = vllm_config.speculative_config

     def _v_up_proj(self, x):
-        # Convert from (B, N, L) to (N, B, L)
-        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-        x = torch.bmm(x, self.W_UV)
-        # Convert from (N, B, V) to (B, N * V)
-        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank)
+        x = torch_npu.npu_transpose_batchmatmul(x,
+                                                self.W_UV,
+                                                perm_x1=[1, 0, 2],
+                                                perm_x2=[0, 1, 2],
+                                                perm_y=[1, 0, 2])
+        x = x.reshape(-1, self.num_heads * self.v_head_dim)
         return x

     # Return `ql_nope`, `q_pe`
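
For context on what the fused call replaces: the removed code transposed the activations, ran a batched matmul against `W_UV`, and transposed back; `npu_transpose_batchmatmul` folds those three steps into one NPU kernel, with `perm_x1`/`perm_x2`/`perm_y` describing the layout of the first input, second input, and output. Below is a minimal plain-PyTorch sketch of the intended semantics, mirroring the code this patch removes (the function name and the toy sizes are hypothetical, not part of the patch):

```python
import torch

def v_up_proj_reference(x, w_uv, num_heads, kv_lora_rank, v_head_dim):
    # Shapes: B = token batch, N = num_heads, L = kv_lora_rank, V = v_head_dim.
    x = x.view(-1, num_heads, kv_lora_rank)       # (B, N, L)
    x = x.transpose(0, 1)                         # perm_x1=[1, 0, 2]: (B, N, L) -> (N, B, L)
    x = torch.bmm(x, w_uv)                        # perm_x2=[0, 1, 2]: W_UV kept as (N, L, V)
    x = x.transpose(0, 1)                         # perm_y=[1, 0, 2]:  (N, B, V) -> (B, N, V)
    return x.reshape(-1, num_heads * v_head_dim)  # (B, N * V)

# Shape check with hypothetical sizes (not taken from the patch):
B, N, L, V = 4, 16, 512, 128
out = v_up_proj_reference(torch.randn(B, N, L), torch.randn(N, L, V), N, L, V)
assert out.shape == (B, N * V)
```

The likely motivation for the fusion is to avoid materializing the transposed intermediates and to cut two separate transpose kernel launches on Ascend NPUs, though the patch itself does not state this.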