From fa99f89e93d1e70d7685a128ef003333ef17b302 Mon Sep 17 00:00:00 2001 From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com> Date: Fri, 9 May 2025 16:39:28 +0800 Subject: [PATCH] [Core] Support the features of prefix cache and chunked prefill in v0/v1 (#782) ### What this PR does / why we need it? Support the features of prefix cache and chunked prefill in v0/v1. --------- Signed-off-by: rjg-lyh <1318825571@qq.com> --- vllm_ascend/attention/attention.py | 113 +++++++++++++++++++++++--- vllm_ascend/attention/attention_v1.py | 28 +++++-- vllm_ascend/attention/mla_v1.py | 6 +- vllm_ascend/platform.py | 6 +- vllm_ascend/worker/model_runner.py | 16 +++- vllm_ascend/worker/model_runner_v1.py | 19 +++-- 6 files changed, 156 insertions(+), 32 deletions(-) diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index b8167fe..cb4f745 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -260,6 +260,8 @@ class AscendMetadata(AttentionMetadata): # requests only. max_decode_seq_len: int + chunked_prefill_enabled: bool + # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) block_tables: Optional[torch.Tensor] @@ -271,6 +273,9 @@ class AscendMetadata(AttentionMetadata): # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] = None + # The query lengths of the input sequences + query_lens: Optional[List[int]] = None + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None @@ -290,8 +295,15 @@ class AscendMetadata(AttentionMetadata): # Number of tokens input to encoder num_encoder_tokens: Optional[int] = None + # Mask for normal situation attn_mask: Optional[torch.Tensor] = None + # Mask for prefix caching + compress_mask: Optional[torch.Tensor] = None + + # Mask for chunked prefill + chunk_mask: Optional[torch.Tensor] = None + # Cross-attention memory-mapping data structures: slot mapping # and block tables cross_slot_mapping: Optional[torch.Tensor] = None @@ -315,6 +327,8 @@ class AscendMetadata(AttentionMetadata): self.slot_mapping[:self.num_prefill_tokens]) seq_lens = (None if self.seq_lens is None else self.seq_lens[:self.num_prefills]) + query_lens = (None if self.query_lens is None else + self.query_lens[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) @@ -329,9 +343,11 @@ class AscendMetadata(AttentionMetadata): slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + query_lens=query_lens, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, + chunked_prefill_enabled=self.chunked_prefill_enabled, block_tables=block_tables, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, @@ -359,6 +375,8 @@ class AscendMetadata(AttentionMetadata): self.slot_mapping[self.num_prefill_tokens:]) seq_lens = (None if self.seq_lens is None else self.seq_lens[self.num_prefills:]) + query_lens = (None if self.query_lens is None else + self.query_lens[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) seq_lens_tensor = (None if self.seq_lens_tensor is None else @@ -371,9 +389,11 @@ class AscendMetadata(AttentionMetadata): slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + query_lens=query_lens, max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, + chunked_prefill_enabled=self.chunked_prefill_enabled, block_tables=block_tables, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, @@ -482,6 +502,8 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): self.block_size = input_builder.block_size self.attn_mask = None + self.compress_mask = None + self.chunk_mask = None if AscendMetadataBuilder._attn_mask_builder is None: AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len( 128, self.input_builder.runner.model_config.dtype) @@ -590,11 +612,13 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): self.input_builder.chunked_prefill_enabled) device = self.runner.device + dtype = self.runner.model_config.dtype use_npu_graph = graph_pad_size != -1 max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) + max_seq_len = max(max_prefill_seq_len, max_decode_seq_len) num_decode_tokens = self.num_decode_tokens if self.num_prefills == 0 and use_npu_graph: @@ -612,12 +636,29 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): ) if self.num_prefills > 0: - self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore - max_prefill_seq_len, - self.input_builder.runner.model_config.dtype, - self.input_builder.runner.device) + if block_tables is None or block_tables.numel() == 0: + # normal mask + self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore + max_prefill_seq_len, dtype, device) + elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled: + # compress mask for prefix cache + self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore + 128, dtype, device) + else: + # chunk_mask for chunk prefill + attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore + max_seq_len, dtype, device) + if attn_mask.numel() > 1 and attn_mask[0][1] > 0: + attn_mask *= -10000 + chunk_mask_list = [] + for i, seq_len in enumerate(seq_lens): + context_len = self.context_lens[i] + chunk_mask_list.append(attn_mask[context_len:seq_len]) + self.chunk_mask = torch.cat(chunk_mask_list, 0) else: self.attn_mask = None + self.compress_mask = None + self.chunk_mask = None assert max_query_len > 0, "query_lens: {}".format(query_lens) @@ -641,11 +682,15 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): multi_modal_placeholder_index_maps=placeholder_index_maps, enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, + query_lens=query_lens, max_query_len=max_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, block_tables=block_tables, attn_mask=self.attn_mask, + compress_mask=self.compress_mask, + chunk_mask=self.chunk_mask, + chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled, ) @@ -681,6 +726,7 @@ class AscendAttentionBackendImpl(AttentionImpl): assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.seq_len_cpu_tensor = None + self.query_len_cpu_tensor = None self.key_cache = None self.value_cache = None @@ -769,7 +815,7 @@ class AscendAttentionBackendImpl(AttentionImpl): slot_indices=slots) if attn_metadata.num_prefills > 0: - + # Prefix cache disabled and chunk prefill disabled or no prefix cache hit if (attn_metadata.block_tables is None or attn_metadata.block_tables.numel() == 0): if attn_type == AttentionType.ENCODER_ONLY: @@ -816,13 +862,60 @@ class AscendAttentionBackendImpl(AttentionImpl): num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, out=output) + # Prefix cache only and cache hit + elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled: + assert kv_cache is not None + assert attn_metadata.prefill_metadata is not None + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array( + attn_metadata.prefill_metadata.seq_lens).astype( + np.int32)) + self.query_lens_tensor_cpu = torch.from_numpy( + np.array( + attn_metadata.prefill_metadata.query_lens).astype( + np.int32)) + block_tables = attn_metadata.prefill_metadata.block_tables + assert attn_metadata.compress_mask is not None + compress_mask = attn_metadata.compress_mask + torch_npu._npu_flash_attention_qlens( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + block_table=block_tables, + mask=compress_mask, + seq_len=self.query_lens_tensor_cpu, + context_lens=self.seq_lens_tensor_cpu, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) + # Splitfuse else: - # TODO: Will support prefix cache and chunked prefill soon. - raise RuntimeError( - "Prefix cache and chunked prefill are currently not supported." - ) - elif attn_metadata.decode_metadata: + assert kv_cache is not None + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array(attn_metadata.seq_lens).astype(np.int32)) + self.query_lens_tensor_cpu = torch.from_numpy( + np.array(attn_metadata.query_lens).astype(np.int32)) + block_tables = attn_metadata.block_tables + assert attn_metadata.chunk_mask is not None + chunk_mask = attn_metadata.chunk_mask + torch_npu._npu_paged_attention_splitfuse( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + block_table=block_tables, + context_lens=self.seq_lens_tensor_cpu, + mask=chunk_mask, + seq_len=self.query_lens_tensor_cpu, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) + # Decode only + else: assert self.key_cache is not None + assert self.value_cache is not None + assert attn_metadata.decode_metadata is not None self.seq_lens_tensor_cpu = torch.from_numpy( np.array(attn_metadata.decode_metadata.seq_lens).astype( np.int32)) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 03d0bcc..862667e 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -96,9 +96,10 @@ class AscendAttentionBackend(AttentionBackend): class AscendAttentionState(Enum): - PrefillOnly = 0 - DecodeOnly = 1 - ChunkedPrefill = 2 + PrefillNoCache = 0 + PrefillCacheHit = 1 + DecodeOnly = 2 + ChunkedPrefill = 3 @dataclass @@ -264,7 +265,7 @@ class AscendAttentionBackendImpl(AttentionImpl): # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata pass # V0-Style scheduler situation. - elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: assert attn_metadata is not None assert attn_metadata.attn_mask is not None mask = attn_metadata.attn_mask @@ -277,8 +278,23 @@ class AscendAttentionBackendImpl(AttentionImpl): num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, out=output) + elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + compress_mask = attn_metadata.attn_mask + torch_npu._npu_flash_attention_qlens( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + block_table=attn_metadata.block_tables, + mask=compress_mask, + seq_len=attn_metadata.query_lens, + context_lens=attn_metadata.seq_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - block_tables = attn_metadata.block_tables torch_npu._npu_paged_attention( query=query, key_cache=self.key_cache, @@ -286,7 +302,7 @@ class AscendAttentionBackendImpl(AttentionImpl): num_kv_heads=self.num_kv_heads, num_heads=self.num_heads, scale_value=self.scale, - block_table=block_tables, + block_table=attn_metadata.block_tables, context_lens=attn_metadata.seq_lens, out=output) # Normal V1 situation. diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 64a5431..caf0eae 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -417,7 +417,7 @@ class AscendMLAImpl(MLAAttentionImpl): num_tokens = query.size(0) attn_output = None - # Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly + # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill: attn_output = torch.empty(num_tokens, self.num_heads * self.v_head_dim, @@ -440,7 +440,7 @@ class AscendMLAImpl(MLAAttentionImpl): scale=self.scale, alibi_slopes=None, causal=True) - elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: attn_output = torch.empty(num_tokens, self.num_heads, self.padding_head_dim, @@ -479,7 +479,7 @@ class AscendMLAImpl(MLAAttentionImpl): self.padding_head_dim)[:, :, :self.v_head_dim] else: raise RuntimeError( - "Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !" + "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !" ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 5d2c8ac..828d0e5 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -175,11 +175,11 @@ class NPUPlatform(Platform): if cache_config: if cache_config.block_size is None: cache_config.block_size = 128 - if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching: + if cache_config.enable_prefix_caching and cache_config.block_size != 128: logger.warning( - "Prefix caching is not supported for V1 now, disable prefix caching" + "If prefix caching is enabled, block size must be set to 128." ) - cache_config.enable_prefix_caching = False + cache_config.block_size = 128 if envs.VLLM_USE_V1: # Activate custom ops for v1. diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index e58f55a..49c221e 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -693,15 +693,23 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): # this may be larger than the sequence length if chunked # prefill is enabled. prefix_cache_len = len(computed_block_nums) * self.block_size + + # The total number of prompt tokens in this sequence. + # When chunked prefill is enabled, this is the token number of + # computed chunks + current chunk. + seq_len = inter_data.seq_lens[seq_idx] + + # When full hit, compute the last block rather than the last token, + # due to the requirements of prefix operator. + if seq_len <= prefix_cache_len: + prefix_cache_len -= self.block_size + seq_group_metadata.seq_data[inter_data.seq_ids[ seq_idx]].update_num_cached_tokens(prefix_cache_len) # The number of so far computed prompt tokens in this sequence. context_len = inter_data.context_lens[seq_idx] - # The total number of prompt tokens in this sequence. - # When chunked prefill is enabled, this is the token number of - # computed chunks + current chunk. - seq_len = inter_data.seq_lens[seq_idx] + if prefix_cache_len <= context_len: # We already passed the cache hit region, # so do normal computation. diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 76d3ea4..9398da0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -107,6 +107,7 @@ class NPUModelRunner: self.model_config = vllm_config.model_config self.lora_config = vllm_config.lora_config self.scheduler_config = vllm_config.scheduler_config + self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled self.device = device self.is_multimodal_model = self.model_config.is_multimodal_model self.block_size = vllm_config.cache_config.block_size @@ -454,11 +455,15 @@ class NPUModelRunner: if attn_state == AscendAttentionState.ChunkedPrefill: return self.attn_mask_builder.get_splitfuse_attn_mask( seq_lens, query_lens, position, self.dtype, self.device) - # Prefill-only situation. - elif attn_state == AscendAttentionState.PrefillOnly: + # Prefill without cache situation. + elif attn_state == AscendAttentionState.PrefillNoCache: max_seq_len = max(seq_lens, default=0) return self.attn_mask_builder.get_attn_mask( max_seq_len, self.dtype, self.device) + # Prefill with cache hit. + elif attn_state == AscendAttentionState.PrefillCacheHit: + return self.attn_mask_builder.get_attn_mask( + 128, self.dtype, self.device) # Decode-only situation. else: return None @@ -528,13 +533,15 @@ class NPUModelRunner: block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) - attn_state = AscendAttentionState.ChunkedPrefill - if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): - attn_state = AscendAttentionState.PrefillOnly + if self.chunked_prefill_enabled: + attn_state = AscendAttentionState.ChunkedPrefill + elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): + attn_state = AscendAttentionState.PrefillNoCache + # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. elif np.all(num_scheduled_tokens == 1): attn_state = AscendAttentionState.DecodeOnly else: - attn_state = AscendAttentionState.ChunkedPrefill + attn_state = AscendAttentionState.PrefillCacheHit attn_mask = self._make_attention_mask(seq_lens=seq_lens, query_lens=num_scheduled_tokens,