[BugFix] Improve the performance of prefixcache features (#4022)
### What this PR does / why we need it?
The code bug caused an empty bubble. When the npu_paged_cache_load
operator was called, it forcibly transferred seq_len2 to the device,
which triggered synchronization and interrupted the CPU operator's
launch stream.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: underfituu <hzhucong@163.com>
This commit is contained in:
@@ -119,3 +119,4 @@ jobs:
|
|||||||
config_file_path: ${{ matrix.test_config.config_file_path }}
|
config_file_path: ${{ matrix.test_config.config_file_path }}
|
||||||
secrets:
|
secrets:
|
||||||
KUBECONFIG_B64: ${{ secrets.KUBECONFIG_B64 }}
|
KUBECONFIG_B64: ${{ secrets.KUBECONFIG_B64 }}
|
||||||
|
|
||||||
@@ -82,7 +82,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
|
|||||||
seq_tot=seq_tot,
|
seq_tot=seq_tot,
|
||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
chunk_seq_lens=chunk_seq_lens)
|
chunk_seq_lens=chunk_seq_lens,
|
||||||
|
chunk_seq_lens_npu=chunk_seq_lens)
|
||||||
|
|
||||||
metadata = AscendMLAPrefillMetadata(
|
metadata = AscendMLAPrefillMetadata(
|
||||||
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
||||||
@@ -103,6 +104,8 @@ class TestAscendMLAPrefillMetadata(TestBase):
|
|||||||
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
|
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
|
||||||
self.assertIs(metadata.chunked_context.workspace, workspace)
|
self.assertIs(metadata.chunked_context.workspace, workspace)
|
||||||
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
|
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
|
||||||
|
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
|
||||||
|
chunk_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
class TestAscendMLADecodeMetadata(TestBase):
|
class TestAscendMLADecodeMetadata(TestBase):
|
||||||
@@ -478,6 +481,7 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
chunk_ctx = MagicMock()
|
chunk_ctx = MagicMock()
|
||||||
chunk_ctx.seq_tot = [8]
|
chunk_ctx.seq_tot = [8]
|
||||||
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
||||||
|
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
|
||||||
chunk_ctx.starts = [torch.tensor([0])]
|
chunk_ctx.starts = [torch.tensor([0])]
|
||||||
|
|
||||||
prefill_meta = MagicMock()
|
prefill_meta = MagicMock()
|
||||||
|
|||||||
@@ -86,7 +86,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase):
|
|||||||
seq_tot=seq_tot,
|
seq_tot=seq_tot,
|
||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
chunk_seq_lens=chunk_seq_lens)
|
chunk_seq_lens=chunk_seq_lens,
|
||||||
|
chunk_seq_lens_npu=chunk_seq_lens)
|
||||||
|
|
||||||
metadata = AscendMLATorchairPrefillMetadata(
|
metadata = AscendMLATorchairPrefillMetadata(
|
||||||
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
||||||
@@ -107,6 +108,8 @@ class TestAscendMLATorchairPrefillMetadata(TestBase):
|
|||||||
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
|
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
|
||||||
self.assertIs(metadata.chunked_context.workspace, workspace)
|
self.assertIs(metadata.chunked_context.workspace, workspace)
|
||||||
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
|
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
|
||||||
|
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
|
||||||
|
chunk_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
class TestAscendMLATorchairDecodeMetadata(TestBase):
|
class TestAscendMLATorchairDecodeMetadata(TestBase):
|
||||||
@@ -661,6 +664,7 @@ class TestAscendMLATorchairImpl(TestBase):
|
|||||||
chunk_ctx = MagicMock()
|
chunk_ctx = MagicMock()
|
||||||
chunk_ctx.seq_tot = [8]
|
chunk_ctx.seq_tot = [8]
|
||||||
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
||||||
|
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
|
||||||
chunk_ctx.starts = [torch.tensor([0])]
|
chunk_ctx.starts = [torch.tensor([0])]
|
||||||
|
|
||||||
prefill_meta = MagicMock()
|
prefill_meta = MagicMock()
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ class AscendMLAPrefillMetadata:
|
|||||||
max_seq_lens: list[int]
|
max_seq_lens: list[int]
|
||||||
workspace: torch.Tensor
|
workspace: torch.Tensor
|
||||||
chunk_seq_lens: torch.Tensor
|
chunk_seq_lens: torch.Tensor
|
||||||
|
chunk_seq_lens_npu: torch.Tensor
|
||||||
|
|
||||||
attn_mask: torch.Tensor
|
attn_mask: torch.Tensor
|
||||||
query_lens: torch.Tensor
|
query_lens: torch.Tensor
|
||||||
@@ -449,6 +450,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||||
chunk_seq_lens=chunk_seq_lens,
|
chunk_seq_lens=chunk_seq_lens,
|
||||||
|
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||||
workspace=self.chunked_prefill_workspace,
|
workspace=self.chunked_prefill_workspace,
|
||||||
)
|
)
|
||||||
prefill_input_positions = input_positions[tokens_start:]
|
prefill_input_positions = input_positions[tokens_start:]
|
||||||
@@ -888,7 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||||
|
|
||||||
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
current_seq_len = torch.tensor(prefill_metadata.query_lens,
|
||||||
|
dtype=torch.int32)
|
||||||
cache_kv_c = kv_c_and_k_pe_cache[0]
|
cache_kv_c = kv_c_and_k_pe_cache[0]
|
||||||
cache_k_pe = kv_c_and_k_pe_cache[1]
|
cache_k_pe = kv_c_and_k_pe_cache[1]
|
||||||
num_heads = cache_k_pe.size(2)
|
num_heads = cache_k_pe.size(2)
|
||||||
@@ -896,8 +899,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||||
|
|
||||||
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||||
seq_len = torch.stack([seq_len1, seq_len2])
|
i]
|
||||||
|
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
||||||
|
i]
|
||||||
|
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||||
kv_c_normed = torch.empty(toks,
|
kv_c_normed = torch.empty(toks,
|
||||||
num_heads,
|
num_heads,
|
||||||
latent_kv_dim,
|
latent_kv_dim,
|
||||||
@@ -913,7 +919,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
cache_kv_c,
|
cache_kv_c,
|
||||||
cache_k_pe,
|
cache_k_pe,
|
||||||
prefill_metadata.block_table,
|
prefill_metadata.block_table,
|
||||||
seq_len2.to(q_nope.device),
|
context_seq_len_npu,
|
||||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
key=kv_c_normed,
|
key=kv_c_normed,
|
||||||
value=k_pe,
|
value=k_pe,
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ class AscendMLATorchairPrefillMetadata:
|
|||||||
max_seq_lens: list[int]
|
max_seq_lens: list[int]
|
||||||
workspace: torch.Tensor
|
workspace: torch.Tensor
|
||||||
chunk_seq_lens: torch.Tensor
|
chunk_seq_lens: torch.Tensor
|
||||||
|
chunk_seq_lens_npu: torch.Tensor
|
||||||
|
|
||||||
attn_mask: torch.Tensor
|
attn_mask: torch.Tensor
|
||||||
query_lens: torch.Tensor
|
query_lens: torch.Tensor
|
||||||
@@ -447,6 +448,7 @@ class AscendMLATorchairMetadataBuilder:
|
|||||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||||
chunk_seq_lens=chunk_seq_lens,
|
chunk_seq_lens=chunk_seq_lens,
|
||||||
|
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||||
workspace=self.chunked_prefill_workspace,
|
workspace=self.chunked_prefill_workspace,
|
||||||
)
|
)
|
||||||
prefill_input_positions = input_positions[tokens_start:]
|
prefill_input_positions = input_positions[tokens_start:]
|
||||||
@@ -760,7 +762,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
q_pe = query[..., self.qk_nope_head_dim:]
|
q_pe = query[..., self.qk_nope_head_dim:]
|
||||||
q_nope = query[..., :self.qk_nope_head_dim]
|
q_nope = query[..., :self.qk_nope_head_dim]
|
||||||
|
|
||||||
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
current_seq_len = torch.tensor(prefill_metadata.query_lens,
|
||||||
|
dtype=torch.int32)
|
||||||
cache_kv_c = kv_c_and_k_pe_cache[0]
|
cache_kv_c = kv_c_and_k_pe_cache[0]
|
||||||
cache_k_pe = kv_c_and_k_pe_cache[1]
|
cache_k_pe = kv_c_and_k_pe_cache[1]
|
||||||
num_heads = cache_k_pe.size(2)
|
num_heads = cache_k_pe.size(2)
|
||||||
@@ -768,8 +771,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||||
|
|
||||||
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||||
seq_len = torch.stack([seq_len1, seq_len2])
|
i]
|
||||||
|
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
||||||
|
i]
|
||||||
|
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||||
kv_c_normed = torch.empty(toks,
|
kv_c_normed = torch.empty(toks,
|
||||||
num_heads,
|
num_heads,
|
||||||
latent_kv_dim,
|
latent_kv_dim,
|
||||||
@@ -785,7 +791,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
cache_kv_c,
|
cache_kv_c,
|
||||||
cache_k_pe,
|
cache_k_pe,
|
||||||
prefill_metadata.block_table,
|
prefill_metadata.block_table,
|
||||||
seq_len2.to(query.device),
|
context_seq_len_npu,
|
||||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
key=kv_c_normed,
|
key=kv_c_normed,
|
||||||
value=k_pe,
|
value=k_pe,
|
||||||
|
|||||||
Reference in New Issue
Block a user