# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec

if TYPE_CHECKING:
    from vllm.attention.layer import Attention


class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models."""

    def __init__(
        self,
        model_config: ModelConfig,
        scheduler_config: SchedulerConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.mm_registry = mm_registry
        self.cache = cache = processor_only_cache_from_config(
            model_config, mm_registry
        )

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        self.mm_limits = mm_registry.get_mm_limits_per_prompt(
            model_config, cache=cache
        )

        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
            model_config,
            cache=cache,
            profiler_limits=self.mm_limits,
        )

        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            max_tokens_by_modality,
        )

        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        max_items_per_prompt_by_modality = dict[str, int]()
        max_items_per_batch_by_modality = dict[str, int]()

        for modality, max_tokens in max_tokens_by_modality.items():
            (
                max_items_per_prompt,
                max_items_per_batch,
            ) = self.get_max_items(modality, max_tokens)

            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
            max_items_per_batch_by_modality[modality] = max_items_per_batch

        self.max_tokens_by_modality = max_tokens_by_modality
        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality

    def get_modality_with_max_tokens(self) -> str:
        max_tokens_by_modality = self.max_tokens_by_modality
        modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
        return modality

    def get_encoder_budget(self) -> int:
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = self.get_encoder_budget()
        # TODO: handle encoder-decoder models once we support them.
        if encoder_budget == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
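        # Worked example (hypothetical numbers, not defaults): with
        # max_model_len=8192, max_tokens_per_item=576 and mm_limit=5, a single
        # prompt can hold min(5, 8192 // 576) = 5 items. The per-batch cap
        # below then scales this by the number of schedulable requests and is
        # clamped by the encoder budget computed above.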
        mm_limit = self.mm_limits[modality]
        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs
        if not scheduler_config.enable_chunked_prefill:
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )

        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch

    def reset_cache(self) -> None:
        if self.cache is not None:
            self.cache.clear_cache()


@dataclass
class AttentionGroup:
    backend: type[AttentionBackend]
    layer_names: list[str]
    kv_cache_spec: KVCacheSpec
    kv_cache_group_id: int
    # When ubatching is enabled, we create one metadata builder per ubatch so
    # that builders which keep internal persistent buffers for cudagraphs do
    # not conflict with the other ubatches.
    metadata_builders: list[AttentionMetadataBuilder] = field(
        default_factory=lambda: []
    )

    def create_metadata_builders(
        self,
        vllm_config,
        device,
        kernel_block_size: int | None,
        num_metadata_builders: int = 1,
    ):
        kv_cache_spec_builder = (
            self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
            if kernel_block_size is not None
            else self.kv_cache_spec
        )
        self.metadata_builders = [
            self.backend.get_builder_cls()(
                kv_cache_spec_builder,
                self.layer_names,
                vllm_config,
                device,
            )
            for _ in range(num_metadata_builders)
        ]

    def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
        assert len(self.metadata_builders) > ubatch_id
        return self.metadata_builders[ubatch_id]


def sanity_check_mm_encoder_outputs(
    mm_embeddings: MultiModalEmbeddings,
    expected_num_items: int,
) -> None:
    """
    Perform sanity checks for the result of
    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
    """
    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `embed_multimodal` method."
    )

    assert len(mm_embeddings) == expected_num_items, (
        "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `embed_multimodal` method."
    )

    assert all(e.ndim == 2 for e in mm_embeddings), (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `embed_multimodal` method."
    )


def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: torch.Tensor | None,
) -> torch.Tensor:
    """
    Scatter the multimodal embeddings into a contiguous tensor that represents
    the placeholder tokens, as described by
    [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask indicating which positions in the placeholder
            tokens need to be filled with multimodal embeddings.
            Shape: `(num_placeholders,)`
    """
    if is_embed is None:
        return embeds

    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders


def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: torch.Tensor | None,
) -> torch.Tensor:
    """
    Reconstructs the embeddings from the placeholder tokens.

    This is the reverse operation of [`scatter_mm_placeholders`]
    [vllm.v1.worker.utils.scatter_mm_placeholders].
    """
    if is_embed is None:
        return placeholders

    return placeholders[is_embed]


def add_kv_sharing_layers_to_kv_cache_groups(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[KVCacheGroupSpec],
    runner_only_attn_layers: set[str] | None = None,
) -> None:
    """
    Sets up KV cache sharing for layers that do not allocate their own KV
    cache, based on the mapping in `shared_kv_cache_layers`: each such layer
    is added to the KV cache group of the layer whose cache it reuses, which
    is needed to ensure that attention metadata is assigned to it later.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is in the keys of this dict, it
            means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model.
        runner_only_attn_layers: If provided, the shared layers are also added
            to this set.
    """
    layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
    for kv_cache_group in kv_cache_groups:
        for layer_name in kv_cache_group.layer_names:
            layer_to_kv_cache_group[layer_name] = kv_cache_group

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
        tgt_kv_cache_group.layer_names.append(layer_name)
        if runner_only_attn_layers is not None:
            runner_only_attn_layers.add(layer_name)


def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: int | None = 1,
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder models, e.g., BART, where
            # the cross-attention and self-attention layers in the same
            # decoder block have different layer names but the same
            # layer_index.
            # TODO - analyze where runner_kv_caches is used and the right
            # way to ensure it properly reflects multiple attention layers
            # in the same decoder block.
            if current_platform.is_cuda() or current_platform.is_xpu():
                # We know that the GPU runner is not impacted by this
                # case. Some test code depends on runner_kv_caches, but
                # not in a way that's impacted by ignoring this.
                pass
            else:
                raise NotImplementedError
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]


def bind_kv_cache_scale(
    kv_caches_scale: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches_scale: list[torch.Tensor],
    num_attn_module: int | None = 1,
) -> None:
    """
    Bind the allocated KV cache scales to both ModelRunner and forward context
    so that they can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache scale list (`runner_kv_caches_scale`)
         with kv_caches_scale.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache scale in kv_caches_scale.

    Args:
        kv_caches_scale: The allocated kv cache scales with layer names as
            keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches_scale: The kv cache scales declared by ModelRunner.
    """
    # Bind kv_caches_scale to ModelRunner
    assert len(runner_kv_caches_scale) == 0

    # Convert kv_caches_scale dict to a list of tensors in the order of
    # layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches_scale:
        index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder models, e.g., BART, where
            # the cross-attention and self-attention layers in the same
            # decoder block have different layer names but the same
            # layer_index.
            if current_platform.is_cuda() or current_platform.is_xpu():
                pass
            else:
                raise NotImplementedError
        layer_name = layer_names[0]
        runner_kv_caches_scale.append(kv_caches_scale[layer_name])

    # Bind kv_caches_scale to forward context
    for layer_name, kv_cache_scale in kv_caches_scale.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache_scale = [kv_cache_scale]


def is_residual_scattered_for_sp(
    vllm_config: VllmConfig, num_input_tokens: int
) -> bool:
    """Check if the residual tensor is scattered for sequence parallelism.

    The residual tensor is scattered across tensor parallel ranks when
    sequence parallelism and tensor parallelism are enabled.

    This follows the same logic as SequenceParallelismPass.is_applicable():
    - In full-graph compilation mode (no splitting ops or using inductor graph
      partition), SP is always applied.
    - Otherwise, SP is only applied for specific shapes in compile_sizes.
    """
    if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
        return False

    tp = vllm_config.parallel_config.tensor_parallel_size
    if tp == 1:
        return False

    # When sequence parallelism is enabled, we always pad num_input_tokens
    # to be a multiple of tensor_parallel_size (tp) earlier.
    assert num_input_tokens % tp == 0

    if (
        not vllm_config.compilation_config.splitting_ops
        or vllm_config.compilation_config.use_inductor_graph_partition
    ):
        return True

    return num_input_tokens in vllm_config.compilation_config.compile_sizes
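

if __name__ == "__main__":
    # Minimal, illustrative smoke test (not part of vLLM's test suite): it
    # exercises `bind_kv_cache` with dummy stand-ins so the binding logic can
    # be checked in isolation. The layer names and the `SimpleNamespace`
    # objects standing in for `Attention` layers are hypothetical.
    from types import SimpleNamespace

    dummy_caches = {
        "model.layers.0.self_attn.attn": torch.zeros(1),
        "model.layers.1.self_attn.attn": torch.zeros(1),
    }
    dummy_forward_context = {
        name: SimpleNamespace(kv_cache=None) for name in dummy_caches
    }
    dummy_runner_caches: list[torch.Tensor] = []

    bind_kv_cache(dummy_caches, dummy_forward_context, dummy_runner_caches)

    # One cache per layer index, and each layer sees its own cache.
    assert len(dummy_runner_caches) == 2
    assert all(
        ctx.kv_cache[0] is dummy_caches[name]
        for name, ctx in dummy_forward_context.items()
    )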