MLA prefill performance optimization (#5275)

### What this PR does / why we need it?
The `_npu_ring_mla` operator's performance deteriorates in long-sequence
scenarios, so long input sequences are split into shorter chunks before
being fed to the operator, improving prefill performance.
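
A minimal sketch of the splitting idea, assuming hypothetical names
(`split_kv_indices`, `chunk_size`) rather than the PR's actual implementation:

```python
import torch


def split_kv_indices(kv_idx: torch.Tensor,
                     chunk_size: int = 2048) -> list[torch.Tensor]:
    """Split one long KV index tensor into bounded-length chunks.

    Each chunk is then passed to _npu_ring_mla separately, keeping the
    per-call sequence length short. The chunk_size here is illustrative,
    not the PR's tuned threshold.
    """
    if kv_idx.numel() == 0:
        return [kv_idx]  # keep a single (empty) chunk for uniform handling
    return list(torch.split(kv_idx, chunk_size))
```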
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: pichangping <1337510399@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>

@@ -865,7 +865,7 @@ class TestAscendMLAImpl(TestBase):
         q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \
             kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info(
                 rank, pcp_size, nums_tokens_per_rank)
+        kv_with_q_head_nomask_idx = [kv_with_q_head_nomask_idx]
         output_head, lse_head = self.impl._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_head_idx),
             q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -876,15 +876,16 @@ class TestAscendMLAImpl(TestBase):
             kv_nomask_idx=kv_with_q_head_nomask_idx,
             attn_mask_seqlens=torch.tensor(
                 [chunk_seqlens, chunk_seqlens], dtype=torch.int32),
-            attn_nomask_seqlens=kv_with_q_head_nomask_seqlens,
+            attn_nomask_seqlens=[kv_with_q_head_nomask_seqlens],
             mask=mask)
         self.assertEqual(output_head.shape,
                          (q_head_idx.shape[0], num_heads, v_head_dim))
         self.assertEqual(lse_head.shape,
                          (num_heads, q_head_idx.shape[0]))
         self.assertEqual(mock_npu_ring_mla.call_count,
-                         1 + (kv_with_q_head_nomask_idx.shape[0] != 0))
+                         1 + (len(kv_with_q_head_nomask_idx[0]) != 0))
         mock_npu_ring_mla.reset_mock()
+        kv_with_q_tail_nomask_idx = [kv_with_q_tail_nomask_idx]
         output_tail, lse_tail = self.impl._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_tail_idx),
             q_pe=torch.index_select(q_pe, 0, q_tail_idx),
@@ -895,7 +896,7 @@ class TestAscendMLAImpl(TestBase):
             kv_nomask_idx=kv_with_q_tail_nomask_idx,
             attn_mask_seqlens=torch.tensor(
                 [chunk_seqlens, chunk_seqlens], dtype=torch.int32),
-            attn_nomask_seqlens=kv_with_q_tail_nomask_seqlens,
+            attn_nomask_seqlens=[kv_with_q_tail_nomask_seqlens],
             mask=mask)
         self.assertEqual(output_tail.shape,
@@ -903,7 +904,7 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(lse_tail.shape,
                          (num_heads, q_tail_idx.shape[0]))
         self.assertEqual(mock_npu_ring_mla.call_count,
-                         1 + (kv_with_q_tail_nomask_idx.shape[0] != 0))
+                         1 + (len(kv_with_q_tail_nomask_idx[0]) != 0))
         mock_npu_ring_mla.reset_mock()

     @patch("torch.distributed.all_to_all_single")