diff --git a/.github/workflows/vllm_ascend_test_nightly_a3.yaml b/.github/workflows/vllm_ascend_test_nightly_a3.yaml index d880a8bf..74741b95 100644 --- a/.github/workflows/vllm_ascend_test_nightly_a3.yaml +++ b/.github/workflows/vllm_ascend_test_nightly_a3.yaml @@ -119,3 +119,4 @@ jobs: config_file_path: ${{ matrix.test_config.config_file_path }} secrets: KUBECONFIG_B64: ${{ secrets.KUBECONFIG_B64 }} + \ No newline at end of file diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index d8ddc6a6..8d15bcaa 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -82,7 +82,8 @@ class TestAscendMLAPrefillMetadata(TestBase): seq_tot=seq_tot, max_seq_lens=max_seq_lens, workspace=workspace, - chunk_seq_lens=chunk_seq_lens) + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens) metadata = AscendMLAPrefillMetadata( attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), @@ -103,6 +104,8 @@ class TestAscendMLAPrefillMetadata(TestBase): self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens) self.assertIs(metadata.chunked_context.workspace, workspace) self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) + self.assertIs(metadata.chunked_context.chunk_seq_lens_npu, + chunk_seq_lens) class TestAscendMLADecodeMetadata(TestBase): @@ -478,6 +481,7 @@ class TestAscendMLAImpl(TestBase): chunk_ctx = MagicMock() chunk_ctx.seq_tot = [8] chunk_ctx.chunk_seq_lens = [torch.tensor([8])] + chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])] chunk_ctx.starts = [torch.tensor([0])] prefill_meta = MagicMock() diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py index 3dd1d2f7..1f108b3e 100644 --- a/tests/ut/torchair/test_torchair_mla.py +++ b/tests/ut/torchair/test_torchair_mla.py @@ -86,7 +86,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase): seq_tot=seq_tot, max_seq_lens=max_seq_lens, workspace=workspace, - chunk_seq_lens=chunk_seq_lens) + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens) metadata = AscendMLATorchairPrefillMetadata( attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), @@ -107,6 +108,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase): self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens) self.assertIs(metadata.chunked_context.workspace, workspace) self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) + self.assertIs(metadata.chunked_context.chunk_seq_lens_npu, + chunk_seq_lens) class TestAscendMLATorchairDecodeMetadata(TestBase): @@ -661,6 +664,7 @@ class TestAscendMLATorchairImpl(TestBase): chunk_ctx = MagicMock() chunk_ctx.seq_tot = [8] chunk_ctx.chunk_seq_lens = [torch.tensor([8])] + chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])] chunk_ctx.starts = [torch.tensor([0])] prefill_meta = MagicMock() diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index f2717851..6d5c1397 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -110,6 +110,7 @@ class AscendMLAPrefillMetadata: max_seq_lens: list[int] workspace: torch.Tensor chunk_seq_lens: torch.Tensor + chunk_seq_lens_npu: torch.Tensor attn_mask: torch.Tensor query_lens: torch.Tensor @@ -449,6 +450,7 @@ class AscendMLAMetadataBuilder: seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), workspace=self.chunked_prefill_workspace, ) prefill_input_positions = input_positions[tokens_start:] @@ -888,7 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl): iters = len(prefill_metadata.chunked_context.seq_tot) - seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) + current_seq_len = torch.tensor(prefill_metadata.query_lens, + dtype=torch.int32) cache_kv_c = kv_c_and_k_pe_cache[0] cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) @@ -896,8 +899,11 @@ class AscendMLAImpl(MLAAttentionImpl): for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] - seq_len = torch.stack([seq_len1, seq_len2]) + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + i] + context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ + i] + seq_len = torch.stack([current_seq_len, context_seq_len]) kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, @@ -913,7 +919,7 @@ class AscendMLAImpl(MLAAttentionImpl): cache_kv_c, cache_k_pe, prefill_metadata.block_table, - seq_len2.to(q_nope.device), + context_seq_len_npu, seq_starts=prefill_metadata.chunked_context.starts[i], key=kv_c_normed, value=k_pe, diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index ce539b7d..51becad9 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -69,6 +69,7 @@ class AscendMLATorchairPrefillMetadata: max_seq_lens: list[int] workspace: torch.Tensor chunk_seq_lens: torch.Tensor + chunk_seq_lens_npu: torch.Tensor attn_mask: torch.Tensor query_lens: torch.Tensor @@ -447,6 +448,7 @@ class AscendMLATorchairMetadataBuilder: seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), workspace=self.chunked_prefill_workspace, ) prefill_input_positions = input_positions[tokens_start:] @@ -760,7 +762,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl): q_pe = query[..., self.qk_nope_head_dim:] q_nope = query[..., :self.qk_nope_head_dim] - seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) + current_seq_len = torch.tensor(prefill_metadata.query_lens, + dtype=torch.int32) cache_kv_c = kv_c_and_k_pe_cache[0] cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) @@ -768,8 +771,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl): for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] - seq_len = torch.stack([seq_len1, seq_len2]) + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ + i] + context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ + i] + seq_len = torch.stack([current_seq_len, context_seq_len]) kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, @@ -785,7 +791,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl): cache_kv_c, cache_k_pe, prefill_metadata.block_table, - seq_len2.to(query.device), + context_seq_len_npu, seq_starts=prefill_metadata.chunked_context.starts[i], key=kv_c_normed, value=k_pe,