[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)
### What this PR does / why we need it? Support the features of prefix cache and chunked prefill in v0/v1. --------- Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -260,6 +260,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
|
||||
chunked_prefill_enabled: bool
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
block_tables: Optional[torch.Tensor]
|
||||
@@ -271,6 +273,9 @@ class AscendMetadata(AttentionMetadata):
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# The query lengths of the input sequences
|
||||
query_lens: Optional[List[int]] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
@@ -290,8 +295,15 @@ class AscendMetadata(AttentionMetadata):
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Mask for normal situation
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# Mask for prefix caching
|
||||
compress_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# Mask for chunked prefill
|
||||
chunk_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
@@ -315,6 +327,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
query_lens = (None if self.query_lens is None else
|
||||
self.query_lens[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
@@ -329,9 +343,11 @@ class AscendMetadata(AttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
query_lens=query_lens,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
||||
block_tables=block_tables,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
@@ -359,6 +375,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[self.num_prefills:])
|
||||
query_lens = (None if self.query_lens is None else
|
||||
self.query_lens[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
@@ -371,9 +389,11 @@ class AscendMetadata(AttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
query_lens=query_lens,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
||||
block_tables=block_tables,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
@@ -482,6 +502,8 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
self.attn_mask = None
|
||||
self.compress_mask = None
|
||||
self.chunk_mask = None
|
||||
if AscendMetadataBuilder._attn_mask_builder is None:
|
||||
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||
128, self.input_builder.runner.model_config.dtype)
|
||||
@@ -590,11 +612,13 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
dtype = self.runner.model_config.dtype
|
||||
use_npu_graph = graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
if self.num_prefills == 0 and use_npu_graph:
|
||||
@@ -612,12 +636,29 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
)
|
||||
|
||||
if self.num_prefills > 0:
|
||||
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_prefill_seq_len,
|
||||
self.input_builder.runner.model_config.dtype,
|
||||
self.input_builder.runner.device)
|
||||
if block_tables is None or block_tables.numel() == 0:
|
||||
# normal mask
|
||||
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_prefill_seq_len, dtype, device)
|
||||
elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
|
||||
# compress mask for prefix cache
|
||||
self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
128, dtype, device)
|
||||
else:
|
||||
# chunk_mask for chunk prefill
|
||||
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_seq_len, dtype, device)
|
||||
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
|
||||
attn_mask *= -10000
|
||||
chunk_mask_list = []
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
context_len = self.context_lens[i]
|
||||
chunk_mask_list.append(attn_mask[context_len:seq_len])
|
||||
self.chunk_mask = torch.cat(chunk_mask_list, 0)
|
||||
else:
|
||||
self.attn_mask = None
|
||||
self.compress_mask = None
|
||||
self.chunk_mask = None
|
||||
|
||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
||||
|
||||
@@ -641,11 +682,15 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
query_lens=query_lens,
|
||||
max_query_len=max_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
block_tables=block_tables,
|
||||
attn_mask=self.attn_mask,
|
||||
compress_mask=self.compress_mask,
|
||||
chunk_mask=self.chunk_mask,
|
||||
chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
|
||||
)
|
||||
|
||||
|
||||
@@ -681,6 +726,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.seq_len_cpu_tensor = None
|
||||
self.query_len_cpu_tensor = None
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
@@ -769,7 +815,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
slot_indices=slots)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
|
||||
# Prefix cache disabled and chunk prefill disabled or no prefix cache hit
|
||||
if (attn_metadata.block_tables is None
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
if attn_type == AttentionType.ENCODER_ONLY:
|
||||
@@ -816,13 +862,60 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
# Prefix cache only and cache hit
|
||||
elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
|
||||
assert kv_cache is not None
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(
|
||||
attn_metadata.prefill_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
self.query_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(
|
||||
attn_metadata.prefill_metadata.query_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.prefill_metadata.block_tables
|
||||
assert attn_metadata.compress_mask is not None
|
||||
compress_mask = attn_metadata.compress_mask
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_tables,
|
||||
mask=compress_mask,
|
||||
seq_len=self.query_lens_tensor_cpu,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
# Splitfuse
|
||||
else:
|
||||
# TODO: Will support prefix cache and chunked prefill soon.
|
||||
raise RuntimeError(
|
||||
"Prefix cache and chunked prefill are currently not supported."
|
||||
)
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.seq_lens).astype(np.int32))
|
||||
self.query_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.query_lens).astype(np.int32))
|
||||
block_tables = attn_metadata.block_tables
|
||||
assert attn_metadata.chunk_mask is not None
|
||||
chunk_mask = attn_metadata.chunk_mask
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
mask=chunk_mask,
|
||||
seq_len=self.query_lens_tensor_cpu,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
# Decode only
|
||||
else:
|
||||
assert self.key_cache is not None
|
||||
assert self.value_cache is not None
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
|
||||
@@ -96,9 +96,10 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
class AscendAttentionState(Enum):
|
||||
PrefillOnly = 0
|
||||
DecodeOnly = 1
|
||||
ChunkedPrefill = 2
|
||||
PrefillNoCache = 0
|
||||
PrefillCacheHit = 1
|
||||
DecodeOnly = 2
|
||||
ChunkedPrefill = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -264,7 +265,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
||||
pass
|
||||
# V0-Style scheduler situation.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
@@ -277,8 +278,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
compress_mask = attn_metadata.attn_mask
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=attn_metadata.block_tables,
|
||||
mask=compress_mask,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
block_tables = attn_metadata.block_tables
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
@@ -286,7 +302,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
# Normal V1 situation.
|
||||
|
||||
@@ -417,7 +417,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
num_tokens = query.size(0)
|
||||
attn_output = None
|
||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly
|
||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
||||
if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
@@ -440,7 +440,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.padding_head_dim,
|
||||
@@ -479,7 +479,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.padding_head_dim)[:, :, :self.v_head_dim]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
|
||||
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
|
||||
)
|
||||
attn_output = attn_output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
|
||||
@@ -175,11 +175,11 @@ class NPUPlatform(Platform):
|
||||
if cache_config:
|
||||
if cache_config.block_size is None:
|
||||
cache_config.block_size = 128
|
||||
if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
|
||||
if cache_config.enable_prefix_caching and cache_config.block_size != 128:
|
||||
logger.warning(
|
||||
"Prefix caching is not supported for V1 now, disable prefix caching"
|
||||
"If prefix caching is enabled, block size must be set to 128."
|
||||
)
|
||||
cache_config.enable_prefix_caching = False
|
||||
cache_config.block_size = 128
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
# Activate custom ops for v1.
|
||||
|
||||
@@ -693,15 +693,23 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
# this may be larger than the sequence length if chunked
|
||||
# prefill is enabled.
|
||||
prefix_cache_len = len(computed_block_nums) * self.block_size
|
||||
|
||||
# The total number of prompt tokens in this sequence.
|
||||
# When chunked prefill is enabled, this is the token number of
|
||||
# computed chunks + current chunk.
|
||||
seq_len = inter_data.seq_lens[seq_idx]
|
||||
|
||||
# When full hit, compute the last block rather than the last token,
|
||||
# due to the requirements of prefix operator.
|
||||
if seq_len <= prefix_cache_len:
|
||||
prefix_cache_len -= self.block_size
|
||||
|
||||
seq_group_metadata.seq_data[inter_data.seq_ids[
|
||||
seq_idx]].update_num_cached_tokens(prefix_cache_len)
|
||||
|
||||
# The number of so far computed prompt tokens in this sequence.
|
||||
context_len = inter_data.context_lens[seq_idx]
|
||||
# The total number of prompt tokens in this sequence.
|
||||
# When chunked prefill is enabled, this is the token number of
|
||||
# computed chunks + current chunk.
|
||||
seq_len = inter_data.seq_lens[seq_idx]
|
||||
|
||||
if prefix_cache_len <= context_len:
|
||||
# We already passed the cache hit region,
|
||||
# so do normal computation.
|
||||
|
||||
@@ -107,6 +107,7 @@ class NPUModelRunner:
|
||||
self.model_config = vllm_config.model_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
|
||||
self.device = device
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
@@ -454,11 +455,15 @@ class NPUModelRunner:
|
||||
if attn_state == AscendAttentionState.ChunkedPrefill:
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens, query_lens, position, self.dtype, self.device)
|
||||
# Prefill-only situation.
|
||||
elif attn_state == AscendAttentionState.PrefillOnly:
|
||||
# Prefill without cache situation.
|
||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
max_seq_len, self.dtype, self.device)
|
||||
# Prefill with cache hit.
|
||||
elif attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
128, self.dtype, self.device)
|
||||
# Decode-only situation.
|
||||
else:
|
||||
return None
|
||||
@@ -528,13 +533,15 @@ class NPUModelRunner:
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillOnly
|
||||
if self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
else:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
attn_state = AscendAttentionState.PrefillCacheHit
|
||||
|
||||
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
|
||||
query_lens=num_scheduled_tokens,
|
||||
|
||||
Reference in New Issue
Block a user