diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index 6bcc6b4e..1adf6419 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -130,10 +130,6 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes, out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], dtype=torch.int32, ) - chunk_actual_seq_lengths_kv_list = [ - torch.cumsum(chunk_seq_lens[i], dim=0).tolist() - for i in range(num_chunks) - ] chunked_context_metadata = CPChunkedContextMetadata( cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True), starts=local_chunk_starts.to(non_blocking=True), @@ -141,7 +137,6 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes, max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), chunk_seq_lens=chunk_seq_lens, chunk_seq_lens_npu=chunk_seq_lens, - chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list, workspace=None, padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens, padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), @@ -505,23 +500,19 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(result.shape[1], self.impl.v_head_dim) @patch("torch_npu.atb.npu_paged_cache_load") - @patch("torch_npu.npu_attention_update") - @patch("torch_npu.npu_fused_infer_attention_score") + @patch("torch_npu.atb.npu_ring_mla") @patch_distributed_groups(dcp_size=2, pcp_size=2) def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp, - mock_pcp, mock_fia, - mock_update, mock_load): + mock_pcp, mock_ring, + mock_load): - def mock_fia_attn(*args, **kwargs): - q = args[0] - v = args[2] - return (torch.randn(q.shape[0], - v.shape[1], - v.shape[2], - dtype=torch.float16), - torch.randn(v.shape[1], q.shape[0], dtype=torch.float16)) + def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen, + head_num, kv_head_num, pre_out, prev_lse, qk_scale, + kernel_type, mask_type, input_layout, calc_type, + output, softmax_lse): + return torch.randn(q_rope.shape[0], value.shape[1], value.shape[2]) - mock_fia.side_effect = mock_fia_attn + mock_ring.side_effect = mock_ring_attn def mock_kv_b_proj(kv_c_normed): return (torch.randn(kv_c_normed.shape[0], @@ -543,13 +534,6 @@ class TestAscendMLAImpl(TestBase): # mock proj self.impl.kv_b_proj.side_effect = mock_kv_b_proj - - def mock_update_fn(lse_list, out_list, mode): - total_len = out_list[0].shape[0] - D = out_list[0].shape[1] - return (torch.randn(total_len, D, dtype=torch.float32), None) - - mock_update.side_effect = mock_update_fn NUM_BLOCKS, BLOCK_SIZE = 10, 32 # fixed USED_BLOCKS = 3 # pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens, num_computed_tokens_of_pcp_dcp @@ -602,8 +586,8 @@ class TestAscendMLAImpl(TestBase): self.impl.num_heads, self.impl.v_head_dim, dtype=torch.float16) - prefix_lse = torch.randn(self.impl.num_heads, - sum(nums_tokens_per_rank), + prefix_lse = torch.randn(sum(nums_tokens_per_rank), + self.impl.num_heads, dtype=torch.float16) chunk_ctx = get_chunk_metadata( pcp_size, @@ -618,7 +602,7 @@ class TestAscendMLAImpl(TestBase): cp_local_block_size=cp_local_block_size) meta = MagicMock() prefill_meta = MagicMock() - prefill_meta.query_lens = torch.tensor(nums_tokens_per_rank) + prefill_meta.query_lens = nums_tokens_per_rank prefill_meta.block_table = torch.randint( 0, USED_BLOCKS, (1, 64)) # (batch, max_blocks) prefill_meta.chunked_context = chunk_ctx @@ -637,14 +621,14 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(mock_reorg.call_count, iters * (1 if dcp_size * pcp_size > 1 else 0)) self.assertEqual(mock_load.call_count, iters) - self.assertEqual(mock_fia.call_count, iters) + self.assertEqual(mock_ring.call_count, iters) mock_reorg.reset_mock() mock_load.reset_mock() - mock_fia.reset_mock() - mock_update.reset_mock() + mock_ring.reset_mock() mock_dcp.reset_mock() mock_pcp.reset_mock() self.assertEqual(out.shape, prefix_out.shape) + self.assertEqual(lse.shape, prefix_lse.shape) @patch_distributed_groups(dcp_size=2, pcp_size=2) def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp, diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 7f457898..d2fa707b 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -102,8 +102,7 @@ class TestAscendMLAPrefillMetadata(TestBase): max_seq_lens=max_seq_lens, workspace=workspace, chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens, - chunk_actual_seq_lengths_kv_list=[[2, 4]]) + chunk_seq_lens_npu=chunk_seq_lens) metadata = AscendMLAPrefillMetadata( attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), @@ -888,9 +887,8 @@ class TestAscendMLAImpl(TestBase): self.assertTrue(torch.equal(prefix_lse, lse)) @patch("torch_npu.atb.npu_paged_cache_load") - @patch("torch_npu.npu_attention_update") - @patch("torch_npu.npu_fused_infer_attention_score") - def test_compute_prefill_context(self, mock_fia, mock_update, mock_load): + @patch("torch_npu.atb.npu_ring_mla") + def test_compute_prefill_context(self, mock_ring, mock_load): S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim _, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim latent_kv_dim = self.impl.kv_lora_rank @@ -901,16 +899,11 @@ class TestAscendMLAImpl(TestBase): kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim) kv_cache_1 = torch.randn(num_blocks, block_size, N, D) kv_cache = [kv_cache_0, kv_cache_1] - prefix_out = torch.randn(S, N, VD) - prefix_lse = torch.randn(N, S) + prefix_out = torch.randn(S, N, 128) + prefix_lse = torch.randn(S, N) self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), ) - # Mock FIA to return output and lse - mock_fia.return_value = (torch.randn(S, N, VD), torch.randn(N, S)) - # Mock attention_update to return merged output - mock_update.return_value = (torch.randn(S * N, VD), None) - chunk_ctx = MagicMock() chunk_ctx.seq_tot = [8] chunk_ctx.chunk_seq_lens = [torch.tensor([8])] @@ -919,7 +912,7 @@ class TestAscendMLAImpl(TestBase): prefill_meta = MagicMock() prefill_meta.chunked_context = chunk_ctx - prefill_meta.query_lens = torch.tensor([S]) + prefill_meta.query_lens = [8] prefill_meta.block_table = torch.randint(0, 100, (S, 4)) meta = MagicMock() @@ -932,10 +925,10 @@ class TestAscendMLAImpl(TestBase): prefix_lse) mock_load.assert_called_once() - mock_fia.assert_called_once() - mock_update.assert_called_once() + mock_ring.assert_called_once() self.assertEqual(out.shape, prefix_out.shape) + self.assertEqual(lse.shape, prefix_lse.shape) @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 5aa5c147..c319698f 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -53,7 +53,6 @@ class CPChunkedContextMetadata: workspace: torch.Tensor chunk_seq_lens: torch.Tensor chunk_seq_lens_npu: torch.Tensor - chunk_actual_seq_lengths_kv_list: list[list[int]] # for mla DCP & PCP padded_chunk_seq_lens_npu: torch.Tensor = None padded_local_chunk_seq_lens: list[list[int]] | None = None diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index c1550a0b..aa9a0e0f 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -30,7 +30,6 @@ from vllm_ascend.attention.mla_v1 import ( # isort: on from vllm_ascend.ascend_forward_context import _EXTRA_CTX -from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.context_parallel.common_cp import ( AscendPCPMetadata, CPChunkedContextMetadata, @@ -190,7 +189,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): max_seq_lens=chunked_context_metadata.max_seq_lens, chunk_seq_lens=self.chunk_seq_lens, chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu, - chunk_actual_seq_lengths_kv_list=chunked_context_metadata.chunk_actual_seq_lengths_kv_list, workspace=chunked_context_metadata.workspace, padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(), padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), @@ -278,10 +276,6 @@ class AscendMlaCPImpl(AscendMLAImpl): **kwargs, ) - # npu_ring_mla needs bfloat16 512x512 mask, different from FIA's int8 2048x2048 mask - # TODO: Remove this when mla_cp.py also migrates to FIA - self._ring_mla_mask_builder = AttentionMaskBuilder(torch.device("npu")) - self.pcp_size = get_pcp_group().world_size self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None @@ -490,10 +484,6 @@ class AscendMlaCPImpl(AscendMLAImpl): attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens - # Use ring_mla-specific mask (bfloat16, 512x512) - # TODO: Remove this when mla_cp.py migrates to FIA - ring_mla_mask = self._ring_mla_mask_builder.get_mla_mask(self.vllm_config.model_config.dtype) - output_head, lse_head = self._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_head_idx), q_pe=torch.index_select(q_pe, 0, q_head_idx), @@ -504,7 +494,7 @@ class AscendMlaCPImpl(AscendMLAImpl): kv_nomask_idx=kv_with_q_head_nomask_idx, attn_mask_seqlens=attn_mask_seqlens, attn_nomask_seqlens=head_attn_nomask_seqlens, - mask=ring_mla_mask, + mask=attn_metadata.attn_mask, ) output_tail, lse_tail = self._attention_with_mask_and_nomask( @@ -517,7 +507,7 @@ class AscendMlaCPImpl(AscendMLAImpl): kv_nomask_idx=kv_with_q_tail_nomask_idx, attn_mask_seqlens=attn_mask_seqlens, attn_nomask_seqlens=tail_attn_nomask_seqlens, - mask=ring_mla_mask, + mask=attn_metadata.attn_mask, ) q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index a38ba6b4..0e62df4b 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -114,7 +114,6 @@ class ChunkedContextMetadata: workspace: torch.Tensor chunk_seq_lens: torch.Tensor chunk_seq_lens_npu: torch.Tensor - chunk_actual_seq_lengths_kv_list: list[list[int]] @dataclass @@ -134,7 +133,6 @@ class AscendMLAPrefillMetadata: sin: torch.Tensor = None cos: torch.Tensor = None pcp_metadata: AscendPCPMetadata | None = None - actual_seq_lengths_q: list[int] | None = None @dataclass @@ -452,7 +450,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): num_decodes=self.num_decodes, num_decode_tokens=self.num_decode_tokens, num_prefills=self.num_prefills, - attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(), + attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config), attn_state=common_attn_metadata.attn_state, prefill=prefill_metadata, decode=decode_metadata, @@ -492,9 +490,6 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True) torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32) - chunk_actual_seq_lengths_kv_list = [ - torch.cumsum(self.chunk_seq_lens[i], dim=0).tolist() for i in range(self.num_chunks) - ] return ChunkedContextMetadata( cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True), starts=chunk_starts.pin_memory().to(self.device, non_blocking=True), @@ -503,7 +498,6 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): chunk_seq_lens=self.chunk_seq_lens, chunk_seq_lens_npu=self.chunk_seq_lens.npu(), workspace=self.chunked_prefill_workspace, - chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list, ) def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int): @@ -538,9 +532,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): prefill_input_positions = input_positions[tokens_start:] cos, sin = get_cos_and_sin_mla(prefill_input_positions) prefill_query_lens = self.query_lens[reqs_start:].to(torch.int32) - actual_seq_lengths_q = torch.cumsum(prefill_query_lens, dim=0).tolist() return AscendMLAPrefillMetadata( - attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(), + attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config), query_lens=prefill_query_lens, seq_lens=self.seq_lens, context_lens=self.seq_lens[reqs_start:], @@ -552,7 +545,6 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): chunked_context=chunked_context_metadata, sin=sin, cos=cos, - actual_seq_lengths_q=actual_seq_lengths_q, ) def build_decode_metadata( @@ -1056,29 +1048,18 @@ class AscendMLAImpl(MLAAttentionImpl): return prefix_output, prefix_lse iters = len(prefill_metadata.chunked_context.seq_tot) + + current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) cache_kv_c = kv_c_and_k_pe_cache[0] cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) - actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q - - if iters == 0: - return prefix_output, prefix_lse - - num_tokens = q_nope.size(0) - D = self.v_head_dim - H = self.num_heads - - if prefix_lse.dim() == 2: - prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1) - prefix_output = prefix_output.to(torch.float32) - prefix_lse = prefix_lse.to(torch.float32) - out_list = [prefix_output.reshape(num_tokens * H, D)] - lse_list = [prefix_lse.reshape(num_tokens * H)] - for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] + # chunk_seq_lens will be padded when pcp&dcp + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i] + seq_len = torch.stack([current_seq_len, context_seq_len]) context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata) kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device) k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device) @@ -1104,35 +1085,29 @@ class AscendMLAImpl(MLAAttentionImpl): k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - actual_seq_lengths_kv = prefill_metadata.chunked_context.chunk_actual_seq_lengths_kv_list[i] - - chunk_out, chunk_lse = torch_npu.npu_fused_infer_attention_score( - q_nope, - k_nope, - v, - query_rope=q_pe, - key_rope=k_pe, - num_heads=self.num_heads, - num_key_value_heads=self.num_heads, - input_layout="TND", - atten_mask=None, - sparse_mode=0, - scale=self.scale, - antiquant_mode=0, - antiquant_scale=None, - softmax_lse_flag=True, - actual_seq_lengths=actual_seq_lengths_q, - actual_seq_lengths_kv=actual_seq_lengths_kv, + mask = attn_metadata.attn_mask + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=mask, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=prefix_output, + prev_lse=prefix_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=prefix_output, + softmax_lse=prefix_lse, ) - if chunk_lse.dim() == 2: - chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1) - chunk_out = chunk_out.to(torch.float32) - chunk_lse = chunk_lse.to(torch.float32) - out_list.append(chunk_out.reshape(num_tokens * H, D)) - lse_list.append(chunk_lse.reshape(num_tokens * H)) - output_final, _ = torch_npu.npu_attention_update(tuple(lse_list), tuple(out_list), 0) - return output_final.view(num_tokens, H, D), None + return prefix_output, prefix_lse def _forward_prefill( self, @@ -1147,54 +1122,35 @@ class AscendMLAImpl(MLAAttentionImpl): assert attn_metadata.prefill is not None assert len(kv_c_and_k_pe_cache) > 1 num_tokens = q_nope.size(0) - prefill_meta = attn_metadata.prefill - - actual_seq_lengths_q = prefill_meta.actual_seq_lengths_q - actual_seq_lengths_kv = actual_seq_lengths_q.copy() - - # FIA with TND layout only supports bfloat16, convert if needed - original_dtype = q_nope.dtype - need_dtype_convert = original_dtype != torch.bfloat16 - if need_dtype_convert: - q_nope = q_nope.to(torch.bfloat16) - q_pe = q_pe.to(torch.bfloat16) - k_nope = k_nope.to(torch.bfloat16) - k_pe = k_pe.to(torch.bfloat16) - value = value.to(torch.bfloat16) attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device) attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device) - common_kwargs = { - "query_rope": q_pe, - "key_rope": k_pe, - "num_heads": self.num_heads, - "num_key_value_heads": self.num_heads, - "input_layout": "TND", - "atten_mask": prefill_meta.attn_mask, - "sparse_mode": 3, - "scale": self.scale, - "antiquant_mode": 0, - "antiquant_scale": None, - "block_table": None, - "block_size": 0, - "softmax_lse_flag": True, - "actual_seq_lengths": actual_seq_lengths_q, - "actual_seq_lengths_kv": actual_seq_lengths_kv, - } - - attn_output, attn_lse = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, value, **common_kwargs) - + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=attn_metadata.attn_mask, + seqlen=attn_metadata.prefill.query_lens, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse, + ) attn_output, attn_lse = self._compute_prefill_context( q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse ) attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim]) - - # Convert back to original dtype if needed - if need_dtype_convert: - attn_output = attn_output.to(original_dtype) - return attn_output def exec_kv_decode(