From 48094148f84f5dad22c9eaf15c052f6f7116af65 Mon Sep 17 00:00:00 2001
From: hucong <33891520+underfituu@users.noreply.github.com>
Date: Sat, 8 Nov 2025 18:45:31 +0800
Subject: [PATCH] [BugFix] Improve the performance of prefix cache features
 (#4022)

### What this PR does / why we need it?
A code bug introduced an empty bubble on the device: each call to the
npu_paged_cache_load operator forcibly transferred seq_len2 to the device,
which triggered a synchronization and interrupted the CPU's operator-launch
stream. This PR builds the device-side copy of the chunk sequence lengths
(chunk_seq_lens_npu) once, when the prefill metadata is constructed, so the
per-chunk loop no longer issues a blocking host-to-device copy.

- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac

---------

Signed-off-by: underfituu
---
 .github/workflows/vllm_ascend_test_nightly_a3.yaml |  1 +
 tests/ut/attention/test_mla_v1.py                  |  6 +++++-
 tests/ut/torchair/test_torchair_mla.py             |  6 +++++-
 vllm_ascend/attention/mla_v1.py                    | 14 ++++++++++----
 vllm_ascend/torchair/torchair_mla.py               | 14 ++++++++++----
 5 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/.github/workflows/vllm_ascend_test_nightly_a3.yaml b/.github/workflows/vllm_ascend_test_nightly_a3.yaml
index d880a8bf..74741b95 100644
--- a/.github/workflows/vllm_ascend_test_nightly_a3.yaml
+++ b/.github/workflows/vllm_ascend_test_nightly_a3.yaml
@@ -119,3 +119,4 @@ jobs:
       config_file_path: ${{ matrix.test_config.config_file_path }}
     secrets:
       KUBECONFIG_B64: ${{ secrets.KUBECONFIG_B64 }}
+      
\ No newline at end of file
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index d8ddc6a6..8d15bcaa 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -82,7 +82,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)
 
         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -103,6 +104,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)
 
 
 class TestAscendMLADecodeMetadata(TestBase):
@@ -478,6 +481,7 @@ class TestAscendMLAImpl(TestBase):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
 
         prefill_meta = MagicMock()
diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py
index 3dd1d2f7..1f108b3e 100644
--- a/tests/ut/torchair/test_torchair_mla.py
+++ b/tests/ut/torchair/test_torchair_mla.py
@@ -86,7 +86,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)
 
         metadata = AscendMLATorchairPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -107,6 +108,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase):
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)
 
 
 class TestAscendMLATorchairDecodeMetadata(TestBase):
@@ -661,6 +664,7 @@ class TestAscendMLATorchairImpl(TestBase):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
 
         prefill_meta = MagicMock()
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index f2717851..6d5c1397 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -110,6 +110,7 @@ class AscendMLAPrefillMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -449,6 +450,7 @@ class AscendMLAMetadataBuilder:
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -888,7 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         iters = len(prefill_metadata.chunked_context.seq_tot)
 
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -896,8 +899,11 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -913,7 +919,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(q_nope.device),
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index ce539b7d..51becad9 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -69,6 +69,7 @@ class AscendMLATorchairPrefillMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -447,6 +448,7 @@ class AscendMLATorchairMetadataBuilder:
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -760,7 +762,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         q_pe = query[..., self.qk_nope_head_dim:]
         q_nope = query[..., :self.qk_nope_head_dim]
 
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -768,8 +771,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
 
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -785,7 +791,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
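
Below is a minimal, runnable sketch of the pattern this patch applies, for readers outside the diff context. `build_chunk_metadata` and `load_chunks` are hypothetical names, not this repo's API, and a plain CUDA/CPU device stands in for Ascend's `npu` so the sketch runs anywhere: the host-to-device copy of the chunk lengths is issued once at metadata-build time, instead of via a blocking `.to(device)` inside the per-chunk loop.

```python
# Minimal sketch of the fix pattern. build_chunk_metadata / load_chunks are
# hypothetical names, and "cuda" (or CPU) stands in for Ascend's "npu".
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_chunk_metadata(chunk_seq_lens: torch.Tensor) -> dict:
    """Built once per batch: the host-to-device copy of the chunk
    lengths happens here, outside the hot loop."""
    return {
        "chunk_seq_lens": chunk_seq_lens,                 # stays on CPU
        "chunk_seq_lens_npu": chunk_seq_lens.to(device),  # single H2D copy
    }


def load_chunks(meta: dict) -> None:
    for i in range(meta["chunk_seq_lens"].size(0)):
        # Before the fix: chunk_seq_lens[i].to(device) ran here, forcing a
        # fresh host-to-device transfer (with a synchronization on the
        # Ascend path) every iteration, stalling the CPU launch stream.
        # After the fix: index the tensor that already lives on the device.
        context_seq_len_npu = meta["chunk_seq_lens_npu"][i]
        _ = context_seq_len_npu  # would be passed to the cache-load kernel


meta = build_chunk_metadata(torch.tensor([[8, 8], [4, 4]], dtype=torch.int32))
load_chunks(meta)
```

Whether a bare `.to(device)` blocks depends on the backend and on whether the source tensor is pinned; per the description above, the Ascend path did synchronize, which is why hoisting the copy out of the loop removes the bubble.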