Iluvatar-mrv100 SDK 4.3.0

This commit is contained in:
2025-09-15 14:58:11 +08:00
parent 9efe891f99
commit 8af8290b1d
1052 changed files with 294967 additions and 1 deletions

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .hasher import MultiModalHashDict, MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict, NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global :class:`~MultiModalRegistry` is used by model runners to
dispatch data processing according to the target model.
See also:
:ref:`mm-processing`
"""
__all__ = [
"BatchedTensorInputs",
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHashDict",
"MultiModalHasher",
"MultiModalKwargs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
"MultiModalPlugin",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
]

77
vllm/multimodal/audio.py Normal file
View File

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
import base64
from io import BytesIO
from pathlib import Path
import numpy as np
import numpy.typing as npt
from vllm.inputs.registry import InputContext
from vllm.utils import PlaceholderModule
from .base import MediaIO, MultiModalPlugin
from .inputs import AudioItem, ModalityData, MultiModalKwargs
# ``librosa`` and ``soundfile`` are optional dependencies. When absent, a
# PlaceholderModule stands in and raises an informative error only if the
# module is actually used, so audio support degrades gracefully.
try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

try:
    import soundfile
except ImportError:
    soundfile = PlaceholderModule("soundfile")  # type: ignore[assignment]
class AudioPlugin(MultiModalPlugin):
    """Plugin for audio data."""

    def get_data_key(self) -> str:
        # Audio inputs live under the "audio" key of the multi-modal data dict.
        return "audio"

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[AudioItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        # There is no generic audio preprocessing; each audio model must
        # register its own input mapper.
        raise NotImplementedError("There is no default audio input mapper")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        # Token budgets are model-specific for audio; no sensible default.
        raise NotImplementedError(
            "There is no default maximum multimodal tokens")
def resample_audio(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` using librosa."""
    resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
    return resampled
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Loads and encodes audio as ``(samples, sampling_rate)`` tuples."""

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        # ``sr=None`` preserves the file's native sampling rate.
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        decoded = base64.b64decode(data)
        return self.load_bytes(decoded)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
        """Serialize audio to WAV and return it base64-encoded."""
        audio, sr = media

        with BytesIO() as buffer:
            soundfile.write(buffer, audio, sr, format="WAV")
            payload = buffer.getvalue()

        return base64.b64encode(payload).decode('utf-8')

468
vllm/multimodal/base.py Normal file
View File

@@ -0,0 +1,468 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
Optional, TypeVar, Union)
from torch import nn
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.sequence import SequenceGroupMetadata
from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
PlaceholderRange)
logger = init_logger(__name__)
MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
MultiModalKwargs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.
If the data is not supported, throw :exc:`TypeError`.
"""
MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""
_T = TypeVar("_T")
N = TypeVar("N", bound=type[nn.Module])
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    def __init__(self) -> None:
        # Per-model-class registries, keyed by the model's ``nn.Module``
        # subclass: one for input mappers, one for max-token calculations.
        self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]()
        self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]()

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality
        (e.g. ``"image"`` or ``"audio"``).
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[Any],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class (decorator factory).

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.
        """

        def wrapper(model_cls: N) -> N:
            # Re-registration is permitted but noisy: warn, then overwrite.
            if self._input_mappers.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            self._input_mappers[model_cls] = (mapper
                                              or self._default_input_mapper)

            return model_cls

        return wrapper

    def map_input(
        self,
        model_config: "ModelConfig",
        data: ModalityData[Any],
        mm_processor_kwargs: Optional[dict[str, Any]],
    ) -> MultiModalKwargs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no mapper is registered for the model class.
            TypeError: If the data type is not supported.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        # In the case of the default mapper, we have to get resource
        # processor through its HuggingFace autoclass; since this goes
        # through **kwargs, we can't inspect it the same way, so we allow
        # drop mm_processor_kwargs based on signature inspection
        # if we're using the default mapper.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        # Guard against nonsensical registrations and calculator results.
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class (decorator factory).

        If `None` is provided, then the default calculation is used instead.
        """

        def wrapper(model_cls: N) -> N:
            if self._max_mm_tokens.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            # Plain integers are validated eagerly; callables are validated
            # lazily in get_max_multimodal_tokens once invoked.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            self._max_mm_tokens[model_cls] = (
                max_mm_tokens or self._default_max_multimodal_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.model_executor.models import supports_multimodal

        model_cls, _ = get_model_architecture(model_config)

        if not supports_multimodal(model_cls):
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            return 0

        if callable(max_mm_tokens):
            # Forward only the kwarg overrides the calculator accepts.
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens,
                overrides=model_config.mm_processor_kwargs,
                requires_kw_only=False,
                allow_var_kwargs=True,
            )
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens
class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.
    """

    class IndexMap(NamedTuple):
        # Flattened source (embedding) indices and destination (placeholder)
        # indices; always the same length.
        src: list[int]
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> tuple[Optional[MultiModalDataDict], dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Examples:

        .. code-block::

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |.................................|

                images      = [A, B]
                src_ranges  = [(0, 4), (4, 8)]
                dest_ranges = [(0, 4), (5, 9)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |  .....                          |

                images      = [A, B]
                src_ranges  = [(2, 4), (4, 6)]
                dest_ranges = [(0, 2), (3, 5)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |     .........                   |

                images      = [B]
                src_ranges  = [(0, 4)]
                dest_ranges = [(0, 4)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |          .......................|

                images      = []
                src_ranges  = []
                dest_ranges = []
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        # Nothing multi-modal in this sequence group: empty map.
        if not seq_mm_data or not seq_mm_placeholders:
            return seq_mm_data, {}

        # For merged processor, we directly use mm_kwargs as mm_data
        if isinstance(seq_mm_data, MultiModalKwargs):
            placeholder_maps = dict[str, MultiModalPlaceholderMap]()

            for modality, placeholders in seq_mm_placeholders.items():
                placeholder_map = MultiModalPlaceholderMap()

                if positions:
                    placeholder_map.append_items_from_seq_group(
                        positions,
                        # Dummy, since we don't care about intersecting items
                        [None] * len(placeholders),
                        placeholders,
                    )

                placeholder_maps[modality] = placeholder_map

            return seq_mm_data, placeholder_maps

        # Legacy path: filter mm_data down to the items whose placeholders
        # intersect ``positions``, building a map per modality as we go.
        mm_data = {**seq_mm_data}
        placeholder_maps = defaultdict[str, MultiModalPlaceholderMap](
            MultiModalPlaceholderMap)

        for modality, placeholders in seq_mm_placeholders.items():
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[modality] \
                    .append_items_from_seq_group(
                        positions,
                        mm_items,
                        placeholders,
                    )

                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ```positions`` to this
        placeholder map and returns the intersecting items.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            # Absolute prompt positions covered by this item's placeholder.
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"],
            )
            # Overlap between the placeholder and the window of interest;
            # an empty range means this item is entirely outside the window.
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices, relative to the start of ``positions``.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Source indices, relative to this item's embeddings, offset by
            # the embeddings accumulated from previous items.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # NOTE: src_len grows by the *full* placeholder length (not just
            # the intersection), keeping offsets aligned with the full
            # embedding tensor.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """
        # Shift the other map's ranges past our current lengths, since the
        # underlying tensors are concatenated after ours.
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.
        """
        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)
class MediaIO(ABC, Generic[_T]):
    """
    Abstract loader for a single media type ``_T``.

    NOTE(review): subclasses in this package also implement ``encode_base64``,
    which is not declared on this ABC — consider adding it as an abstract
    method; confirm against all implementers first.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Decode a media item from its raw byte representation."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Decode a media item from base64-encoded data.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Decode a media item from a file on disk."""
        raise NotImplementedError

103
vllm/multimodal/hasher.py Normal file
View File

@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
import pickle
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
from blake3 import blake3
from PIL import Image
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.inputs import TokensPrompt
logger = init_logger(__name__)
# Keys are modality names (e.g. "image"); values are one hash per item.
MultiModalHashDict = Mapping[str, list[str]]
"""
A dictionary containing hashes for items in each modality.
"""
class MultiModalHasher:
    """Content-hashing helpers for multi-modal inputs (blake3-based)."""

    @classmethod
    def serialize_item(cls, obj: object) -> bytes:
        """Serialize a single leaf value into a byte string for hashing."""
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, Image.Image):
            return obj.tobytes()

        # Numeric values are routed through a NumPy array.
        value = obj
        if isinstance(value, torch.Tensor):
            value = value.numpy()
        if isinstance(value, (int, float)):
            value = np.array(value)
        if isinstance(value, np.ndarray):
            return value.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    @classmethod
    def item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """Yield ``(key, value)`` byte pairs, recursing into containers."""
        # Recursive cases: containers contribute one pair per leaf, with a
        # dotted path recording the leaf's position.
        if isinstance(obj, (list, tuple)):
            for idx, member in enumerate(obj):
                yield from cls.item_to_bytes(f"{key}.{idx}", member)
        elif isinstance(obj, dict):
            for name, member in obj.items():
                yield from cls.item_to_bytes(f"{key}.{name}", member)
        else:
            yield cls.serialize_item(key), cls.serialize_item(obj)

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Compute a blake3 hex digest over the given keyword arguments."""
        hasher = blake3()

        for name, value in kwargs.items():
            for key_bytes, value_bytes in cls.item_to_bytes(name, value):
                hasher.update(key_bytes)
                hasher.update(value_bytes)

        return hasher.hexdigest()

    @classmethod
    def hash_prompt_mm_data(
            cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]:
        """Hash multimodal data in the user input prompt if they exist."""
        if "multi_modal_data" not in prompt:
            return None

        mm_data = prompt["multi_modal_data"]
        if not mm_data:
            # mm_data can be None or an empty dict.
            return None

        # Normalize single items to lists, then hash each item under its
        # modality name.
        return {
            modality: [
                cls.hash_kwargs(**{modality: item})
                for item in (items if isinstance(items, list) else [items])
            ]
            for modality, items in mm_data.items()
        }

155
vllm/multimodal/image.py Normal file
View File

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
import base64
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
import torch
from PIL import Image
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_image_processor
from vllm.utils import is_list_of
from .base import MediaIO, MultiModalPlugin
from .inputs import ImageItem, ModalityData, MultiModalKwargs
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
class ImagePlugin(MultiModalPlugin):
    """Plugin for image data."""

    def get_data_key(self) -> str:
        return "image"

    def _get_hf_image_processor(
        self,
        model_config: "ModelConfig",
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ):
        # Fetch (and cache) the HuggingFace image processor for this model.
        kwargs = mm_processor_kwargs if mm_processor_kwargs is not None else {}
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **kwargs)

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[ImageItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """Map PIL images or precomputed embeddings to model kwargs."""
        model_config = ctx.model_config

        # PIL image(s): run them through the HF image processor.
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            processor = self._get_hf_image_processor(
                model_config,
                mm_processor_kwargs,
            )
            if processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")

            # NOTE: It may make sense to forward the mm_processor_kwargs
            # here too. For now, to keep it simple, we only allow it be
            # used for the initialization call though, just in case the
            # signatures of the preprocessor initializer don't match
            # preprocess()
            try:
                batch_data = processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
            except Exception:
                logger.error(
                    "Failed to process image (%s) with the default mapper. "
                    "This is most likely an edge-case with this model's image "
                    "processor in transformers (type: %s), and not vLLM.",
                    data,
                    type(processor).__name__)
                raise

            return MultiModalKwargs(batch_data)

        # Precomputed image embedding(s): pass straight through.
        if isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
            return MultiModalKwargs({"image_embeds": data})

        raise TypeError(f"Invalid image type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 3000
def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor."""
    target_size = (int(image.width * size_factor),
                   int(image.height * size_factor))
    result = image.resize(target_size)

    # Optionally apply a PIL transpose operation (-1 means none).
    if transpose >= 0:
        result = result.transpose(Image.Transpose(transpose))

    return result
class ImageMediaIO(MediaIO[Image.Image]):
    """Loads and encodes ``PIL.Image`` objects in a fixed color mode."""

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()

        # Every loaded image is converted to this mode (default RGB).
        self.image_mode = image_mode

    def _finalize(self, image: Image.Image) -> Image.Image:
        # Force decoding of pixel data before converting the color mode.
        image.load()
        return image.convert(self.image_mode)

    def load_bytes(self, data: bytes) -> Image.Image:
        return self._finalize(Image.open(BytesIO(data)))

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> Image.Image:
        return self._finalize(Image.open(filepath))

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Serialize the image in ``image_format`` and base64-encode it."""
        with BytesIO() as buffer:
            media.convert(self.image_mode).save(buffer, image_format)
            payload = buffer.getvalue()

        return base64.b64encode(payload).decode('utf-8')
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Loads and encodes pre-computed image-embedding tensors."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        buffer = BytesIO(data)
        # weights_only=True prevents arbitrary code execution when
        # deserializing untrusted payloads.
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        # NOTE(review): this encodes the raw ndarray buffer, while the
        # ``load_*`` methods expect ``torch.save`` format — encode/load do
        # not round-trip. Verify callers never feed encode_base64 output
        # back into load_base64 before changing either side.
        return base64.b64encode(media.numpy()).decode('utf-8')

769
vllm/multimodal/inputs.py Normal file
View File

@@ -0,0 +1,769 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
import numpy as np
import torch
import torch.types
from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.utils import full_groupby, is_list_of
if TYPE_CHECKING:
from .hasher import MultiModalHashDict
_T = TypeVar("_T")
HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
"""
HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
list[np.ndarray], list[torch.Tensor]]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
"""
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
"""
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
"""
ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""
VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
torch.Tensor]
"""
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""
ModalityData: TypeAlias = Union[_T, list[_T]]
"""
Either a single data item, or a list of data items.
The number of data items allowed per modality is restricted by
:code:`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""


# Open-ended mapping: custom plugins may introduce modalities beyond the
# built-in image/video/audio keys.
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.

The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
"""
class PlaceholderRange(TypedDict):
    """
    Placeholder location information for multi-modal data.

    Example:

        Prompt: :code:`AAAA BBBB What is in these images?`

        Images A and B will have:

        .. code-block::

            A: { "offset": 0, "length": 4 }
            B: { "offset": 5, "length": 4 }
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
                      tuple[torch.Tensor, ...]]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""


def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """
    Equality check between :data:`NestedTensors` objects.

    Tensors compare element-wise via :func:`torch.equal`; lists compare
    recursively and must have equal lengths; anything else falls back to
    ``==``.
    """
    if isinstance(a, torch.Tensor):
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    elif isinstance(b, torch.Tensor):
        return isinstance(a, torch.Tensor) and torch.equal(b, a)

    if isinstance(a, list):
        # ``zip`` alone truncates to the shorter list, which would make
        # [x] compare equal to [x, y]; check lengths explicitly first.
        return (isinstance(b, list) and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))
    if isinstance(b, list):
        return (isinstance(a, list) and len(a) == len(b)
                and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)))

    # Both a and b are scalars
    return a == b
# Keyword-argument name -> batched tensor data, as produced by batching.
BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""
@dataclass(frozen=True)
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in :class:`MultiModalKwargs`.
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in :class:`MultiModalKwargs`,
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in :class:`MultiModalKwargs`,
    i.e. the value of the keyword argument to be passed to the model.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        # Elements are equal when they target the same (modality, key) slot,
        # hold equal nested-tensor data, and share the same field strategy.
        if not isinstance(other, self.__class__):
            return False

        same_slot = (self.modality, self.key) == (other.modality, other.key)
        return (same_slot and nested_tensors_equal(self.data, other.data)
                and type(self.field) == type(other.field))  # noqa: E721
@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    :class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        # Bind every argument except ``data`` so that subclasses can build
        # elements with a single positional argument.
        make_elem = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return make_elem(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct :class:`MultiModalFieldElem` instances to represent
        the provided data.

        This is the inverse of :meth:`reduce_data`.
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        raise NotImplementedError

    def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
        """
        Merge the data from multiple instances of :class:`MultiModalFieldElem`.

        This is the inverse of :meth:`build_elems`.
        """
        # Refuse to merge elements built by heterogeneous field strategies.
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        return self._reduce_data([elem.data for elem in elems])
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Field whose batch elements are slices along the first dimension.

    See also:
        :func:`MultiModalFieldConfig.batched`
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # One element per entry along the leading dimension.
        return [make_elem(item) for item in data]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        all_tensors = len(batch) > 0 and is_list_of(
            batch, torch.Tensor, check="all")
        if all_tensors:
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()

            first_shape = batch[0].shape
            if all(t.shape == first_shape for t in batch):
                return torch.stack(batch)

        # Heterogeneous shapes: keep the ragged list as-is.
        return batch
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Field whose batch elements are arbitrary slices of flattened data.

    See also:
        :func:`MultiModalFieldConfig.flat`
        :func:`MultiModalFieldConfig.flat_from_sizes`
    """
    slices: Sequence[slice]

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # One element per registered slice of the flattened data.
        return [make_elem(data[s]) for s in self.slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        all_tensors = len(batch) > 0 and is_list_of(
            batch, torch.Tensor, check="all")
        if all_tensors:
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            first_shape = batch[0].shape
            if all(t.shape[1:] == first_shape[1:] for t in batch):
                return torch.concat(batch)

        # Ragged trailing dims: flatten one nesting level into a plain list.
        return [entry for elem in batch for entry in elem]
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Field whose data is shared by all items in the batch.

    See also:
        :func:`MultiModalFieldConfig.shared`
    """
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        # Each of the ``batch_size`` items references the same data object.
        return [make_elem(data)] * self.batch_size

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        # All entries are identical by construction; keep a single copy.
        return batch[0]
class MultiModalFieldConfig:
    """Pairs a modality with a strategy for batching its keyword argument."""

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

        .. code-block::

            Input:
                Data: [[AAAA]
                       [BBBB]
                       [CCCC]]

            Output:
                Element 1: [AAAA]
                Element 2: [BBBB]
                Element 3: [CCCC]
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str, slices: Sequence[slice]):
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice that is used to extract
                the data corresponding to it.

        Example:

        .. code-block::

            Given:
                slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

            Input:
                Data: [AAABBBBCC]

            Output:
                Element 1: [AAA]
                Element 2: [BBBB]
                Element 3: [CC]
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.

        Example:

        .. code-block::

            Given:
                size_per_item: [3, 4, 2]

            Input:
                Data: [AAABBBBCC]

            Output:
                Element 1: [AAA]
                Element 2: [BBBB]
                Element 3: [CC]

        See also:
            :func:`MultiModalFieldConfig.flat`
        """
        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Convert per-item sizes into contiguous [start, stop) slices via a
        # running sum of boundaries.
        boundaries = [0, *accumulate(size_per_item)]
        slices = [
            slice(start, stop)
            for start, stop in zip(boundaries, boundaries[1:])
        ]

        return MultiModalFieldConfig.flat(modality, slices)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

        .. code-block::

            Given:
                batch_size: 4

            Input:
                Data: [XYZ]

            Output:
                Element 1: [XYZ]
                Element 2: [XYZ]
                Element 3: [XYZ]
                Element 4: [XYZ]
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Delegate element construction to the configured field strategy."""
        return self.field.build_elems(self.modality, key, batch)
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    The :class:`MultiModalFieldElem`\\ s that make up one data item in
    :class:`MultiModalDataItems`, keyed by keyword-argument name.
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item from a sequence of field elements."""
        mapping: dict[str, MultiModalFieldElem] = {}
        for elem in elems:
            mapping[elem.key] = elem
        return MultiModalKwargsItem(mapping)

    @property
    def modality(self) -> str:
        """The single modality shared by every element in this item."""
        unique = {elem.modality for elem in self.data.values()}
        assert len(unique) == 1, f"Found different modalities={unique}"
        return next(iter(unique))
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    The metadata :code:`items` enables us to obtain the keyword arguments
    corresponding to each data item in :class:`MultiModalDataItems`, via
    :meth:`get_item` and :meth:`get_items`.
    """
    @staticmethod
    def from_hf_inputs(
        hf_inputs: BatchFeature,
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """Build from HF processor outputs by splitting each configured
        keyword argument into per-item elements and regrouping by item."""
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)
        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
            # Every key of a modality must yield the same number of items;
            # otherwise we cannot line up elements into per-item groups.
            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")
            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                # Collect the item_idx-th element of every key into one item.
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))
        return MultiModalKwargs.from_items(items)
    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """Construct a new :class:`MultiModalKwargs` from multiple items."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)
        # Merge per-item elements back into batched keyword arguments using
        # each field's own reduction (stack, concat, or share).
        data = {
            key: elems[0].field.reduce_data(elems)
            for key, elems in elems_by_key.items() if len(elems) > 0
        }
        return MultiModalKwargs(data, items=items)
    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        super().__init__(data)
        # Per-item metadata grouped by modality; empty when constructed
        # directly from raw keyword arguments (V0 path).
        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)
    @property
    def modalities(self):
        """The modalities for which per-item metadata is available."""
        return self._items_by_modality.keys()
    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.
        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors
        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)
        # Recurse first so inner lists collapse to tensors where possible.
        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked
        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()
        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_
        return torch.stack(tensors_)
    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.
        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}
        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)
        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)
        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }
    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of `batched_inputs` onto `device`."""
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
        # non_blocking lets host-to-device copies overlap with compute.
        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )
        return cast(BatchedTensorInputs, json_mapped)
    def __delitem__(self, key: str) -> None:
        """Delete `key` from the batched data and from every item."""
        super().__delitem__(key)
        # Keep the per-item metadata consistent with the batched view.
        for items in self._items_by_modality.values():
            for item in items:
                item.pop(key, None)
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False
        ks = self.keys()
        # Tensor values need elementwise comparison, not `==` on dicts.
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))
    def _validate_modality(self, method_name: str, modality: str) -> None:
        # Guard for item-level accessors, which require `items` metadata.
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")
        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")
    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])
    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]
    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]
# Type alias: maps a modality name (e.g. "image") to the placeholder token
# ranges that its items occupy in the prompt.
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    :class:`vllm.multimodal.processing.BaseMultiModalProcessor`,
    ready to be passed to vLLM internals.
    """
    type: Literal["multimodal"]
    """The type of inputs."""
    prompt: str
    """The processed prompt text."""
    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""
    # NotRequired: only present for models whose tokenizer emits them.
    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""
    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""
    # None when hashing/caching of multi-modal inputs is disabled.
    mm_hashes: Optional["MultiModalHashDict"]
    """The hashes of the multi-modal data."""
    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """
class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
    ready to be passed to vLLM internals.

    Extends :class:`MultiModalInputs` with the encoder-side prompt fields
    for encoder-decoder models.
    """
    encoder_prompt: str
    """The processed encoder prompt text."""
    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""
    # NotRequired: only present for models whose tokenizer emits them.
    encoder_token_type_ids: NotRequired[list[int]]
    """The token type IDs of the encoder prompt."""

453
vllm/multimodal/parse.py Normal file
View File

@@ -0,0 +1,453 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
Union)
import numpy as np
import torch
from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of
from .audio import resample_audio
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
_T = TypeVar("_T")
_I = TypeVar("_I")
class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in :class:`MultiModalDataItems`.

    ``_T`` is the type of the underlying container and ``_I`` the type of
    a single item inside it.
    """
    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()
        # Raw data for this modality; how it is indexed is subclass-specific.
        self.data = data
        self.modality = modality
    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"len={len(self)})")
    def __len__(self) -> int:
        return self.get_count()
    def __getitem__(self, index: int) -> _I:
        return self.get(index)
    if TYPE_CHECKING:
        # Auto-generated
        # `__len__` + `__getitem__` already make this iterable at runtime;
        # this stub only informs type checkers of the element type.
        def __iter__(self) -> Iterator[_I]:
            ...
    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError
    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError
    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(idx) for idx in range(self.get_count())]
    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError
    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items stored as a plain sequence."""

    def get_count(self) -> int:
        """Number of items in the sequence."""
        return len(self.data)

    def get(self, index: int) -> _T:
        """Look up a single item by position."""
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # HF processors take the pluralized modality name, e.g. "images".
        plural_key = self.modality + "s"
        return {plural_key: self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # Everything goes through the HF processor; nothing bypasses it.
        return {}
class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
                                       torch.Tensor]):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def get_count(self) -> int:
        """Number of embedding items (first dimension / list length)."""
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        """Embedding tensor for a single item."""
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # Embeddings skip the HF processor entirely.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # Forwarded to the model directly, e.g. as "image_embeds".
        passthrough_key = self.modality + "_embeds"
        return {passthrough_key: self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Number of embedding vectors for the given item."""
        return len(self.get(item_idx))
class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
                                           Mapping[str, torch.Tensor]]):
    """
    Base class for data items that are expressed as a dictionary of tensors.
    Usually, the dictionary keys correspond to the outputs of HF processor.
    """
    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        super().__init__(data, modality)
        # Validate the data before the field configs, since the factory
        # consumes the data.
        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (f"The data should contain the fields: {required_fields}, "
                   f"but only found the following keys: {data_keys}")
            raise ValueError(msg)
        fields_config = fields_factory(data)
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)
        self.fields_config = fields_config
        self.required_fields = required_fields
        # Pre-split the dictionary into per-item kwargs so that `get_count`
        # and `get` can delegate to MultiModalKwargs.
        self._kwargs = MultiModalKwargs.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )
    def get_count(self) -> int:
        """Number of items, as determined by the field configs."""
        return self._kwargs.get_item_count(self.modality)
    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        """Per-item slice of each tensor field, keyed by field name."""
        return {
            k: v.data
            for k, v in self._kwargs.get_item(self.modality, index).items()
        }
    def get_processor_data(self) -> Mapping[str, object]:
        # Pre-processed tensors skip the HF processor.
        return {}
    def get_passthrough_data(self) -> Mapping[str, object]:
        return self.data
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Audio items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfAudioItem]) -> None:
        super().__init__(data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Number of samples in the audio at the given index."""
        return len(self.get(item_idx))
class AudioEmbeddingItems(EmbeddingItems):
    """Precomputed audio embeddings (bypass the HF processor)."""
    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "audio")
class ImageSize(NamedTuple):
    """Image dimensions in (width, height) order, matching PIL's `size`."""
    width: int
    height: int
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Image items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfImageItem]) -> None:
        super().__init__(data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Resolve the (width, height) of the image at the given index."""
        image = self.get(item_idx)
        if isinstance(image, Image):
            # PIL's `size` is already (width, height).
            width, height = image.size
            return ImageSize(width, height)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Arrays/tensors are assumed channels-first: (C, H, W).
            _, height, width = image.shape
            return ImageSize(width, height)
        assert_never(image)
class ImageEmbeddingItems(EmbeddingItems):
    """Precomputed image embeddings (bypass the HF processor)."""
    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "image")
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """Video items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfVideoItem]) -> None:
        super().__init__(data, "video")

    def get_num_frames(self, item_idx: int) -> int:
        """Number of frames in the video at the given index."""
        return len(self.get(item_idx))

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Resolve the (width, height) of the video's frames."""
        # Assume that the video isn't empty
        first_frame = self.get(item_idx)[0]
        if isinstance(first_frame, Image):
            # PIL's `size` is already (width, height).
            width, height = first_frame.size
            return ImageSize(width, height)
        if isinstance(first_frame, (np.ndarray, torch.Tensor)):
            # Frames are assumed channels-first: (C, H, W).
            _, height, width = first_frame.shape
            return ImageSize(width, height)
        assert_never(first_frame)
class VideoEmbeddingItems(EmbeddingItems):
    """Precomputed video embeddings (bypass the HF processor)."""
    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "video")
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As :data:`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized
    such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising :exc:`KeyError`
        even if the modality is not found.
        """
        if modality in self:
            return self[modality].get_count()
        if not strict:
            return 0
        known_modalities = set(self.keys())
        raise KeyError(f"Modality {modality!r} not found. "
                       f"Available modalities: {known_modalities}")

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {name: entry.get_count() for name, entry in self.items()}

    def get_items(
        self,
        modality: str,
        typ: Union[type[_D], tuple[type[_D], ...]],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.
        """
        if modality not in self:
            known_modalities = set(self.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {known_modalities}")
        entry = self[modality]
        if not isinstance(entry, typ):
            raise TypeError(f"Invalid type of data items for {modality=}. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(entry)}")
        return entry  # type: ignore[return-value]
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
Optional[ModalityDataItems[Any, Any]]]
class MultiModalDataParser:
    """
    Parses :data:`~vllm.multimodal.inputs.MultiModalDataDict` into
    :class:`MultiModalDataItems`.

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
    """

    def __init__(self, *, target_sr: Optional[float] = None) -> None:
        super().__init__()
        self.target_sr = target_sr

    def _is_embeddings(
        self, data: object
    ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
        """Return whether `data` is precomputed embeddings (vs raw media)."""
        # A batched embedding tensor is 3-D; a list of per-item embedding
        # tensors contains 2-D tensors.
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
            return data[0].ndim == 2
        return False

    def _is_empty(self, data: object) -> TypeGuard[None]:
        """Return whether `data` contains no items at all."""
        if isinstance(data, list):
            return len(data) == 0
        if isinstance(data, np.ndarray):
            return data.size == 0
        if isinstance(data, torch.Tensor):
            # BUGFIX: unlike NumPy's `ndarray.size` (an int attribute),
            # `torch.Tensor.size` is a method, so the previous
            # `data.size == 0` check was always False and empty tensors
            # were never detected. Use `numel()` for the element count.
            return data.numel() == 0
        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, Optional[float]]:
        """Normalize an audio item to ``(samples, sampling_rate)``;
        the sampling rate is ``None`` when unknown."""
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None
        assert_never(audio)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Normalize audio data into items, resampling when a sampling
        rate is attached; returns ``None`` for empty data."""
        # also check single audio item with sampling rate
        if self._is_empty(data) or (isinstance(data, tuple)
                                    and self._is_empty(data[0])):
            return None
        if self._is_embeddings(data):
            return AudioEmbeddingItems(data)
        # A single item (list of floats, 1-D array, or (samples, sr) tuple)
        # is wrapped into a singleton batch.
        if (is_list_of(data, float)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 1
                or isinstance(data, tuple)):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one item per leading-dimension element.
            data_items = [elem for elem in data]
        else:
            data_items = data
        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                new_audio = audio
            else:
                target_sr = self.target_sr
                if target_sr is None:
                    raise RuntimeError(
                        "Audio resampling is not supported when "
                        "`target_sr` is not provided")
                new_audio = resample_audio(audio,
                                           orig_sr=orig_sr,
                                           target_sr=target_sr)
            new_audios.append(new_audio)
        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Normalize image data into items; returns ``None`` when empty."""
        if self._is_empty(data):
            return None
        if self._is_embeddings(data):
            return ImageEmbeddingItems(data)
        # A single image (PIL image or 3-D array) becomes a singleton batch.
        if (isinstance(data, Image)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 3):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data
        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Normalize video data into items; returns ``None`` when empty."""
        if self._is_empty(data):
            return None
        if self._is_embeddings(data):
            return VideoEmbeddingItems(data)
        # A single video (list of frames or 4-D array) becomes a
        # singleton batch.
        if (is_list_of(data, Image)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 4):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data
        return VideoProcessorItems(data_items)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Map each supported modality to its parsing method."""
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
        }

    def parse_mm_data(self,
                      mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse user-provided multi-modal data into normalized items.

        Raises:
            ValueError: If a modality has no registered subparser.
        """
        subparsers = self._get_subparsers()
        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")
            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data
        return mm_items

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,251 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, NamedTuple, Optional, TypeVar, cast
import numpy as np
import numpy.typing as npt
from PIL import Image
import vllm.envs as envs
from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalKwargs,
MultiModalPlaceholderDict)
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__)
@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    :meth:`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
    """
    # The prompt text before multi-modal processing.
    prompt_text: str
    # The multi-modal data accompanying the prompt.
    mm_data: MultiModalDataDict
    # Extra keyword arguments forwarded to the HF processor.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class DummyEncoderData(NamedTuple):
    """Dummy data used for profiling."""
    # Token IDs fed to the encoder, padded to the profiling length.
    prompt_token_ids: list[int]
class DummyDecoderData(NamedTuple):
    """Dummy data used for profiling."""
    # Token IDs (with placeholders), padded to the profiling length.
    prompt_token_ids: list[int]
    # Batched multi-modal keyword arguments for the model.
    multi_modal_data: MultiModalKwargs
    # Placeholder token ranges per modality.
    multi_modal_placeholders: MultiModalPlaceholderDict
_I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()
        self.info = info

    @abstractmethod
    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        :code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens.
        """
        raise NotImplementedError

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        """A batch of identical silent clips (shared by reference)."""
        silence = np.zeros((length, ))
        return [silence for _ in range(num_audios)]

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        """A batch of identical solid-color images (shared by reference)."""
        blank = Image.new("RGB", (width, height), color=255)
        return [blank for _ in range(num_images)]

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        """A batch of identical constant-valued clips (shared by reference)."""
        # NOTE(review): axes here are (frames, width, height, 3); most video
        # pipelines use (frames, height, width, 3). Harmless for the usual
        # square dummy inputs — confirm before profiling non-square shapes.
        clip = np.full((num_frames, width, height, 3), 255)
        return [clip for _ in range(num_videos)]
class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """
    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()
        self.processor = processor
    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The processing info of the wrapped processor."""
        return self.processor.info
    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy-inputs builder of the wrapped processor."""
        return self.processor.dummy_inputs
    def get_mm_limits(self) -> Mapping[str, int]:
        """Resolve the per-prompt item limit for each supported modality,
        validating user-configured limits against the model's maximums."""
        mm_config = self.processing_info.ctx.get_mm_config()
        supported_mm_limits = self.processing_info.get_supported_mm_limits()
        mm_limits = {
            modality: mm_config.get_limit_per_prompt(modality)
            for modality in supported_mm_limits
        }
        for modality, supported_limit in supported_mm_limits.items():
            limit = mm_limits[modality]
            # `None` means the model imposes no upper bound for the modality.
            if supported_limit is not None and supported_limit < limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but this model only supports "
                    f"at most {supported_limit} {modality} items.")
        return mm_limits
    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalInputs:
        """Run the full processor on builder-generated dummy inputs."""
        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts)
        return self.processor.apply(
            prompt=processor_inputs.prompt_text,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        )
    def get_and_validate_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
        """Produce dummy inputs and verify the processor emitted exactly
        the expected number of placeholder tokens per modality."""
        if mm_counts is None:
            mm_counts = self.get_mm_limits()
        info = self.processing_info
        mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
            seq_len, mm_counts)
        if mm_counts.keys() - mm_max_tokens_per_item.keys():
            raise AssertionError(
                "The keys returned by `get_supported_mm_limits` "
                f"({set(mm_counts.keys())}) should be a subset of those "
                "returned by `get_mm_max_tokens_per_item` "
                f"({set(mm_max_tokens_per_item.keys())})")
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        placeholders_by_modality = mm_inputs["mm_placeholders"]
        total_placeholders_by_modality = {
            modality: sum(item["length"] for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }
        # Worst-case expectation: max tokens per item times the item count.
        expected_placeholders_by_modality = {
            modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
            for modality in placeholders_by_modality
        }
        if total_placeholders_by_modality != expected_placeholders_by_modality:
            raise AssertionError(
                f"The processed dummy data has a total of "
                f"{total_placeholders_by_modality} placeholder tokens, which "
                f"is not the expected {expected_placeholders_by_modality} "
                "tokens.")
        return mm_inputs, total_placeholders_by_modality
    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """Build dummy encoder data padded to at least `seq_len` tokens."""
        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
        total_len = len(encoder_prompt_token_ids)
        # Pad with zeros up to seq_len; no-op if already long enough.
        num_tokens_to_pad = max(total_len, seq_len) - total_len
        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        return DummyEncoderData(encoder_prompt_token_ids)
    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """Build dummy decoder data padded to at least `seq_len` tokens."""
        (
            mm_inputs,
            total_placeholders_by_modality,
        ) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)
        # V0 does not support chunked prefill.
        if total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning(
                "The sequence length used for profiling ("
                "max_num_batched_tokens / max_num_seqs = %d) is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain "
                "multi-modal inputs to fail during inference, even when "
                "the input text is short. To avoid this, you should "
                "increase `max_model_len`, reduce `max_num_seqs`, "
                "and/or reduce `mm_counts`.", seq_len, total_len,
                total_placeholders_by_modality)
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))
        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

503
vllm/multimodal/registry.py Normal file
View File

@@ -0,0 +1,503 @@
# SPDX-License-Identifier: Apache-2.0
import functools
from collections import UserDict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar
import torch.nn as nn
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
from vllm.utils import ClassRegistry
from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
DummyEncoderData, MultiModalProfiler)
from .video import VideoPlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
N = TypeVar("N", bound=type[nn.Module])
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a :class:`BaseProcessingInfo` instance from the context.

    (The previous docstring claimed this constructs a
    `MultiModalProcessor`, but `_I_co` is bound to `BaseProcessingInfo` —
    that text was copy-pasted from `MultiModalProcessorFactory`.)
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        ...
class DummyInputsBuilderFactory(Protocol[_I]):
    """
    Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
    """
    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        ...
class MultiModalProcessorFactory(Protocol[_I]):
    """Constructs a :class:`MultiModalProcessor` instance from the context."""
    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        # Optional cache for processed multi-modal inputs.
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        ...
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """Bundles the three factories needed to build a multi-modal processor."""
    # Creates the processing info from the input context.
    info: ProcessingInfoFactory[_I]
    # Creates the processor from the info and dummy-inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Creates the dummy-inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]
    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ):
        """Wire the factories together into a ready-to-use processor."""
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)
class _MultiModalLimits(UserDict["ModelConfig", dict[str, int]]):
"""
Wraps `_limits_by_model` for a more informative error message
when attempting to access a model that does not exist.
"""
def __getitem__(self, key: "ModelConfig") -> dict[str, int]:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
"forget to call `init_mm_limits_per_prompt`?")
raise KeyError(msg) from exc
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """
    # Plugins implementing the legacy input-mapper path, one per modality.
    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
    def __init__(
        self,
        *,
        plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        # Maps each plugin's data key (e.g. "image") to the plugin.
        self._plugins = {p.get_data_key(): p for p in plugins}
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()
        # This is used for non-multimodal models
        self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
        self._limits_by_model = _MultiModalLimits()
        self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
    def register_plugin(self, plugin: MultiModalPlugin) -> None:
        """
        Register a multi-modal plugin so it can be recognized by vLLM.
        """
        data_type_key = plugin.get_data_key()
        if data_type_key in self._plugins:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type_key,
                plugin)
        self._plugins[data_type_key] = plugin
    def _get_plugin(self, data_type_key: str):
        """Return the plugin for ``data_type_key``, failing fast if none
        is registered for that modality."""
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin
        msg = f"Unknown multi-modal data type: {data_type_key}"
        raise NotImplementedError(msg)
    def register_input_mapper(
        self,
        data_type_key: str,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self._get_plugin(data_type_key).register_input_mapper(mapper)
    def register_image_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for image data to a model class.
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper("image", mapper)
    def map_input(
        self,
        model_config: "ModelConfig",
        data: MultiModalDataDict,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> MultiModalKwargs:
        """
        Apply an input mapper to the data passed to the model.
        The data belonging to each modality is passed to the corresponding
        plugin which in turn converts the data into keyword arguments
        via the input mapper registered for that model.
        See :meth:`MultiModalPlugin.map_input` for more details.
        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        merged_dict = dict[str, NestedTensors]()
        for data_key, data_value in data.items():
            plugin = self._get_plugin(data_key)
            num_items = len(data_value) if isinstance(data_value, list) else 1
            # Enforce the per-prompt item limit configured for this model.
            max_items = self._limits_by_model[model_config][data_key]
            if num_items > max_items:
                raise ValueError(
                    f"You set {data_key}={max_items} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but found {num_items} items "
                    "in the same prompt.")
            input_dict = plugin.map_input(model_config, data_value,
                                          mm_processor_kwargs)
            # Merge outputs from all plugins; identical keyword names from
            # different modalities would clobber each other, so fail loudly.
            for input_key, input_tensor in input_dict.items():
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")
                merged_dict[input_key] = input_tensor
        return MultiModalKwargs(merged_dict)
    def create_input_mapper(self, model_config: "ModelConfig"):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.
        """
        # NOTE - we currently make the assumption that if a model has multiple
        # supported modalities, they take the same kwargs. For the default,
        # this could be an issue in the future if it falls back to two HF
        # resources and we can't inspect the signature easily since it's
        # getting initialized through the autoclass.
        #
        # If this is a problem in the future, we should revisit it, but since
        # it potentially introduces a lot of complexity for a currently
        # uncommon case, we do not for simplicity of both use & implementation
        return functools.partial(self.map_input, model_config)
    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data belonging to a specific modality, that are
        passed to the language model for a model class.
        """
        return self._get_plugin(data_type_key) \
            .register_max_multimodal_tokens(max_mm_tokens)
    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of image tokens, corresponding to a single
        image, that are passed to the language model for a model class.
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)
    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.
        """
        # Prefer the merged-processor path when the model defines one;
        # otherwise fall back to the per-plugin estimates.
        if self.has_processor(model_config):
            processor = self.create_processor(model_config, disable_cache=True)
            seq_len = model_config.max_model_len
            mm_limits = self.get_mm_limits_per_prompt(model_config)
            return processor.info.get_mm_max_tokens_per_item(
                seq_len, mm_limits)
        return {
            key: plugin.get_max_multimodal_tokens(model_config)
            for key, plugin in self._plugins.items()
        }
    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.
        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config)
        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
            if mm_limits[key] > 0
        }
    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens from each modality
        for profiling the memory usage of a model.
        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config)
        # Per-item token counts scaled by the per-prompt item limits.
        return {
            key: mm_limits[key] * max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
        }
    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.
        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())
    def init_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> None:
        """
        Initialize the maximum number of multi-modal input instances for each
        modality that are allowed per prompt for a model class.
        """
        if model_config in self._limits_by_model:
            logger.warning(
                "`mm_limits` has already been set for model=%s, and will "
                "be overwritten by the new values.", model_config.model)
        multimodal_config = model_config.multimodal_config
        if multimodal_config is None:
            # Text-only model: every modality gets a limit of 0.
            limits_per_plugin = self._disabled_limits_per_plugin
        else:
            config_limits_per_plugin = multimodal_config.limit_per_prompt
            extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
            if extra_keys:
                logger.warning(
                    "Detected extra keys in `--limit-mm-per-prompt` which "
                    "are not registered as multi-modal plugins: %s. "
                    "They will be ignored.", extra_keys)
            # NOTE: Currently the default is set to 1 for each plugin
            # TODO: Automatically determine the limits based on budget
            # once more models support multi-image inputs
            limits_per_plugin = {
                key: multimodal_config.get_limit_per_prompt(key)
                for key in self._plugins
            }
        self._limits_by_model[model_config] = limits_per_plugin
    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.
        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        if self.has_processor(model_config):
            processor = self.create_processor(model_config, disable_cache=True)
            profiler = MultiModalProfiler(processor)
            return profiler.get_mm_limits()
        return self._limits_by_model[model_config]
    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.
        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.
        See also:
            :ref:`mm-processing`
        """
        # Returns a class decorator that records the factories for the model.
        def wrapper(model_cls: N) -> N:
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)
            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )
            return model_cls
        return wrapper
    def _get_model_cls(self, model_config: "ModelConfig"):
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        model_cls, _ = get_model_architecture(model_config)
        return model_cls
    def has_processor(self, model_config: "ModelConfig") -> bool:
        """
        Test whether a multi-modal processor is defined for a specific model.
        See also:
            :ref:`mm-processing`
        """
        return self._get_model_cls(model_config) in self._processor_factories
    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        disable_cache: Optional[bool] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.
        See also:
            :ref:`mm-processing`
        """
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)
        if disable_cache is None:
            disable_cache = model_config.disable_mm_preprocessor_cache
        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]
        ctx = InputProcessingContext(model_config, tokenizer)
        cache = None if disable_cache else self._processing_cache
        return factories.build_processor(ctx, cache=cache)
    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.
        The model is identified by ``model_config``.
        """
        processor = self.create_processor(model_config, disable_cache=True)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead.")
        return dummy_data
    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.
        The model is identified by ``model_config``.
        """
        processor = self.create_processor(model_config, disable_cache=True)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            # Unlike the decoder path above, a short encoder prompt is only
            # logged, not raised.
            logger.warning_once(
                f"Expected at least {seq_len} dummy encoder tokens for "
                f"profiling, but found {len(token_ids)} tokens instead.")
        return dummy_data

386
vllm/multimodal/utils.py Normal file
View File

@@ -0,0 +1,386 @@
# SPDX-License-Identifier: Apache-2.0
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO
_M = TypeVar("_M")
if TYPE_CHECKING:
from .hasher import MultiModalHashDict
from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
class MediaConnector:
    """Loads audio, image and video media from HTTP(S), data: and file: URLs."""
    def __init__(
        self,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
    ) -> None:
        """
        Args:
            connection: HTTP connection used for http(s) URL fetches.
            allowed_local_media_path: If non-empty, the only directory from
                which ``file://`` URLs may be loaded; it must exist and be a
                directory. When empty, local file loading is disabled.
        """
        super().__init__()
        self.connection = connection
        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)
            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")
        else:
            allowed_local_media_path_ = None
        self.allowed_local_media_path = allowed_local_media_path_
    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Decode a ``data:<media_type>;base64,<payload>`` URL via
        ``media_io``; only base64-encoded payloads are supported."""
        data_spec, data = url_spec.path.split(",", 1)
        media_type, data_type = data_spec.split(";", 1)
        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)
        return media_io.load_base64(media_type, data)
    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Load a ``file://`` URL, enforcing the allowed local media path."""
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")
        filepath = Path(url_spec.path)
        # `resolve()` canonicalizes the path so `..` components cannot
        # escape the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")
        return media_io.load_file(filepath)
    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Synchronously load media from a HTTP(S), data: or file: URL."""
        url_spec = urlparse(url)
        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = connection.get_bytes(url, timeout=fetch_timeout)
            return media_io.load_bytes(data)
        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)
        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)
        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)
    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Async variant of :meth:`load_from_url`; only the HTTP fetch is
        actually asynchronous."""
        url_spec = urlparse(url)
        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = await connection.async_get_bytes(url, timeout=fetch_timeout)
            return media_io.load_bytes(data)
        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)
        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)
        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)
    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL, returning the waveform and its sampling rate.
        """
        audio_io = AudioMediaIO()
        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )
    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO()
        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )
    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from a HTTP or base64 data URL.
        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        return self.load_from_url(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from a HTTP or base64 data URL.
        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        return await self.load_from_url_async(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Load video from a HTTP or base64 data URL.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)
        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Asynchronously load video from a HTTP or base64 data URL.
        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)
        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load image embedding from a URL.
        """
        # NOTE(review): `data` here is the base64 payload itself, not a URL
        # that gets fetched.
        image_embedding_io = ImageEmbeddingMediaIO()
        return image_embedding_io.load_base64("", data)
# Default connector used by the module-level convenience functions below.
global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""
# Module-level shortcuts bound to the global connector.
fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image
fetch_video = global_media_connector.fetch_video
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Serialize an audio waveform and its sampling rate to base64."""
    return AudioMediaIO().encode_base64((audio, sampling_rate))
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """Serialize a PIL image to a base64 string.

    The image is first converted to ``image_mode`` (RGB by default) and
    written out in ``format`` before the bytes are base64-encoded.
    """
    io_helper = ImageMediaIO(image_mode=image_mode)
    return io_helper.encode_base64(image, image_format=format)
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode a stack of video frames as a comma-separated base64 string."""
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
def merge_and_sort_multimodal_metadata(
    mm_positions: "MultiModalPlaceholderDict",
    mm_hashes: Optional["MultiModalHashDict"],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
    """Flatten placeholder metadata from all modalities into sorted lists.

    All :class:`PlaceholderRange` entries from every modality in
    ``mm_positions`` are merged into a single list sorted by ``offset``
    (starting index in the input sequence) in ascending order.

    If ``mm_hashes`` is given, the hashes are reordered the same way.

    Returns:
        list[str]: Modality of each item, in input-sequence order.
        list[PlaceholderRange]: All PlaceholderRanges from mm_positions,
            sorted by offset.
        Optional[list[str]]: Hashes in the same sorted order when
            ``mm_hashes`` is given, ``None`` otherwise.
    """
    modalities = list(mm_positions.keys())
    assert len(modalities) > 0, "No modalities found in the mm_positions."
    # Fast path: a single modality is already sorted within itself.
    if len(modalities) == 1:
        only = modalities[0]
        ranges = list(mm_positions[only])
        hashes = None if not mm_hashes else mm_hashes[only]
        return [only] * len(ranges), ranges, hashes
    # Flatten into (modality, placeholder, hash) triples across modalities;
    # modalities without hashes contribute None placeholders.
    flattened = []
    for modality in modalities:
        ranges = list(mm_positions[modality])
        if mm_hashes and modality in mm_hashes:
            hash_values: list[Optional[str]] = list(mm_hashes[modality])
        else:
            hash_values = [None] * len(ranges)
        flattened.extend(zip([modality] * len(ranges), ranges, hash_values))
    # Order by starting position in the input sequence.
    flattened.sort(key=lambda item: item[1]['offset'])
    sorted_modalities = [modality for modality, _, _ in flattened]
    merged_placeholders = [placeholder for _, placeholder, _ in flattened]
    if mm_hashes is None:
        merged_hashes = None
    else:
        merged_hashes = [str(hash_value) for _, _, hash_value in flattened]
    return sorted_modalities, merged_placeholders, merged_hashes
def group_mm_inputs_by_modality(
        mm_inputs: list["MultiModalKwargs"]) -> list[list["MultiModalKwargs"]]:
    """Batch consecutive inputs of the same modality together.

    Consecutive ``MultiModalKwargs`` sharing a single modality are grouped
    into one inner list so they can be batched together; an input spanning
    multiple modalities always ends up in a group of its own.

    Args:
        mm_inputs: List of MultiModalKwargs.

    Returns:
        list[list[MultiModalKwargs]]: Groups of consecutive inputs with the
        same modality.
    """
    if not mm_inputs:
        return []
    def grouping_key(mm_input: "MultiModalKwargs") -> Union[str, int]:
        modalities = mm_input.modalities
        # Multi-modality inputs get a unique key so they are never merged.
        if len(modalities) > 1:
            return id(mm_input)
        if len(modalities) == 1:
            return next(iter(modalities))
        # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is
        # empty; this keeps InternVL on the legacy pipeline working in v1.
        return ""
    return [list(batch) for _, batch in groupby(mm_inputs, key=grouping_key)]

233
vllm/multimodal/video.py Normal file
View File

@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
import base64
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_video_processor
from vllm.utils import is_list_of
from .base import MediaIO, ModalityData
from .image import ImageMediaIO, ImagePlugin
from .inputs import MultiModalKwargs, VideoItem
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
class VideoPlugin(ImagePlugin):
    """Plugin for video data."""
    def get_data_key(self) -> str:
        return "video"
    def _get_hf_video_processor(
        self,
        model_config: "ModelConfig",
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Fetch (and cache) the HF video processor for this model."""
        kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        return cached_get_video_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **kwargs)
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[VideoItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """Map raw video frames to model kwargs via the HF video processor."""
        model_config = ctx.model_config
        # A single-item list is treated the same as the bare item.
        if isinstance(data, list) and len(data) == 1:
            data = data[0]  # type: ignore
        if not (isinstance(data, np.ndarray) or is_list_of(data, np.ndarray)):
            raise TypeError(f"Invalid video type: {type(data)}")
        video_processor = self._get_hf_video_processor(
            model_config,
            mm_processor_kwargs,
        )
        if video_processor is None:
            raise RuntimeError("No HuggingFace processor is available "
                               "to process the video object")
        try:
            # NOTE: Similar to image; it may be a good idea to filter and
            # pass mm_processor_kwargs here too, but for now we don't to
            # avoid extra complexity if the initializer and preprocess
            # signatures of the processor don't align
            batch_data = video_processor(data, return_tensors="pt").data
        except Exception:
            logger.error("Failed to process video (%s)", data)
            raise
        return MultiModalKwargs(batch_data)
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 4096
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Resize every frame of a ``(N, H, W, C)`` video to ``size``.

    Args:
        frames: Video frames with shape ``(num_frames, height, width, C)``.
        size: Target ``(height, width)``.

    Returns:
        A new array of shape ``(num_frames, *size, C)`` with the same dtype.
    """
    # cv2 is imported lazily so text-only deployments never need it.
    import cv2
    target_height, target_width = size
    frame_count, _, _, channel_count = frames.shape
    out = np.empty((frame_count, target_height, target_width, channel_count),
                   dtype=frames.dtype)
    for index, frame in enumerate(frames):
        out[index] = cv2.resize(frame, (target_width, target_height))
    return out
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
    """Uniformly scale each frame's height and width by ``size_factor``."""
    _, height, width, _ = frames.shape
    scaled = (int(height * size_factor), int(width * size_factor))
    return resize_video(frames, scaled)
def sample_frames_from_video(frames: npt.NDArray,
                             num_frames: int) -> npt.NDArray:
    """Uniformly sample ``num_frames`` frames from a video array.

    A value of ``-1`` returns the input unchanged. Otherwise the sampled
    indices are spread evenly from the first to the last frame
    (endpoints included).
    """
    if num_frames == -1:
        return frames
    last = frames.shape[0] - 1
    picked = np.linspace(0, last, num_frames, dtype=int)
    return frames[picked, ...]
class VideoLoader:
    """Abstract interface for decoding raw video bytes into frames."""
    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Decode ``data`` into an array of frames.

        Args:
            data: Encoded video bytes.
            num_frames: Number of frames to sample; ``-1`` keeps all.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        # BUGFIX: the first parameter of a @classmethod is the class, so it
        # must be named `cls`, not `self`.
        raise NotImplementedError
class OpenCVVideoBackend(VideoLoader):
    """Video loader backed by OpenCV's buffered stream capture API."""
    def get_cv2_video_api(self):
        """Pick a usable buffered-stream backend.

        Non-built-in (plugin) backends are skipped unless their plugin
        version is at least ABI 1 / API 2. Returns the first acceptable
        backend preference, or ``None`` to let OpenCV use its default.
        """
        import cv2.videoio_registry as vr
        api_pref = None
        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                if (abi < 1 or (abi == 1 and api < 2)):
                    continue
            api_pref = backend
            break
        return api_pref
    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Decode in-memory video bytes into an RGB frame array.

        Args:
            data: Encoded video bytes.
            num_frames: Number of frames to sample uniformly; ``-1`` (or a
                value larger than the stream length) decodes every frame.

        Returns:
            Array of shape ``(frames, height, width, 3)`` in RGB order.

        Raises:
            ValueError: If the stream cannot be opened or fewer frames than
                requested could be decoded.
        """
        import cv2
        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")
        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        full_read = num_frames == -1 or total_frames_num < num_frames
        if full_read:
            frame_idx = list(range(0, total_frames_num))
        else:
            uniform_sampled_frames = np.linspace(0,
                                                 total_frames_num - 1,
                                                 num_frames,
                                                 dtype=int)
            frame_idx = uniform_sampled_frames.tolist()
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
        # Set membership is O(1) vs. O(n) for the list lookup in the loop.
        needed = set(frame_idx)
        i = 0
        for idx in range(total_frames_num):
            ok = cap.grab()  # advance to the next frame without decoding
            if not ok:
                break
            if idx in needed:  # only decompress the frames we keep
                ret, frame = cap.retrieve()
                if ret:
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1
        # We expect all requested frames to be loaded.
        # BUGFIX: the old check `assert i == num_frames` always failed on a
        # full read, where num_frames is -1 (or exceeds the stream length);
        # compare against the number of frames actually requested instead,
        # and raise rather than assert so the check survives `-O`.
        if i != len(frame_idx):
            raise ValueError(
                f"Expected to decode {len(frame_idx)} frames, "
                f"but only got {i}.")
        return frames
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Loads and serializes videos as ``(frames, H, W, C)`` numpy arrays."""
    def __init__(
        self,
        image_io: ImageMediaIO,
        *,
        num_frames: int = 32,
    ) -> None:
        super().__init__()
        # Frame-level codec used for the "video/jpeg" frame-list format.
        self.image_io = image_io
        self.num_frames = num_frames
        self.video_loader = OpenCVVideoBackend
    def load_bytes(self, data: bytes) -> npt.NDArray:
        """Decode raw video bytes, sampling up to ``num_frames`` frames."""
        return self.video_loader.load_bytes(data, self.num_frames)
    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
        """Decode base64 video data.

        ``video/jpeg`` is interpreted as a comma-separated sequence of
        base64 JPEG frames; any other media type is decoded as a full
        video container.
        """
        if media_type.lower() != "video/jpeg":
            return self.load_bytes(base64.b64decode(data))
        decoded_frames = [
            np.array(self.image_io.load_base64("image/jpeg", frame_data))
            for frame_data in data.split(",")
        ]
        return np.stack(decoded_frames)
    def load_file(self, filepath: Path) -> npt.NDArray:
        """Decode a video file from disk."""
        return self.load_bytes(filepath.read_bytes())
    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode frames as a comma-separated list of base64 JPEG images."""
        if video_format != "JPEG":
            msg = "Only JPEG format is supported for now."
            raise NotImplementedError(msg)
        encoded_frames = (
            self.image_io.encode_base64(Image.fromarray(frame),
                                        image_format=video_format)
            for frame in media)
        return ",".join(encoded_frames)