[Feature] Remove the transpose step after attention and switch to transpose_batchmatmul (#5390)

1. The `npu_fused_infer_attention_score` kernel supports specifying the
output layout. By selecting an appropriate layout, we can avoid the
transpose operation that would otherwise be required after attention
(see the first sketch below).
2. The `npu_transpose_batchmatmul` function lets us control whether the
output tensor is transposed. By configuring `perm_y`, the additional
transpose after the `v_up` projection becomes unnecessary (see the
second sketch below).
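
Below is a minimal sketch of point 1. The shapes, dtypes, and
sequence-length arguments are illustrative assumptions, not values from
this PR; only the `input_layout` mechanism is what the change relies on.
The layout string has two halves: the part before the underscore
describes the Q/K/V inputs, the part after it describes the attention
output.

```python
import torch
import torch_npu

# Illustrative sizes only (not taken from this PR).
num_tokens, num_heads, head_dim = 8, 16, 128

q = torch.rand(num_tokens, num_heads, head_dim, dtype=torch.float16).npu()
k = torch.rand(num_tokens, num_heads, head_dim, dtype=torch.float16).npu()
v = torch.rand(num_tokens, num_heads, head_dim, dtype=torch.float16).npu()

# With "TND_NTD" the kernel reads [num_tokens, num_heads, dim] inputs and
# writes the attention output as [num_heads, num_tokens, dim], which is
# already the shape the V up-projection consumes, so the usual
# post-attention transpose disappears.
out, _ = torch_npu.npu_fused_infer_attention_score(
    q, k, v,
    num_heads=num_heads,
    input_layout="TND_NTD",
    # Cumulative per-sequence token counts; a single sequence here.
    actual_seq_lengths=[num_tokens],
    actual_seq_lengths_kv=[num_tokens],
    scale=head_dim**-0.5)

print(out.shape)  # torch.Size([16, 8, 128])
```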
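And a sketch of point 2, mirroring the new `_v_up_proj` path against the
`torch.bmm` + `transpose` code it replaces. The sizes here are made up,
and the loose fp16 tolerance is an assumption about accumulation
differences between the two kernels.

```python
import torch
import torch_npu

# Illustrative MLA-style sizes (not taken from this PR).
num_heads, batch, kv_lora_rank, v_head_dim = 16, 4, 512, 128

x = torch.rand(num_heads, batch, kv_lora_rank, dtype=torch.float16).npu()
w_uv = torch.rand(num_heads, kv_lora_rank, v_head_dim,
                  dtype=torch.float16).npu()

# Fused path: (N, B, L) @ (N, L, V) -> (B, N, V). perm_y=(1, 0, 2)
# permutes the result inside the kernel, so no separate transpose runs.
y = torch_npu.npu_transpose_batchmatmul(x, w_uv, perm_y=(1, 0, 2))

# Old path: bmm produces (N, B, V), then an explicit transpose to (B, N, V).
ref = torch.bmm(x, w_uv).transpose(0, 1)

print(y.shape, ref.shape)  # both (4, 16, 128)
# fp16 matmuls may accumulate differently across kernels; compare loosely.
torch.testing.assert_close(y, ref, rtol=3e-3, atol=3e-3)
```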

- vLLM version: release/v0.13.0
- vLLM main: 254f6b9867

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Author: Jade Zheng
Date: 2025-12-26 22:03:46 +08:00
Committed by: GitHub
Parent: bc5b7a5fb5
Commit: 8b9ca86827


@@ -790,22 +790,12 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

     def _v_up_proj(self, x):
-        if x.dtype in [torch.float16, torch.bfloat16] \
-            and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
-            x = x.view(-1, self.num_heads, self.kv_lora_rank)
-            b, _, _ = x.shape
-            res = torch.empty((b, self.num_heads, self.v_head_dim),
-                              dtype=x.dtype,
-                              device=x.device)
-            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
-            x = res.reshape(-1, self.num_heads * self.v_head_dim)
-        else:
-            # Convert from (B, N, L) to (N, B, L)
-            x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
-            # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
+        x = x.view(self.num_heads, -1, self.kv_lora_rank)
+        # Multiply (N, B, L) x (N, L, V) -> (B, N, V)
+        x = torch_npu.npu_transpose_batchmatmul(x, self.W_UV, perm_y=(1, 0, 2))
+        # Convert from (B, N, V) to (B, N * V)
+        x = x.reshape(-1, self.num_heads * self.v_head_dim)
         return x

     # Return `ql_nope`, `q_pe`
@@ -1169,16 +1159,18 @@ class AscendMLAImpl(MLAAttentionImpl):
                              self.kv_lora_rank)
         k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
                          self.qk_rope_head_dim)
-        input_layout = "BNSD"
         if attn_metadata.attn_state in [
                 AscendAttentionState.SpecDecoding,
                 AscendAttentionState.ChunkedPrefill,
                 AscendAttentionState.DecodeOnly,
         ] and self.speculative_config is not None:
-            # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
-            input_layout = "TND"
-            # [bs * q_seq_len, num_heads_per_rank, dim]
+            # Input shape: [num_tokens, num_heads, dim]
+            # Output shape: [num_heads, num_tokens, dim]
+            # The right half of the layout string describes the layout of
+            # the attention output; it is set to NTD to avoid the need for
+            # a transpose operation after attention.
+            input_layout = "TND_NTD"
             # TODO: If the driver is upgraded later, the contiguous function can be deleted.
             q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
             q_pe = q_pe.view(num_tokens, self.num_heads, -1)
@@ -1186,6 +1178,11 @@ class AscendMLAImpl(MLAAttentionImpl):
             spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
             actual_seq_lengths = decode_meta.actual_seq_lengths_q
         else:
+            # Input shape: [num_reqs, num_heads, seq_len, dim]
+            # Output shape: [num_heads, num_reqs, seq_len, dim]
+            # The output layout is set to NBSD to eliminate the need for a
+            # transpose operation after attention.
+            input_layout = "BNSD_NBSD"
             q_nope = q_nope.view(num_tokens, self.num_heads, 1,
                                  -1).contiguous()
             q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
@@ -1227,7 +1224,10 @@ class AscendMLAImpl(MLAAttentionImpl):
                 q_nope, k_nope, k_nope, **common_kwargs)
             update_graph_params_workspaces(num_tokens, workspace)
-        attn_output = torch.empty_like(q_nope)
+        attn_output = torch.empty(
+            (q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]),
+            dtype=q_nope.dtype,
+            device=q_nope.device)
         softmax_lse = torch.empty(num_tokens,
                                   dtype=q_nope.dtype,
                                   device=q_nope.device)