From 8b9ca868274d169c47424df9b74376d83f441d89 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Fri, 26 Dec 2025 22:03:46 +0800
Subject: [PATCH] [Feature] Remove the transpose step after attention and
 switch to transpose_batchmatmul (#5390)

1. The `npu_fused_infer_attention_score` kernel supports specifying the
   output layout. By selecting the appropriate layout, we can avoid the
   transpose operation typically required after attention.
2. The `npu_transpose_batchmatmul` function lets us control whether the
   output tensor is transposed. By configuring `perm_y`, the additional
   transpose after the `v_up` projection becomes unnecessary.

- vLLM version: release/v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/254f6b986720c92ddf97fbb1a6a6465da8e87e29

---------

Signed-off-by: Jade Zheng
---
 vllm_ascend/attention/mla_v1.py | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index b77682db..e529f0bf 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -790,22 +790,12 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
 
     def _v_up_proj(self, x):
-        if x.dtype in [torch.float16, torch.bfloat16] \
-                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
-            x = x.view(-1, self.num_heads, self.kv_lora_rank)
-            b, _, _ = x.shape
-            res = torch.empty((b, self.num_heads, self.v_head_dim),
-                              dtype=x.dtype,
-                              device=x.device)
-            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
-            x = res.reshape(-1, self.num_heads * self.v_head_dim)
-        else:
-            # Convert from (B, N, L) to (N, B, L)
-            x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
-            # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
+        x = x.view(self.num_heads, -1, self.kv_lora_rank)
+        # Multiply (N, B, L) x (N, L, V) -> (B, N, V)
+        x = torch_npu.npu_transpose_batchmatmul(x, self.W_UV, perm_y=(1, 0, 2))
+        # Convert from (B, N, V) to (B, N * V)
+        x = x.reshape(-1, self.num_heads * self.v_head_dim)
         return x
 
     # Return `ql_nope`, `q_pe`
@@ -1169,16 +1159,18 @@ class AscendMLAImpl(MLAAttentionImpl):
                                  self.kv_lora_rank)
             k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
                              self.qk_rope_head_dim)
-            input_layout = "BNSD"
 
             if attn_metadata.attn_state in [
                     AscendAttentionState.SpecDecoding,
                     AscendAttentionState.ChunkedPrefill,
                     AscendAttentionState.DecodeOnly,
             ] and self.speculative_config is not None:
-                # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
-                input_layout = "TND"
-                # [bs * q_seq_len, num_heads_per_rank, dim]
+                # Input shape: [num_tokens, num_heads, dim]
+                # Output shape: [num_heads, num_tokens, dim]
+                # The right part of the layout indicates the layout of the
+                # attention output. It is set to NTD to avoid the need for a
+                # transpose operation after attention.
+                input_layout = "TND_NTD"
                 # TODO: If the driver is upgraded later, the contiguous function can be deleted.
                 q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
                 q_pe = q_pe.view(num_tokens, self.num_heads, -1)
@@ -1186,6 +1178,11 @@ class AscendMLAImpl(MLAAttentionImpl):
                 spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
                 actual_seq_lengths = decode_meta.actual_seq_lengths_q
             else:
+                # Input shape: [num_reqs, num_heads, seq_len, dim]
+                # Output shape: [num_heads, num_reqs, seq_len, dim]
+                # The output layout is set to NBSD to eliminate the need for a
+                # transpose operation after attention.
+                input_layout = "BNSD_NBSD"
                 q_nope = q_nope.view(num_tokens, self.num_heads, 1,
                                      -1).contiguous()
                 q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
@@ -1227,7 +1224,10 @@ class AscendMLAImpl(MLAAttentionImpl):
                     q_nope, k_nope, k_nope, **common_kwargs)
                 update_graph_params_workspaces(num_tokens, workspace)
 
-            attn_output = torch.empty_like(q_nope)
+            attn_output = torch.empty(
+                (q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]),
+                dtype=q_nope.dtype,
+                device=q_nope.device)
             softmax_lse = torch.empty(num_tokens,
                                       dtype=q_nope.dtype,
                                       device=q_nope.device)
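
Note for reviewers: the sketch below is not part of the patch. It is a minimal illustration of the shape flow the new `_v_up_proj` relies on, assuming an Ascend device with `torch_npu` available; the sizes N, B, L, V are made-up placeholders for num_heads, batch, kv_lora_rank and v_head_dim, and the `torch.bmm` path mirrors the old fallback code.

```python
# Minimal sketch (not part of the patch): fused v_up projection via
# npu_transpose_batchmatmul, assuming torch_npu on an Ascend NPU.
import torch
import torch_npu

N, B, L, V = 16, 4, 512, 128  # illustrative num_heads, batch, kv_lora_rank, v_head_dim

# Attention output in head-first layout (N, B, L), as produced by the new
# "*_NTD"/"*_NBSD" output layouts, plus the per-head up-projection weight.
x = torch.randn(N, B, L, dtype=torch.bfloat16, device="npu")
w_uv = torch.randn(N, L, V, dtype=torch.bfloat16, device="npu")

# Old path: bmm yields (N, B, V) and needs an explicit transpose to (B, N, V).
ref = torch.bmm(x, w_uv).transpose(0, 1)

# New path: perm_y=(1, 0, 2) asks the kernel to emit (B, N, V) directly,
# so no separate transpose kernel runs after the projection.
out = torch_npu.npu_transpose_batchmatmul(x, w_uv, perm_y=(1, 0, 2))

assert out.shape == ref.shape == (B, N, V)
# Both paths then flatten to the (B, N * V) layout expected downstream.
out = out.reshape(B, N * V)
```

Folding the permutation into the matmul, together with the NTD/NBSD attention output layouts, removes the standalone transpose kernels while still handing downstream code the usual (B, N * V) tensor.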