MLA prefill preformance optimization (#5456)

### What this PR does / why we need it?
Since the _npu_ring_mla operator deteriorates in long-sequencescenarios,
the long sequence is split into shorter sequences for input to improve
performance.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: pichangping <1337510399@qq.com>
This commit is contained in:
pichangping
2026-01-05 11:41:59 +08:00
committed by GitHub
parent c23cf30709
commit 50e7934415
4 changed files with 351 additions and 46 deletions

View File

@@ -813,7 +813,7 @@ class TestAscendMLAImpl(TestBase):
q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \
kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info(
rank, pcp_size, nums_tokens_per_rank)
kv_with_q_head_nomask_idx = [kv_with_q_head_nomask_idx]
output_head, lse_head = self.impl._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -824,15 +824,16 @@ class TestAscendMLAImpl(TestBase):
kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32),
attn_nomask_seqlens=kv_with_q_head_nomask_seqlens,
attn_nomask_seqlens=[kv_with_q_head_nomask_seqlens],
mask=mask)
self.assertEqual(output_head.shape,
(q_head_idx.shape[0], num_heads, v_head_dim))
self.assertEqual(lse_head.shape,
(num_heads, q_head_idx.shape[0]))
self.assertEqual(mock_npu_ring_mla.call_count,
1 + (kv_with_q_head_nomask_idx.shape[0] != 0))
1 + (len(kv_with_q_head_nomask_idx[0]) != 0))
mock_npu_ring_mla.reset_mock()
kv_with_q_tail_nomask_idx = [kv_with_q_tail_nomask_idx]
output_tail, lse_tail = self.impl._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
@@ -843,7 +844,7 @@ class TestAscendMLAImpl(TestBase):
kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32),
attn_nomask_seqlens=kv_with_q_tail_nomask_seqlens,
attn_nomask_seqlens=[kv_with_q_tail_nomask_seqlens],
mask=mask)
self.assertEqual(output_tail.shape,
@@ -851,7 +852,7 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(lse_tail.shape,
(num_heads, q_tail_idx.shape[0]))
self.assertEqual(mock_npu_ring_mla.call_count,
1 + (kv_with_q_tail_nomask_idx.shape[0] != 0))
1 + (len(kv_with_q_tail_nomask_idx[0]) != 0))
mock_npu_ring_mla.reset_mock()
@patch_distributed_groups(dcp_size=2, pcp_size=2)