[v0.18.0][BugFix]Revert the code: Replace npu_ring_mla wit FIA with MLA prefill. (#7961)

This pull request reverts previous changes to switch to FIA and instead
implements npu_ring_mla for MLA prefill operations(#5704 ). The change
streamlines the attention mechanism by removing unnecessary metadata
tracking and updating the underlying NPU operations to use the
ring-based MLA kernel. This adjustment ensures better compatibility and
performance for MLA prefill tasks within the vLLM Ascend backend.

Highlights

- Migration to npu_ring_mla: Replaced the usage of
npu_fused_infer_attention_score (FIA) with npu_ring_mla for MLA prefill
operations across the codebase to improve performance and alignment with
the intended architecture.
- Cleanup of redundant metadata: Removed
chunk_actual_seq_lengths_kv_list and actual_seq_lengths_q from various
metadata structures as they are no longer required for the updated
attention implementation.
- Test suite updates: Updated unit tests in test_mla_cp.py and
test_mla_v1.py to mock npu_ring_mla instead of the deprecated FIA
functions and adjusted test assertions to reflect the new implementation
details.

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2026-04-09 17:00:25 +08:00
committed by GitHub
parent 7c9aa498d6
commit f668ff9ef0
5 changed files with 73 additions and 151 deletions

View File

@@ -102,8 +102,7 @@ class TestAscendMLAPrefillMetadata(TestBase):
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens,
chunk_actual_seq_lengths_kv_list=[[2, 4]])
chunk_seq_lens_npu=chunk_seq_lens)
metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -888,9 +887,8 @@ class TestAscendMLAImpl(TestBase):
self.assertTrue(torch.equal(prefix_lse, lse))
@patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.npu_attention_update")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_compute_prefill_context(self, mock_fia, mock_update, mock_load):
@patch("torch_npu.atb.npu_ring_mla")
def test_compute_prefill_context(self, mock_ring, mock_load):
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
latent_kv_dim = self.impl.kv_lora_rank
@@ -901,16 +899,11 @@ class TestAscendMLAImpl(TestBase):
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1]
prefix_out = torch.randn(S, N, VD)
prefix_lse = torch.randn(N, S)
prefix_out = torch.randn(S, N, 128)
prefix_lse = torch.randn(S, N)
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
# Mock FIA to return output and lse
mock_fia.return_value = (torch.randn(S, N, VD), torch.randn(N, S))
# Mock attention_update to return merged output
mock_update.return_value = (torch.randn(S * N, VD), None)
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
@@ -919,7 +912,7 @@ class TestAscendMLAImpl(TestBase):
prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx
prefill_meta.query_lens = torch.tensor([S])
prefill_meta.query_lens = [8]
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
meta = MagicMock()
@@ -932,10 +925,10 @@ class TestAscendMLAImpl(TestBase):
prefix_lse)
mock_load.assert_called_once()
mock_fia.assert_called_once()
mock_update.assert_called_once()
mock_ring.assert_called_once()
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")