[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)

### What this PR does / why we need it?
Support the features of prefix cache and chunked prefill in v0/v1.

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-05-09 16:39:28 +08:00
committed by GitHub
parent 324f819b92
commit fa99f89e93
6 changed files with 156 additions and 32 deletions

View File

@@ -96,9 +96,10 @@ class AscendAttentionBackend(AttentionBackend):
class AscendAttentionState(Enum):
PrefillOnly = 0
DecodeOnly = 1
ChunkedPrefill = 2
PrefillNoCache = 0
PrefillCacheHit = 1
DecodeOnly = 2
ChunkedPrefill = 3
@dataclass
@@ -264,7 +265,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
@@ -277,8 +278,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
compress_mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=attn_metadata.block_tables,
mask=compress_mask,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
@@ -286,7 +302,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.