diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index c55234b..9851e51 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -82,7 +82,8 @@ class TestAscendMLAPrefillMetadata(TestBase): seq_tot=seq_tot, max_seq_lens=max_seq_lens, workspace=workspace, - chunk_seq_lens=chunk_seq_lens) + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens) metadata = AscendMLAPrefillMetadata( attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), @@ -103,6 +104,8 @@ class TestAscendMLAPrefillMetadata(TestBase): self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens) self.assertIs(metadata.chunked_context.workspace, workspace) self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) + self.assertIs(metadata.chunked_context.chunk_seq_lens_npu, + chunk_seq_lens) class TestAscendMLADecodeMetadata(TestBase): @@ -428,6 +431,7 @@ class TestAscendMLAImpl(TestBase): chunk_ctx = MagicMock() chunk_ctx.seq_tot = [8] chunk_ctx.chunk_seq_lens = [torch.tensor([8])] + chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])] chunk_ctx.starts = [torch.tensor([0])] prefill_meta = MagicMock() diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py index 3dd1d2f..1f108b3 100644 --- a/tests/ut/torchair/test_torchair_mla.py +++ b/tests/ut/torchair/test_torchair_mla.py @@ -86,7 +86,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase): seq_tot=seq_tot, max_seq_lens=max_seq_lens, workspace=workspace, - chunk_seq_lens=chunk_seq_lens) + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens) metadata = AscendMLATorchairPrefillMetadata( attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), @@ -107,6 +108,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase): self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens) self.assertIs(metadata.chunked_context.workspace, workspace) self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) + self.assertIs(metadata.chunked_context.chunk_seq_lens_npu, + chunk_seq_lens) class TestAscendMLATorchairDecodeMetadata(TestBase): @@ -661,6 +664,7 @@ class TestAscendMLATorchairImpl(TestBase): chunk_ctx = MagicMock() chunk_ctx.seq_tot = [8] chunk_ctx.chunk_seq_lens = [torch.tensor([8])] + chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])] chunk_ctx.starts = [torch.tensor([0])] prefill_meta = MagicMock() diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 177d91b..4044126 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -80,6 +80,7 @@ class AscendMLAPrefillMetadata: max_seq_lens: list[int] workspace: torch.Tensor chunk_seq_lens: torch.Tensor + chunk_seq_lens_npu: torch.Tensor attn_mask: torch.Tensor query_lens: torch.Tensor @@ -371,6 +372,7 @@ class AscendMLAMetadataBuilder: seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), workspace=self.chunked_prefill_workspace, ) prefill_input_positions = input_positions[tokens_start:] @@ -766,7 +768,8 @@ class AscendMLAImpl(MLAAttentionImpl): iters = len(prefill_metadata.chunked_context.seq_tot) - seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) + current_seq_len = torch.tensor(prefill_metadata.query_lens, + dtype=torch.int32) cache_kv_c = kv_c_and_k_pe_cache[0] cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) @@ -774,8 +777,11 @@ class AscendMLAImpl(MLAAttentionImpl): for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] - seq_len = torch.stack([seq_len1, seq_len2]) + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + i] + context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ + i] + seq_len = torch.stack([current_seq_len, context_seq_len]) kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, @@ -791,7 +797,7 @@ class AscendMLAImpl(MLAAttentionImpl): cache_kv_c, cache_k_pe, prefill_metadata.block_table, - seq_len2.to(q_nope.device), + context_seq_len_npu, seq_starts=prefill_metadata.chunked_context.starts[i], key=kv_c_normed, value=k_pe, diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 32543a8..3ffcdfb 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -72,6 +72,7 @@ class AscendMLATorchairPrefillMetadata: max_seq_lens: list[int] workspace: torch.Tensor chunk_seq_lens: torch.Tensor + chunk_seq_lens_npu: torch.Tensor attn_mask: torch.Tensor query_lens: torch.Tensor @@ -462,6 +463,7 @@ class AscendMLATorchairMetadataBuilder: seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), workspace=self.chunked_prefill_workspace, ) prefill_input_positions = input_positions[tokens_start:] @@ -777,7 +779,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl): q_pe = query[..., self.qk_nope_head_dim:] q_nope = query[..., :self.qk_nope_head_dim] - seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) + current_seq_len = torch.tensor(prefill_metadata.query_lens, + dtype=torch.int32) cache_kv_c = kv_c_and_k_pe_cache[0] cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) @@ -785,8 +788,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl): for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] - seq_len = torch.stack([seq_len1, seq_len2]) + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + i] + context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ + i] + seq_len = torch.stack([current_seq_len, context_seq_len]) kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, @@ -802,7 +808,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl): cache_kv_c, cache_k_pe, prefill_metadata.block_table, - seq_len2.to(query.device), + context_seq_len_npu, seq_starts=prefill_metadata.chunked_context.starts[i], key=kv_c_normed, value=k_pe,