# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict from dataclasses import dataclass from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder) from vllm.attention.backends.utils import CommonAttentionState from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that # lack attention. class PlaceholderAttentionBackend(AttentionBackend): """Placeholder backend for when no attention is needed.""" @staticmethod def get_name() -> str: return "NO_ATTENTION" @staticmethod def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: return PlaceholderAttentionImpl @staticmethod def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: return PlaceholderAttentionMetadataBuilder @staticmethod def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: return PlaceholderAttentionMetadata @staticmethod def get_state_cls() -> Type["CommonAttentionState"]: return CommonAttentionState @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: return (1, 1, 1, 1, 1) @staticmethod def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: return @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: return @dataclass class PlaceholderAttentionMetadata(AttentionMetadata): """Attention metadata for prefill and decode batched together.""" # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool # Maximum query length in the batch. max_query_len: Optional[int] # Max number of query tokens among request in the batch. max_decode_query_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. query_start_loc: Optional[torch.Tensor] = None # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] = None # Placeholder. block_tables: Optional[torch.Tensor] = None _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @property def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self.num_prefills == 0: return None if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata # Compute some attn_metadata fields which default to None query_start_loc = (None if self.query_start_loc is None else self.query_start_loc[:self.num_prefills + 1]) seq_lens = (None if self.seq_lens is None else self.seq_lens[:self.num_prefills]) seq_lens_tensor = (None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills]) seq_start_loc = (None if self.seq_start_loc is None else self.seq_start_loc[:self.num_prefills + 1]) context_lens_tensor = (None if self.context_lens_tensor is None else self.context_lens_tensor[:self.num_prefills]) # Placeholders slot_mapping = torch.empty(0) block_tables = torch.empty(0) self._cached_prefill_metadata = PlaceholderAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) return self._cached_prefill_metadata @property def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self.num_decode_tokens == 0: return None if self._cached_decode_metadata is not None: return self._cached_decode_metadata assert self.seq_lens_tensor is not None # Placeholders slot_mapping = torch.empty(0) block_tables = torch.empty(0) seq_lens_tensor = (None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:]) self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, query_start_loc=(self.query_start_loc[self.num_prefills:] - self.query_start_loc[self.num_prefills]) if self.query_start_loc is not None else None, seq_start_loc=self.seq_start_loc[self.num_prefills:] if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, ) return self._cached_decode_metadata def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", sampled_token_ids: Optional[torch.Tensor], block_size: int, num_seqs: int, num_queries: int, turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ # When using cudagraph, the num_seqs is padded to the next captured # batch sized, but num_queries tracks the actual number of requests in # the batch. For --enforce-eager mode, num_seqs == num_queries if num_seqs != num_queries: assert num_seqs > num_queries assert self.use_cuda_graph assert not turn_prefills_into_decodes, \ ("Multi-Step + Chunked-Prefill is not supported for attention-free" "models. turn_prefills_into_decodes is a " "Multi-Step + Chunked-Prefill specific parameter.") assert self.seq_lens is not None assert self.max_decode_seq_len == max(self.seq_lens) assert self.num_prefills == 0 assert self.num_prefill_tokens == 0 assert self.num_decode_tokens == num_seqs assert self.seq_lens is not None assert len(self.seq_lens) == num_seqs assert self.seq_lens_tensor is not None assert self.seq_lens_tensor.shape == (num_seqs, ) assert self.max_query_len == 1 assert self.max_prefill_seq_len == 0 assert self.query_start_loc is not None assert self.query_start_loc.shape == (num_queries + 1, ) assert self.seq_start_loc is not None assert self.seq_start_loc.shape == (num_seqs + 1, ) assert self.context_lens_tensor is not None assert self.context_lens_tensor.shape == (num_queries, ) # Update query lengths. Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size for i in range(num_queries): self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) # Update sequences, masking off entries greater than num_queries device = self.seq_lens_tensor.device mask = torch.arange(self.seq_lens_tensor.size(0), device=device) < num_queries self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype) if sampled_token_ids is not None: model_input.input_tokens.masked_scatter_( mask, sampled_token_ids[:num_queries]) class PlaceholderAttentionMetadataBuilder( AttentionMetadataBuilder[PlaceholderAttentionMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.input_builder = input_builder self.runner = input_builder.runner def prepare(self): self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] self.curr_seq_lens: List[int] = [] self.multimodal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", chunked_prefill_enabled: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. """ is_prompt = inter_data.is_prompt for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], inter_data.orig_seq_lens, inter_data.seq_lens, inter_data.query_lens, inter_data.context_lens, inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: mm_maps = inter_data.multi_modal_placeholder_maps if mm_maps: for modality, placeholders in mm_maps.items(): self.multimodal_placeholder_maps[modality].extend( placeholders) self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): """Build attention metadata with on-device tensors. Args: seq_lens: The maybe padded sequence lengths of the input sequences. query_lens: The query lengths of the input sequences. cuda_graph_pad_size: The padding size for cuda graph. -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ # Some input builders such as ModelInputForCPUBuilder do not have the # "inter_data_list" attribute. # Let's check inter_data_list exists before we reference it. if hasattr(self.input_builder, "inter_data_list"): for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled) device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: max_decode_query_len = max(decode_query_lens) else: max_decode_query_len = 1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens query_start_loc = list(accumulate(query_lens, initial=0)) seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: num_decode_tokens = batch_size - self.num_prefill_tokens assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert device is not None context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, device, self.runner.pin_memory) seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, device, self.runner.pin_memory) placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in self.multimodal_placeholder_maps.items() } # Placeholders slot_mapping_tensor = torch.empty(0) block_tables = torch.empty(0) return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc_tensor, seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) class PlaceholderAttentionImpl(AttentionImpl): def __init__(self, *args, **kwargs) -> None: return def forward(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError