[v0.18.0][BugFix]Revert the code: Replace FIA with npu_ring_mla for MLA prefill. (#7961)
This pull request reverts previous changes to switch to FIA and instead implements npu_ring_mla for MLA prefill operations (#5704). The change streamlines the attention mechanism by removing unnecessary metadata tracking and updating the underlying NPU operations to use the ring-based MLA kernel. This adjustment ensures better compatibility and performance for MLA prefill tasks within the vLLM Ascend backend. Highlights - Migration to npu_ring_mla: Replaced the usage of npu_fused_infer_attention_score (FIA) with npu_ring_mla for MLA prefill operations across the codebase to improve performance and alignment with the intended architecture. - Cleanup of redundant metadata: Removed chunk_actual_seq_lengths_kv_list and actual_seq_lengths_q from various metadata structures as they are no longer required for the updated attention implementation. - Test suite updates: Updated unit tests in test_mla_cp.py and test_mla_v1.py to mock npu_ring_mla instead of the deprecated FIA functions and adjusted test assertions to reflect the new implementation details. Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -130,10 +130,6 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
|
||||
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
chunk_actual_seq_lengths_kv_list = [
|
||||
torch.cumsum(chunk_seq_lens[i], dim=0).tolist()
|
||||
for i in range(num_chunks)
|
||||
]
|
||||
chunked_context_metadata = CPChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||
starts=local_chunk_starts.to(non_blocking=True),
|
||||
@@ -141,7 +137,6 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens,
|
||||
chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list,
|
||||
workspace=None,
|
||||
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens,
|
||||
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
|
||||
@@ -505,23 +500,19 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(result.shape[1], self.impl.v_head_dim)
|
||||
|
||||
@patch("torch_npu.atb.npu_paged_cache_load")
|
||||
@patch("torch_npu.npu_attention_update")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
@patch("torch_npu.atb.npu_ring_mla")
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=2)
|
||||
def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp,
|
||||
mock_pcp, mock_fia,
|
||||
mock_update, mock_load):
|
||||
mock_pcp, mock_ring,
|
||||
mock_load):
|
||||
|
||||
def mock_fia_attn(*args, **kwargs):
|
||||
q = args[0]
|
||||
v = args[2]
|
||||
return (torch.randn(q.shape[0],
|
||||
v.shape[1],
|
||||
v.shape[2],
|
||||
dtype=torch.float16),
|
||||
torch.randn(v.shape[1], q.shape[0], dtype=torch.float16))
|
||||
def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
|
||||
head_num, kv_head_num, pre_out, prev_lse, qk_scale,
|
||||
kernel_type, mask_type, input_layout, calc_type,
|
||||
output, softmax_lse):
|
||||
return torch.randn(q_rope.shape[0], value.shape[1], value.shape[2])
|
||||
|
||||
mock_fia.side_effect = mock_fia_attn
|
||||
mock_ring.side_effect = mock_ring_attn
|
||||
|
||||
def mock_kv_b_proj(kv_c_normed):
|
||||
return (torch.randn(kv_c_normed.shape[0],
|
||||
@@ -543,13 +534,6 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
# mock proj
|
||||
self.impl.kv_b_proj.side_effect = mock_kv_b_proj
|
||||
|
||||
def mock_update_fn(lse_list, out_list, mode):
|
||||
total_len = out_list[0].shape[0]
|
||||
D = out_list[0].shape[1]
|
||||
return (torch.randn(total_len, D, dtype=torch.float32), None)
|
||||
|
||||
mock_update.side_effect = mock_update_fn
|
||||
NUM_BLOCKS, BLOCK_SIZE = 10, 32 # fixed
|
||||
USED_BLOCKS = 3
|
||||
# pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens, num_computed_tokens_of_pcp_dcp
|
||||
@@ -602,8 +586,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.impl.num_heads,
|
||||
self.impl.v_head_dim,
|
||||
dtype=torch.float16)
|
||||
prefix_lse = torch.randn(self.impl.num_heads,
|
||||
sum(nums_tokens_per_rank),
|
||||
prefix_lse = torch.randn(sum(nums_tokens_per_rank),
|
||||
self.impl.num_heads,
|
||||
dtype=torch.float16)
|
||||
chunk_ctx = get_chunk_metadata(
|
||||
pcp_size,
|
||||
@@ -618,7 +602,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
cp_local_block_size=cp_local_block_size)
|
||||
meta = MagicMock()
|
||||
prefill_meta = MagicMock()
|
||||
prefill_meta.query_lens = torch.tensor(nums_tokens_per_rank)
|
||||
prefill_meta.query_lens = nums_tokens_per_rank
|
||||
prefill_meta.block_table = torch.randint(
|
||||
0, USED_BLOCKS, (1, 64)) # (batch, max_blocks)
|
||||
prefill_meta.chunked_context = chunk_ctx
|
||||
@@ -637,14 +621,14 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(mock_reorg.call_count,
|
||||
iters * (1 if dcp_size * pcp_size > 1 else 0))
|
||||
self.assertEqual(mock_load.call_count, iters)
|
||||
self.assertEqual(mock_fia.call_count, iters)
|
||||
self.assertEqual(mock_ring.call_count, iters)
|
||||
mock_reorg.reset_mock()
|
||||
mock_load.reset_mock()
|
||||
mock_fia.reset_mock()
|
||||
mock_update.reset_mock()
|
||||
mock_ring.reset_mock()
|
||||
mock_dcp.reset_mock()
|
||||
mock_pcp.reset_mock()
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=2)
|
||||
def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp,
|
||||
|
||||
@@ -102,8 +102,7 @@ class TestAscendMLAPrefillMetadata(TestBase):
|
||||
max_seq_lens=max_seq_lens,
|
||||
workspace=workspace,
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens,
|
||||
chunk_actual_seq_lengths_kv_list=[[2, 4]])
|
||||
chunk_seq_lens_npu=chunk_seq_lens)
|
||||
|
||||
metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
||||
@@ -888,9 +887,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertTrue(torch.equal(prefix_lse, lse))
|
||||
|
||||
@patch("torch_npu.atb.npu_paged_cache_load")
|
||||
@patch("torch_npu.npu_attention_update")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_compute_prefill_context(self, mock_fia, mock_update, mock_load):
|
||||
@patch("torch_npu.atb.npu_ring_mla")
|
||||
def test_compute_prefill_context(self, mock_ring, mock_load):
|
||||
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
|
||||
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
|
||||
latent_kv_dim = self.impl.kv_lora_rank
|
||||
@@ -901,16 +899,11 @@ class TestAscendMLAImpl(TestBase):
|
||||
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
|
||||
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
|
||||
kv_cache = [kv_cache_0, kv_cache_1]
|
||||
prefix_out = torch.randn(S, N, VD)
|
||||
prefix_lse = torch.randn(N, S)
|
||||
prefix_out = torch.randn(S, N, 128)
|
||||
prefix_lse = torch.randn(S, N)
|
||||
|
||||
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
|
||||
|
||||
# Mock FIA to return output and lse
|
||||
mock_fia.return_value = (torch.randn(S, N, VD), torch.randn(N, S))
|
||||
# Mock attention_update to return merged output
|
||||
mock_update.return_value = (torch.randn(S * N, VD), None)
|
||||
|
||||
chunk_ctx = MagicMock()
|
||||
chunk_ctx.seq_tot = [8]
|
||||
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
||||
@@ -919,7 +912,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
prefill_meta = MagicMock()
|
||||
prefill_meta.chunked_context = chunk_ctx
|
||||
prefill_meta.query_lens = torch.tensor([S])
|
||||
prefill_meta.query_lens = [8]
|
||||
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
|
||||
|
||||
meta = MagicMock()
|
||||
@@ -932,10 +925,10 @@ class TestAscendMLAImpl(TestBase):
|
||||
prefix_lse)
|
||||
|
||||
mock_load.assert_called_once()
|
||||
mock_fia.assert_called_once()
|
||||
mock_update.assert_called_once()
|
||||
mock_ring.assert_called_once()
|
||||
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
@patch('vllm_ascend.ascend_forward_context.get_forward_context')
|
||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
|
||||
|
||||
Reference in New Issue
Block a user