From a970b27e2ddc4de2e49aebf7dca447bb143a1b5d Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Fri, 23 May 2025 14:14:06 +0800 Subject: [PATCH] [WIP][Perf]remove unnecessary padding before MLA V1 prefill (#917) ### What this PR does / why we need it? Currently, the implementation for MLA V1 pads q, k, v to `head_dim` 256 to conform to early MLA kernel. But the new MLA kernel supports `head_dim` that can't be divided by 128. Therefore we can remove those unnecessary paddings to boost the performance ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Signed-off-by: angazenn Co-authored-by: angazenn --- vllm_ascend/attention/mla_v1.py | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index fabf95e..d987eab 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -373,12 +373,6 @@ class AscendMLAImpl(MLAAttentionImpl): self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim - # TODO: below padding should be removed after kernel is ready - # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here - # and slice the final result to guarantee its functionality. 
- self.padding_head_dim = ( - (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 + - 1) * 128 # Hack for V1 for now to avoid torch library overhead (since we are # already inside an attention custom op), pull out the forward @@ -520,7 +514,7 @@ class AscendMLAImpl(MLAAttentionImpl): elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: attn_output = torch.empty(num_tokens, self.num_heads, - self.padding_head_dim, + self.v_head_dim, dtype=query.dtype, device=query.device) k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( @@ -529,31 +523,17 @@ class AscendMLAImpl(MLAAttentionImpl): [self.qk_nope_head_dim, self.v_head_dim], dim=-1) key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - pad_query = torch.nn.functional.pad(query, [ - 0, self.padding_head_dim - self.qk_rope_head_dim - - self.qk_nope_head_dim - ], - value=0) - pad_key = torch.nn.functional.pad(key, [ - 0, self.padding_head_dim - self.qk_rope_head_dim - - self.qk_nope_head_dim - ], - value=0) - pad_value = torch.nn.functional.pad( - value, [0, self.padding_head_dim - self.v_head_dim], value=0) torch_npu._npu_flash_attention( - query=pad_query, - key=pad_key, - value=pad_value, + query=query, + key=key, + value=value, mask=attn_metadata.attn_mask, seq_len=attn_metadata.prefill.context_lens, scale_value=self.scale, num_heads=self.num_heads, num_kv_heads=self.num_heads, out=attn_output) - attn_output = attn_output.view( - -1, self.num_heads, - self.padding_head_dim)[:, :, :self.v_head_dim] + attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) else: raise RuntimeError( "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"