[WIP][Perf]remove unnecessary padding before MLA V1 prefill (#917)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? Currently, the implementation for MLA V1 pads q, k, v to `head_dim` 256 to conform to early MLA kernel. But the new MLA kernel supports `head_dim` that can't be devided by 128. Therefore we can remove those unnecessary paddings to boost the performance ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -373,12 +373,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
# TODO: below padding should be removed after kernel is ready
|
||||
# we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here
|
||||
# and slice the final result to guarantee its functionality.
|
||||
self.padding_head_dim = (
|
||||
(self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
|
||||
1) * 128
|
||||
|
||||
# Hack for V1 for now to avoid torch library overhead (since we are
|
||||
# already inside an attention custom op), pull out the forward
|
||||
@@ -520,7 +514,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.padding_head_dim,
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
@@ -529,31 +523,17 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
||||
dim=-1)
|
||||
pad_query = torch.nn.functional.pad(query, [
|
||||
0, self.padding_head_dim - self.qk_rope_head_dim -
|
||||
self.qk_nope_head_dim
|
||||
],
|
||||
value=0)
|
||||
pad_key = torch.nn.functional.pad(key, [
|
||||
0, self.padding_head_dim - self.qk_rope_head_dim -
|
||||
self.qk_nope_head_dim
|
||||
],
|
||||
value=0)
|
||||
pad_value = torch.nn.functional.pad(
|
||||
value, [0, self.padding_head_dim - self.v_head_dim], value=0)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=pad_query,
|
||||
key=pad_key,
|
||||
value=pad_value,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=attn_metadata.attn_mask,
|
||||
seq_len=attn_metadata.prefill.context_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_heads,
|
||||
out=attn_output)
|
||||
attn_output = attn_output.view(
|
||||
-1, self.num_heads,
|
||||
self.padding_head_dim)[:, :, :self.v_head_dim]
|
||||
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
|
||||
|
||||
Reference in New Issue
Block a user