[WIP][Perf]remove unnecessary padding before MLA V1 prefill (#917)

### What this PR does / why we need it?
Currently, the MLA V1 implementation pads q, k and v to a `head_dim` of 256 to conform to the early MLA kernel, which only supported `head_dim` values divisible by 128. The new MLA kernel supports `head_dim` values that are not divisible by 128, so we can remove the unnecessary padding to boost performance.
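
For reference, the padding target that the removed code computed simply rounds the combined q/k head dim up to the next multiple of 128. A small illustrative sketch (the dims 128 + 64 = 192 are the DeepSeek-style MLA values that make the padded `head_dim` come out to the 256 mentioned above):

```python
# Round-up used by the removed padding code: pad the combined q/k head dim
# (qk_nope_head_dim + qk_rope_head_dim) up to the next multiple of 128.
qk_nope_head_dim = 128  # DeepSeek-style MLA value (illustrative)
qk_rope_head_dim = 64   # DeepSeek-style MLA value (illustrative)

padding_head_dim = ((qk_nope_head_dim + qk_rope_head_dim - 1) // 128 + 1) * 128
print(padding_head_dim)  # 256
```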

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Commit a970b27e2d (parent dc6172efd3), authored by Angazenn, 2025-05-23 14:14:06 +08:00, committed by GitHub

```diff
@@ -373,12 +373,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
-        # TODO: below padding should be removed after kernel is ready
-        # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here
-        # and slice the final result to guarantee its functionality.
-        self.padding_head_dim = (
-            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
-            1) * 128
         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
@@ -520,7 +514,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.padding_head_dim,
+                                      self.v_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -529,31 +523,17 @@ class AscendMLAImpl(MLAAttentionImpl):
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
-            pad_query = torch.nn.functional.pad(query, [
-                0, self.padding_head_dim - self.qk_rope_head_dim -
-                self.qk_nope_head_dim
-            ],
-                                                value=0)
-            pad_key = torch.nn.functional.pad(key, [
-                0, self.padding_head_dim - self.qk_rope_head_dim -
-                self.qk_nope_head_dim
-            ],
-                                              value=0)
-            pad_value = torch.nn.functional.pad(
-                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=pad_query,
-                key=pad_key,
-                value=pad_value,
+                query=query,
+                key=key,
+                value=value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(
-                -1, self.num_heads,
-                self.padding_head_dim)[:, :, :self.v_head_dim]
+            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
```
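
Why dropping the pad-then-slice dance is numerically safe can be seen in a small standalone NumPy sketch (hypothetical example, not vllm-ascend code): zero-padding q and k leaves the attention scores unchanged, because the extra dimensions contribute 0 to every dot product, and the zero-padded tail of v only produces zero output columns that the old code sliced off anyway.

```python
import numpy as np

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
num_tokens, qk_dim, v_dim, pad_dim = 4, 192, 128, 256  # 192 = nope 128 + rope 64
q = rng.standard_normal((num_tokens, qk_dim))
k = rng.standard_normal((num_tokens, qk_dim))
v = rng.standard_normal((num_tokens, v_dim))
scale = qk_dim ** -0.5  # the scale is based on the true head dim in both paths

# new path: attention with the true head dims (what the new kernel supports)
direct = softmax(q @ k.T * scale) @ v

# old path: zero-pad q/k/v up to 256, attend, then slice the output back
pad = lambda x: np.pad(x, ((0, 0), (0, pad_dim - x.shape[1])))
padded = (softmax(pad(q) @ pad(k).T * scale) @ pad(v))[:, :v_dim]

assert np.allclose(direct, padded)
```

Besides the avoided copies, the output buffer shrinks from `padding_head_dim` (256) to `v_head_dim` columns per head, so the final `view` no longer needs a slice.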