[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)

### What this PR does / why we need it?
Support the features of prefix cache and chunked prefill in v0/v1.

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-05-09 16:39:28 +08:00
committed by GitHub
parent 324f819b92
commit fa99f89e93
6 changed files with 156 additions and 32 deletions

View File

@@ -107,6 +107,7 @@ class NPUModelRunner:
self.model_config = vllm_config.model_config
self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
self.device = device
self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size
@@ -454,11 +455,15 @@ class NPUModelRunner:
if attn_state == AscendAttentionState.ChunkedPrefill:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, query_lens, position, self.dtype, self.device)
# Prefill-only situation.
elif attn_state == AscendAttentionState.PrefillOnly:
# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0)
return self.attn_mask_builder.get_attn_mask(
max_seq_len, self.dtype, self.device)
# Prefill with cache hit.
elif attn_state == AscendAttentionState.PrefillCacheHit:
return self.attn_mask_builder.get_attn_mask(
128, self.dtype, self.device)
# Decode-only situation.
else:
return None
@@ -528,13 +533,15 @@ class NPUModelRunner:
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
attn_state = AscendAttentionState.ChunkedPrefill
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillOnly
if self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# Assume this is the decode stage: each request schedules exactly one token (i.e., only the newest token missed the cache).
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
else:
attn_state = AscendAttentionState.ChunkedPrefill
attn_state = AscendAttentionState.PrefillCacheHit
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens,