MLA prefill preformance optimization (#5275)

### What this PR does / why we need it?
Since the _npu_ring_mla operator deteriorates in long-sequencescenarios,
the long sequence is split into shorter sequences for input to improve
performance.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: pichangping <1337510399@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
pichangping
2025-12-27 09:19:45 +08:00
committed by GitHub
parent 1486e0d06c
commit 711f1861e4
4 changed files with 361 additions and 50 deletions

View File

@@ -778,11 +778,18 @@ class AscendMlaCPImpl(AscendMLAImpl):
return output
def _attention_with_mask_and_nomask(
self, q_nope: torch.Tensor, q_pe: torch.Tensor,
k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor,
kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor,
attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor,
mask: torch.Tensor):
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
value: torch.Tensor,
kv_mask_idx: torch.Tensor,
kv_nomask_idx: list[torch.Tensor],
attn_mask_seqlens: torch.Tensor,
attn_nomask_seqlens: list[torch.Tensor],
mask: torch.Tensor,
):
attn_output = torch.empty(q_nope.shape[0],
self.num_heads,
self.v_head_dim,
@@ -816,30 +823,32 @@ class AscendMlaCPImpl(AscendMLAImpl):
softmax_lse=attn_lse)
# nomask
if kv_nomask_idx.shape[0] == 0:
if not kv_nomask_idx or len(kv_nomask_idx[0]) == 0:
return attn_output, attn_lse
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx)
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope_nomask,
k_rope=k_pe_nomask,
value=value_nomask,
mask=mask,
seqlen=attn_nomask_seqlens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=attn_output,
prev_lse=attn_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=attn_output,
softmax_lse=attn_lse)
for kv_nomask_idx_split, attn_nomask_seqlens_split in zip(
kv_nomask_idx, attn_nomask_seqlens):
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx_split)
value_nomask = torch.index_select(value, 0, kv_nomask_idx_split)
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx_split)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope_nomask,
k_rope=k_pe_nomask,
value=value_nomask,
mask=mask,
seqlen=attn_nomask_seqlens_split,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=attn_output,
prev_lse=attn_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=attn_output,
softmax_lse=attn_lse)
return attn_output, attn_lse
def _forward_decode_pcp_dcp(