# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
                    Optional, TypeVar, Union)

from torch import nn

from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
                        resolve_mm_processor_kwargs)

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.sequence import SequenceGroupMetadata

# NOTE: These names are used at runtime (type alias construction and
# ``isinstance`` checks below), so they cannot live under ``TYPE_CHECKING``.
from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
                     PlaceholderRange)

logger = init_logger(__name__)

MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
                                 MultiModalKwargs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""

_T = TypeVar("_T")
N = TypeVar("N", bound=type[nn.Module])


class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    def __init__(self) -> None:
        # Per-model-class registries; entries are added through the
        # ``register_*`` decorator factories below.
        self._input_mappers = ClassRegistry[nn.Module,
                                            MultiModalInputMapper]()
        self._max_mm_tokens = ClassRegistry[nn.Module,
                                            MultiModalTokensCalc]()

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[Any],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper for the decorated model
            class and returns that class unchanged.
        """

        def wrapper(model_cls: N) -> N:
            if self._input_mappers.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # Explicit ``None`` check so the code matches the documented
            # contract ("If `None` is provided ...").
            self._input_mappers[model_cls] = (self._default_input_mapper
                                              if mapper is None else mapper)

            return model_cls

        return wrapper

    def map_input(
        self,
        model_config: "ModelConfig",
        data: ModalityData[Any],
        mm_processor_kwargs: Optional[dict[str, Any]],
    ) -> MultiModalKwargs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.
        """

        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)

        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        # In the case of the default mapper, we have to get resource
        # processor through its HuggingFace autoclass; since this goes
        # through **kwargs, we can't inspect it the same way, so we allow
        # dropping mm_processor_kwargs based on signature inspection
        # only if we're using the default mapper.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        """
        Ensure that ``max_mm_tokens`` is a positive integer.

        Raises:
            ValueError: If ``max_mm_tokens`` is less than 1.
        """
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value (or callable) for the
            decorated model class and returns that class unchanged.

        Raises:
            ValueError: (from the decorator) if an integer smaller than 1
                is provided.
        """

        def wrapper(model_cls: N) -> N:
            if self._max_mm_tokens.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # Constant values can be validated eagerly; callables are
            # validated when they are evaluated in
            # ``get_max_multimodal_tokens``.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            # Explicit ``None`` check so the code matches the documented
            # contract; falsy integers never reach this point because they
            # are rejected by the validation above.
            self._max_mm_tokens[model_cls] = (
                self._default_max_multimodal_tokens
                if max_mm_tokens is None else max_mm_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            ValueError: If the registered value (or the result of the
                registered callable) is less than 1.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.model_executor.models import supports_multimodal

        model_cls, _ = get_model_architecture(model_config)

        if not supports_multimodal(model_cls):
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            return 0

        if callable(max_mm_tokens):
            # Only forward the processor kwargs that the callable accepts.
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens,
                overrides=model_config.mm_processor_kwargs,
                requires_kw_only=False,
                allow_var_kwargs=True,
            )
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens


class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.
    """

    class IndexMap(NamedTuple):
        src: list[int]
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        """
        Create an empty placeholder map (no ranges, zero lengths).
        """
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> tuple[Optional[MultiModalDataDict], dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Examples:

        .. code-block::

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |.................................|

                images      = [A, B]
                src_ranges  = [(0, 4), (4, 8)]
                dest_ranges = [(0, 4), (5, 9)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |  .....                          |

                images      = [A, B]
                src_ranges  = [(2, 4), (4, 6)]
                dest_ranges = [(0, 2), (3, 5)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |     .........                   |

                images      = [B]
                src_ranges  = [(0, 4)]
                dest_ranges = [(0, 4)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |          .......................|

                images      = []
                src_ranges  = []
                dest_ranges = []
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        if not seq_mm_data or not seq_mm_placeholders:
            return seq_mm_data, {}

        # For merged processor, we directly use mm_kwargs as mm_data
        if isinstance(seq_mm_data, MultiModalKwargs):
            placeholder_maps = dict[str, MultiModalPlaceholderMap]()

            for modality, placeholders in seq_mm_placeholders.items():
                placeholder_map = MultiModalPlaceholderMap()

                if positions:
                    placeholder_map.append_items_from_seq_group(
                        positions,
                        # Dummy, since we don't care about intersecting items
                        [None] * len(placeholders),
                        placeholders,
                    )

                placeholder_maps[modality] = placeholder_map

            return seq_mm_data, placeholder_maps

        # Shallow copy so that the sequence group's own data is not mutated
        # by the ``pop`` calls below.
        mm_data = {**seq_mm_data}
        placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
            MultiModalPlaceholderMap)

        for modality, placeholders in seq_mm_placeholders.items():
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[modality] \
                    .append_items_from_seq_group(
                        positions,
                        mm_items,
                        placeholders,
                    )

                # Only keep modalities that still have items in ``positions``.
                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        Raises:
            ValueError: If ``multi_modal_items`` and
                ``multi_modal_placeholders`` differ in length.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"],
            )
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices are relative to the start of ``positions``.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Source indices are relative to the start of the placeholder,
            # shifted by the embeddings already accounted for in this map.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The full placeholder length is added, even when only part of
            # the placeholder intersects ``positions``.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """

        # Each ``extend`` call consumes its generator immediately, so the
        # offsets use the lengths from *before* the ``+=`` updates.
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "MultiModalPlaceholderMap.IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the flattened source and destination index lists
                differ in length.
        """

        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)


class MediaIO(ABC, Generic[_T]):
    """
    Abstract interface for loading a media item of type ``_T`` from raw
    bytes, base64-encoded text, or a file path.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """
        Load a media item from raw bytes.
        """
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Load a media item from base64 text, given its media type.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """
        Load a media item from the file at ``filepath``.
        """
        raise NotImplementedError