Revert "MLA prefill preformance optimization (#5275)" (#5410)

We'll release 0.13.0 soon. The main branch is frozen. Let's revert the
newest change and redo it once 0.13.0 is released
- vLLM version: release/v0.13.0
- vLLM main:
81786c8774
This commit is contained in:
wangxiyuan
2025-12-27 09:48:56 +08:00
committed by GitHub
parent 711f1861e4
commit d1f0df7b4b
4 changed files with 50 additions and 361 deletions

View File

@@ -865,7 +865,7 @@ class TestAscendMLAImpl(TestBase):
q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \
kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info(
rank, pcp_size, nums_tokens_per_rank)
kv_with_q_head_nomask_idx = [kv_with_q_head_nomask_idx]
output_head, lse_head = self.impl._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -876,16 +876,15 @@ class TestAscendMLAImpl(TestBase):
kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32),
attn_nomask_seqlens=[kv_with_q_head_nomask_seqlens],
attn_nomask_seqlens=kv_with_q_head_nomask_seqlens,
mask=mask)
self.assertEqual(output_head.shape,
(q_head_idx.shape[0], num_heads, v_head_dim))
self.assertEqual(lse_head.shape,
(num_heads, q_head_idx.shape[0]))
self.assertEqual(mock_npu_ring_mla.call_count,
1 + (len(kv_with_q_head_nomask_idx[0]) != 0))
1 + (kv_with_q_head_nomask_idx.shape[0] != 0))
mock_npu_ring_mla.reset_mock()
kv_with_q_tail_nomask_idx = [kv_with_q_tail_nomask_idx]
output_tail, lse_tail = self.impl._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
@@ -896,7 +895,7 @@ class TestAscendMLAImpl(TestBase):
kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32),
attn_nomask_seqlens=[kv_with_q_tail_nomask_seqlens],
attn_nomask_seqlens=kv_with_q_tail_nomask_seqlens,
mask=mask)
self.assertEqual(output_tail.shape,
@@ -904,7 +903,7 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(lse_tail.shape,
(num_heads, q_tail_idx.shape[0]))
self.assertEqual(mock_npu_ring_mla.call_count,
1 + (len(kv_with_q_tail_nomask_idx[0]) != 0))
1 + (kv_with_q_tail_nomask_idx.shape[0] != 0))
mock_npu_ring_mla.reset_mock()
@patch("torch.distributed.all_to_all_single")