[Feature] Remove the transpose step after attention and switch to transpose_batchmatmul (#5390)
1. The `npu_fused_infer_attention_score` kernel supports specifying the
output layout as part of `input_layout` (e.g. `TND_NTD`, `BNSD_NBSD`).
By selecting an output layout with the head dimension first, we avoid
the transpose operation that was previously required after attention.
2. The `torch_npu.npu_transpose_batchmatmul` function lets us control
whether the output tensor is transposed. By passing `perm_y=(1, 0, 2)`,
the result of the `v_up` projection is produced directly in (B, N, V)
order, so the additional transpose after `v_up` becomes unnecessary.
- vLLM version: release/v0.13.0
- vLLM main: 254f6b9867
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -790,22 +790,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||||
|
|
||||||
def _v_up_proj(self, x):
|
def _v_up_proj(self, x):
|
||||||
if x.dtype in [torch.float16, torch.bfloat16] \
|
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
|
||||||
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
|
x = x.view(self.num_heads, -1, self.kv_lora_rank)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
# Multiply (N, B, L) x (N, L, V) -> (B, N, V)
|
||||||
b, _, _ = x.shape
|
x = torch_npu.npu_transpose_batchmatmul(x, self.W_UV, perm_y=(1, 0, 2))
|
||||||
res = torch.empty((b, self.num_heads, self.v_head_dim),
|
# Convert from (B, N, V) to (B, N * V)
|
||||||
dtype=x.dtype,
|
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||||
device=x.device)
|
|
||||||
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
|
|
||||||
x = res.reshape(-1, self.num_heads * self.v_head_dim)
|
|
||||||
else:
|
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
|
||||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
|
||||||
x = torch.bmm(x, self.W_UV)
|
|
||||||
# Convert from (N, B, V) to (B, N * V)
|
|
||||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
# Return `ql_nope`, `q_pe`
|
# Return `ql_nope`, `q_pe`
|
||||||
@@ -1169,16 +1159,18 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.kv_lora_rank)
|
self.kv_lora_rank)
|
||||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||||
self.qk_rope_head_dim)
|
self.qk_rope_head_dim)
|
||||||
input_layout = "BNSD"
|
|
||||||
|
|
||||||
if attn_metadata.attn_state in [
|
if attn_metadata.attn_state in [
|
||||||
AscendAttentionState.SpecDecoding,
|
AscendAttentionState.SpecDecoding,
|
||||||
AscendAttentionState.ChunkedPrefill,
|
AscendAttentionState.ChunkedPrefill,
|
||||||
AscendAttentionState.DecodeOnly,
|
AscendAttentionState.DecodeOnly,
|
||||||
] and self.speculative_config is not None:
|
] and self.speculative_config is not None:
|
||||||
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
|
# Input shape: [num_tokens, num_heads, dim]
|
||||||
input_layout = "TND"
|
# Output shape: [num_heads, num_tokens, dim]
|
||||||
# [bs * q_seq_len, num_heads_per_rank, dim]
|
# The right part layout indicates the layout of the attention
|
||||||
|
# output. It is set to NTD to avoid the need for a transpose
|
||||||
|
# operation after attention.
|
||||||
|
input_layout = "TND_NTD"
|
||||||
# TODO: If the driver is upgraded later, the contiguous function can be deleted.
|
# TODO: If the driver is upgraded later, the contiguous function can be deleted.
|
||||||
q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
|
q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
|
||||||
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
|
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
|
||||||
@@ -1186,6 +1178,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||||
else:
|
else:
|
||||||
|
# Input shape: [num_reqs, num_heads, seq_len, dim]
|
||||||
|
# Output shape: [num_heads, num_reqs, seq_len, dim]
|
||||||
|
# The output layout is set to NBSD to eliminate the need for a
|
||||||
|
# transpose operation after attention.
|
||||||
|
input_layout = "BNSD_NBSD"
|
||||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
||||||
-1).contiguous()
|
-1).contiguous()
|
||||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||||
@@ -1227,7 +1224,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
q_nope, k_nope, k_nope, **common_kwargs)
|
q_nope, k_nope, k_nope, **common_kwargs)
|
||||||
update_graph_params_workspaces(num_tokens, workspace)
|
update_graph_params_workspaces(num_tokens, workspace)
|
||||||
|
|
||||||
attn_output = torch.empty_like(q_nope)
|
attn_output = torch.empty(
|
||||||
|
(q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]),
|
||||||
|
dtype=q_nope.dtype,
|
||||||
|
device=q_nope.device)
|
||||||
softmax_lse = torch.empty(num_tokens,
|
softmax_lse = torch.empty(num_tokens,
|
||||||
dtype=q_nope.dtype,
|
dtype=q_nope.dtype,
|
||||||
device=q_nope.device)
|
device=q_nope.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user