[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)
### What this PR does / why we need it? Add support for prefix caching and chunked prefill in v0/v1. --------- Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -260,6 +260,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
|
||||
chunked_prefill_enabled: bool
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical blocks)
|
||||
block_tables: Optional[torch.Tensor]
|
||||
@@ -271,6 +273,9 @@ class AscendMetadata(AttentionMetadata):
|
||||
# the computed tokens + new tokens. None if it is decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# The query lengths of the input sequences
|
||||
query_lens: Optional[List[int]] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
@@ -290,8 +295,15 @@ class AscendMetadata(AttentionMetadata):
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Mask for normal situation
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# Mask for prefix caching
|
||||
compress_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# Mask for chunked prefill
|
||||
chunk_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
@@ -315,6 +327,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
query_lens = (None if self.query_lens is None else
|
||||
self.query_lens[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
@@ -329,9 +343,11 @@ class AscendMetadata(AttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
query_lens=query_lens,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
||||
block_tables=block_tables,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
@@ -359,6 +375,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[self.num_prefills:])
|
||||
query_lens = (None if self.query_lens is None else
|
||||
self.query_lens[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
@@ -371,9 +389,11 @@ class AscendMetadata(AttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
query_lens=query_lens,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
||||
block_tables=block_tables,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
@@ -482,6 +502,8 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
self.attn_mask = None
|
||||
self.compress_mask = None
|
||||
self.chunk_mask = None
|
||||
if AscendMetadataBuilder._attn_mask_builder is None:
|
||||
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||
128, self.input_builder.runner.model_config.dtype)
|
||||
@@ -590,11 +612,13 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
dtype = self.runner.model_config.dtype
|
||||
use_npu_graph = graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
if self.num_prefills == 0 and use_npu_graph:
|
||||
@@ -612,12 +636,29 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
)
|
||||
|
||||
if self.num_prefills > 0:
|
||||
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_prefill_seq_len,
|
||||
self.input_builder.runner.model_config.dtype,
|
||||
self.input_builder.runner.device)
|
||||
if block_tables is None or block_tables.numel() == 0:
|
||||
# normal mask
|
||||
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_prefill_seq_len, dtype, device)
|
||||
elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
|
||||
# compress mask for prefix cache
|
||||
self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
128, dtype, device)
|
||||
else:
|
||||
# chunk_mask for chunk prefill
|
||||
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_seq_len, dtype, device)
|
||||
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
|
||||
attn_mask *= -10000
|
||||
chunk_mask_list = []
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
context_len = self.context_lens[i]
|
||||
chunk_mask_list.append(attn_mask[context_len:seq_len])
|
||||
self.chunk_mask = torch.cat(chunk_mask_list, 0)
|
||||
else:
|
||||
self.attn_mask = None
|
||||
self.compress_mask = None
|
||||
self.chunk_mask = None
|
||||
|
||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
||||
|
||||
@@ -641,11 +682,15 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
query_lens=query_lens,
|
||||
max_query_len=max_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
block_tables=block_tables,
|
||||
attn_mask=self.attn_mask,
|
||||
compress_mask=self.compress_mask,
|
||||
chunk_mask=self.chunk_mask,
|
||||
chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
|
||||
)
|
||||
|
||||
|
||||
@@ -681,6 +726,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.seq_len_cpu_tensor = None
|
||||
self.query_len_cpu_tensor = None
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
@@ -769,7 +815,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
slot_indices=slots)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
|
||||
# Prefix cache and chunked prefill both disabled, or no prefix cache hit
|
||||
if (attn_metadata.block_tables is None
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
if attn_type == AttentionType.ENCODER_ONLY:
|
||||
@@ -816,13 +862,60 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
# Prefix cache only and cache hit
|
||||
elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
|
||||
assert kv_cache is not None
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(
|
||||
attn_metadata.prefill_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
self.query_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(
|
||||
attn_metadata.prefill_metadata.query_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.prefill_metadata.block_tables
|
||||
assert attn_metadata.compress_mask is not None
|
||||
compress_mask = attn_metadata.compress_mask
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_tables,
|
||||
mask=compress_mask,
|
||||
seq_len=self.query_lens_tensor_cpu,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
# Splitfuse
|
||||
else:
|
||||
# TODO: Will support prefix cache and chunked prefill soon.
|
||||
raise RuntimeError(
|
||||
"Prefix cache and chunked prefill are currently not supported."
|
||||
)
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.seq_lens).astype(np.int32))
|
||||
self.query_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.query_lens).astype(np.int32))
|
||||
block_tables = attn_metadata.block_tables
|
||||
assert attn_metadata.chunk_mask is not None
|
||||
chunk_mask = attn_metadata.chunk_mask
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
mask=chunk_mask,
|
||||
seq_len=self.query_lens_tensor_cpu,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
# Decode only
|
||||
else:
|
||||
assert self.key_cache is not None
|
||||
assert self.value_cache is not None
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
|
||||
Reference in New Issue
Block a user