[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)
### What this PR does / why we need it? Supports the prefix cache and chunked prefill features in v0/v1. --------- Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -107,6 +107,7 @@ class NPUModelRunner:
|
||||
self.model_config = vllm_config.model_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
|
||||
self.device = device
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
@@ -454,11 +455,15 @@ class NPUModelRunner:
|
||||
if attn_state == AscendAttentionState.ChunkedPrefill:
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens, query_lens, position, self.dtype, self.device)
|
||||
# Prefill-only situation.
|
||||
elif attn_state == AscendAttentionState.PrefillOnly:
|
||||
# Prefill without cache situation.
|
||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
max_seq_len, self.dtype, self.device)
|
||||
# Prefill with cache hit.
|
||||
elif attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
128, self.dtype, self.device)
|
||||
# Decode-only situation.
|
||||
else:
|
||||
return None
|
||||
@@ -528,13 +533,15 @@ class NPUModelRunner:
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillOnly
|
||||
if self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# We assume this is the decode stage when every request schedules exactly one token (i.e. only one token per request missed the cache).
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
else:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
attn_state = AscendAttentionState.PrefillCacheHit
|
||||
|
||||
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
|
||||
query_lens=num_scheduled_tokens,
|
||||
|
||||
Reference in New Issue
Block a user