[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)
### What this PR does / why we need it?

Support the prefix cache and chunked prefill features in v0/v1.

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -693,15 +693,23 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
# this may be larger than the sequence length if chunked
|
||||
# prefill is enabled.
|
||||
prefix_cache_len = len(computed_block_nums) * self.block_size
|
||||
|
||||
# The total number of prompt tokens in this sequence.
|
||||
# When chunked prefill is enabled, this is the token number of
|
||||
# computed chunks + current chunk.
|
||||
seq_len = inter_data.seq_lens[seq_idx]
|
||||
|
||||
# On a full cache hit, compute the whole last block rather than just the last token,
|
||||
# due to the requirements of prefix operator.
|
||||
if seq_len <= prefix_cache_len:
|
||||
prefix_cache_len -= self.block_size
|
||||
|
||||
seq_group_metadata.seq_data[inter_data.seq_ids[
|
||||
seq_idx]].update_num_cached_tokens(prefix_cache_len)
|
||||
|
||||
# The number of so far computed prompt tokens in this sequence.
|
||||
context_len = inter_data.context_lens[seq_idx]
|
||||
# The total number of prompt tokens in this sequence.
|
||||
# When chunked prefill is enabled, this is the token number of
|
||||
# computed chunks + current chunk.
|
||||
seq_len = inter_data.seq_lens[seq_idx]
|
||||
|
||||
if prefix_cache_len <= context_len:
|
||||
# We already passed the cache hit region,
|
||||
# so do normal computation.
|
||||
|
||||
@@ -107,6 +107,7 @@ class NPUModelRunner:
|
||||
self.model_config = vllm_config.model_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
|
||||
self.device = device
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
@@ -454,11 +455,15 @@ class NPUModelRunner:
|
||||
if attn_state == AscendAttentionState.ChunkedPrefill:
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens, query_lens, position, self.dtype, self.device)
|
||||
# Prefill-only situation.
|
||||
elif attn_state == AscendAttentionState.PrefillOnly:
|
||||
# Prefill without cache situation.
|
||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
max_seq_len, self.dtype, self.device)
|
||||
# Prefill with cache hit.
|
||||
elif attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
128, self.dtype, self.device)
|
||||
# Decode-only situation.
|
||||
else:
|
||||
return None
|
||||
@@ -528,13 +533,15 @@ class NPUModelRunner:
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillOnly
|
||||
if self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# Treat the step as decode-only when every request schedules exactly one token; this includes a prefill whose prompt hit the cache for all but one token.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
else:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
attn_state = AscendAttentionState.PrefillCacheHit
|
||||
|
||||
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
|
||||
query_lens=num_scheduled_tokens,
|
||||
|
||||
Reference in New Issue
Block a user