diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py
index 32b891b6..71711bf2 100644
--- a/tests/e2e/singlecard/test_aclgraph_accuracy.py
+++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py
@@ -85,8 +85,8 @@ CASE_DS_FULL_DECODE_ONLY = LLMTestCase(
     prompts=PROMPTS_LONG,
     golden_answers=[
         "\n\nSelect an assignment template",
-        "\n\nI'm not sure how to approach this problem. I'm not sure if I should use the law of total probability or if I should use",
-        "\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x$ be the common root of the equations",
+        "\n\nI'm not sure how to approach this problem. I'm thinking that the area of the triangle is $1/2$ times the area",
+        "\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x = \\alpha$ be the common root",
     ],
 )
 
@@ -106,8 +106,8 @@ CASE_DS_EX = LLMTestCase(
     prompts=PROMPTS_LONG,
     golden_answers=[
         "\n\nSelect an assignment template",
-        "\n\nI'm not sure how to approach this problem. I'm not sure if I should use the law of total probability or if I should use",
-        "\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x$ be the common root of the equations",
+        "\n\nI'm not sure how to approach this problem. I'm thinking that the area of the triangle is $1/2$ times the area",
+        "\n\n## Answer\n\n$a + b + c = 0$\n\nSolution\n\nLet $x = \\alpha$ be the common root",
     ],
 )
diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py
index 1adf6419..6bcc6b4e 100755
--- a/tests/ut/attention/test_mla_cp.py
+++ b/tests/ut/attention/test_mla_cp.py
@@ -130,6 +130,10 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
         out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
         dtype=torch.int32,
     )
+    chunk_actual_seq_lengths_kv_list = [
+        torch.cumsum(chunk_seq_lens[i], dim=0).tolist()
+        for i in range(num_chunks)
+    ]
     chunked_context_metadata = CPChunkedContextMetadata(
         cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
         starts=local_chunk_starts.to(non_blocking=True),
@@ -137,6 +141,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
         max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
         chunk_seq_lens=chunk_seq_lens,
         chunk_seq_lens_npu=chunk_seq_lens,
+        chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list,
         workspace=None,
         padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens,
         padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
@@ -500,19 +505,23 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[1], self.impl.v_head_dim)
 
     @patch("torch_npu.atb.npu_paged_cache_load")
-    @patch("torch_npu.atb.npu_ring_mla")
+    @patch("torch_npu.npu_attention_update")
+    @patch("torch_npu.npu_fused_infer_attention_score")
     @patch_distributed_groups(dcp_size=2, pcp_size=2)
     def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp,
-                                                  mock_pcp, mock_ring,
-                                                  mock_load):
+                                                  mock_pcp, mock_fia,
+                                                  mock_update, mock_load):
 
-        def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
-                           head_num, kv_head_num, pre_out, prev_lse, qk_scale,
-                           kernel_type, mask_type, input_layout, calc_type,
-                           output, softmax_lse):
-            return torch.randn(q_rope.shape[0], value.shape[1], value.shape[2])
+        def mock_fia_attn(*args, **kwargs):
+            q = args[0]
+            v = args[2]
+            return (torch.randn(q.shape[0],
+                                v.shape[1],
+                                v.shape[2],
+                                dtype=torch.float16),
+                    torch.randn(v.shape[1], q.shape[0], dtype=torch.float16))
 
-        mock_ring.side_effect = mock_ring_attn
+        mock_fia.side_effect = mock_fia_attn
 
         def mock_kv_b_proj(kv_c_normed):
             return (torch.randn(kv_c_normed.shape[0],
@@ -534,6 +543,13 @@ class TestAscendMLAImpl(TestBase):
         # mock proj
         self.impl.kv_b_proj.side_effect = mock_kv_b_proj
+
+        def mock_update_fn(lse_list, out_list, mode):
+            total_len = out_list[0].shape[0]
+            D = out_list[0].shape[1]
+            return (torch.randn(total_len, D, dtype=torch.float32), None)
+
+        mock_update.side_effect = mock_update_fn
 
         NUM_BLOCKS, BLOCK_SIZE = 10, 32  # fixed
         USED_BLOCKS = 3
         # pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens, num_computed_tokens_of_pcp_dcp
@@ -586,8 +602,8 @@ class TestAscendMLAImpl(TestBase):
                                  self.impl.num_heads,
                                  self.impl.v_head_dim,
                                  dtype=torch.float16)
-        prefix_lse = torch.randn(sum(nums_tokens_per_rank),
-                                 self.impl.num_heads,
+        prefix_lse = torch.randn(self.impl.num_heads,
+                                 sum(nums_tokens_per_rank),
                                  dtype=torch.float16)
         chunk_ctx = get_chunk_metadata(
             pcp_size,
@@ -602,7 +618,7 @@ class TestAscendMLAImpl(TestBase):
             cp_local_block_size=cp_local_block_size)
         meta = MagicMock()
         prefill_meta = MagicMock()
-        prefill_meta.query_lens = nums_tokens_per_rank
+        prefill_meta.query_lens = torch.tensor(nums_tokens_per_rank)
         prefill_meta.block_table = torch.randint(
             0, USED_BLOCKS, (1, 64))  # (batch, max_blocks)
         prefill_meta.chunked_context = chunk_ctx
@@ -621,14 +637,14 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(mock_reorg.call_count,
                          iters * (1 if dcp_size * pcp_size > 1 else 0))
         self.assertEqual(mock_load.call_count, iters)
-        self.assertEqual(mock_ring.call_count, iters)
+        self.assertEqual(mock_fia.call_count, iters)
         mock_reorg.reset_mock()
         mock_load.reset_mock()
-        mock_ring.reset_mock()
+        mock_fia.reset_mock()
+        mock_update.reset_mock()
         mock_dcp.reset_mock()
         mock_pcp.reset_mock()
         self.assertEqual(out.shape, prefix_out.shape)
-        self.assertEqual(lse.shape, prefix_lse.shape)
 
     @patch_distributed_groups(dcp_size=2, pcp_size=2)
     def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp,
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 8fe78566..c625969f 100755
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -102,7 +102,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
             max_seq_lens=max_seq_lens,
             workspace=workspace,
             chunk_seq_lens=chunk_seq_lens,
-            chunk_seq_lens_npu=chunk_seq_lens)
+            chunk_seq_lens_npu=chunk_seq_lens,
+            chunk_actual_seq_lengths_kv_list=[[2, 4]])
 
         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -886,8 +887,9 @@ class TestAscendMLAImpl(TestBase):
         self.assertTrue(torch.equal(prefix_lse, lse))
 
     @patch("torch_npu.atb.npu_paged_cache_load")
-    @patch("torch_npu.atb.npu_ring_mla")
-    def test_compute_prefill_context(self, mock_ring, mock_load):
+    @patch("torch_npu.npu_attention_update")
+    @patch("torch_npu.npu_fused_infer_attention_score")
+    def test_compute_prefill_context(self, mock_fia, mock_update, mock_load):
         S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
         _, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
         latent_kv_dim = self.impl.kv_lora_rank
@@ -898,11 +900,16 @@ class TestAscendMLAImpl(TestBase):
         kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
         kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
         kv_cache = [kv_cache_0, kv_cache_1]
-        prefix_out = torch.randn(S, N, 128)
-        prefix_lse = torch.randn(S, N)
+        prefix_out = torch.randn(S, N, VD)
+        prefix_lse = torch.randn(N, S)
         self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
 
+        # Mock FIA to return output and lse
+        mock_fia.return_value = (torch.randn(S, N, VD), torch.randn(N, S))
+        # Mock attention_update to return merged output
+        mock_update.return_value = (torch.randn(S * N, VD), None)
+
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
@@ -911,7 +918,7 @@ class TestAscendMLAImpl(TestBase):
 
         prefill_meta = MagicMock()
         prefill_meta.chunked_context = chunk_ctx
-        prefill_meta.query_lens = [8]
+        prefill_meta.query_lens = torch.tensor([S])
         prefill_meta.block_table = torch.randint(0, 100, (S, 4))
 
         meta = MagicMock()
@@ -924,10 +931,10 @@ class TestAscendMLAImpl(TestBase):
                                                         prefix_lse)
 
         mock_load.assert_called_once()
-        mock_ring.assert_called_once()
+        mock_fia.assert_called_once()
+        mock_update.assert_called_once()
 
         self.assertEqual(out.shape, prefix_out.shape)
-        self.assertEqual(lse.shape, prefix_lse.shape)
 
     @patch('vllm_ascend.ascend_forward_context.get_forward_context')
     @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py
index c319698f..5aa5c147 100644
--- a/vllm_ascend/attention/context_parallel/common_cp.py
+++ b/vllm_ascend/attention/context_parallel/common_cp.py
@@ -53,6 +53,7 @@ class CPChunkedContextMetadata:
     workspace: torch.Tensor
     chunk_seq_lens: torch.Tensor
     chunk_seq_lens_npu: torch.Tensor
+    chunk_actual_seq_lengths_kv_list: list[list[int]]
     # for mla DCP & PCP
     padded_chunk_seq_lens_npu: torch.Tensor = None
     padded_local_chunk_seq_lens: list[list[int]] | None = None
diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py
index 2ea6c529..c1bf1d36 100644
--- a/vllm_ascend/attention/context_parallel/mla_cp.py
+++ b/vllm_ascend/attention/context_parallel/mla_cp.py
@@ -30,6 +30,7 @@ from vllm_ascend.attention.mla_v1 import (
 # isort: on
 
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.context_parallel.common_cp import (
     AscendPCPMetadata, CPChunkedContextMetadata,
@@ -189,6 +190,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 max_seq_lens=chunked_context_metadata.max_seq_lens,
                 chunk_seq_lens=self.chunk_seq_lens,
                 chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu,
+                chunk_actual_seq_lengths_kv_list=chunked_context_metadata.chunk_actual_seq_lengths_kv_list,
                 workspace=chunked_context_metadata.workspace,
                 padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
                 padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
@@ -276,6 +278,10 @@ class AscendMlaCPImpl(AscendMLAImpl):
             **kwargs,
         )
 
+        # npu_ring_mla needs bfloat16 512x512 mask, different from FIA's int8 2048x2048 mask
+        # TODO: Remove this when mla_cp.py also migrates to FIA
+        self._ring_mla_mask_builder = AttentionMaskBuilder(torch.device("npu"))
+
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
         self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
@@ -484,6 +490,10 @@ class AscendMlaCPImpl(AscendMLAImpl):
         attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
         head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
         tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
+        # Use ring_mla-specific mask (bfloat16, 512x512)
+        # TODO: Remove this when mla_cp.py migrates to FIA
+        ring_mla_mask = self._ring_mla_mask_builder.get_mla_mask(self.vllm_config.model_config.dtype)
+
         output_head, lse_head = self._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_head_idx),
             q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -494,7 +504,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             kv_nomask_idx=kv_with_q_head_nomask_idx,
             attn_mask_seqlens=attn_mask_seqlens,
             attn_nomask_seqlens=head_attn_nomask_seqlens,
-            mask=attn_metadata.attn_mask,
+            mask=ring_mla_mask,
         )
 
         output_tail, lse_tail = self._attention_with_mask_and_nomask(
@@ -507,7 +517,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             kv_nomask_idx=kv_with_q_tail_nomask_idx,
             attn_mask_seqlens=attn_mask_seqlens,
             attn_nomask_seqlens=tail_attn_nomask_seqlens,
-            mask=attn_metadata.attn_mask,
+            mask=ring_mla_mask,
         )
 
         q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index c68d4ff0..3cb9615b 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -112,6 +112,7 @@ class ChunkedContextMetadata:
     workspace: torch.Tensor
     chunk_seq_lens: torch.Tensor
     chunk_seq_lens_npu: torch.Tensor
+    chunk_actual_seq_lengths_kv_list: list[list[int]]
 
 
 @dataclass
@@ -131,6 +132,7 @@ class AscendMLAPrefillMetadata:
     sin: torch.Tensor = None
    cos: torch.Tensor = None
     pcp_metadata: AscendPCPMetadata | None = None
+    actual_seq_lengths_q: list[int] | None = None
 
 
 @dataclass
@@ -447,7 +449,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             num_decodes=self.num_decodes,
             num_decode_tokens=self.num_decode_tokens,
             num_prefills=self.num_prefills,
-            attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
+            attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
             attn_state=common_attn_metadata.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
@@ -486,6 +488,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
         self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True)
         torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32)
+        chunk_actual_seq_lengths_kv_list = [
+            torch.cumsum(self.chunk_seq_lens[i], dim=0).tolist() for i in range(self.num_chunks)
+        ]
         return ChunkedContextMetadata(
             cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
             starts=chunk_starts.pin_memory().to(self.device, non_blocking=True),
@@ -494,6 +499,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             chunk_seq_lens=self.chunk_seq_lens,
             chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
             workspace=self.chunked_prefill_workspace,
+            chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list,
         )
 
     def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
@@ -527,9 +533,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         prefill_input_positions = input_positions[tokens_start:]
         cos, sin = get_cos_and_sin_mla(prefill_input_positions)
 
+        prefill_query_lens = self.query_lens[reqs_start:].to(torch.int32)
+        actual_seq_lengths_q = torch.cumsum(prefill_query_lens, dim=0).tolist()
         return AscendMLAPrefillMetadata(
-            attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
-            query_lens=self.query_lens[reqs_start:].to(torch.int32),
+            attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
+            query_lens=prefill_query_lens,
             seq_lens=self.seq_lens,
             context_lens=self.seq_lens[reqs_start:],
             input_positions=prefill_input_positions,
@@ -540,6 +548,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             chunked_context=chunked_context_metadata,
             sin=sin,
             cos=cos,
+            actual_seq_lengths_q=actual_seq_lengths_q,
         )
 
     def build_decode_metadata(
@@ -887,8 +896,11 @@ class AscendMLAImpl(MLAAttentionImpl):
             post_process_after_loading_for_shard_weight_series(layer)
 
     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
-        kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()  # type: ignore[union-attr]
-        q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()  # type: ignore[union-attr]
+        assert self.fused_qkv_a_proj is not None
+        assert self.q_a_layernorm is not None
+        assert self.kv_a_layernorm is not None
+        kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()
+        q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()
         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
         kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
@@ -990,17 +1002,18 @@ class AscendMLAImpl(MLAAttentionImpl):
             return prefix_output, prefix_lse
 
         iters = len(prefill_metadata.chunked_context.seq_tot)
-
-        current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
         latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
+
+        actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q
+
+        chunk_outputs = []
+        chunk_lses = []
+
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            # chunk_seq_lens will be padded when pcp&dcp
-            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([current_seq_len, context_seq_len])
             context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata)
             kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
             k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
@@ -1026,27 +1039,61 @@ class AscendMLAImpl(MLAAttentionImpl):
             k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
 
-            mask = attn_metadata.attn_mask
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_pe,
-                k_nope=k_nope,
-                k_rope=k_pe,
-                value=v,
-                mask=mask,
-                seqlen=seq_len,
-                head_num=self.num_heads,
-                kv_head_num=self.num_heads,
-                pre_out=prefix_output,
-                prev_lse=prefix_lse,
-                qk_scale=self.scale,
-                kernel_type="kernel_type_high_precision",
-                mask_type="no_mask",
-                input_layout="type_bsnd",
-                calc_type="calc_type_default",
-                output=prefix_output,
-                softmax_lse=prefix_lse,
+            actual_seq_lengths_kv = prefill_metadata.chunked_context.chunk_actual_seq_lengths_kv_list[i]
+
+            chunk_out, chunk_lse = torch_npu.npu_fused_infer_attention_score(
+                q_nope,
+                k_nope,
+                v,
+                query_rope=q_pe,
+                key_rope=k_pe,
+                num_heads=self.num_heads,
+                num_key_value_heads=self.num_heads,
+                input_layout="TND",
+                atten_mask=None,
+                sparse_mode=0,
+                scale=self.scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                softmax_lse_flag=True,
+                actual_seq_lengths=actual_seq_lengths_q,
+                actual_seq_lengths_kv=actual_seq_lengths_kv,
             )
+            chunk_outputs.append(chunk_out)
+            chunk_lses.append(chunk_lse)
+
+        if len(chunk_outputs) > 0:
+            num_tokens = q_nope.size(0)
+            D = self.v_head_dim
+            H = self.num_heads
+
+            # Normalize prefix output/lse to [num_tokens, H, D] and [num_tokens, H, 1]
+            prefix_output = prefix_output.to(torch.float32)
+            prefix_lse = prefix_lse.to(torch.float32)
+            if prefix_lse.dim() == 2:
+                prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
+
+            # Concat output and lse: [num_tokens, H, D+1]
+            all_out_lse = [torch.cat([prefix_output, prefix_lse], dim=-1)]
+            for chunk_out, chunk_lse in zip(chunk_outputs, chunk_lses):
+                chunk_out = chunk_out.to(torch.float32)
+                chunk_lse = chunk_lse.to(torch.float32)
+                if chunk_lse.dim() == 2:
+                    chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
+                all_out_lse.append(torch.cat([chunk_out, chunk_lse], dim=-1))
+
+            # Stack and split: [N, num_tokens, H, D+1]
+            all_out_lse = torch.stack(all_out_lse, dim=0)
+            N = all_out_lse.size(0)
+            out_flat, lse_flat = torch.split(all_out_lse, [D, 1], dim=-1)
+
+            # Flatten and unbind for npu_attention_update
+            out_list = out_flat.view(N, num_tokens * H, D).unbind(0)
+            lse_list = lse_flat.view(N, num_tokens * H).unbind(0)
+
+            output_final, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
+            return output_final.view(num_tokens, H, D), None
+
         return prefix_output, prefix_lse
 
     def _forward_prefill(
@@ -1062,33 +1109,54 @@ class AscendMLAImpl(MLAAttentionImpl):
         assert attn_metadata.prefill is not None
         assert len(kv_c_and_k_pe_cache) > 1
         num_tokens = q_nope.size(0)
+        prefill_meta = attn_metadata.prefill
+
+        actual_seq_lengths_q = prefill_meta.actual_seq_lengths_q
+        actual_seq_lengths_kv = actual_seq_lengths_q.copy()
+
+        # FIA with TND layout only supports bfloat16, convert if needed
+        original_dtype = q_nope.dtype
+        need_dtype_convert = original_dtype != torch.bfloat16
+        if need_dtype_convert:
+            q_nope = q_nope.to(torch.bfloat16)
+            q_pe = q_pe.to(torch.bfloat16)
+            k_nope = k_nope.to(torch.bfloat16)
+            k_pe = k_pe.to(torch.bfloat16)
+            value = value.to(torch.bfloat16)
+
         attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device)
         attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device)
-        torch_npu.atb.npu_ring_mla(
-            q_nope=q_nope,
-            q_rope=q_pe,
-            k_nope=k_nope,
-            k_rope=k_pe,
-            value=value,
-            mask=attn_metadata.attn_mask,
-            seqlen=attn_metadata.prefill.query_lens,
-            head_num=self.num_heads,
-            kv_head_num=self.num_heads,
-            pre_out=None,
-            prev_lse=None,
-            qk_scale=self.scale,
-            kernel_type="kernel_type_high_precision",
-            mask_type="mask_type_triu",
-            input_layout="type_bsnd",
-            calc_type="calc_type_first_ring",
-            output=attn_output,
-            softmax_lse=attn_lse,
-        )
+
+        common_kwargs = {
+            "query_rope": q_pe,
+            "key_rope": k_pe,
+            "num_heads": self.num_heads,
+            "num_key_value_heads": self.num_heads,
+            "input_layout": "TND",
+            "atten_mask": prefill_meta.attn_mask,
+            "sparse_mode": 3,
+            "scale": self.scale,
+            "antiquant_mode": 0,
+            "antiquant_scale": None,
+            "block_table": None,
+            "block_size": 0,
+            "softmax_lse_flag": True,
+            "actual_seq_lengths": actual_seq_lengths_q,
+            "actual_seq_lengths_kv": actual_seq_lengths_kv,
+        }
+
+        attn_output, attn_lse = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, value, **common_kwargs)
+
         attn_output, attn_lse = self._compute_prefill_context(
             q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse
         )
         attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim])
+
+        # Convert back to original dtype if needed
+        if need_dtype_convert:
+            attn_output = attn_output.to(original_dtype)
+
         return attn_output
 
@@ -1099,6 +1167,7 @@ class AscendMLAImpl(MLAAttentionImpl):
     def exec_kv_decode(
         self,
         kv_no_split: torch.Tensor,
         cos: torch.Tensor,
         sin: torch.Tensor,
         kv_cache: tuple,
         slots: torch.Tensor,
     ):
+        assert self.kv_a_layernorm is not None
         B = kv_no_split.shape[0]
         N = self.num_kv_heads
         S = 1
@@ -1126,6 +1195,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         kv_cache: tuple,
         slots: torch.Tensor,
     ):
+        assert self.kv_a_layernorm is not None
         B = kv_no_split.shape[0]
         N = self.num_kv_heads
         S = 1
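
Reviewer note (not part of the diff): the new chunked-prefix path in _compute_prefill_context collects each chunk's attention output together with its softmax log-sum-exp (LSE) and hands the final merge to torch_npu.npu_attention_update. The snippet below is a minimal CPU-only sketch of that LSE merge for reference; merge_attention_partials and the toy shapes are assumptions for illustration, not the NPU operator's actual signature.

import torch


def merge_attention_partials(outs, lses):
    """Merge partial attention outputs using their softmax LSE values.

    outs: list of [tokens, heads, v_dim] partial outputs (hypothetical shapes)
    lses: list of [tokens, heads] log-sum-exp values for the same partials
    """
    lse = torch.stack(lses, dim=0)            # [n, tokens, heads]
    out = torch.stack(outs, dim=0)            # [n, tokens, heads, v_dim]
    m = lse.max(dim=0, keepdim=True).values   # stabilizer for exp()
    w = (lse - m).exp()                       # softmax mass of each partial
    return (out * w.unsqueeze(-1)).sum(dim=0) / w.sum(dim=0).unsqueeze(-1)


if __name__ == "__main__":
    tokens, heads, v_dim, kv1, kv2 = 4, 2, 8, 16, 12
    q = torch.randn(tokens, heads, v_dim)
    k = torch.randn(kv1 + kv2, heads, v_dim)
    v = torch.randn(kv1 + kv2, heads, v_dim)

    def attn(q, k, v):
        scores = torch.einsum("thd,shd->ths", q, k)
        out = torch.einsum("ths,shd->thd", scores.softmax(dim=-1), v)
        return out, torch.logsumexp(scores, dim=-1)

    o1, l1 = attn(q, k[:kv1], v[:kv1])   # "prefix" partial
    o2, l2 = attn(q, k[kv1:], v[kv1:])   # "chunk" partial
    full, _ = attn(q, k, v)              # attention over all KV at once
    assert torch.allclose(merge_attention_partials([o1, o2], [l1, l2]), full, atol=1e-5)

The same identity is why the diff can drop the returned LSE from _compute_prefill_context once all partials are merged in a single npu_attention_update call.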