[Refactor] Replace npu_ring_mla with FIA in MLA prefill (#5704)
### What this PR does / why we need it?

**Refactor: Replace npu_ring_mla with FIA in MLA prefill**

This PR refactors the MLA (Multi-head Latent Attention) prefill implementation by replacing the `npu_ring_mla` operator with `npu_fused_infer_attention_score` (FIA), unifying the attention backend with the standard attention implementation.

**Key changes:**

1. **Core prefill refactoring (`mla_v1.py`)**
   - Replace `npu_ring_mla` with `npu_fused_infer_attention_score` in `_forward_prefill` and `_compute_prefill_context`
   - Use the TND layout with `softmax_lse_flag=True` for prefill attention
   - Use `npu_attention_update` to merge multiple chunk outputs via their LSE (log-sum-exp) values, as illustrated in the merge sketch below
   - Change `attn_mask` from `get_final_mla_mask()` to `get_splitfuse_attn_mask()` for FIA compatibility
2. **Data type handling**
   - Add automatic float16 → bfloat16 conversion, since FIA with the TND layout only supports bfloat16 (see the dtype sketch below)
   - Convert the output back to the original dtype after the FIA computation
3. **Metadata optimization**
   - Pre-calculate `actual_seq_lengths_q` in `AscendMLAPrefillMetadata`
   - Pre-calculate `chunk_actual_seq_lengths_kv_list` in `ChunkedContextMetadata`
   - Move `torch.cumsum` operations from the forward pass to the metadata-building phase (see the cumsum sketch below)
4. **CP compatibility (`mla_cp.py`)**
   - Add `_ring_mla_mask_builder` to build `npu_ring_mla`-compatible masks for Context Parallel scenarios
   - Add a `chunk_actual_seq_lengths_kv_list` field to `CPChunkedContextMetadata`

**Why we need it:**

- **Backend unification**: aligns MLA prefill with the standard attention implementation (`attention_v1.py`)
- **Better chunked-context support**: FIA + `npu_attention_update` provides native LSE-based merging of partial outputs
- **Future compatibility**: prepares for the eventual removal of `npu_ring_mla` across the codebase

### Does this PR introduce _any_ user-facing change?

**No.** This is a pure refactoring with no functional change: same behavior, unified backend.

---

- Related issue: #5463 (item 7)
- vLLM version: v0.14.1

Signed-off-by: lico67373 <918688502@qq.com>
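For reference, the chunk merge mentioned in item 1 rests on a standard identity: attention computed separately over disjoint key chunks can be combined exactly using each chunk's log-sum-exp. The following is a minimal PyTorch sketch of that merge, not the `npu_attention_update` kernel itself; `merge_with_lse` is an illustrative name, and the `[heads, tokens]` LSE layout mirrors the shapes used in the updated test below.

```python
import torch

def merge_with_lse(out_a: torch.Tensor, lse_a: torch.Tensor,
                   out_b: torch.Tensor, lse_b: torch.Tensor):
    # out_*: [tokens, heads, v_dim]; lse_*: [heads, tokens] (log-sum-exp of
    # the attention scores each partial output was normalized with).
    lse_max = torch.maximum(lse_a, lse_b)       # subtract max for stability
    w_a = torch.exp(lse_a - lse_max)            # unnormalized chunk weights
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    # Reshape weights to [tokens, heads, 1] so they broadcast over v_dim.
    scale_a = (w_a / denom).transpose(0, 1).unsqueeze(-1)
    scale_b = (w_b / denom).transpose(0, 1).unsqueeze(-1)
    merged = out_a * scale_a + out_b * scale_b  # exact combined attention
    merged_lse = lse_max + torch.log(denom)     # LSE over both chunks
    return merged, merged_lse

# Example: 2 query tokens, 4 heads, value head dim 8.
out_a, out_b = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
lse_a, lse_b = torch.randn(4, 2), torch.randn(4, 2)
merged, merged_lse = merge_with_lse(out_a, lse_a, out_b, lse_b)
```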
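The dtype handling in item 2 amounts to a round-trip around the kernel call. A minimal sketch, with `fia_fn` standing in for the actual FIA invocation (the wrapper name and call shape here are illustrative, not the PR's exact code):

```python
import torch

def call_with_bf16_roundtrip(fia_fn, query, key, value):
    # FIA with the TND layout only supports bfloat16, so float16 inputs are
    # converted before the call and the output is cast back afterwards.
    orig_dtype = query.dtype
    if orig_dtype == torch.float16:
        query, key, value = (t.to(torch.bfloat16) for t in (query, key, value))
    out = fia_fn(query, key, value)
    return out.to(orig_dtype) if out.dtype != orig_dtype else out
```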
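Item 3 moves the cumulative-length computation that TND-layout kernels consume out of the per-step forward pass and into metadata construction, so it runs once per batch. A toy illustration with made-up lengths:

```python
import torch

# At metadata-build time: per-request query lengths -> cumulative end offsets.
query_lens = torch.tensor([3, 5, 2])
actual_seq_lengths_q = torch.cumsum(query_lens, dim=0).tolist()  # [3, 8, 10]
# The forward pass then reads the precomputed list instead of recomputing it.
```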
```diff
@@ -102,7 +102,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
             max_seq_lens=max_seq_lens,
             workspace=workspace,
             chunk_seq_lens=chunk_seq_lens,
-            chunk_seq_lens_npu=chunk_seq_lens)
+            chunk_seq_lens_npu=chunk_seq_lens,
+            chunk_actual_seq_lengths_kv_list=[[2, 4]])

         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -886,8 +887,9 @@ class TestAscendMLAImpl(TestBase):
         self.assertTrue(torch.equal(prefix_lse, lse))

     @patch("torch_npu.atb.npu_paged_cache_load")
-    @patch("torch_npu.atb.npu_ring_mla")
-    def test_compute_prefill_context(self, mock_ring, mock_load):
+    @patch("torch_npu.npu_attention_update")
+    @patch("torch_npu.npu_fused_infer_attention_score")
+    def test_compute_prefill_context(self, mock_fia, mock_update, mock_load):
         S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
         _, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
         latent_kv_dim = self.impl.kv_lora_rank
@@ -898,11 +900,16 @@ class TestAscendMLAImpl(TestBase):
         kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
         kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
         kv_cache = [kv_cache_0, kv_cache_1]
-        prefix_out = torch.randn(S, N, 128)
-        prefix_lse = torch.randn(S, N)
+        prefix_out = torch.randn(S, N, VD)
+        prefix_lse = torch.randn(N, S)

         self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )

+        # Mock FIA to return output and lse
+        mock_fia.return_value = (torch.randn(S, N, VD), torch.randn(N, S))
+        # Mock attention_update to return merged output
+        mock_update.return_value = (torch.randn(S * N, VD), None)
+
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
@@ -911,7 +918,7 @@ class TestAscendMLAImpl(TestBase):

         prefill_meta = MagicMock()
         prefill_meta.chunked_context = chunk_ctx
-        prefill_meta.query_lens = [8]
+        prefill_meta.query_lens = torch.tensor([S])
         prefill_meta.block_table = torch.randint(0, 100, (S, 4))

         meta = MagicMock()
@@ -924,10 +931,10 @@ class TestAscendMLAImpl(TestBase):
                                                        prefix_lse)

         mock_load.assert_called_once()
-        mock_ring.assert_called_once()
+        mock_fia.assert_called_once()
+        mock_update.assert_called_once()

         self.assertEqual(out.shape, prefix_out.shape)
         self.assertEqual(lse.shape, prefix_lse.shape)

     @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
```