[BugFix] Improve the performance of prefixcache features (#4022)

### What this PR does / why we need it?
A code bug caused an empty bubble in the execution pipeline: every time the
npu_paged_cache_load operator was called, `seq_len2` was forcibly copied to
the device inside the per-chunk loop. That blocking host-to-device copy
triggered a synchronization, which stalled the CPU's operator-launch stream.
This PR precomputes the device-side copy (`chunk_seq_lens_npu`) once when the
prefill metadata is built, so the hot loop no longer synchronizes.
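
For context, here is a minimal, hypothetical sketch of the pattern this PR applies. The helper names `build_chunked_context` and `chunked_prefill_loop` are illustrative, not the actual vllm-ascend API; `Tensor.npu()` is provided by `torch_npu`. The host-to-device copy of the per-chunk sequence lengths is done once at metadata-build time instead of inside the per-chunk loop:

```python
import torch
import torch_npu  # noqa: F401  (adds Tensor.npu() on Ascend devices)


def build_chunked_context(chunk_seq_lens: torch.Tensor) -> dict:
    """Hypothetical metadata builder: perform the H2D copy once, up front."""
    return {
        "chunk_seq_lens": chunk_seq_lens,            # CPU copy, for torch.stack
        "chunk_seq_lens_npu": chunk_seq_lens.npu(),  # single up-front H2D copy
    }


def chunked_prefill_loop(ctx: dict, current_seq_len: torch.Tensor) -> None:
    for i in range(ctx["chunk_seq_lens"].size(0)):
        # Before the fix: chunk_seq_lens[i].to(device) here issued a blocking
        # copy on every iteration, forcing a synchronization that stalled the
        # CPU's operator-launch stream (the "empty bubble").
        context_seq_len_npu = ctx["chunk_seq_lens_npu"][i]  # already on device
        seq_len = torch.stack([current_seq_len, ctx["chunk_seq_lens"][i]])
        # ... pass context_seq_len_npu (and seq_len) to the
        # npu_paged_cache_load operator, with no per-iteration sync
```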

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: underfituu <hzhucong@163.com>
Author: hucong
Date: 2025-11-08 18:45:31 +08:00
Committed by: GitHub
Parent: 1d81a289d0
Commit: 48094148f8
5 changed files with 31 additions and 10 deletions


```diff
@@ -119,3 +119,4 @@ jobs:
       config_file_path: ${{ matrix.test_config.config_file_path }}
     secrets:
       KUBECONFIG_B64: ${{ secrets.KUBECONFIG_B64 }}
```


```diff
@@ -82,7 +82,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)
 
         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -103,6 +104,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)
 
 
 class TestAscendMLADecodeMetadata(TestBase):
@@ -478,6 +481,7 @@ class TestAscendMLAImpl(TestBase):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
 
         prefill_meta = MagicMock()
```


```diff
@@ -86,7 +86,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)
 
         metadata = AscendMLATorchairPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -107,6 +108,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase):
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)
 
 
 class TestAscendMLATorchairDecodeMetadata(TestBase):
@@ -661,6 +664,7 @@ class TestAscendMLATorchairImpl(TestBase):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
 
         prefill_meta = MagicMock()
```


```diff
@@ -110,6 +110,7 @@ class AscendMLAPrefillMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -449,6 +450,7 @@ class AscendMLAMetadataBuilder:
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens_npu=chunk_seq_lens.npu(),
                 workspace=self.chunked_prefill_workspace,
             )
         prefill_input_positions = input_positions[tokens_start:]
@@ -888,7 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl):
         iters = len(prefill_metadata.chunked_context.seq_tot)
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
 
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -896,8 +899,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -913,7 +919,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(q_nope.device),
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
```


```diff
@@ -69,6 +69,7 @@ class AscendMLATorchairPrefillMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -447,6 +448,7 @@ class AscendMLATorchairMetadataBuilder:
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens_npu=chunk_seq_lens.npu(),
                 workspace=self.chunked_prefill_workspace,
             )
         prefill_input_positions = input_positions[tokens_start:]
@@ -760,7 +762,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         q_pe = query[..., self.qk_nope_head_dim:]
         q_nope = query[..., :self.qk_nope_head_dim]
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
 
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -768,8 +771,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -785,7 +791,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
```