[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)
### What this PR does / why we need it? Support the features of prefix cache and chunked prefill in v0/v1. --------- Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -96,9 +96,10 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
class AscendAttentionState(Enum):
|
||||
PrefillOnly = 0
|
||||
DecodeOnly = 1
|
||||
ChunkedPrefill = 2
|
||||
PrefillNoCache = 0
|
||||
PrefillCacheHit = 1
|
||||
DecodeOnly = 2
|
||||
ChunkedPrefill = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -264,7 +265,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
||||
pass
|
||||
# V0-Style scheduler situation.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
@@ -277,8 +278,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
compress_mask = attn_metadata.attn_mask
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=attn_metadata.block_tables,
|
||||
mask=compress_mask,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
block_tables = attn_metadata.block_tables
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
@@ -286,7 +302,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
# Normal V1 situation.
|
||||
|
||||
Reference in New Issue
Block a user