[Feature] Remove the transpose step after attention and switch to transpose_batchmatmul (#5390)
1. The `npu_fused_infer_attention_score` kernel supports specifying the
output layout. By selecting the appropriate layout, we can avoid the
transpose operation typically required after the attention.
2. The `transpose_batchmatmul` function allows us to control whether the
output tensor is transposed. If we configure `perm_y`, an additional
transpose after executing `v_up` becomes unnecessary.
- vLLM version: release/v0.13.0
- vLLM main:
254f6b9867
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -790,22 +790,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
if x.dtype in [torch.float16, torch.bfloat16] \
|
||||
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
b, _, _ = x.shape
|
||||
res = torch.empty((b, self.num_heads, self.v_head_dim),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
|
||||
x = res.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
else:
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
|
||||
x = x.view(self.num_heads, -1, self.kv_lora_rank)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (B, N, V)
|
||||
x = torch_npu.npu_transpose_batchmatmul(x, self.W_UV, perm_y=(1, 0, 2))
|
||||
# Convert from (B, N, V) to (B, N * V)
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return x
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
@@ -1169,16 +1159,18 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
input_layout = "BNSD"
|
||||
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.DecodeOnly,
|
||||
] and self.speculative_config is not None:
|
||||
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
|
||||
input_layout = "TND"
|
||||
# [bs * q_seq_len, num_heads_per_rank, dim]
|
||||
# Input shape: [num_tokens, num_heads, dim]
|
||||
# Output shape: [num_heads, num_tokens, dim]
|
||||
# The right part layout indicates the layout of the attention
|
||||
# output. It is set to NTD to avoid the need for a transpose
|
||||
# operation after attention.
|
||||
input_layout = "TND_NTD"
|
||||
# TODO: If the driver is upgraded later, the contiguous function can be deleted.
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
|
||||
@@ -1186,6 +1178,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||
else:
|
||||
# Input shape: [num_reqs, num_heads, seq_len, dim]
|
||||
# Output shape: [num_heads, num_reqs, seq_len, dim]
|
||||
# The output layout is set to NBSD to eliminate the need for a
|
||||
# transpose operation after attention.
|
||||
input_layout = "BNSD_NBSD"
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
@@ -1227,7 +1224,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
attn_output = torch.empty_like(q_nope)
|
||||
attn_output = torch.empty(
|
||||
(q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]),
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
softmax_lse = torch.empty(num_tokens,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
Reference in New Issue
Block a user