[0.11.0][BugFix] Improve the performance of the prefix cache feature (#4021)

### What this PR does / why we need it?
Cherry-pick of https://github.com/vllm-project/vllm-ascend/pull/4022

The bug introduced an idle bubble on the device: every call to the
npu_paged_cache_load operator moved seq_len2 from host to device on the
fly, which forced a synchronization and stalled the CPU's operator-launch
stream. The fix materializes the device-side copy once in the metadata
builder (chunk_seq_lens_npu), so the per-chunk host-to-device transfer and
its sync disappear from the hot loop.
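
Below is a minimal, self-contained sketch of the pattern this fix applies. It uses plain PyTorch with a generic device as a stand-in for torch_npu, and the function names are hypothetical, not the PR's actual code:

```python
import torch


def consume_chunks_slow(chunk_seq_lens_cpu: torch.Tensor,
                        device: torch.device) -> None:
    """Anti-pattern removed by this PR."""
    for i in range(chunk_seq_lens_cpu.size(0)):
        # Copying a pageable CPU tensor to the device inside the loop is a
        # blocking transfer: the host waits for the copy before it can launch
        # the next kernel, leaving an idle bubble on the device.
        seq_len_dev = chunk_seq_lens_cpu[i].to(device)
        _ = seq_len_dev  # the kernel consuming seq_len_dev would go here


def consume_chunks_fast(chunk_seq_lens_cpu: torch.Tensor,
                        device: torch.device) -> None:
    """Pattern adopted by this PR (chunk_seq_lens_npu)."""
    # One up-front copy while metadata is built; the loop then only indexes
    # a tensor already resident on the device, so kernel launches stream out
    # without host-side stalls.
    chunk_seq_lens_dev = chunk_seq_lens_cpu.to(device)
    for i in range(chunk_seq_lens_cpu.size(0)):
        seq_len_dev = chunk_seq_lens_dev[i]
        _ = seq_len_dev  # the kernel consuming seq_len_dev would go here
```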


---------

Signed-off-by: underfituu <hzhucong@163.com>
hucong committed 2025-11-10 11:51:34 +08:00 (committed by GitHub)
parent c2d58c0655
commit 7ea17fbee3
4 changed files with 30 additions and 10 deletions

View File

@@ -82,7 +82,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)

         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -103,6 +104,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)


 class TestAscendMLADecodeMetadata(TestBase):
@@ -428,6 +431,7 @@ class TestAscendMLAImpl(TestBase):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
         prefill_meta = MagicMock()

View File

@@ -86,7 +86,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)

         metadata = AscendMLATorchairPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -107,6 +108,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase):
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)


 class TestAscendMLATorchairDecodeMetadata(TestBase):
@@ -661,6 +664,7 @@ class TestAscendMLATorchairImpl(TestBase):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
         prefill_meta = MagicMock()

View File

@@ -80,6 +80,7 @@ class AscendMLAPrefillMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor

     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -371,6 +372,7 @@ class AscendMLAMetadataBuilder:
             seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
             max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
             chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens.npu(),
             workspace=self.chunked_prefill_workspace,
         )
         prefill_input_positions = input_positions[tokens_start:]
@@ -766,7 +768,8 @@ class AscendMLAImpl(MLAAttentionImpl):
         iters = len(prefill_metadata.chunked_context.seq_tot)
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -774,8 +777,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -791,7 +797,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(q_nope.device),
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,

View File

@@ -72,6 +72,7 @@ class AscendMLATorchairPrefillMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor

     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -462,6 +463,7 @@ class AscendMLATorchairMetadataBuilder:
             seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
             max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
             chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens.npu(),
             workspace=self.chunked_prefill_workspace,
         )
         prefill_input_positions = input_positions[tokens_start:]
@@ -777,7 +779,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         q_pe = query[..., self.qk_nope_head_dim:]
         q_nope = query[..., :self.qk_nope_head_dim]
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -785,8 +788,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -802,7 +808,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,