[Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782)

### What this PR does / why we need it?
Add support for prefix caching and chunked prefill in both V0 and V1.

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-05-09 16:39:28 +08:00
committed by GitHub
parent 324f819b92
commit fa99f89e93
6 changed files with 156 additions and 32 deletions

View File

@@ -260,6 +260,8 @@ class AscendMetadata(AttentionMetadata):
# requests only. # requests only.
max_decode_seq_len: int max_decode_seq_len: int
chunked_prefill_enabled: bool
# (batch_size, max_blocks_per_seq). # (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block) # Block addresses per sequence. (Seq id -> list of physical block)
block_tables: Optional[torch.Tensor] block_tables: Optional[torch.Tensor]
@@ -271,6 +273,9 @@ class AscendMetadata(AttentionMetadata):
# the computed tokens + new tokens None if it is a decoding. # the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None seq_lens: Optional[List[int]] = None
# The query lengths of the input sequences
query_lens: Optional[List[int]] = None
# Maximum query length in the batch. None for decoding. # Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None max_query_len: Optional[int] = None
@@ -290,8 +295,15 @@ class AscendMetadata(AttentionMetadata):
# Number of tokens input to encoder # Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None num_encoder_tokens: Optional[int] = None
# Mask for normal situation
attn_mask: Optional[torch.Tensor] = None attn_mask: Optional[torch.Tensor] = None
# Mask for prefix caching
compress_mask: Optional[torch.Tensor] = None
# Mask for chunked prefill
chunk_mask: Optional[torch.Tensor] = None
# Cross-attention memory-mapping data structures: slot mapping # Cross-attention memory-mapping data structures: slot mapping
# and block tables # and block tables
cross_slot_mapping: Optional[torch.Tensor] = None cross_slot_mapping: Optional[torch.Tensor] = None
@@ -315,6 +327,8 @@ class AscendMetadata(AttentionMetadata):
self.slot_mapping[:self.num_prefill_tokens]) self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills]) self.seq_lens[:self.num_prefills])
query_lens = (None if self.query_lens is None else
self.query_lens[:self.num_prefills])
block_tables = (None if self.block_tables is None else block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills]) self.block_tables[:self.num_prefills])
@@ -329,9 +343,11 @@ class AscendMetadata(AttentionMetadata):
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
query_lens=query_lens,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len, max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0, max_decode_seq_len=0,
chunked_prefill_enabled=self.chunked_prefill_enabled,
block_tables=block_tables, block_tables=block_tables,
# Begin encoder & cross attn fields below... # Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens=self.encoder_seq_lens,
@@ -359,6 +375,8 @@ class AscendMetadata(AttentionMetadata):
self.slot_mapping[self.num_prefill_tokens:]) self.slot_mapping[self.num_prefill_tokens:])
seq_lens = (None if self.seq_lens is None else seq_lens = (None if self.seq_lens is None else
self.seq_lens[self.num_prefills:]) self.seq_lens[self.num_prefills:])
query_lens = (None if self.query_lens is None else
self.query_lens[self.num_prefills:])
block_tables = (None if self.block_tables is None else block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:]) self.block_tables[self.num_prefills:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else seq_lens_tensor = (None if self.seq_lens_tensor is None else
@@ -371,9 +389,11 @@ class AscendMetadata(AttentionMetadata):
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
query_lens=query_lens,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len, max_decode_seq_len=self.max_decode_seq_len,
chunked_prefill_enabled=self.chunked_prefill_enabled,
block_tables=block_tables, block_tables=block_tables,
# Begin encoder & cross attn fields below... # Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens=self.encoder_seq_lens,
@@ -482,6 +502,8 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.block_size = input_builder.block_size self.block_size = input_builder.block_size
self.attn_mask = None self.attn_mask = None
self.compress_mask = None
self.chunk_mask = None
if AscendMetadataBuilder._attn_mask_builder is None: if AscendMetadataBuilder._attn_mask_builder is None:
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len( AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
128, self.input_builder.runner.model_config.dtype) 128, self.input_builder.runner.model_config.dtype)
@@ -590,11 +612,13 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.input_builder.chunked_prefill_enabled) self.input_builder.chunked_prefill_enabled)
device = self.runner.device device = self.runner.device
dtype = self.runner.model_config.dtype
use_npu_graph = graph_pad_size != -1 use_npu_graph = graph_pad_size != -1
max_query_len = max(query_lens) max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0)
max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
num_decode_tokens = self.num_decode_tokens num_decode_tokens = self.num_decode_tokens
if self.num_prefills == 0 and use_npu_graph: if self.num_prefills == 0 and use_npu_graph:
@@ -612,12 +636,29 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
) )
if self.num_prefills > 0: if self.num_prefills > 0:
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore if block_tables is None or block_tables.numel() == 0:
max_prefill_seq_len, # normal mask
self.input_builder.runner.model_config.dtype, self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
self.input_builder.runner.device) max_prefill_seq_len, dtype, device)
elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
# compress mask for prefix cache
self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
128, dtype, device)
else:
# chunk_mask for chunk prefill
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
max_seq_len, dtype, device)
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
attn_mask *= -10000
chunk_mask_list = []
for i, seq_len in enumerate(seq_lens):
context_len = self.context_lens[i]
chunk_mask_list.append(attn_mask[context_len:seq_len])
self.chunk_mask = torch.cat(chunk_mask_list, 0)
else: else:
self.attn_mask = None self.attn_mask = None
self.compress_mask = None
self.chunk_mask = None
assert max_query_len > 0, "query_lens: {}".format(query_lens) assert max_query_len > 0, "query_lens: {}".format(query_lens)
@@ -641,11 +682,15 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
query_lens=query_lens,
max_query_len=max_query_len, max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len, max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len, max_decode_seq_len=max_decode_seq_len,
block_tables=block_tables, block_tables=block_tables,
attn_mask=self.attn_mask, attn_mask=self.attn_mask,
compress_mask=self.compress_mask,
chunk_mask=self.chunk_mask,
chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
) )
@@ -681,6 +726,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.seq_len_cpu_tensor = None self.seq_len_cpu_tensor = None
self.query_len_cpu_tensor = None
self.key_cache = None self.key_cache = None
self.value_cache = None self.value_cache = None
@@ -769,7 +815,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
slot_indices=slots) slot_indices=slots)
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
# Prefix cache disabled and chunk prefill disabled or no prefix cache hit
if (attn_metadata.block_tables is None if (attn_metadata.block_tables is None
or attn_metadata.block_tables.numel() == 0): or attn_metadata.block_tables.numel() == 0):
if attn_type == AttentionType.ENCODER_ONLY: if attn_type == AttentionType.ENCODER_ONLY:
@@ -816,13 +862,60 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
out=output) out=output)
# Prefix cache only and cache hit
elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
assert kv_cache is not None
assert attn_metadata.prefill_metadata is not None
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(
attn_metadata.prefill_metadata.seq_lens).astype(
np.int32))
self.query_lens_tensor_cpu = torch.from_numpy(
np.array(
attn_metadata.prefill_metadata.query_lens).astype(
np.int32))
block_tables = attn_metadata.prefill_metadata.block_tables
assert attn_metadata.compress_mask is not None
compress_mask = attn_metadata.compress_mask
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=block_tables,
mask=compress_mask,
seq_len=self.query_lens_tensor_cpu,
context_lens=self.seq_lens_tensor_cpu,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
# Splitfuse
else: else:
# TODO: Will support prefix cache and chunked prefill soon. assert kv_cache is not None
raise RuntimeError( self.seq_lens_tensor_cpu = torch.from_numpy(
"Prefix cache and chunked prefill are currently not supported." np.array(attn_metadata.seq_lens).astype(np.int32))
) self.query_lens_tensor_cpu = torch.from_numpy(
elif attn_metadata.decode_metadata: np.array(attn_metadata.query_lens).astype(np.int32))
block_tables = attn_metadata.block_tables
assert attn_metadata.chunk_mask is not None
chunk_mask = attn_metadata.chunk_mask
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=block_tables,
context_lens=self.seq_lens_tensor_cpu,
mask=chunk_mask,
seq_len=self.query_lens_tensor_cpu,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
# Decode only
else:
assert self.key_cache is not None assert self.key_cache is not None
assert self.value_cache is not None
assert attn_metadata.decode_metadata is not None
self.seq_lens_tensor_cpu = torch.from_numpy( self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.decode_metadata.seq_lens).astype( np.array(attn_metadata.decode_metadata.seq_lens).astype(
np.int32)) np.int32))

View File

@@ -96,9 +96,10 @@ class AscendAttentionBackend(AttentionBackend):
class AscendAttentionState(Enum): class AscendAttentionState(Enum):
PrefillOnly = 0 PrefillNoCache = 0
DecodeOnly = 1 PrefillCacheHit = 1
ChunkedPrefill = 2 DecodeOnly = 2
ChunkedPrefill = 3
@dataclass @dataclass
@@ -264,7 +265,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass pass
# V0-Style scheduler situation. # V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None assert attn_metadata is not None
assert attn_metadata.attn_mask is not None assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask mask = attn_metadata.attn_mask
@@ -277,8 +278,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
out=output) out=output)
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
compress_mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=attn_metadata.block_tables,
mask=compress_mask,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention( torch_npu._npu_paged_attention(
query=query, query=query,
key_cache=self.key_cache, key_cache=self.key_cache,
@@ -286,7 +302,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads, num_heads=self.num_heads,
scale_value=self.scale, scale_value=self.scale,
block_table=block_tables, block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens, context_lens=attn_metadata.seq_lens,
out=output) out=output)
# Normal V1 situation. # Normal V1 situation.

View File

@@ -417,7 +417,7 @@ class AscendMLAImpl(MLAAttentionImpl):
num_tokens = query.size(0) num_tokens = query.size(0)
attn_output = None attn_output = None
# Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill: if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
attn_output = torch.empty(num_tokens, attn_output = torch.empty(num_tokens,
self.num_heads * self.v_head_dim, self.num_heads * self.v_head_dim,
@@ -440,7 +440,7 @@ class AscendMLAImpl(MLAAttentionImpl):
scale=self.scale, scale=self.scale,
alibi_slopes=None, alibi_slopes=None,
causal=True) causal=True)
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
attn_output = torch.empty(num_tokens, attn_output = torch.empty(num_tokens,
self.num_heads, self.num_heads,
self.padding_head_dim, self.padding_head_dim,
@@ -479,7 +479,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.padding_head_dim)[:, :, :self.v_head_dim] self.padding_head_dim)[:, :, :self.v_head_dim]
else: else:
raise RuntimeError( raise RuntimeError(
"Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !" "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
) )
attn_output = attn_output.reshape( attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim]) [num_tokens, self.num_heads * self.v_head_dim])

View File

@@ -175,11 +175,11 @@ class NPUPlatform(Platform):
if cache_config: if cache_config:
if cache_config.block_size is None: if cache_config.block_size is None:
cache_config.block_size = 128 cache_config.block_size = 128
if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching: if cache_config.enable_prefix_caching and cache_config.block_size != 128:
logger.warning( logger.warning(
"Prefix caching is not supported for V1 now, disable prefix caching" "If prefix caching is enabled, block size must be set to 128."
) )
cache_config.enable_prefix_caching = False cache_config.block_size = 128
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
# Activate custom ops for v1. # Activate custom ops for v1.

View File

@@ -693,15 +693,23 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
# this may be larger than the sequence length if chunked # this may be larger than the sequence length if chunked
# prefill is enabled. # prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size prefix_cache_len = len(computed_block_nums) * self.block_size
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
# On a full cache hit, recompute the whole last block rather than only
# the last token, as required by the prefix-cache operator.
if seq_len <= prefix_cache_len:
prefix_cache_len -= self.block_size
seq_group_metadata.seq_data[inter_data.seq_ids[ seq_group_metadata.seq_data[inter_data.seq_ids[
seq_idx]].update_num_cached_tokens(prefix_cache_len) seq_idx]].update_num_cached_tokens(prefix_cache_len)
# The number of so far computed prompt tokens in this sequence. # The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx] context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len: if prefix_cache_len <= context_len:
# We already passed the cache hit region, # We already passed the cache hit region,
# so do normal computation. # so do normal computation.

View File

@@ -107,6 +107,7 @@ class NPUModelRunner:
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.lora_config = vllm_config.lora_config self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
self.device = device self.device = device
self.is_multimodal_model = self.model_config.is_multimodal_model self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
@@ -454,11 +455,15 @@ class NPUModelRunner:
if attn_state == AscendAttentionState.ChunkedPrefill: if attn_state == AscendAttentionState.ChunkedPrefill:
return self.attn_mask_builder.get_splitfuse_attn_mask( return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, query_lens, position, self.dtype, self.device) seq_lens, query_lens, position, self.dtype, self.device)
# Prefill-only situation. # Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillOnly: elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0) max_seq_len = max(seq_lens, default=0)
return self.attn_mask_builder.get_attn_mask( return self.attn_mask_builder.get_attn_mask(
max_seq_len, self.dtype, self.device) max_seq_len, self.dtype, self.device)
# Prefill with cache hit.
elif attn_state == AscendAttentionState.PrefillCacheHit:
return self.attn_mask_builder.get_attn_mask(
128, self.dtype, self.device)
# Decode-only situation. # Decode-only situation.
else: else:
return None return None
@@ -528,13 +533,15 @@ class NPUModelRunner:
block_offsets, block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens]) out=self.slot_mapping_np[:total_num_scheduled_tokens])
attn_state = AscendAttentionState.ChunkedPrefill if self.chunked_prefill_enabled:
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): attn_state = AscendAttentionState.ChunkedPrefill
attn_state = AscendAttentionState.PrefillOnly elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# Treat the batch as decode when every request schedules exactly one token; this also covers a prefill in which only one token missed the cache.
elif np.all(num_scheduled_tokens == 1): elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly attn_state = AscendAttentionState.DecodeOnly
else: else:
attn_state = AscendAttentionState.ChunkedPrefill attn_state = AscendAttentionState.PrefillCacheHit
attn_mask = self._make_attention_mask(seq_lens=seq_lens, attn_mask = self._make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens, query_lens=num_scheduled_tokens,