[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)

### What this PR does / why we need it?
Support the features of prefix cache and chunked prefill in v0/v1.

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-05-09 16:39:28 +08:00
committed by GitHub
parent 324f819b92
commit fa99f89e93
6 changed files with 156 additions and 32 deletions

View File

@@ -260,6 +260,8 @@ class AscendMetadata(AttentionMetadata):
# requests only.
max_decode_seq_len: int
chunked_prefill_enabled: bool
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
block_tables: Optional[torch.Tensor]
@@ -271,6 +273,9 @@ class AscendMetadata(AttentionMetadata):
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
# The query lengths of the input sequences
query_lens: Optional[List[int]] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
@@ -290,8 +295,15 @@ class AscendMetadata(AttentionMetadata):
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Mask for normal situation
attn_mask: Optional[torch.Tensor] = None
# Mask for prefix caching
compress_mask: Optional[torch.Tensor] = None
# Mask for chunked prefill
chunk_mask: Optional[torch.Tensor] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
@@ -315,6 +327,8 @@ class AscendMetadata(AttentionMetadata):
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
query_lens = (None if self.query_lens is None else
self.query_lens[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
@@ -329,9 +343,11 @@ class AscendMetadata(AttentionMetadata):
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
query_lens=query_lens,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
chunked_prefill_enabled=self.chunked_prefill_enabled,
block_tables=block_tables,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
@@ -359,6 +375,8 @@ class AscendMetadata(AttentionMetadata):
self.slot_mapping[self.num_prefill_tokens:])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[self.num_prefills:])
query_lens = (None if self.query_lens is None else
self.query_lens[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
@@ -371,9 +389,11 @@ class AscendMetadata(AttentionMetadata):
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
query_lens=query_lens,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
chunked_prefill_enabled=self.chunked_prefill_enabled,
block_tables=block_tables,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
@@ -482,6 +502,8 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.block_size = input_builder.block_size
self.attn_mask = None
self.compress_mask = None
self.chunk_mask = None
if AscendMetadataBuilder._attn_mask_builder is None:
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
128, self.input_builder.runner.model_config.dtype)
@@ -590,11 +612,13 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
dtype = self.runner.model_config.dtype
use_npu_graph = graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
num_decode_tokens = self.num_decode_tokens
if self.num_prefills == 0 and use_npu_graph:
@@ -612,12 +636,29 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
)
if self.num_prefills > 0:
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
max_prefill_seq_len,
self.input_builder.runner.model_config.dtype,
self.input_builder.runner.device)
if block_tables is None or block_tables.numel() == 0:
# normal mask
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
max_prefill_seq_len, dtype, device)
elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
# compress mask for prefix cache
self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
128, dtype, device)
else:
# chunk_mask for chunk prefill
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
max_seq_len, dtype, device)
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
attn_mask *= -10000
chunk_mask_list = []
for i, seq_len in enumerate(seq_lens):
context_len = self.context_lens[i]
chunk_mask_list.append(attn_mask[context_len:seq_len])
self.chunk_mask = torch.cat(chunk_mask_list, 0)
else:
self.attn_mask = None
self.compress_mask = None
self.chunk_mask = None
assert max_query_len > 0, "query_lens: {}".format(query_lens)
@@ -641,11 +682,15 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
query_lens=query_lens,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
block_tables=block_tables,
attn_mask=self.attn_mask,
compress_mask=self.compress_mask,
chunk_mask=self.chunk_mask,
chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
)
@@ -681,6 +726,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.seq_len_cpu_tensor = None
self.query_len_cpu_tensor = None
self.key_cache = None
self.value_cache = None
@@ -769,7 +815,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
slot_indices=slots)
if attn_metadata.num_prefills > 0:
# Prefix cache disabled and chunk prefill disabled or no prefix cache hit
if (attn_metadata.block_tables is None
or attn_metadata.block_tables.numel() == 0):
if attn_type == AttentionType.ENCODER_ONLY:
@@ -816,13 +862,60 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
# Prefix cache only and cache hit
elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
assert kv_cache is not None
assert attn_metadata.prefill_metadata is not None
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(
attn_metadata.prefill_metadata.seq_lens).astype(
np.int32))
self.query_lens_tensor_cpu = torch.from_numpy(
np.array(
attn_metadata.prefill_metadata.query_lens).astype(
np.int32))
block_tables = attn_metadata.prefill_metadata.block_tables
assert attn_metadata.compress_mask is not None
compress_mask = attn_metadata.compress_mask
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=block_tables,
mask=compress_mask,
seq_len=self.query_lens_tensor_cpu,
context_lens=self.seq_lens_tensor_cpu,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
# Splitfuse
else:
# TODO: Will support prefix cache and chunked prefill soon.
raise RuntimeError(
"Prefix cache and chunked prefill are currently not supported."
)
elif attn_metadata.decode_metadata:
assert kv_cache is not None
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.seq_lens).astype(np.int32))
self.query_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.query_lens).astype(np.int32))
block_tables = attn_metadata.block_tables
assert attn_metadata.chunk_mask is not None
chunk_mask = attn_metadata.chunk_mask
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=block_tables,
context_lens=self.seq_lens_tensor_cpu,
mask=chunk_mask,
seq_len=self.query_lens_tensor_cpu,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
# Decode only
else:
assert self.key_cache is not None
assert self.value_cache is not None
assert attn_metadata.decode_metadata is not None
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.decode_metadata.seq_lens).astype(
np.int32))

View File

@@ -96,9 +96,10 @@ class AscendAttentionBackend(AttentionBackend):
class AscendAttentionState(Enum):
PrefillOnly = 0
DecodeOnly = 1
ChunkedPrefill = 2
PrefillNoCache = 0
PrefillCacheHit = 1
DecodeOnly = 2
ChunkedPrefill = 3
@dataclass
@@ -264,7 +265,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
@@ -277,8 +278,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
compress_mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=attn_metadata.block_tables,
mask=compress_mask,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
@@ -286,7 +302,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.

View File

@@ -417,7 +417,7 @@ class AscendMLAImpl(MLAAttentionImpl):
num_tokens = query.size(0)
attn_output = None
# Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
attn_output = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
@@ -440,7 +440,7 @@ class AscendMLAImpl(MLAAttentionImpl):
scale=self.scale,
alibi_slopes=None,
causal=True)
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
attn_output = torch.empty(num_tokens,
self.num_heads,
self.padding_head_dim,
@@ -479,7 +479,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.padding_head_dim)[:, :, :self.v_head_dim]
else:
raise RuntimeError(
"Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
)
attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])

View File

@@ -175,11 +175,11 @@ class NPUPlatform(Platform):
if cache_config:
if cache_config.block_size is None:
cache_config.block_size = 128
if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
if cache_config.enable_prefix_caching and cache_config.block_size != 128:
logger.warning(
"Prefix caching is not supported for V1 now, disable prefix caching"
"If prefix caching is enabled, block size must be set to 128."
)
cache_config.enable_prefix_caching = False
cache_config.block_size = 128
if envs.VLLM_USE_V1:
# Activate custom ops for v1.

View File

@@ -693,15 +693,23 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
# When full hit, compute the last block rather than the last token,
# due to the requirements of prefix operator.
if seq_len <= prefix_cache_len:
prefix_cache_len -= self.block_size
seq_group_metadata.seq_data[inter_data.seq_ids[
seq_idx]].update_num_cached_tokens(prefix_cache_len)
# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.

View File

@@ -107,6 +107,7 @@ class NPUModelRunner:
self.model_config = vllm_config.model_config
self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
self.device = device
self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size
@@ -454,11 +455,15 @@ class NPUModelRunner:
if attn_state == AscendAttentionState.ChunkedPrefill:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, query_lens, position, self.dtype, self.device)
# Prefill-only situation.
elif attn_state == AscendAttentionState.PrefillOnly:
# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0)
return self.attn_mask_builder.get_attn_mask(
max_seq_len, self.dtype, self.device)
# Prefill with cache hit.
elif attn_state == AscendAttentionState.PrefillCacheHit:
return self.attn_mask_builder.get_attn_mask(
128, self.dtype, self.device)
# Decode-only situation.
else:
return None
@@ -528,13 +533,15 @@ class NPUModelRunner:
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
attn_state = AscendAttentionState.ChunkedPrefill
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillOnly
if self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
else:
attn_state = AscendAttentionState.ChunkedPrefill
attn_state = AscendAttentionState.PrefillCacheHit
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens,