[Refactor] Replace npu_ring_mla with FIA in MLA prefill (#5704)

### What this PR does / why we need it?

**Refactor: Replace npu_ring_mla with FIA in MLA prefill**

This PR refactors the MLA (Multi-head Latent Attention) prefill
implementation by replacing `npu_ring_mla` with the
`npu_fused_infer_attention_score` (FIA) operator, unifying the attention
backend with the standard attention implementation.

**Key changes:**

1. **Core prefill refactoring (`mla_v1.py`)**
   - Replace `npu_ring_mla` with `npu_fused_infer_attention_score` in
     `_forward_prefill` and `_compute_prefill_context` (see the sketch
     after this list)
   - Use the TND layout with `softmax_lse_flag=True` for prefill attention
   - Use `npu_attention_update` to merge multiple chunk outputs via their
     LSE (Log-Sum-Exp) values
   - Change `attn_mask` from `get_final_mla_mask()` to
     `get_splitfuse_attn_mask()` for FIA compatibility

2. **Data type handling**
   - Add automatic float16 → bfloat16 conversion (FIA with the TND layout
     only supports bfloat16)
   - Convert the output back to the original dtype after the FIA
     computation

3. **Metadata optimization**
   - Pre-calculate `actual_seq_lengths_q` in `AscendMLAPrefillMetadata`
   - Pre-calculate `chunk_actual_seq_lengths_kv_list` in
     `ChunkedContextMetadata`
   - Move `torch.cumsum` operations from the forward pass to the
     metadata-building phase

4. **CP compatibility (`mla_cp.py`)**
   - Add `_ring_mla_mask_builder` to build `npu_ring_mla`-compatible masks
     for Context Parallel scenarios
   - Add a `chunk_actual_seq_lengths_kv_list` field to
     `CPChunkedContextMetadata`
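
For orientation, here is a minimal sketch of the new prefill path. It is
not the actual code in `mla_v1.py`: the helper name and parameters
(`prefill_attention_sketch`, `query_lens`, `kv_lens`) are placeholders, and
the keyword names of `npu_fused_infer_attention_score` are written from
memory of the torch_npu documentation, so treat them as assumptions and
check the docs before copying.

```python
import torch
import torch_npu  # Ascend NPU extension providing the FIA operator

# Sketch only: keyword names (atten_mask, actual_seq_lengths,
# actual_seq_lengths_kv, num_heads, scale, input_layout, softmax_lse_flag)
# are assumed from the torch_npu docs and may differ from the real call.

def prefill_attention_sketch(q, k, v, attn_mask, query_lens, kv_lens,
                             num_heads, scale):
    # FIA with the TND layout only supports bfloat16, so convert float16
    # inputs up front and remember the original dtype.
    orig_dtype = q.dtype
    if orig_dtype == torch.float16:
        q, k, v = (t.to(torch.bfloat16) for t in (q, k, v))

    # actual_seq_lengths_* are cumulative lengths; the PR pre-computes them
    # once in the metadata-building phase instead of calling cumsum on
    # every forward pass.
    actual_seq_lengths_q = torch.cumsum(query_lens, dim=0).tolist()
    actual_seq_lengths_kv = torch.cumsum(kv_lens, dim=0).tolist()

    out, lse = torch_npu.npu_fused_infer_attention_score(
        q, k, v,
        atten_mask=attn_mask,
        actual_seq_lengths=actual_seq_lengths_q,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        num_heads=num_heads,
        scale=scale,
        input_layout="TND",
        softmax_lse_flag=True,  # also return the per-token log-sum-exp
    )
    # Convert back so downstream layers see the original dtype.
    return out.to(orig_dtype), lse
```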

**Why we need it:**
- **Backend unification**: Aligns MLA prefill with standard attention
implementation (`attention_v1.py`)
- **Better chunked context support**: FIA + `npu_attention_update`
provides native LSE-based output merging (a pure-PyTorch reference of
that merge is sketched after this list)
- **Future compatibility**: Prepares for eventual `npu_ring_mla` removal
across the codebase
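
To make the LSE-based merging concrete, below is a small pure-PyTorch
reference of the math that `npu_attention_update` performs when combining
per-chunk attention outputs. This is an illustrative CPU sketch under the
assumption that outputs are laid out as (tokens, heads, head_dim) and LSE
as (heads, tokens), mirroring the shapes used in the updated tests; it is
not the NPU operator itself.

```python
import torch

def merge_chunk_outputs(outs: list[torch.Tensor],
                        lses: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Numerically stable merge of per-chunk attention outputs.

    outs[i]: (tokens, heads, head_dim) attention output of chunk i
    lses[i]: (heads, tokens) log-sum-exp of chunk i's softmax denominator
    Returns the merged output and merged LSE in the same layouts.
    """
    lse = torch.stack(lses, dim=0)                      # (chunks, heads, tokens)
    lse_max = lse.max(dim=0, keepdim=True).values
    # Per-chunk weights, proportional to the true softmax denominators.
    weights = torch.exp(lse - lse_max)                  # (chunks, heads, tokens)
    denom = weights.sum(dim=0)                          # (heads, tokens)

    out = torch.stack(outs, dim=0)                      # (chunks, tokens, heads, dim)
    # Bring weights to (chunks, tokens, heads, 1) to broadcast over head_dim.
    w = (weights / denom).permute(0, 2, 1).unsqueeze(-1)
    merged_out = (out * w).sum(dim=0)                   # (tokens, heads, dim)

    merged_lse = lse_max.squeeze(0) + torch.log(denom)  # (heads, tokens)
    return merged_out, merged_lse
```

A quick sanity check for this kind of merge is to split a random sequence
into two chunks, merge the per-chunk outputs with their LSEs, and compare
against a single full-context attention pass.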

### Does this PR introduce _any_ user-facing change?

**No.** This is a pure refactoring with no functional changes: same
behavior, unified backend.

---
- Related issue: #5463 (item 7)
- vLLM version: v0.14.1

Signed-off-by: lico67373 <918688502@qq.com>
Commit 71c21f76f5 (parent e20f0b1a0d), committed via GitHub by LICO67373
on 2026-03-16 10:33:09 +08:00.
6 changed files with 183 additions and 79 deletions

View File

@@ -130,6 +130,10 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunk_actual_seq_lengths_kv_list = [
torch.cumsum(chunk_seq_lens[i], dim=0).tolist()
for i in range(num_chunks)
]
chunked_context_metadata = CPChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
starts=local_chunk_starts.to(non_blocking=True),
@@ -137,6 +141,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens,
chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list,
workspace=None,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
@@ -500,19 +505,23 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(result.shape[1], self.impl.v_head_dim)
@patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla")
@patch("torch_npu.npu_attention_update")
@patch("torch_npu.npu_fused_infer_attention_score")
@patch_distributed_groups(dcp_size=2, pcp_size=2)
def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp,
mock_pcp, mock_ring,
mock_load):
mock_pcp, mock_fia,
mock_update, mock_load):
def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
head_num, kv_head_num, pre_out, prev_lse, qk_scale,
kernel_type, mask_type, input_layout, calc_type,
output, softmax_lse):
return torch.randn(q_rope.shape[0], value.shape[1], value.shape[2])
def mock_fia_attn(*args, **kwargs):
q = args[0]
v = args[2]
return (torch.randn(q.shape[0],
v.shape[1],
v.shape[2],
dtype=torch.float16),
torch.randn(v.shape[1], q.shape[0], dtype=torch.float16))
mock_ring.side_effect = mock_ring_attn
mock_fia.side_effect = mock_fia_attn
def mock_kv_b_proj(kv_c_normed):
return (torch.randn(kv_c_normed.shape[0],
@@ -534,6 +543,13 @@ class TestAscendMLAImpl(TestBase):
# mock proj
self.impl.kv_b_proj.side_effect = mock_kv_b_proj
def mock_update_fn(lse_list, out_list, mode):
total_len = out_list[0].shape[0]
D = out_list[0].shape[1]
return (torch.randn(total_len, D, dtype=torch.float32), None)
mock_update.side_effect = mock_update_fn
NUM_BLOCKS, BLOCK_SIZE = 10, 32 # fixed
USED_BLOCKS = 3
# pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens, num_computed_tokens_of_pcp_dcp
@@ -586,8 +602,8 @@ class TestAscendMLAImpl(TestBase):
self.impl.num_heads,
self.impl.v_head_dim,
dtype=torch.float16)
prefix_lse = torch.randn(sum(nums_tokens_per_rank),
self.impl.num_heads,
prefix_lse = torch.randn(self.impl.num_heads,
sum(nums_tokens_per_rank),
dtype=torch.float16)
chunk_ctx = get_chunk_metadata(
pcp_size,
@@ -602,7 +618,7 @@ class TestAscendMLAImpl(TestBase):
cp_local_block_size=cp_local_block_size)
meta = MagicMock()
prefill_meta = MagicMock()
prefill_meta.query_lens = nums_tokens_per_rank
prefill_meta.query_lens = torch.tensor(nums_tokens_per_rank)
prefill_meta.block_table = torch.randint(
0, USED_BLOCKS, (1, 64)) # (batch, max_blocks)
prefill_meta.chunked_context = chunk_ctx
@@ -621,14 +637,14 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(mock_reorg.call_count,
iters * (1 if dcp_size * pcp_size > 1 else 0))
self.assertEqual(mock_load.call_count, iters)
self.assertEqual(mock_ring.call_count, iters)
self.assertEqual(mock_fia.call_count, iters)
mock_reorg.reset_mock()
mock_load.reset_mock()
mock_ring.reset_mock()
mock_fia.reset_mock()
mock_update.reset_mock()
mock_dcp.reset_mock()
mock_pcp.reset_mock()
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch_distributed_groups(dcp_size=2, pcp_size=2)
def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp,

View File

@@ -102,7 +102,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens)
chunk_seq_lens_npu=chunk_seq_lens,
chunk_actual_seq_lengths_kv_list=[[2, 4]])
metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -886,8 +887,9 @@ class TestAscendMLAImpl(TestBase):
self.assertTrue(torch.equal(prefix_lse, lse))
@patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla")
def test_compute_prefill_context(self, mock_ring, mock_load):
@patch("torch_npu.npu_attention_update")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_compute_prefill_context(self, mock_fia, mock_update, mock_load):
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
latent_kv_dim = self.impl.kv_lora_rank
@@ -898,11 +900,16 @@ class TestAscendMLAImpl(TestBase):
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1]
prefix_out = torch.randn(S, N, 128)
prefix_lse = torch.randn(S, N)
prefix_out = torch.randn(S, N, VD)
prefix_lse = torch.randn(N, S)
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
# Mock FIA to return output and lse
mock_fia.return_value = (torch.randn(S, N, VD), torch.randn(N, S))
# Mock attention_update to return merged output
mock_update.return_value = (torch.randn(S * N, VD), None)
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
@@ -911,7 +918,7 @@ class TestAscendMLAImpl(TestBase):
prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx
prefill_meta.query_lens = [8]
prefill_meta.query_lens = torch.tensor([S])
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
meta = MagicMock()
@@ -924,10 +931,10 @@ class TestAscendMLAImpl(TestBase):
prefix_lse)
mock_load.assert_called_once()
mock_ring.assert_called_once()
mock_fia.assert_called_once()
mock_update.assert_called_once()
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")