MLA prefill performance optimization (#5456)
### What this PR does / why we need it?
Since the _npu_ring_mla operator's performance degrades in long-sequence scenarios,
long prefill sequences are split into shorter chunks before being passed to the
operator, which improves performance.
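Below is a minimal, hypothetical sketch of the chunking idea described above: one long prefill sequence is cut into shorter ranges so that each kernel call stays within a length where the operator still performs well. The names (`split_long_sequence`, `MAX_CHUNK_LEN`) and the threshold value are illustrative assumptions, not part of this PR.

```python
# Illustrative sketch only; names and the chunk threshold are assumptions,
# not the vllm-ascend implementation.
MAX_CHUNK_LEN = 4096  # assumed length beyond which the kernel slows down


def split_long_sequence(seq_len: int, max_chunk: int = MAX_CHUNK_LEN):
    """Split one long sequence into (start, end) ranges of at most max_chunk tokens."""
    chunks = []
    start = 0
    while start < seq_len:
        end = min(start + max_chunk, seq_len)
        chunks.append((start, end))
        start = end
    return chunks


# A 10k-token prefill becomes three shorter attention calls:
# [(0, 4096), (4096, 8192), (8192, 10000)]
print(split_long_sequence(10_000))
```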
- vLLM version: v0.13.0
- vLLM main: 5326c89803
---------
Signed-off-by: pichangping <1337510399@qq.com>
@@ -813,7 +813,7 @@ class TestAscendMLAImpl(TestBase):
         q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \
             kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info(
                 rank, pcp_size, nums_tokens_per_rank)
-
+        kv_with_q_head_nomask_idx = [kv_with_q_head_nomask_idx]
         output_head, lse_head = self.impl._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_head_idx),
             q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -824,15 +824,16 @@ class TestAscendMLAImpl(TestBase):
             kv_nomask_idx=kv_with_q_head_nomask_idx,
             attn_mask_seqlens=torch.tensor(
                 [chunk_seqlens, chunk_seqlens], dtype=torch.int32),
-            attn_nomask_seqlens=kv_with_q_head_nomask_seqlens,
+            attn_nomask_seqlens=[kv_with_q_head_nomask_seqlens],
             mask=mask)
         self.assertEqual(output_head.shape,
                          (q_head_idx.shape[0], num_heads, v_head_dim))
         self.assertEqual(lse_head.shape,
                          (num_heads, q_head_idx.shape[0]))
         self.assertEqual(mock_npu_ring_mla.call_count,
-                         1 + (kv_with_q_head_nomask_idx.shape[0] != 0))
+                         1 + (len(kv_with_q_head_nomask_idx[0]) != 0))
         mock_npu_ring_mla.reset_mock()
+        kv_with_q_tail_nomask_idx = [kv_with_q_tail_nomask_idx]
         output_tail, lse_tail = self.impl._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_tail_idx),
             q_pe=torch.index_select(q_pe, 0, q_tail_idx),
@@ -843,7 +844,7 @@ class TestAscendMLAImpl(TestBase):
             kv_nomask_idx=kv_with_q_tail_nomask_idx,
             attn_mask_seqlens=torch.tensor(
                 [chunk_seqlens, chunk_seqlens], dtype=torch.int32),
-            attn_nomask_seqlens=kv_with_q_tail_nomask_seqlens,
+            attn_nomask_seqlens=[kv_with_q_tail_nomask_seqlens],
             mask=mask)

         self.assertEqual(output_tail.shape,
@@ -851,7 +852,7 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(lse_tail.shape,
                          (num_heads, q_tail_idx.shape[0]))
         self.assertEqual(mock_npu_ring_mla.call_count,
-                         1 + (kv_with_q_tail_nomask_idx.shape[0] != 0))
+                         1 + (len(kv_with_q_tail_nomask_idx[0]) != 0))
         mock_npu_ring_mla.reset_mock()

     @patch_distributed_groups(dcp_size=2, pcp_size=2)
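For context on why the test checks both an output and an `lse` tensor per chunk: when a sequence is attended to in pieces, the partial results are typically recombined using their log-sum-exp values. The sketch below shows that standard merge rule under the shapes asserted above; the helper name is hypothetical and this is not the vllm-ascend code path.

```python
import torch


def merge_partial_attention(o1, lse1, o2, lse2):
    """Combine two partial attention results computed over disjoint KV chunks.

    Assumed shapes, matching the test's assertions:
      o1, o2:     (num_tokens, num_heads, v_head_dim)
      lse1, lse2: (num_heads, num_tokens)
    """
    lse = torch.logaddexp(lse1, lse2)                          # merged log-sum-exp
    w1 = torch.exp(lse1 - lse).transpose(0, 1).unsqueeze(-1)   # (tokens, heads, 1)
    w2 = torch.exp(lse2 - lse).transpose(0, 1).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```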