Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
38
vllm/multimodal/__init__.py
Normal file
38
vllm/multimodal/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Public entry point of the multimodal subpackage: re-exports the input
# types, the hasher, and the registry, and instantiates the global registry.
from .hasher import MultiModalHasher
from .inputs import (
    BatchedTensorInputs,
    ModalityData,
    MultiModalDataBuiltins,
    MultiModalDataDict,
    MultiModalKwargsItems,
    MultiModalPlaceholderDict,
    MultiModalUUIDDict,
    NestedTensors,
)
from .registry import MultiModalRegistry

MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
is used by model runners to dispatch data processing according to the target
model.

Info:
    [mm_processing](../../../design/mm_processing.md)
"""

__all__ = [
    "BatchedTensorInputs",
    "ModalityData",
    "MultiModalDataBuiltins",
    "MultiModalDataDict",
    "MultiModalHasher",
    "MultiModalKwargsItems",
    "MultiModalPlaceholderDict",
    "MultiModalUUIDDict",
    "NestedTensors",
    "MULTIMODAL_REGISTRY",
    "MultiModalRegistry",
]
|
||||
336
vllm/multimodal/audio.py
Normal file
336
vllm/multimodal/audio.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
|
||||
try:
|
||||
import scipy.signal as scipy_signal
|
||||
except ImportError:
|
||||
scipy_signal = PlaceholderModule("scipy").placeholder_attr("signal") # type: ignore[assignment]
|
||||
|
||||
# ============================================================
|
||||
|
||||
|
||||
class ChannelReduction(str, Enum):
    """How to collapse multi-channel audio down to the target channel count."""

    # Average across channels (default; preserves energy balance).
    MEAN = "mean"
    # Keep only the first channel.
    FIRST = "first"
    # Element-wise maximum across channels.
    MAX = "max"
    # Element-wise sum across channels.
    SUM = "sum"
|
||||
|
||||
|
||||
@dataclass
class AudioSpec:
    """Target audio format for a model's feature extractor.

    Describes the expected audio format so that audio data can be
    normalized before processing.

    Attributes:
        target_channels: Number of output channels. ``None`` means
            passthrough (no normalization). 1 = mono, 2 = stereo, etc.
        channel_reduction: How to reduce channels when the input has more
            channels than the target. Only consulted when reducing.
    """

    target_channels: int | None = 1
    channel_reduction: ChannelReduction = ChannelReduction.MEAN

    @property
    def needs_normalization(self) -> bool:
        """True unless this spec is a passthrough (no target channel count)."""
        return self.target_channels is not None

    def __repr__(self) -> str:
        if self.target_channels is None:
            return "AudioSpec(passthrough)"
        return (
            f"AudioSpec(channels={self.target_channels}, "
            f"reduction={self.channel_reduction.value})"
        )
|
||||
|
||||
|
||||
# Pre-defined specs for common use cases
# Collapse any input to a single channel by averaging across channels.
MONO_AUDIO_SPEC = AudioSpec(target_channels=1, channel_reduction=ChannelReduction.MEAN)
# Leave the audio untouched (no channel normalization).
PASSTHROUGH_AUDIO_SPEC = AudioSpec(target_channels=None)
|
||||
|
||||
|
||||
def normalize_audio(
    audio: npt.NDArray[np.floating] | torch.Tensor,
    spec: AudioSpec,
) -> npt.NDArray[np.floating] | torch.Tensor:
    """Normalize audio to the specified format.

    This function handles channel reduction for multi-channel audio,
    supporting both numpy arrays and torch tensors.

    Args:
        audio: Input audio data. Can be:
            - 1D array/tensor: (time,) - already mono
            - 2D array/tensor: (channels, time) - standard format from torchaudio
            - 2D array/tensor: (time, channels) - format from soundfile
              (will be auto-detected and transposed if time > channels)
        spec: AudioSpec defining the target format.

    Returns:
        Normalized audio in the same type as input (numpy or torch).
        For mono output (target_channels=1), returns 1D array/tensor.

    Raises:
        ValueError: If audio has unsupported dimensions, channel expansion
            is requested (e.g., mono to stereo), or the reduction method
            is unknown.
    """
    if not spec.needs_normalization:
        return audio

    # 1D input is already mono; nothing to reduce.
    if audio.ndim == 1:
        if spec.target_channels == 1:
            return audio
        raise ValueError(f"Cannot expand mono audio to {spec.target_channels} channels")

    if audio.ndim != 2:
        raise ValueError(f"Unsupported audio shape: {audio.shape}. Expected 1D or 2D.")

    # Auto-detect format: if shape[0] > shape[1], assume (time, channels)
    # (soundfile convention, where the time axis is typically much larger)
    # and transpose to the canonical (channels, time).
    # `.T` works identically for numpy arrays and torch tensors, so no type
    # dispatch is needed (the original had a redundant `isinstance` check
    # whose two branches were identical).
    if audio.shape[0] > audio.shape[1]:
        audio = audio.T

    num_channels = audio.shape[0]

    # Already at the target channel count; nothing to do.
    if num_channels == spec.target_channels:
        return audio

    # Channel expansion is not supported.
    if num_channels < spec.target_channels:
        raise ValueError(
            f"Cannot expand {num_channels} channels to {spec.target_channels}"
        )

    if spec.target_channels == 1:
        return _reduce_to_mono(audio, spec.channel_reduction)

    # Reduce to N > 1 channels: currently just takes the first N channels
    # (the configured reduction method is not applied in this path).
    return audio[: spec.target_channels]


def _reduce_to_mono(
    audio: npt.NDArray[np.floating] | torch.Tensor,
    reduction: ChannelReduction,
) -> npt.NDArray[np.floating] | torch.Tensor:
    """Collapse a (channels, time) array/tensor to 1D mono using `reduction`."""
    is_numpy = isinstance(audio, np.ndarray)

    if reduction == ChannelReduction.MEAN:
        return np.mean(audio, axis=0) if is_numpy else audio.mean(dim=0)
    if reduction == ChannelReduction.FIRST:
        return audio[0]
    if reduction == ChannelReduction.MAX:
        return np.max(audio, axis=0) if is_numpy else audio.max(dim=0).values
    if reduction == ChannelReduction.SUM:
        return np.sum(audio, axis=0) if is_numpy else audio.sum(dim=0)

    raise ValueError(f"Unknown reduction method: {reduction}")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Audio Resampling
|
||||
# ============================================================
|
||||
|
||||
|
||||
def resample_audio_librosa(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample `audio` from `orig_sr` to `target_sr` via librosa."""
    return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
|
||||
|
||||
|
||||
def resample_audio_scipy(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
):
    """Resample `audio` from `orig_sr` to `target_sr` using scipy.

    Uses polyphase filtering (`scipy.signal.resample_poly`) with an exact
    reduced-fraction approximation of the rate ratio.

    The previous implementation used floor division (`orig_sr // target_sr`)
    for the up/down factors, which silently produced the wrong output rate
    for non-integer ratios (e.g. 44100 -> 16000 was resampled by 2x instead
    of ~2.756x). Deriving the factors from the reduced fraction fixes that
    while producing identical results for integer ratios.

    Args:
        audio: Input audio samples.
        orig_sr: Original sample rate in Hz.
        target_sr: Desired sample rate in Hz.

    Returns:
        The resampled audio, or the input unchanged when the rates already
        match.
    """
    # Stdlib; imported locally to keep module-level dependencies unchanged.
    from fractions import Fraction

    if orig_sr == target_sr:
        return audio

    # Reduced rational approximation of target_sr / orig_sr, e.g.
    # 16000/44100 -> 160/441.
    ratio = Fraction(float(target_sr) / float(orig_sr)).limit_denominator(10000)
    if ratio == 1:
        # Rates differ only below the approximation tolerance; no work needed.
        return audio
    return scipy_signal.resample_poly(audio, ratio.numerator, ratio.denominator)
|
||||
|
||||
|
||||
class AudioResampler:
    """Resample audio data to a fixed target sample rate."""

    def __init__(
        self,
        target_sr: float | None = None,
        method: Literal["librosa", "scipy"] = "librosa",
    ):
        # `target_sr` may be None, in which case resample() refuses to run.
        self.target_sr = target_sr
        self.method = method

    def resample(
        self,
        audio: npt.NDArray[np.floating],
        *,
        orig_sr: float,
    ) -> npt.NDArray[np.floating]:
        """Resample `audio` from `orig_sr` to the configured target rate.

        Raises:
            RuntimeError: If no target sample rate was configured.
            ValueError: If the configured method is unknown.
        """
        target_sr = self.target_sr
        if target_sr is None:
            raise RuntimeError(
                "Audio resampling is not supported when `target_sr` is not provided"
            )

        # Rates that agree within an absolute tolerance need no work.
        if math.isclose(float(orig_sr), float(target_sr), rel_tol=0.0, abs_tol=1e-6):
            return audio

        if self.method == "librosa":
            return resample_audio_librosa(audio, orig_sr=orig_sr, target_sr=target_sr)
        if self.method == "scipy":
            return resample_audio_scipy(audio, orig_sr=orig_sr, target_sr=target_sr)

        raise ValueError(
            f"Invalid resampling method: {self.method}. "
            "Supported methods are 'librosa' and 'scipy'."
        )
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Audio Chunking / Splitting
|
||||
# ============================================================
|
||||
|
||||
|
||||
def split_audio(
    audio_data: np.ndarray,
    sample_rate: int,
    max_clip_duration_s: float,
    overlap_duration_s: float,
    min_energy_window_size: int,
) -> list[np.ndarray]:
    """Split audio into chunks with intelligent split points.

    Splits long audio into smaller chunks at low-energy regions to minimize
    cutting through speech. Searches a trailing overlap window of each chunk
    for a quiet moment to cut at.

    Args:
        audio_data: Audio array to split. Can be 1D (mono) or multi-dimensional.
            Splits along the last dimension (time axis).
        sample_rate: Sample rate of the audio in Hz.
        max_clip_duration_s: Maximum duration of each chunk in seconds.
        overlap_duration_s: Overlap duration in seconds between consecutive chunks.
            Used to search for optimal split points.
        min_energy_window_size: Window size in samples for finding low-energy regions.

    Returns:
        List of audio chunks. Each chunk is a numpy array with the same shape
        as the input except for the last (time) dimension.

    Example:
        >>> audio = np.random.randn(1040000)  # 65 seconds at 16kHz
        >>> chunks = split_audio(
        ...     audio_data=audio,
        ...     sample_rate=16000,
        ...     max_clip_duration_s=30.0,
        ...     overlap_duration_s=1.0,
        ...     min_energy_window_size=1600,
        ... )
        >>> len(chunks)
        3
    """
    max_chunk_len = int(sample_rate * max_clip_duration_s)
    overlap_len = int(sample_rate * overlap_duration_s)
    total_len = audio_data.shape[-1]

    pieces: list[np.ndarray] = []
    start = 0
    while start < total_len:
        if start + max_chunk_len >= total_len:
            # Everything that remains fits in one final chunk.
            pieces.append(audio_data[..., start:])
            break

        # Search the trailing overlap window for the quietest cut point.
        window_lo = start + max_chunk_len - overlap_len
        window_hi = min(start + max_chunk_len, total_len)
        cut = find_split_point(audio_data, window_lo, window_hi, min_energy_window_size)

        pieces.append(audio_data[..., start:cut])
        start = cut

    return pieces
|
||||
|
||||
|
||||
def find_split_point(
    wav: np.ndarray,
    start_idx: int,
    end_idx: int,
    min_energy_window: int,
) -> int:
    """Find the best point to split audio by looking for silence or low amplitude.

    Searches for the quietest region within a specified range by calculating
    RMS energy in sliding windows.

    Args:
        wav: Audio array.
        start_idx: Start index of search region (inclusive).
        end_idx: End index of search region (exclusive).
        min_energy_window: Window size in samples for energy calculation.

    Returns:
        Index of the quietest point within the search region. This is the
        recommended split point to minimize audio artifacts. When the region
        is shorter than one energy window, `start_idx` is returned.

    Example:
        >>> audio = np.ones(32000)
        >>> audio[16000:17600] = 0.0
        >>> split_idx = find_split_point(
        ...     wav=audio,
        ...     start_idx=0,
        ...     end_idx=32000,
        ...     min_energy_window=1600,
        ... )
        >>> 16000 <= split_idx <= 17600
        True
    """
    # NOTE(review): this slices axis 0; for multi-dimensional input the
    # caller's time axis is the last one — confirm intended layout is 1D.
    segment = wav[start_idx:end_idx]

    # Fix: fall back to `start_idx` rather than 0. Previously, when the
    # search region was shorter than one energy window (the loop below never
    # runs), this returned index 0 — outside the region — which made callers
    # such as `split_audio` emit an empty chunk and loop forever.
    min_energy = math.inf
    quietest_idx = start_idx

    # Scan non-overlapping windows and keep the one with the lowest RMS.
    for off in range(0, len(segment) - min_energy_window, min_energy_window):
        window = segment[off : off + min_energy_window]
        energy = (window**2).mean() ** 0.5  # RMS energy of this window
        if energy < min_energy:
            quietest_idx = off + start_idx
            min_energy = energy

    return quietest_idx
|
||||
725
vllm/multimodal/cache.py
Normal file
725
vllm/multimodal/cache.py
Normal file
@@ -0,0 +1,725 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import operator
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||
MsgpackSerde,
|
||||
SingleWriterShmObjectStorage,
|
||||
SingleWriterShmRingBuffer,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.cache import CacheInfo, LRUCache
|
||||
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
|
||||
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
|
||||
from vllm.utils.mem_utils import format_gib
|
||||
|
||||
from .inputs import (
|
||||
MultiModalBatchedField,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
NestedTensors,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from .processing.processor import ResolvedPromptUpdate
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItem:
    """
    The data to store inside `MultiModalProcessorOnlyCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
        prompt_updates: The prompt updates corresponding to `item`.
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()
        # Keep the processed data and its prompt updates together so a
        # cache hit restores both at once.
        self.item = item
        self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItemMetadata:
    """
    The metadata to store inside `MultiModalProcessorSenderCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
            Since P1 already stores the tensor data, we only store its size
            metadata in P0 to reduce memory usage. The size metadata is still
            needed to keep the same cache eviction policy as P0.
        prompt_updates: The prompt updates corresponding to `item`.
            This needs to stay on P0 because for some models, they are
            dependent on the processed tensor data (cached on P1).
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        # Only the size survives on P0; the tensor data itself lives on P1.
        self.item_size = MultiModalCache.get_item_size(item)
        self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
# A value that can be stored in a multi-modal cache: a full cache item
# (data + prompt updates), its size-only metadata counterpart, raw kwargs
# item(s), or a plain mapping of nested tensors.
MultiModalCacheValue: TypeAlias = (
    MultiModalProcessorCacheItem
    | MultiModalProcessorCacheItemMetadata
    | MultiModalKwargsItems
    | MultiModalKwargsItem
    | Mapping[str, NestedTensors]
)

# Type variable bound to the cache value union above.
_V = TypeVar("_V", bound=MultiModalCacheValue)
|
||||
|
||||
|
||||
class MultiModalCache:
    """Helpers for sizing multi-modal cache values and building LRU caches."""

    @classmethod
    def get_leaf_size(cls, leaf: object) -> int:
        """Return the approximate in-memory size of a single leaf, in bytes."""
        if isinstance(leaf, MultiModalProcessorCacheItem):
            return cls.get_leaf_size(leaf.item)
        if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
            # Metadata carries the precomputed size of the item it mirrors.
            return leaf.item_size

        # These are not subclasses of dict, so unwrap their payload explicitly.
        if isinstance(
            leaf,
            (MultiModalKwargsItems, MultiModalKwargsItem, MultiModalFieldElem),
        ):
            return cls.get_item_size(leaf.data)  # type: ignore

        # sys.getsizeof doesn't work for tensors
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        """Total size in bytes of `value`, summed over all of its leaves."""
        total = json_reduce_leaves(
            operator.add, json_map_leaves(cls.get_leaf_size, value)
        )

        if debug:
            logger.debug(
                "Calculated size of %s to be %s GiB (%d leaves)",
                type(value),
                format_gib(total),
                json_count_leaves(value),
            )

        return total

    @classmethod
    def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
        """
        Get the number of leaf elements in a multi-modal cache value.

        This provides a measure of structural complexity that can be useful
        for debugging cache performance and understanding data patterns.

        Args:
            value: The multi-modal cache value to analyze.

        Returns:
            The number of leaf elements in the nested structure.
        """
        return json_count_leaves(value)

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """Build an `LRUCache` bounded by `capacity_gb` gibibytes of data."""
        return LRUCache(
            GiB_bytes * capacity_gb,
            getsizeof=lambda x: cls.get_item_size(x, debug=debug),
        )
|
||||
|
||||
|
||||
_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)


class BaseMultiModalCache(ABC, Generic[_I, _O]):
    """
    Abstract base class to read/write multi-modal items from cache.

    Multi-modal caching is built around a client/server split, where the
    client executes in the frontend process (=P0) and the server in the
    core process (=P1). The data flow is as follows:

    ```
         is_cached() x N    get_and_update()
    P0:  From API -----------------> -----------------> To P1

         get_and_update()
    P1:  From P0 -----------------> To model
    ```

    `is_cached()` can be called any number of times in P0. However,
    `get_and_update()` must be called in P0 and P1 one after another
    so that their cache eviction order remains the same.

    This ensures that the keys in P0 and P1 caches are mirrored,
    allowing us to determine whether a key is cached in P1 by looking
    up the P0 cache, without having to communicate with P1.
    """

    @abstractmethod
    def get_and_update_item(
        self,
        mm_item: _I,
        mm_hash: str,
    ) -> _O:
        """
        Possibly update a multi-modal item based on whether it is
        in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_item: The multi-modal item to update.
            mm_hash: The hash of `mm_item`.

        Returns:
            The updated multi-modal item.
        """
        raise NotImplementedError

    def get_and_update(
        self,
        mm_items: Sequence[_I],
        mm_hashes: list[str],
    ) -> list[_O]:
        """
        Possibly update a sequence of multi-modal items based on whether
        they are in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_items: The multi-modal items to update.
            mm_hashes: The hash of each item in `mm_items`.

        Returns:
            A new list of updated multi-modal items.
        """
        assert len(mm_items) == len(mm_hashes)

        return [
            self.get_and_update_item(item, item_hash)
            for item, item_hash in zip(mm_items, mm_hashes)
        ]

    @abstractmethod
    def clear_cache(self) -> None:
        """Clear the underlying cache."""
        raise NotImplementedError
|
||||
|
||||
|
||||
# Input to a P0 processor cache update: the processed kwargs item together
# with its prompt updates, or `None` when the data is already cached.
MultiModalProcessorCacheInItem: TypeAlias = (
    tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None
)


# Output of a P0 processor cache update: the kwargs item (or `None` when it
# does not need to be sent again) plus the prompt updates, which are always
# returned.
MultiModalProcessorCacheOutItem: TypeAlias = tuple[
    MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"]
]
|
||||
|
||||
|
||||
class BaseMultiModalProcessorCache(
    BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]
):
    """The required interface for caches on P0."""

    @abstractmethod
    def is_cached_item(self, mm_hash: str) -> bool:
        """
        Check whether a single multi-modal item is in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hash: The hash of the item to check.

        Returns:
            `True` if the item is cached, otherwise `False`.
        """
        raise NotImplementedError

    def is_cached(self, mm_hashes: list[str]) -> list[bool]:
        """
        Check whether each of a sequence of multi-modal items is in the
        underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hashes: The hash of each item to check.

        Returns:
            For each item, `True` if the item is cached, otherwise `False`.
        """
        return [self.is_cached_item(item_hash) for item_hash in mm_hashes]

    def close(self) -> None:
        """Close the underlying cache, if needed. No-op by default."""

    @abstractmethod
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        """
        Update the cache eviction order for a multi-modal item without
        changing its value.

        Args:
            mm_hash: The hash of the multi-modal item.
        """
        raise NotImplementedError

    @abstractmethod
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        """
        Get (and reset) the multi-modal cache stats.

        Returns:
            The current multi-modal caching stats.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is disabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes
      tensor data and metadata) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        # Full items (tensor data + prompt updates) are kept locally since
        # there is no separate P1 cache in this mode.
        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItem,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        cached = self._cache.get(mm_hash)
        if cached is not None:
            # Hit: serve the stored data and prompt updates.
            return cached.item, cached.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        # Miss: remember the item and pass the input through unchanged.
        self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
        return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)
|
||||
|
||||
|
||||
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.
    - If the item is not in the cache, store the metadata of that item so
      that the eviction policy remains the same as the cache on P1,
      and return the input.
      By only storing the metadata, we avoid keeping the data itself in
      memory inside P0.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        # Size-only metadata: mirrors the P1 cache's eviction policy without
        # duplicating tensor data in P0 memory.
        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItemMetadata,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        cached = self._cache.get(mm_hash)
        if cached is not None:
            # Hit: drop the data (P1 already has it); keep prompt updates.
            return None, cached.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        # Miss: record the metadata and forward the input unchanged.
        self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
        return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)
|
||||
|
||||
|
||||
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.
    - If the item is not in the cache, store the data in shared memory.
    """

    def __init__(self, vllm_config: "VllmConfig") -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=True,  # sender is the writer
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
        )
        # cache prompt_updates for P0 only
        # NOTE: The annotation must be quoted: `ResolvedPromptUpdate` is
        # imported under `TYPE_CHECKING` only, and PEP 526 evaluates
        # annotations on complex targets (`self.*`) at runtime, so an
        # unquoted annotation raised `NameError` on instantiation.
        self._p0_cache: "dict[str, Sequence[ResolvedPromptUpdate]]" = {}

        # Hit/miss counters backing make_stats().
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def _stat(self, *, delta: bool = False) -> CacheInfo:
        """Current stats; when `delta` is set, relative to the previous call."""
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return self._shm_cache.is_cached(mm_hash)

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if self._shm_cache.is_cached(mm_hash):
            self._hits += 1
            self._total += 1

            # Hit: send only the shared-memory address instead of the data.
            address, monotonic_id = self._shm_cache.get_cached(mm_hash)
            prompt_updates = self._p0_cache[mm_hash]
            return self.address_as_item(address, monotonic_id), prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
        item, prompt_updates = mm_item

        self._total += 1

        try:
            address, monotonic_id = self._shm_cache.put(mm_hash, item)
            # Try to remove dangling items if p0 cache is too large.
            if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
                self.remove_dangling_items()

            self._p0_cache[mm_hash] = prompt_updates
            return self.address_as_item(address, monotonic_id), prompt_updates
        except (ValueError, MemoryError) as e:
            # put may fail if the object is too large or
            # the cache is full.
            # In this case we log the error and keep the original mm_input.
            logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
            return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        """Touch the item in shared memory cache to prevent eviction.
        Increments writer_flag on sender side."""
        self._shm_cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()
        self._p0_cache.clear()

        # Reset stats along with the cache contents.
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._stat(delta=delta)

    @override
    def close(self) -> None:
        self._shm_cache.close()

    def remove_dangling_items(self) -> None:
        """Remove items that are no longer in the shared memory cache."""
        cached_hashes = self._shm_cache.key_index.keys()
        dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
        for mm_hash in dangling_hashes:
            del self._p0_cache[mm_hash]

    def address_as_item(
        self,
        address: int,
        monotonic_id: int,
    ) -> MultiModalKwargsItem:
        """Wrap a shared-memory (address, monotonic_id) pair as a kwargs item
        suitable for sending over IPC in place of the tensor data."""
        addr_elem = MultiModalFieldElem(
            data=address,
            field=MultiModalBatchedField(),
        )
        id_elem = MultiModalFieldElem(
            data=monotonic_id,
            field=MultiModalBatchedField(),
        )

        return MultiModalKwargsItem({"address": addr_elem, "monotonic_id": id_elem})
|
||||
|
||||
|
||||
class BaseMultiModalReceiverCache(
    BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
):
    """The required interface for caches on P1."""

    def get_and_update_features(
        self,
        mm_features: list["MultiModalFeatureSpec"],
    ) -> list["MultiModalFeatureSpec"]:
        """
        Update multimodal features with cached encoder outputs.

        All identifiers are touched first, before any update, so that an
        item in the list being updated cannot be evicted mid-update.

        Uses mm_hash for the cache key to share across LoRAs (falls back to
        identifier for backward compatibility).
        """
        # Pass 1: refresh the eviction order for every feature.
        for feature in mm_features:
            self.touch_receiver_cache_item(
                feature.mm_hash or feature.identifier, feature.data
            )

        # Pass 2: swap in (or store) the cached data.
        for feature in mm_features:
            feature.data = self.get_and_update_item(
                feature.data, feature.mm_hash or feature.identifier
            )
        return mm_features

    @abstractmethod
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        """
        Update the cache eviction order for a multi-modal item without
        changing its value.

        Args:
            mm_hash: The hash of the multi-modal item.
            mm_item: The multi-modal item itself. This is optional and
                may not be needed by some cache implementations.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
    """
    P1-side cache used when IPC caching is enabled.

    Per-item policy:

    - Cache hit: the cached item (tensor data included) replaces the input.
    - Cache miss: the incoming item is stored and returned unchanged.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        processor_cache_gb = (
            model_config.get_multimodal_config().mm_processor_cache_gb
        )
        # LRU cache sized in GiB, holding full kwargs items.
        self._cache = MultiModalCache.get_lru_cache(
            processor_cache_gb,
            MultiModalKwargsItem,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        hit = self._cache.get(mm_hash)
        if hit is not None:
            return hit

        # A miss must carry its own data; the sender omits data only when
        # it believes P1 already has the item cached.
        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = mm_item
        return mm_item

    @override
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        # Only bump LRU recency; the payload is not needed here.
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
|
||||
|
||||
|
||||
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 Worker Process when IPC caching is enabled.

    How to update each item:

    - If the item has an address, replace the input with the cached item.
    - If not, return the input.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> None:
        """
        Args:
            vllm_config: Full engine config; parallel and multimodal
                sub-configs determine buffer sizing and reader count.
            shared_worker_lock: Lock shared across worker processes,
                used to serialize reader-side access to the object store.
        """
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        # Attach to the shared-memory ring buffer by name; the sender (P0)
        # owns creation, this side only reads.
        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=False,  # Server is a reader
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
            reader_lock=shared_worker_lock,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        """Dereference an address-item from shared memory, or pass through
        an item that already carries its data inline."""
        assert mm_item is not None, f"Expected an address item for {mm_hash=}"
        if "address" in mm_item:
            # The item is a handle produced by the P0 sender; fetch the
            # actual payload from the shared-memory object store.
            address = cast(int, mm_item["address"].data)
            monotonic_id = cast(int, mm_item["monotonic_id"].data)
            return self._shm_cache.get(address, monotonic_id)

        return mm_item

    @override
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        """Touch the item in shared memory cache to prevent eviction.
        Increments reader_count on receiver side."""
        assert mm_item is not None
        if "address" in mm_item:
            address = cast(int, mm_item["address"].data)
            monotonic_id = cast(int, mm_item["monotonic_id"].data)
            self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id)

    @override
    def clear_cache(self) -> None:
        """Drop every entry from the shared-memory object store."""
        self._shm_cache.clear()
|
||||
193
vllm/multimodal/encoder_budget.py
Normal file
193
vllm/multimodal/encoder_budget.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_mm_max_toks_per_item(
    model_config: ModelConfig,
    mm_registry: MultiModalRegistry,
    processor: BaseMultiModalProcessor,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens per data item from each modality based
    on underlying model configuration.

    Prefers the processor's explicit per-item token count; if the processor
    does not provide one, falls back to counting placeholder embeddings in
    a dummy multimodal input.
    """
    explicit_counts = processor.info.get_mm_max_tokens_per_item(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    if explicit_counts is not None:
        return explicit_counts

    # Fallback: profile with dummy inputs and count placeholder embeddings.
    dummy_inputs = mm_registry.get_dummy_mm_inputs(
        model_config,
        mm_counts=mm_counts,
        processor=processor,
    )

    totals: dict[str, int] = {}
    for modality, placeholders in dummy_inputs["mm_placeholders"].items():
        totals[modality] = sum(p.get_num_embeds() for p in placeholders)
    return totals
|
||||
|
||||
|
||||
class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models.

    Derives, from the engine configuration, how many multi-modal items of
    each modality can appear per prompt and per batch, and how large the
    encoder compute/cache budgets are.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        super().__init__()

        self.model_config = model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config = vllm_config.scheduler_config

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        with set_default_torch_num_threads():  # Avoid hang during startup
            cache = mm_registry.processor_only_cache_from_config(vllm_config)
            processor = mm_registry.create_processor(model_config, cache=cache)

        self.cache = cache
        self.processor = processor
        mm_config = model_config.get_multimodal_config()
        enable_mm_embeds = mm_config is not None and mm_config.enable_mm_embeds

        supported_mm_limits = processor.info.supported_mm_limits
        self.mm_limits = mm_limits = processor.info.allowed_mm_limits

        # Modalities that pass through the MM encoder tower
        tower_modalities = {
            modality
            for modality in supported_mm_limits
            if mm_limits.get(modality, 0) > 0
        }
        # Modalities that bypass the tower (pre-computed embeddings only)
        embed_only_modalities = {
            modality
            for modality in supported_mm_limits
            if enable_mm_embeds and mm_limits.get(modality, 0) == 0
        }

        active_modalities = tower_modalities | embed_only_modalities

        # One item per active modality is enough to measure per-item cost.
        all_mm_max_toks_per_item = get_mm_max_toks_per_item(
            model_config,
            mm_registry,
            processor,
            mm_counts=dict.fromkeys(active_modalities, 1),
        )

        if embed_only_modalities:
            logger.info_once(
                "enable_mm_embeds is True; modalities handled as embedding-only: %s",
                tuple(embed_only_modalities),
            )

        # Some models (e.g., Qwen3Omni with use_audio_in_video=True) share
        # placeholders between modalities, so not all active modalities will
        # have their own entry in the returned dict. We filter to only include
        # modalities that have independent placeholder tokens.
        active_mm_max_toks_per_item = {
            modality: all_mm_max_toks_per_item[modality]
            for modality in active_modalities
            if modality in all_mm_max_toks_per_item
        }
        tower_mm_max_toks_per_item = {
            modality: active_mm_max_toks_per_item[modality]
            for modality in tower_modalities
            if modality in active_mm_max_toks_per_item
        }

        # Encoder budget is computed from all active modalities (including
        # embedding-only ones that need encoder cache space).
        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            active_mm_max_toks_per_item,
        )

        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        mm_max_items_per_prompt = dict[str, int]()
        mm_max_items_per_batch = dict[str, int]()

        # Per-prompt/per-batch limits are only relevant for tower modalities
        # (embedding-only modalities don't go through the encoder tower).
        for modality, max_toks_per_item in tower_mm_max_toks_per_item.items():
            (
                mm_max_items_per_prompt[modality],
                mm_max_items_per_batch[modality],
            ) = self._get_max_items(modality, max_toks_per_item)

        self.mm_max_toks_per_item = tower_mm_max_toks_per_item
        self.mm_max_items_per_prompt: Mapping[str, int] = mm_max_items_per_prompt
        self.mm_max_items_per_batch: Mapping[str, int] = mm_max_items_per_batch

    def _get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Return ``(max_items_per_prompt, max_items_per_batch)`` for one
        modality, bounded by both encoder and decoder budgets."""
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        if (encoder_budget := self.get_encoder_budget()) == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
        mm_limit = self.mm_limits[modality]

        # At least one item is always allowed per prompt.
        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs

        # Without chunked prefill, a whole item must fit into one batch,
        # which further limits how many requests can be batched together.
        if not scheduler_config.enable_chunked_prefill:
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )

        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch

    def get_modality_with_max_tokens(self) -> str:
        """Return the tower modality with the largest per-item token count
        (ties broken by modality name)."""
        mm_max_toks_per_item = self.mm_max_toks_per_item
        modality, _ = max(mm_max_toks_per_item.items(), key=lambda x: (x[1], x[0]))

        return modality

    def get_encoder_budget(self) -> int:
        """Effective encoder budget: the tighter of compute and cache size."""
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def reset_cache(self) -> None:
        """Clear the processor cache, if one was created."""
        if self.cache is not None:
            self.cache.clear_cache()
|
||||
294
vllm/multimodal/evs.py
Normal file
294
vllm/multimodal/evs.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def compute_retained_tokens_count(
    tokens_per_frame: int, num_frames: int, q: float
) -> int:
    """
    Return how many video tokens survive pruning at rate ``q``.

    At least one frame's worth of tokens is always retained, so the first
    frame survives regardless of the pruning rate.

    Args:
        tokens_per_frame: The number of tokens per frame.
        num_frames: The total number of frames.
        q: The pruning rate.

    Returns:
        The number of retained tokens.
    """
    kept = int(tokens_per_frame * num_frames * (1 - q))
    return max(tokens_per_frame, kept)
|
||||
|
||||
|
||||
def compute_retention_mask(
    video_embeds: torch.Tensor,
    video_size_thw: torch.LongTensor | tuple[int, int, int],
    spatial_merge_size: int,
    q: float,
) -> torch.Tensor:
    """
    Compute a boolean keep-mask over video embeddings (EVS pruning).

    Args:
        video_embeds (`torch.Tensor`): The input video embeddings
            of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
        video_size_thw (`torch.LongTensor` of shape `(3)`):
            The temporal, height and width of video.
        spatial_merge_size: Size reduction for rows & cols dimensions.
        q: (`float`): Pruning rate factor [0,1)

    Returns:
        `torch.Tensor`: The retention mask for the video embeddings of
        `(T * H * W // spatial_merge_size ^ 2)` shape.
    """
    num_frames, height, width = map(int, video_size_thw)
    rows = height // spatial_merge_size
    cols = width // spatial_merge_size

    # Use reshape instead of einops to avoid graph breaks.
    grid = video_embeds.reshape(num_frames, rows, cols, video_embeds.size(-1))

    # Core EVS: score each token by dissimilarity to the same spatial
    # position in the previous frame.
    frame_similarity = torch.nn.functional.cosine_similarity(
        grid[1:, ...], grid[:-1, ...], dim=-1
    )
    scores = 1 - frame_similarity

    # The first frame gets an unbeatable score so all of its tokens
    # are always retained.
    scores = torch.cat(
        [255 * torch.ones_like(grid[:1, :, :, 0]), scores], dim=0
    )

    flat_scores = scores.view(-1)
    ranking = torch.argsort(flat_scores, dim=-1, descending=True, stable=True)
    keep_count = compute_retained_tokens_count(
        tokens_per_frame=rows * cols, num_frames=num_frames, q=q
    )

    keep_mask = torch.zeros_like(flat_scores, dtype=torch.bool)
    keep_mask[ranking[:keep_count]] = True
    keep_mask = keep_mask.reshape(scores.size())

    return keep_mask.view(-1)  # "T H W -> (T H W)"
|
||||
|
||||
|
||||
def compute_mrope_for_media(
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    tokens_per_second: float = 1.0,
    video_second_per_grid: float = 1.0,
) -> torch.Tensor:
    """
    Compute mrope positions for video embeddings from grid dimensions.

    Matches the original Qwen 2.5 implementation, but the positions are
    built as if the media were the first element of the sequence.

    Args:
        video_size_thw: Media size (num frames, rows, cols)
        spatial_merge_size: Size reduction for rows & cols dimensions.
        tokens_per_second: Number of tokens per second.
        video_second_per_grid: Number of seconds per video.

    Returns:
        Tensor of shape `(T * H * W, 4)` where columns 0:3 are the mrope
        positions (t, h, w) and column 3 repeats the media's grid width
        for every position.
    """
    grid_t = video_size_thw[0]
    grid_h = video_size_thw[1] // spatial_merge_size
    grid_w = video_size_thw[2] // spatial_merge_size

    # Temporal index is scaled by real-time duration, then truncated.
    time_scale = tokens_per_second * video_second_per_grid
    t_pos = (
        torch.arange(grid_t)
        .view(-1, 1)
        .expand(-1, grid_h * grid_w)
        .mul(time_scale)
        .long()
        .flatten()
    )
    h_pos = (
        torch.arange(grid_h)
        .view(1, -1, 1)
        .expand(grid_t, -1, grid_w)
        .flatten()
    )
    w_pos = (
        torch.arange(grid_w)
        .view(1, 1, -1)
        .expand(grid_t, grid_h, -1)
        .flatten()
    )
    # Fourth channel: the grid width, broadcast to every position.
    w_max = (
        torch.tensor([grid_w])
        .view(1, 1, 1)
        .expand(grid_t, grid_h, grid_w)
        .flatten()
    )

    return torch.stack([t_pos, h_pos, w_pos, w_max], dim=1)
|
||||
|
||||
|
||||
def recompute_mrope_positions(
    input_ids: torch.LongTensor,
    multimodal_positions: list[torch.Tensor],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
    vision_start_token_id: int,
    image_token_id: int,
    video_token_id: int,
) -> tuple[torch.LongTensor, int]:
    """
    Update part of input mrope positions.
    Original mrope_positions are computed incorrectly, so once we prune media
    tokens we should reflect this in the mrope positions for the LLM.

    This method supports chunked prefill approach where
    multimodal_embeddings are passed to LLM in chunks, so input
    multimodal_embeddings may contain zero, some or even some part of all
    multimodal_embeddings for a given prompt.

    Each multimodal_positions has 4 extra channels
    (First 3 channels corresponds to original 3 mrope positions, last channel
    is the maximum width of the media repeated). Provided multimodal_positions
    do not reflect location of media position in sequence - they are computed
    like the media is in the 0-th position in the sequence.

    Method works as follows: it recomputes mrope_positions starting from the
    `num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
    shifts all text tokens that goes after total_len_of_multimodal_embeddings.

    It also handles case when multimodal_embeddings is partial
    (e.g. one media is split into two prefill stages)

    Args:
        input_ids: (N,) All input tokens of the prompt (entire sequence).
        multimodal_positions: List of mrope positions for each media.
        mrope_positions: Existing mrope positions (4, N) for entire sequence.
        num_computed_tokens: A number of computed tokens so far.
        vision_start_token_id: Token indicating start of vision media.
        image_token_id: Image token id
        video_token_id: Video token id

    Returns:
        Tuple of (mrope_positions, mrope_position_delta).
    """

    # Tensors
    # NOTE(review): the docstring says mrope_positions is (4, N) while this
    # comment says (3, N) — confirm the actual leading dimension with callers.
    positions: torch.LongTensor = typing.cast(
        torch.LongTensor, mrope_positions.clone()
    )  # (3, N)
    N = input_ids.numel()

    # Boolean masks over the whole sequence.
    image_mask = input_ids.eq(image_token_id)
    video_mask = input_ids.eq(video_token_id)
    media_mask = image_mask | video_mask
    text_mask = ~media_mask

    # Early exit: no media in this chunk
    if len(multimodal_positions) == 0:
        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
        return positions, delta

    total_mm_tokens = torch.count_nonzero(media_mask)
    seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])

    # Early exit: we've updated positions for all media tokens
    # (and consequently - for all remaining text tokens)
    if seen_mm_tokens == total_mm_tokens:
        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
        return positions, delta

    # Indices of every vision-start token in the full sequence.
    vision_start_indices = (input_ids == vision_start_token_id).nonzero(as_tuple=True)[
        0
    ]

    # NOTE(review): seen_mm_tokens is computed once before this loop and not
    # refreshed as num_computed_tokens advances below — presumably each call
    # carries at most one partially-processed media; verify against callers.
    for mm_pos in multimodal_positions:
        # Each mm_pos can be a complete embedding for single media
        # or it can be a part of a single media (due to chunked prefill)

        # Cases to cover
        # - Current prefill chunk has no vision start indexes at all
        # - Vision start token appeared in previous prefill round
        # - Regular case
        seen_vision_start_indices = vision_start_indices[
            vision_start_indices < num_computed_tokens
        ]

        if len(seen_vision_start_indices):
            # If we have encountered some vision start indexes,
            # then we should check the condition:
            # | --- prefill 1 ------| ---- prefill 2 ----- |
            # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
            last_vision_start_token = seen_vision_start_indices[-1]
            seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
                media_mask[:last_vision_start_token]
            )
            in_the_middle_of_media = (
                seen_mm_tokens > seem_mm_tokens_before_last_vision_start
            )

            if in_the_middle_of_media:
                # Resume the media whose tokens were partially emitted in a
                # previous prefill chunk.
                mm_embeddings_seen = (
                    seen_mm_tokens - seem_mm_tokens_before_last_vision_start
                )
                global_mm_start = last_vision_start_token
            else:
                # We have completed previous mm_embedding part and
                # ready to start a new one
                next_vision_start_token = vision_start_indices[
                    vision_start_indices >= num_computed_tokens
                ][0]
                mm_embeddings_seen = 0
                global_mm_start = next_vision_start_token

        else:
            # If there were no vision start indexes so far,
            # let's find first vision start index
            next_vision_start_token = vision_start_indices[
                vision_start_indices >= num_computed_tokens
            ][0]

            mm_embeddings_seen = 0
            global_mm_start = next_vision_start_token

        # Offset right after vision_start_token
        base = positions[-1, global_mm_start] + 1
        local_start = global_mm_start + 1 + mm_embeddings_seen
        local_end = local_start + mm_pos.shape[1]
        # Write the media's (t, h, w) positions, shifted to its place in
        # the sequence.
        positions[:, local_start:local_end] = mm_pos[0:3] + base

        # mm_pos[3, 0] is the max width of the media
        offset = mm_pos[3, 0] + base

        # Re-number all text tokens after the media as consecutive positions.
        text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)

        positions[:, local_end:N] = text_pos_sum + offset - 1

        # Include distance to the next vision start token
        num_computed_tokens += mm_pos.shape[1]

    mrope_positions_delta = (positions.max() + 1 - N).item()
    return positions, mrope_positions_delta
|
||||
162
vllm/multimodal/hasher.py
Normal file
162
vllm/multimodal/hasher.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import pickle
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .media import MediaWithBytes
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=3)
def _get_hasher_factory(algorithm: str) -> Callable[[], "hashlib._Hash"]:
    """
    Resolve a hash-algorithm name to a factory of hasher instances.

    Supports blake3 (default), sha256, and sha512 for FIPS compliance.
    Results are cached per algorithm name.

    Args:
        algorithm: Hash algorithm name (blake3, sha256, or sha512)

    See: https://github.com/vllm-project/vllm/issues/18334
    """
    normalized = algorithm.lower()

    if normalized == "blake3":
        # Imported lazily: blake3 is an optional dependency.
        from blake3 import blake3

        return blake3
    if normalized == "sha256":
        return hashlib.sha256
    if normalized == "sha512":
        return hashlib.sha512

    # This should never happen due to env_with_choices validation
    raise ValueError(f"Unsupported hash algorithm: {normalized}")
|
||||
|
||||
|
||||
class MultiModalHasher:
    """Serializes multi-modal values into a stable byte stream and hashes it.

    The byte layout produced here defines cache identity for multi-modal
    items — do not change it without invalidating existing hashes.
    """

    @classmethod
    def serialize_item(cls, obj: object) -> Iterable[bytes | memoryview]:
        """Yield a deterministic byte representation of a single value."""
        # Simple cases
        if isinstance(obj, (bytes, memoryview)):
            return (obj,)
        if isinstance(obj, str):
            return (obj.encode("utf-8"),)
        if isinstance(obj, (int, float)):
            # NumPy scalar serialization pins a fixed binary width/format.
            return (np.array(obj).tobytes(),)

        if isinstance(obj, Image.Image):
            # If the image carries a UUID in its EXIF ImageID tag, hash that
            # instead of the pixel data (cheap, user-supplied identity).
            exif = obj.getexif()
            if Image.ExifTags.Base.ImageID in exif and isinstance(
                exif[Image.ExifTags.Base.ImageID], uuid.UUID
            ):
                return (exif[Image.ExifTags.Base.ImageID].bytes,)

            data = {"mode": obj.mode, "data": np.asarray(obj)}
            palette = obj.palette
            if palette is not None:
                # Palette bytes affect rendering, so they affect identity too.
                data["palette"] = palette.palette
                if palette.rawmode is not None:
                    data["palette_rawmode"] = palette.rawmode

            return cls.iter_item_to_bytes("image", data)

        if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
            # Same EXIF ImageID shortcut as above for wrapped images.
            exif = obj.media.getexif()
            if Image.ExifTags.Base.ImageID in exif and isinstance(
                exif[Image.ExifTags.Base.ImageID], uuid.UUID
            ):
                return (exif[Image.ExifTags.Base.ImageID].bytes,)

            # Hash the original encoded bytes rather than decoded pixels.
            return cls.iter_item_to_bytes("image", obj.original_bytes)

        if isinstance(obj, torch.Tensor):
            tensor_obj: torch.Tensor = obj.cpu()
            tensor_dtype = tensor_obj.dtype
            tensor_shape = tensor_obj.shape

            # NumPy does not support bfloat16.
            # Workaround: View the tensor as a contiguous 1D array of bytes
            if tensor_dtype == torch.bfloat16:
                tensor_obj = tensor_obj.contiguous()
                tensor_obj = tensor_obj.view((tensor_obj.numel(),)).view(torch.uint8)

                # Keep the original dtype/shape in the stream so that
                # differently-shaped tensors with equal bytes hash apart.
                return cls.iter_item_to_bytes(
                    "tensor",
                    {
                        "original_dtype": str(tensor_dtype),
                        "original_shape": tuple(tensor_shape),
                        "data": tensor_obj.numpy(),
                    },
                )

            return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())

        if isinstance(obj, np.ndarray):
            if obj.ndim == 0:
                arr_data = obj.item()
            elif obj.flags.c_contiguous:
                # Not valid for 0-D arrays
                arr_data = obj.view(np.uint8).data
            else:
                # If the array is non-contiguous, we need to copy it first
                arr_data = obj.tobytes()

            return cls.iter_item_to_bytes(
                "ndarray",
                {
                    "dtype": obj.dtype.str,
                    "shape": obj.shape,
                    "data": arr_data,
                },
            )

        # Last resort: pickle is not guaranteed stable across versions,
        # hence the warning.
        logger.warning(
            "No serialization method found for %s. Falling back to pickle.", type(obj)
        )

        return (pickle.dumps(obj),)

    @classmethod
    def iter_item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[bytes | memoryview]:
        """Yield key-prefixed bytes for ``obj``, recursing into containers.

        Nested keys are joined with '.' so the stream is unambiguous.
        """
        if obj is None:
            # None is represented by the bare key with no value bytes.
            yield key.encode("utf-8")
            return
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
        else:
            yield key.encode("utf-8")
            yield from cls.serialize_item(obj)

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Return a hex digest over the given kwargs.

        Keys are sorted so the digest is independent of argument order.
        """
        hasher_factory = _get_hasher_factory(envs.VLLM_MM_HASHER_ALGORITHM)
        hasher = hasher_factory()

        for k, v in sorted(kwargs.items(), key=lambda kv: kv[0]):
            for bytes_ in cls.iter_item_to_bytes(k, v):
                hasher.update(bytes_)

        return hasher.hexdigest()
|
||||
36
vllm/multimodal/image.py
Normal file
36
vllm/multimodal/image.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def rescale_image_size(
    image: Image.Image, size_factor: float, transpose: int = -1
) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor.

    Args:
        image: Source image.
        size_factor: Multiplier applied to width and height
            (result truncated to int).
        transpose: ``PIL.Image.Transpose`` value applied after resizing;
            a negative value means no transpose.
    """
    scaled = image.resize(
        (int(image.width * size_factor), int(image.height * size_factor))
    )
    if transpose < 0:
        return scaled
    return scaled.transpose(Image.Transpose(transpose))
|
||||
|
||||
|
||||
def rgba_to_rgb(
    image: Image.Image,
    background_color: tuple[int, int, int] | list[int] = (255, 255, 255),
) -> Image.Image:
    """Convert an RGBA image to RGB by compositing it onto a solid
    background color."""
    assert image.mode == "RGBA"
    background = Image.new("RGB", image.size, background_color)
    alpha_channel = image.split()[3]  # 3 is the alpha channel
    background.paste(image, mask=alpha_channel)
    return background
|
||||
|
||||
|
||||
def convert_image_mode(image: Image.Image, to_mode: str):
    """Convert ``image`` to ``to_mode``, flattening RGBA onto a white
    background when targeting RGB."""
    if image.mode == to_mode:
        return image
    if image.mode == "RGBA" and to_mode == "RGB":
        return rgba_to_rgb(image)
    return image.convert(to_mode)
|
||||
1159
vllm/multimodal/inputs.py
Normal file
1159
vllm/multimodal/inputs.py
Normal file
File diff suppressed because it is too large
Load Diff
20
vllm/multimodal/media/__init__.py
Normal file
20
vllm/multimodal/media/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
|
||||
from .base import MediaIO, MediaWithBytes
|
||||
from .connector import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .video import VIDEO_LOADER_REGISTRY, VideoMediaIO
|
||||
|
||||
# Public API of vllm.multimodal.media, re-exported from the submodules above.
__all__ = [
    "MediaIO",
    "MediaWithBytes",
    "AudioEmbeddingMediaIO",
    "AudioMediaIO",
    "ImageEmbeddingMediaIO",
    "ImageMediaIO",
    "VIDEO_LOADER_REGISTRY",
    "VideoMediaIO",
    "MEDIA_CONNECTOR_REGISTRY",
    "MediaConnector",
]
|
||||
89
vllm/multimodal/media/audio.py
Normal file
89
vllm/multimodal/media/audio.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import numpy.typing as npt
|
||||
import pybase64
|
||||
import torch
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import soundfile
|
||||
except ImportError:
|
||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||
|
||||
|
||||
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Load and encode raw audio as ``(waveform, sample_rate)`` pairs."""

    def __init__(self, **kwargs) -> None:
        super().__init__()

        # `kwargs` contains custom arguments from --media-io-kwargs for
        # this modality. They can be passed to the underlying media
        # loaders (e.g. custom implementations) for flexible control.
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode audio bytes at the file's native sample rate."""
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode base64-encoded audio; ``media_type`` is ignored."""
        raw = base64.b64decode(data)
        return self.load_bytes(raw)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Load audio from disk at the file's native sample rate."""
        return librosa.load(filepath, sr=None)

    def encode_base64(
        self,
        media: tuple[npt.NDArray, int],
        *,
        audio_format: str = "WAV",
    ) -> str:
        """Encode ``(waveform, sample_rate)`` into base64 audio bytes."""
        waveform, sample_rate = media

        buffer = BytesIO()
        soundfile.write(buffer, waveform, sample_rate, format=audio_format)
        encoded = buffer.getvalue()

        return base64.b64encode(encoded).decode("utf-8")
|
||||
|
||||
|
||||
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Deserializes pre-computed audio embedding tensors saved with
    ``torch.save``."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        # Sparse-tensor invariant checks guard against out-of-bounds writes
        # from maliciously crafted serialized tensors.
        with torch.sparse.check_sparse_tensor_invariants():
            tensor = torch.load(BytesIO(data), weights_only=True)
        return tensor.to_dense()

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Decode strictly-validated base64 into a tensor; `media_type` is
        unused."""
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> torch.Tensor:
        # Same invariant check rationale as in load_bytes.
        with torch.sparse.check_sparse_tensor_invariants():
            tensor = torch.load(filepath, weights_only=True)
        return tensor.to_dense()

    def encode_base64(self, media: torch.Tensor) -> str:
        """Serialize a tensor to a base64 string."""
        return tensor2base64(media)
|
||||
61
vllm/multimodal/media/base.py
Normal file
61
vllm/multimodal/media/base.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@dataclass
class MediaWithBytes(Generic[_T]):
    """
    Wrapper that couples a media object with its original encoded bytes.

    This ensures the raw bytes and media object remain synchronized,
    preventing cache corruption from in-place modifications.

    The wrapper delegates attribute access to the underlying media object,
    making it behave transparently like the wrapped type (e.g., PIL.Image).

    NOTE: Currently, this wrapper is used only for the image modality.
    """

    # The decoded media object being wrapped.
    media: _T
    # Raw encoded bytes the media was decoded from (hidden from repr).
    original_bytes: bytes = field(repr=False)

    def __array__(self, *args, **kwargs) -> np.ndarray:
        """Allow np.array(obj) to return np.array(obj.media)."""
        return np.array(self.media, *args, **kwargs)

    def __getstate__(self):
        # Explicit pickling support so serialization does not route through
        # the __getattr__ delegation below.
        return self.__dict__.copy()

    def __setstate__(self, state: dict[str, Any]):
        self.__dict__.update(state)

    def __getattr__(self, name: str):
        """Delegate attribute access to the underlying media object."""
        return getattr(self.media, name)
|
||||
|
||||
|
||||
class MediaIO(ABC, Generic[_T]):
    """Abstract loader interface for one media modality: decode from raw
    bytes, base64 text, or a file on disk."""

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Decode a media object from raw encoded bytes."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Decode a media object from a local file."""
        raise NotImplementedError
|
||||
343
vllm/multimodal/media/connector.py
Normal file
343
vllm/multimodal/media/connector.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
from urllib.request import url2pathname
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from urllib3.util import Url, parse_url
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
|
||||
from .base import MediaIO
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .video import VideoMediaIO
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
# Shared executor for offloading blocking media decoding away from the
# asyncio event loop; shut down cleanly at interpreter exit.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT
)
atexit.register(global_thread_pool.shutdown)

# Registry of pluggable media-connector implementations.
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
    """Fetches and decodes multimodal media (audio, images, video, and
    pre-computed embeddings) from HTTP, ``data:`` and ``file:`` URLs,
    with both synchronous and asynchronous entry points."""

    def __init__(
        self,
        media_io_kwargs: dict[str, dict[str, Any]] | None = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                inputs, keyed by modalities. For example,
                to set num_frames for video, set
                `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                from.
            allowed_media_domains: If set, only media URLs that belong to this
                domain can be used for multi-modal inputs.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = media_io_kwargs or {}
        self.connection = connection

        local_root: Path | None = None
        if allowed_local_media_path:
            local_root = Path(allowed_local_media_path)

            if not local_root.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{local_root} does not exist."
                )
            if not local_root.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{local_root} must be a directory."
                )

        self.allowed_local_media_path = local_root
        self.allowed_media_domains = (
            [] if allowed_media_domains is None else allowed_media_domains
        )

    def _load_data_url(
        self,
        url_spec: Url,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a ``data:`` URL (base64 payloads only) via `media_io`."""
        path = url_spec.path or ""
        spec, payload = path.split(",", 1)
        media_type, encoding = spec.split(";", 1)
        # The parsed media_type keeps a leading "/" (e.g. "/video/jpeg");
        # strip it before dispatch.
        media_type = media_type.lstrip("/")

        if encoding != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, payload)

    def _load_file_url(
        self,
        url_spec: Url,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a ``file:`` URL, restricted to `allowed_local_media_path`."""
        root = self.allowed_local_media_path
        if root is None:
            raise RuntimeError(
                "Cannot load local files without `--allowed-local-media-path`."
            )

        path = url_spec.path or ""
        netloc = url_spec.netloc or ""
        filepath = Path(url2pathname(netloc + path))
        # Resolve symlinks/".." before the containment check.
        if root not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path {root}`."
            )

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec: Url) -> None:
        """Raise if a domain allowlist is configured and the URL is not on
        it."""
        domains = self.allowed_media_domains
        if domains and url_spec.hostname not in domains:
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{domains}. Input URL domain: "
                f"{url_spec.hostname}"
            )

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:  # type: ignore[type-var]
        """Synchronously load media from an HTTP, data, or file URL."""
        url_spec = parse_url(url)
        scheme = url_spec.scheme

        if scheme and scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            payload = self.connection.get_bytes(
                url_spec.url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            return media_io.load_bytes(payload)

        if scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if scheme == "file":
            return self._load_file_url(url_spec, media_io)

        raise ValueError("The URL must be either a HTTP, data or file URL.")

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: int | None = None,
    ) -> _M:
        """Asynchronously load media from an HTTP, data, or file URL.

        Decoding is offloaded to `global_thread_pool` so the event loop is
        not blocked by CPU-bound work.
        """
        url_spec = parse_url(url)
        loop = asyncio.get_running_loop()
        scheme = url_spec.scheme

        if scheme and scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            payload = await self.connection.async_get_bytes(
                url_spec.url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            return await loop.run_in_executor(
                global_thread_pool, media_io.load_bytes, payload
            )

        if scheme == "data":
            return await loop.run_in_executor(
                global_thread_pool, self._load_data_url, url_spec, media_io
            )

        if scheme == "file":
            return await loop.run_in_executor(
                global_thread_pool, self._load_file_url, url_spec, media_io
            )

        raise ValueError("The URL must be either a HTTP, data or file URL.")

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, int | float]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP or base64 data URL.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(
            image_mode=image_mode, **self.media_io_kwargs.get("image", {})
        )
        video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load image embedding from a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)

    def fetch_audio_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load audio embedding from a URL.
        """
        audio_embedding_io = AudioEmbeddingMediaIO()

        return audio_embedding_io.load_base64("", data)
|
||||
113
vllm/multimodal/media/image.py
Normal file
113
vllm/multimodal/media/image.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from ..image import convert_image_mode, rgba_to_rgb
|
||||
from .base import MediaIO, MediaWithBytes
|
||||
|
||||
|
||||
class ImageMediaIO(MediaIO[Image.Image]):
    """Loads PIL images and keeps the original encoded bytes alongside the
    decoded image (via `MediaWithBytes`) so caches stay consistent."""

    def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
        super().__init__()

        self.image_mode = image_mode
        # `kwargs` carries custom options from --media-io-kwargs for this
        # modality; they can be forwarded to the underlying media loaders
        # (e.g. custom implementations) for flexible control.
        self.kwargs = kwargs

        # Background color applied when flattening RGBA onto RGB.
        # Defaults to white for backward compatibility.
        rgba_bg = kwargs.get("rgba_background_color", (255, 255, 255))
        if isinstance(rgba_bg, list):
            # Normalize lists to tuples for consistency.
            rgba_bg = tuple(rgba_bg)

        is_valid_bg = (
            isinstance(rgba_bg, tuple)
            and len(rgba_bg) == 3
            and all(isinstance(c, int) and 0 <= c <= 255 for c in rgba_bg)
        )
        if not is_valid_bg:
            raise ValueError(
                "rgba_background_color must be a list or tuple of 3 integers "
                "in the range [0, 255]."
            )
        self.rgba_background_color = rgba_bg

    def _convert_image_mode(
        self, image: Image.Image | MediaWithBytes[Image.Image]
    ) -> Image.Image:
        """Convert image mode with custom background color."""
        if isinstance(image, MediaWithBytes):
            image = image.media

        if image.mode == self.image_mode:
            return image
        if image.mode == "RGBA" and self.image_mode == "RGB":
            # Flatten transparency using the configured background.
            return rgba_to_rgb(image, self.rgba_background_color)
        return convert_image_mode(image, self.image_mode)

    def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
        """Decode an image and pair it with its original encoded bytes."""
        decoded = Image.open(BytesIO(data))
        return MediaWithBytes(self._convert_image_mode(decoded), data)

    def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
        """Decode strictly-validated base64 image data; `media_type` is
        unused."""
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
        """Load an image from disk, retaining the raw file bytes."""
        with open(filepath, "rb") as f:
            raw = f.read()
        decoded = Image.open(BytesIO(raw))
        return MediaWithBytes(self._convert_image_mode(decoded), raw)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "PNG",
    ) -> str:
        """Serialize an image to base64 in `image_format`."""
        with BytesIO() as buf:
            self._convert_image_mode(media).save(buf, image_format)
            payload = buf.getvalue()

        return pybase64.b64encode(payload).decode("utf-8")
|
||||
|
||||
|
||||
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Deserializes pre-computed image embedding tensors saved with
    ``torch.save``."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        # Sparse-tensor invariant checks guard against out-of-bounds writes
        # from maliciously crafted serialized tensors.
        with torch.sparse.check_sparse_tensor_invariants():
            tensor = torch.load(BytesIO(data), weights_only=True)
        return tensor.to_dense()

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Decode strictly-validated base64 into a tensor; `media_type` is
        unused."""
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> torch.Tensor:
        # Same invariant check rationale as in load_bytes.
        with torch.sparse.check_sparse_tensor_invariants():
            tensor = torch.load(filepath, weights_only=True)
        return tensor.to_dense()

    def encode_base64(self, media: torch.Tensor) -> str:
        """Serialize a tensor to a base64 string."""
        return tensor2base64(media)
|
||||
89
vllm/multimodal/media/video.py
Normal file
89
vllm/multimodal/media/video.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm import envs
|
||||
|
||||
from ..video import VIDEO_LOADER_REGISTRY
|
||||
from .base import MediaIO
|
||||
from .image import ImageMediaIO
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
    """Decodes videos into ``(frames, metadata)`` through a pluggable video
    loader backend, delegating per-frame work to an `ImageMediaIO`."""

    def __init__(
        self,
        image_io: ImageMediaIO,
        num_frames: int = 32,
        **kwargs,
    ) -> None:
        super().__init__()

        self.image_io = image_io
        self.num_frames = num_frames

        # `kwargs` carries custom options from --media-io-kwargs for this
        # modality; they can be forwarded to the underlying media loaders
        # (e.g. custom implementations) for flexible control.
        #
        # "video_backend" may be given per request to override the global
        # VLLM_VIDEO_LOADER_BACKEND env var, e.g.:
        #   --media-io-kwargs '{"video": {"video_backend": "torchcodec"}}'
        backend = kwargs.pop("video_backend", None) or envs.VLLM_VIDEO_LOADER_BACKEND
        self.kwargs = kwargs
        self.video_loader = VIDEO_LOADER_REGISTRY.load(backend)

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode raw video bytes with the configured backend."""
        return self.video_loader.load_bytes(
            data, num_frames=self.num_frames, **self.kwargs
        )

    def load_base64(
        self, media_type: str, data: str
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode base64 video data.

        ``video/jpeg`` payloads are treated as a comma-separated list of
        base64 JPEG frames; anything else is decoded as a whole video file.
        """
        if media_type.lower() == "video/jpeg":
            decode_frame = partial(
                self.image_io.load_base64,
                "image/jpeg",
            )
            frames = [np.asarray(decode_frame(chunk)) for chunk in data.split(",")]
            return np.stack(frames), {}

        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
        """Load and decode a video file from disk."""
        return self.load_bytes(filepath.read_bytes())

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Serialize frames to a comma-separated list of base64 JPEG images."""
        if video_format == "JPEG":
            encode_frame = partial(
                self.image_io.encode_base64,
                image_format=video_format,
            )
            return ",".join(encode_frame(Image.fromarray(frame)) for frame in media)

        raise NotImplementedError("Only JPEG format is supported for now.")
|
||||
710
vllm/multimodal/parse.py
Normal file
710
vllm/multimodal/parse.py
Normal file
@@ -0,0 +1,710 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from collections.abc import Callable, Iterator, Mapping, Sequence, Set
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
TypeAlias,
|
||||
TypeGuard,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
from .audio import AudioResampler, AudioSpec, normalize_audio
|
||||
from .inputs import (
|
||||
AudioItem,
|
||||
HfAudioItem,
|
||||
HfImageItem,
|
||||
HfVideoItem,
|
||||
ImageItem,
|
||||
ModalityData,
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalUUIDDict,
|
||||
VideoItem,
|
||||
)
|
||||
from .media import MediaWithBytes
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_I = TypeVar("_I")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL.Image as PILImage
|
||||
else:
|
||||
PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
|
||||
|
||||
|
||||
class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.data: _T = data
        self.modality = modality

    def __repr__(self) -> str:
        return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})"

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # Auto-generated
        def __iter__(self) -> Iterator[_I]: ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(i) for i in range(self.get_count())]

    def get_item_for_hash(self, index: int) -> object:
        """Item used when hashing; defaults to the regular item."""
        return self.get(index)

    def get_all_items_for_hash(self) -> list[object]:
        """All items in hashing form."""
        return [self.get_item_for_hash(i) for i in range(self.get_count())]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are arranged in a list."""

    def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
        """Extract media from wrapper if present."""
        if isinstance(item, MediaWithBytes):
            return item.media
        return item

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> _T:
        return self._unwrap(self.data[index])

    def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
        # Return the raw (possibly wrapped) item so hashing can reuse
        # original_bytes when present.
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # HF processors expect the pluralized modality key, e.g. "images".
        return {f"{self.modality}s": self.get_all()}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {}
|
||||
|
||||
|
||||
def validate_embedding_ndim(
|
||||
tensor: torch.Tensor,
|
||||
modality: str,
|
||||
index: int | None = None,
|
||||
) -> None:
|
||||
"""Validate tensor ndim for multimodal embeddings.
|
||||
|
||||
Single embeddings should be 2D (seq_len, hidden_size).
|
||||
Batched embeddings should be 3D (batch, seq_len, hidden_size).
|
||||
|
||||
Args:
|
||||
tensor: The tensor to validate.
|
||||
modality: The modality name for error messages (e.g., "image", "audio").
|
||||
index: Optional index for list items, included in error messages.
|
||||
"""
|
||||
if tensor.ndim < 2 or tensor.ndim > 3:
|
||||
idx_str = f" [{index}]" if index is not None else ""
|
||||
raise ValueError(
|
||||
f"{modality.capitalize()} embedding{idx_str} must be 2D "
|
||||
f"(seq_len, hidden_size) or 3D (batch, seq_len, hidden_size), "
|
||||
f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingItems(
    ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor]
):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        modality: str,
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, modality)

        # ndim must be checked before hidden size: the hidden-size check
        # indexes shape[-1] and only makes sense on correctly-ranked tensors.
        self._validate_ndim()
        if expected_hidden_size is not None:
            self._validate_hidden_size(expected_hidden_size)

    def _validate_ndim(self) -> None:
        """Validate that embedding tensors have correct ndim (2D or 3D)."""
        data = self.data
        if isinstance(data, torch.Tensor):
            validate_embedding_ndim(data, self.modality)
            return

        # List form: every entry must be 2D (seq_len, hidden_size).
        for idx, tensor in enumerate(data):
            if tensor.ndim != 2:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding [{idx}] must be "
                    f"2D (seq_len, hidden_size), got {tensor.ndim}D tensor "
                    f"with shape {tuple(tensor.shape)}"
                )

    def _validate_hidden_size(self, expected_hidden_size: int) -> None:
        """Validate that embedding hidden dimension matches expected size.

        This validates hidden dimensions to prevent vulnerabilities: Embeddings
        with correct ndim but wrong hidden dimension could bypass initial
        checks and cause crashes during model inference when dimensions don't
        match.
        """
        data = self.data
        if isinstance(data, torch.Tensor):
            # Batched tensor: shape is (batch, seq_len, hidden_size)
            actual = data.shape[-1]
            if actual != expected_hidden_size:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding hidden dimension "
                    f"mismatch: got {actual}, but model expects "
                    f"{expected_hidden_size}. Embedding shape: {tuple(data.shape)}"
                )
            return

        # List of tensors: each has shape (seq_len, hidden_size)
        for idx, tensor in enumerate(data):
            actual = tensor.shape[-1]
            if actual != expected_hidden_size:
                raise ValueError(
                    f"{self.modality.capitalize()} embedding [{idx}] hidden "
                    f"dimension mismatch: got {actual}, but model "
                    f"expects {expected_hidden_size}. "
                    f"Embedding shape: {tuple(tensor.shape)}"
                )

    def _unwrap(
        self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
    ) -> torch.Tensor:
        """Extract media from wrapper if present."""
        if isinstance(item, MediaWithBytes):
            return item.media
        return item

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        return self._unwrap(self.data[index])

    def get_processor_data(self) -> Mapping[str, object]:
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Sequence length of the embedding at `item_idx`."""
        return len(self.get(item_idx))
|
||||
|
||||
|
||||
class DictEmbeddingItems(
    ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]
):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        # Imported here to defer the transformers dependency to first use.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (
                f"The data should contain the fields: {required_fields}, "
                f"but only found the following keys: {data_keys}"
            )
            raise ValueError(msg)

        fields_config = fields_factory(data)
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)

        self.fields_config = fields_config
        self.required_fields = required_fields

        # Pre-split the dict into per-item kwargs using the field config.
        self._kwargs = MultiModalKwargsItems.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        return len(self._kwargs[self.modality])

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        return self._kwargs[self.modality][index].get_data()

    def get_processor_data(self) -> Mapping[str, object]:
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return self.data
|
||||
|
||||
|
||||
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem | None]):
    """Batch of raw audio items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfAudioItem | None]) -> None:
        super().__init__(data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Return the number of samples in the audio at ``item_idx``.

        Raises:
            ValueError: If the item is a cache placeholder (``None``).
        """
        if (audio := self.get(item_idx)) is None:
            raise ValueError(f"Cannot get length of cached audio at {item_idx}")

        return len(audio)
|
||||
|
||||
|
||||
class AudioEmbeddingItems(EmbeddingItems):
    """Pre-computed audio embeddings supplied directly by the user."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "audio", expected_hidden_size)
|
||||
|
||||
|
||||
class ImageSize(NamedTuple):
    """Image dimensions in pixels, ordered (width, height)."""

    width: int
    height: int
|
||||
|
||||
|
||||
class ImageProcessorItems(ProcessorBatchItems[HfImageItem | None]):
    """Batch of raw image items to be fed through the HF processor."""

    def __init__(self, data: Sequence[HfImageItem | None]) -> None:
        super().__init__(data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Return the (width, height) of the image at ``item_idx``.

        Raises:
            ValueError: If the item is a cache placeholder (``None``).
        """
        image = self.get(item_idx)
        if image is None:
            raise ValueError(f"Cannot get size of cached image at {item_idx}")

        if isinstance(image, PILImage.Image):
            width, height = image.size
            return ImageSize(width, height)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Assumes (C, H, W) layout for array inputs — TODO confirm.
            _, height, width = image.shape
            return ImageSize(width, height)

        assert_never(image)
|
||||
|
||||
|
||||
class ImageEmbeddingItems(EmbeddingItems):
    """Pre-computed image embeddings supplied directly by the user."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "image", expected_hidden_size)
|
||||
|
||||
|
||||
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem | None]):
    """Batch of raw video items (frame sequences) plus optional metadata."""

    def __init__(
        self,
        data: Sequence[HfVideoItem | None],
        metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
    ) -> None:
        super().__init__(data, "video")

        # Optional metadata attached by the parser; may be per-item or shared.
        self.metadata = metadata

    def get_num_frames(self, item_idx: int) -> int:
        """Return the frame count of the video at ``item_idx``.

        Raises:
            ValueError: If the item is a cache placeholder (``None``).
        """
        if (video := self.get(item_idx)) is None:
            raise ValueError(f"Cannot get length of cached video at {item_idx}")

        return len(video)

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Return the (width, height) of the first frame at ``item_idx``.

        Raises:
            ValueError: If the item is a cache placeholder or has no frames.
        """
        video = self.get(item_idx)
        if video is None:
            raise ValueError(f"Cannot get size of cached video at {item_idx}")
        if len(video) == 0:
            raise ValueError(f"Cannot get size of empty video at {item_idx}")

        frame = video[0]

        if isinstance(frame, PILImage.Image):
            width, height = frame.size
            return ImageSize(width, height)
        if isinstance(frame, (np.ndarray, torch.Tensor)):
            # Assumes (C, H, W) layout for array frames — TODO confirm.
            _, height, width = frame.shape
            return ImageSize(width, height)

        assert_never(frame)
|
||||
|
||||
|
||||
class VideoEmbeddingItems(EmbeddingItems):
    """Pre-computed video embeddings supplied directly by the user."""

    def __init__(
        self,
        data: torch.Tensor | list[torch.Tensor],
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__(data, "video", expected_hidden_size)
|
||||
|
||||
|
||||
class VisionChunkProcessorItems(ProcessorBatchItems[Any]):
    """Processor items for vision chunks (unified image and video chunks)."""

    def __init__(self, data: Sequence[Any]) -> None:
        # Items are registered under the dedicated "vision_chunk" modality.
        super().__init__(data, "vision_chunk")
|
||||
|
||||
|
||||
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
|
||||
|
||||
|
||||
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
    normalized such that each entry corresponds to a list.
    """

    def _missing_modality_error(self, modality: str) -> KeyError:
        # Single place that builds the "modality not found" error so that
        # `get_count` and `get_items` report it consistently.
        available_modalities = set(self.keys())
        return KeyError(
            f"Modality {modality!r} not found. "
            f"Available modalities: {available_modalities}"
        )

    def select(self, modalities: Set[str]):
        """
        Construct a new `MultiModalDataItems` instance containing only the
        selected modalities.
        """
        return MultiModalDataItems(
            {modality: self[modality] for modality in modalities}
        )

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising [`KeyError`][]
        even if the modality is not found.
        """
        if modality not in self:
            if strict:
                raise self._missing_modality_error(modality)

            return 0

        return self[modality].get_count()

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {m: items.get_count() for m, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: type[_D] | tuple[type[_D], ...],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises:
            KeyError: If the modality is not found.
            TypeError: If the items are not of the required type.
        """
        if modality not in self:
            raise self._missing_modality_error(modality)

        items = self[modality]
        if not isinstance(items, typ):
            raise TypeError(
                f"Invalid type of data items for {modality=}. "
                f"Expected type: {typ}, but "
                f"found type: {type(items)}"
            )

        return items  # type: ignore[return-value]
|
||||
|
||||
|
||||
# A subparser for one modality: takes the raw input data and returns parsed
# items, or `None` when the input should be skipped (e.g. empty data).
ModalityDataParser: TypeAlias = Callable[
    [ModalityData[Any]], ModalityDataItems[Any, Any] | None
]
|
||||
|
||||
|
||||
class MultiModalDataParser:
    """
    Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
    into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
        target_channels (int, optional): Target number of audio channels.
            If provided, normalizes audio to this many channels (e.g., 1 for mono).
            If None, audio channels are passed through unchanged.
        audio_resample_method: Backend used for audio resampling
            ("librosa" or "scipy").
        video_needs_metadata: If True, every video item must carry metadata;
            a missing-metadata item raises `ValueError`.
        expected_hidden_size (int, optional): Expected hidden dimension for
            embedding inputs. If provided, validates that user-supplied
            embeddings have the correct hidden size to prevent crashes
            during model inference.
    """

    def __init__(
        self,
        *,
        target_sr: float | None = None,
        target_channels: int | None = None,
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
        video_needs_metadata: bool = False,
        expected_hidden_size: int | None = None,
    ) -> None:
        super().__init__()

        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )
        self.target_channels = target_channels
        self.video_needs_metadata = video_needs_metadata
        self.expected_hidden_size = expected_hidden_size

    @classmethod
    def is_embeddings(
        cls, data: object
    ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
        """Return whether `data` looks like pre-computed embeddings.

        A 3D tensor (batch of 2D embeddings), or a non-empty list of 2D
        tensors, is treated as embeddings; everything else is raw media.
        """
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor) and len(data) > 0:
            return data[0].ndim == 2  # type: ignore[index]

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, float | None]:
        """Normalize an audio item to ``(waveform, sample_rate_or_None)``."""
        if isinstance(audio, tuple):
            # Already a (waveform, sample_rate) pair.
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _get_video_with_metadata(
        self,
        video: VideoItem,
    ) -> tuple[np.ndarray, dict[str, Any] | None]:
        """Normalize a video item to ``(frames, metadata_or_None)``."""
        if isinstance(video, tuple):
            # Already a (frames, metadata) pair.
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            return video.numpy(), None

        assert_never(video)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio inputs into processor items (or embedding items)."""
        if data is None:
            return None

        if self.is_embeddings(data):
            return AudioEmbeddingItems(data, self.expected_hidden_size)

        data_items: list[AudioItem]
        if (
            (is_list_of(data, float) and len(data) > 0)
            or (isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 1)
            or isinstance(data, tuple)
        ):
            # A single audio item (waveform or (waveform, sr) tuple).
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one waveform per leading-dim element.
            data_items = list(data)
        else:
            data_items = data  # type: ignore[assignment]

        # The spec only depends on configuration, so build it once
        # instead of once per item.
        spec = (
            AudioSpec(target_channels=self.target_channels)
            if self.target_channels is not None
            else None
        )

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                new_audio = audio
            else:
                new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)

            # Apply channel normalization if target_channels is set
            if spec is not None:
                new_audio = normalize_audio(new_audio, spec)

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image inputs into processor items (or embedding items)."""
        if data is None:
            return None

        if self.is_embeddings(data):
            return ImageEmbeddingItems(data, self.expected_hidden_size)

        if isinstance(data, (PILImage.Image, MediaWithBytes)) or (
            isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3
        ):
            # A single image.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one image per leading-dim element.
            data_items = list(data)
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse video inputs into processor items (or embedding items).

        When `self.video_needs_metadata` is set, each parsed item must carry
        metadata; items are then stored as `(frames, metadata)` pairs.
        """
        if data is None:
            return None

        if self.is_embeddings(data):
            return VideoEmbeddingItems(data, self.expected_hidden_size)

        data_items: list[VideoItem]
        if (is_list_of(data, PILImage.Image) and len(data) > 0) or (
            isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 4
        ):
            # A single video (list of frames or a 4D array).
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one video per leading-dim element.
            data_items = list(data)
        elif isinstance(data, tuple) and len(data) == 2:
            # A single (frames, metadata) pair.
            data_items = [data]
        else:
            data_items = data  # type: ignore[assignment]

        new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
        metadata_lst: list[dict[str, Any] | None] = []
        for data_item in data_items:
            video, metadata = self._get_video_with_metadata(data_item)
            if self.video_needs_metadata:
                if metadata is None:
                    raise ValueError(
                        "Video metadata is required but not found in mm input. "
                        "Please check your video input in `multi_modal_data`"
                    )
                new_videos.append((video, metadata))
                metadata_lst.append(metadata)
            else:
                new_videos.append(video)

        # When metadata is not needed, `metadata_lst` is simply empty.
        # (A now-removed reset of the loop-local `metadata` here was dead code.)
        return VideoProcessorItems(new_videos, metadata=metadata_lst)

    def _parse_vision_chunk_data(
        self,
        data: ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse vision chunk data (unified image and video chunks)."""
        if data is None:
            return None

        if self.is_embeddings(data):
            raise ValueError("Do not support embedding data for vision_chunk right now")

        if isinstance(data, dict):
            # A single chunk dict; wrap it into a batch of one.
            data = [data]

        return VisionChunkProcessorItems(data)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Map each supported modality to its parsing method."""
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
            "vision_chunk": self._parse_vision_chunk_data,
        }

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse a full multi-modal data dict, dispatching per modality.

        Raises:
            ValueError: If a modality has no registered subparser.
        """
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data

        return mm_items
|
||||
|
||||
|
||||
# NOTE(review): a `None` entry presumably means "no UUID provided for that
# item" — confirm against callers of `parse_mm_uuids`.
MultiModalUUIDItems: TypeAlias = dict[str, Sequence[str | None]]
"""
As [`MultiModalUUIDDict`][vllm.multimodal.inputs.MultiModalUUIDDict], but
normalized such that each entry corresponds to a list.
"""
|
||||
|
||||
|
||||
def parse_mm_uuids(mm_uuids: MultiModalUUIDDict | None) -> MultiModalUUIDItems:
    """Normalize user-provided UUIDs so each modality maps to a sequence.

    A bare string is wrapped into a single-element list; sequences are
    passed through unchanged. `None` yields an empty mapping.
    """
    if mm_uuids is None:
        return {}

    parsed: MultiModalUUIDItems = {}
    for modality, uuids in mm_uuids.items():
        parsed[modality] = [uuids] if isinstance(uuids, str) else uuids

    return parsed
|
||||
29
vllm/multimodal/processing/__init__.py
Normal file
29
vllm/multimodal/processing/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .context import BaseProcessingInfo, InputProcessingContext, TimingContext
|
||||
from .dummy_inputs import BaseDummyInputsBuilder
|
||||
from .inputs import ProcessorInputs
|
||||
from .processor import (
|
||||
BaseMultiModalProcessor,
|
||||
EncDecMultiModalProcessor,
|
||||
PromptIndexTargets,
|
||||
PromptInsertion,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
|
||||
# Public API of the `vllm.multimodal.processing` subpackage.
__all__ = [
    "BaseProcessingInfo",
    "InputProcessingContext",
    "TimingContext",
    "BaseDummyInputsBuilder",
    "ProcessorInputs",
    "BaseMultiModalProcessor",
    "EncDecMultiModalProcessor",
    "PromptUpdate",
    "PromptIndexTargets",
    "PromptUpdateDetails",
    "PromptInsertion",
    "PromptReplacement",
]
|
||||
507
vllm/multimodal/processing/context.py
Normal file
507
vllm/multimodal/processing/context.py
Normal file
@@ -0,0 +1,507 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, overload
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.inputs import MultiModalDataDict
|
||||
from vllm.multimodal.parse import (
|
||||
DictEmbeddingItems,
|
||||
EmbeddingItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
|
||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = object
|
||||
ProcessorMixin = object
|
||||
|
||||
ModelConfig = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TimingContext:
    """Helper class to record execution times during multi-modal processing."""

    # If disabled, `record` becomes a no-op.
    enabled: bool = True

    # Accumulated execution time (in seconds) per processing stage.
    stage_secs: dict[str, float] = field(default_factory=dict)

    @property
    def total_secs(self) -> float:
        """Total recorded time across all stages."""
        return sum(self.stage_secs.values())

    @contextmanager
    def record(self, stage: str):
        """Record the execution time for a processing stage."""
        if not self.enabled:
            yield
            return

        start = time.perf_counter()
        try:
            yield
        finally:
            # Accumulate, so re-entering the same stage adds to its total.
            elapsed = time.perf_counter() - start
            self.stage_secs[stage] = self.stage_secs.get(stage, 0.0) + elapsed

    def get_stats_dict(self):
        """Return per-stage timings plus the aggregate total, keyed for stats."""
        stats = {f"{stage}_secs": secs for stage, secs in self.stage_secs.items()}
        stats["preprocessor_total_secs"] = self.total_secs
        return stats
|
||||
|
||||
|
||||
_T = TypeVar("_T")
# Config/processor type variables default to the HF base classes so that
# `get_hf_config`/`get_hf_processor` can narrow their return types when
# callers pass a concrete subclass.
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
||||
|
||||
|
||||
@dataclass(frozen=True)
class InputProcessingContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    tokenizer: TokenizerLike | None
    """The tokenizer used to tokenize the inputs."""

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer, raising if tokenizer init was skipped."""
        if self.tokenizer is None:
            raise ValueError(
                "You cannot pass text prompts when `skip_tokenizer_init=True`"
            )

        return self.tokenizer

    @overload
    def get_hf_config(self, /) -> PretrainedConfig: ...

    @overload
    def get_hf_config(
        self,
        typ: type[_C] | tuple[type[_C], ...],
        /,
    ) -> _C: ...

    def get_hf_config(
        self,
        typ: type[Any] | tuple[type[Any], ...] | None = None,
        /,
    ) -> Any:
        """
        Get the HuggingFace configuration
        (`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        if typ is None:
            # Imported lazily to avoid requiring transformers at module load.
            from transformers.configuration_utils import PretrainedConfig

            typ = PretrainedConfig

        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, typ):
            raise TypeError(
                "Invalid type of HuggingFace config. "
                f"Expected type: {typ}, but "
                f"found type: {type(hf_config)}"
            )

        return hf_config

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Get the HuggingFace image processor configuration of the model.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Get the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        mm_config = self.model_config.multimodal_config
        if mm_config is None:
            raise RuntimeError("Not a multimodal model")

        return mm_config

    @overload
    def get_hf_processor(self, /, **kwargs: object) -> ProcessorMixin: ...

    @overload
    def get_hf_processor(
        self,
        typ: type[_P] | tuple[type[_P], ...],
        /,
        **kwargs: object,
    ) -> _P: ...

    def get_hf_processor(
        self,
        typ: type[Any] | tuple[type[Any], ...] | None = None,
        /,
        **kwargs: object,
    ) -> Any:
        """
        Get the HuggingFace processor
        (`transformers.ProcessorMixin`) of the model,
        additionally checking its type.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        if typ is None:
            from transformers.processing_utils import ProcessorMixin

            typ = ProcessorMixin

        # Mistral tokenizers wrap a transformers tokenizer; HF processors
        # need the underlying one.
        tokenizer = self.tokenizer
        if is_mistral_tokenizer(tokenizer):
            tokenizer = tokenizer.transformers_tokenizer

        merged_kwargs = self.get_merged_mm_kwargs(kwargs)
        # The tokenizer is passed explicitly below; drop any duplicate.
        merged_kwargs.pop("tokenizer", None)

        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            tokenizer=tokenizer,
            **merged_kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Initialize a HuggingFace-like processor class, merging the
        keyword arguments with those in the model's configuration.
        """
        merged_kwargs = self.get_merged_mm_kwargs(kwargs)

        return typ(**merged_kwargs)

    def _postprocess_output(
        self,
        output: JSONTree,
    ) -> JSONTree:
        """Cast all floating-point tensors in `output` to the model dtype."""

        def _postprocess_one(x: object):
            if isinstance(x, torch.Tensor):  # noqa: SIM102
                # This mimics the behavior of transformers.BatchFeature
                if x.is_floating_point():
                    x = x.to(dtype=self.model_config.dtype)

            return x

        return json_map_leaves(_postprocess_one, output)

    def get_merged_mm_kwargs(self, kwargs: Mapping[str, object]):
        """Merge caller kwargs with the model's configured mm processor kwargs."""
        mm_config = self.model_config.get_multimodal_config()
        return mm_config.merge_mm_processor_kwargs(kwargs)

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},  # noqa: B006 - read-only default
        *,
        num_tries: int = 1,
        max_tries: int = 5,
    ) -> BatchFeature | JSONTree:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        Retries up to `max_tries` times on the tokenizer's transient
        "Already borrowed" error; any other failure is wrapped in a
        `ValueError` that records the offending data and kwargs.
        """
        assert callable(hf_processor)

        merged_kwargs = self.get_merged_mm_kwargs(kwargs)

        # Only forward kwargs the processor actually accepts.
        allowed_kwargs = get_allowed_kwarg_only_overrides(
            hf_processor,
            merged_kwargs,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            output = hf_processor(**data, **allowed_kwargs, return_tensors="pt")
        except Exception as exc:
            # See https://github.com/huggingface/tokenizers/issues/537
            # NOTE: guard on `exc.args` (not `exc`, which is always truthy)
            # so indexing `exc.args[0]` cannot itself raise IndexError.
            if (
                isinstance(exc, RuntimeError)
                and exc.args
                and exc.args[0] == "Already borrowed"
                and num_tries < max_tries
            ):
                logger.warning(
                    "Failed to acquire tokenizer in current thread. "
                    "Retrying (%d/%d)...",
                    num_tries,
                    max_tries,
                )
                time.sleep(0.5)
                return self.call_hf_processor(
                    hf_processor,
                    data,
                    kwargs,
                    num_tries=num_tries + 1,
                    max_tries=max_tries,
                )

            msg = (
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={allowed_kwargs}"
            )

            raise ValueError(msg) from exc

        # this emulates output.to(dtype=self.model_config.dtype)
        from transformers.feature_extraction_utils import BatchFeature

        if isinstance(output, BatchFeature):
            output_ = self._postprocess_output(output.data)
            return BatchFeature(output_)

        logger.warning_once(
            "%s did not return `BatchFeature`. "
            "Make sure to match the behaviour of `ProcessorMixin` when "
            "implementing custom processors.",
            type(hf_processor).__name__,
        )

        return self._postprocess_output(output)
|
||||
|
||||
|
||||
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        # Context giving access to model config, tokenizer and HF processor.
        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The model name or path from the model config."""
        return self.ctx.model_config.model

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer (raises if tokenizer init was skipped)."""
        return self.ctx.get_tokenizer()

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HuggingFace model configuration."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    def get_default_tok_params(self) -> TokenizeParams:
        """Construct the default parameters for tokenization."""
        model_config = self.ctx.model_config
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=True,
        )

    @cached_property
    def default_tok_params(self) -> TokenizeParams:
        # Cached: tokenization params are derived once from static config.
        return self.get_default_tok_params()

    def _get_expected_hidden_size(self) -> int | None:
        """
        Get expected hidden size for embedding validation if `mm_embeds` are enabled.

        This validates hidden dimensions to prevent a vulnerability where embeddings
        with correct `ndim` but wrong `shape` could cause crashes at inference time.
        """
        model_config = self.ctx.model_config
        mm_config = model_config.get_multimodal_config()

        if mm_config.enable_mm_embeds:
            return model_config.get_inputs_embeds_size()

        return None

    def get_data_parser(self) -> MultiModalDataParser:
        """
        Constructs a parser to preprocess multi-modal data items
        before passing them to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].

        You can support additional modalities by creating a subclass
        of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
        that has additional subparsers.
        """
        return MultiModalDataParser(
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    @cached_property
    def data_parser(self) -> MultiModalDataParser:
        # Cached: one parser instance is reused for all requests.
        return self.get_data_parser()

    @property
    def skip_prompt_length_check(self) -> bool:
        """Whether to skip the prompt length check (subclasses may override)."""
        return False

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @cached_property
    def supported_mm_limits(self) -> Mapping[str, int | None]:
        """The maximum supported number of items for each modality."""
        return self.get_supported_mm_limits()

    @cached_property
    def allowed_mm_limits(self) -> Mapping[str, int]:
        """The maximum allowed number of items for each modality."""
        mm_config = self.ctx.get_mm_config()

        allowed_limits = dict[str, int]()
        for modality, supported_limit in self.supported_mm_limits.items():
            user_limit = mm_config.get_limit_per_prompt(modality)

            # `None` supported limit means "unbounded": the user-configured
            # limit applies on its own.
            allowed_limits[modality] = (
                user_limit
                if supported_limit is None
                else min(user_limit, supported_limit)
            )

        return allowed_limits

    def validate_num_items(self, modality: str, num_items: int) -> None:
        """
        Raise `ValueError` if the number of input items for the given modality
        is invalid.
        """
        supported_limit = self.supported_mm_limits.get(modality, 0)
        allowed_limit = self.allowed_mm_limits.get(modality, 0)

        # `None` means the model imposes no limit of its own.
        if supported_limit is None:
            supported_limit = allowed_limit

        limit = min(supported_limit, allowed_limit)

        if num_items > limit:
            msg = f"At most {limit} {modality}(s) may be provided in one prompt."

            # Only suggest raising the user limit when the model itself
            # could support this many items.
            if num_items <= supported_limit:
                msg += " Set `--limit-mm-per-prompt` to increase this limit."

            raise ValueError(msg)

    def parse_mm_data(
        self,
        mm_data: MultiModalDataDict,
        *,
        validate: bool = True,
    ) -> MultiModalDataItems:
        """
        Normalize
        [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
        to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
        before passing them to
        [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)

        if validate:
            mm_config = self.ctx.get_mm_config()

            for modality, items in mm_items.items():
                # Direct embedding inputs require explicit opt-in.
                if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
                    if not mm_config.enable_mm_embeds:
                        raise ValueError(
                            f"You must set `--enable-mm-embeds` to input "
                            f"`{modality}_embeds`"
                        )
                    if mm_config.get_limit_per_prompt(modality) == 0:
                        logger.debug(
                            "Skipping count validation for modality "
                            "'%s' (embeddings with limit=0)",
                            modality,
                        )
                        continue
                self.validate_num_items(modality, len(items))

        return mm_items

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int] | None:
        """
        Return the maximum number of tokens per item of for each modality.

        When `None` (the default) is returned, vLLM will generate dummy inputs
        (images/videos) at maximum possible sizes and process them to determine
        the maximum token count per modality.

        This approach works but can be very slow for certain models (e.g.,
        Qwen2.5-VL), leading to very long startup time. For better performance,
        each model can override this method to return pre-computed maximum token
        counts, avoiding the need for dummy input generation and processing.

        Note:
            The maximum number of tokens per item of each modality returned
            from this function should respect the model's maximum sequence
            length and the maximum number of items of each modality allowed,
            and agree with dummy inputs (images/videos) at maximum possible
            sizes.
        """
        return None
|
||||
187
vllm/multimodal/processing/dummy_inputs.py
Normal file
187
vllm/multimodal/processing/dummy_inputs.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config.multimodal import (
|
||||
AudioDummyOptions,
|
||||
BaseDummyOptions,
|
||||
ImageDummyOptions,
|
||||
VideoDummyOptions,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..inputs import MultiModalDataDict
|
||||
from .context import BaseProcessingInfo
|
||||
from .inputs import ProcessorInputs
|
||||
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data used to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Build the text input corresponding to `mm_counts`."""
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length.
            mm_counts: Count of items per modality.
            mm_options: Configurable options per modality; models can use
                these to customize dummy data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> ProcessorInputs:
        """
        Build the full processor input (text + media) which, after
        processing, results in the maximum possible number of placeholder
        tokens.

        Args:
            seq_len: Sequence length.
            mm_counts: Count of items per modality.
            mm_options: Configurable options per modality.
        """
        text = self.get_dummy_text(mm_counts)
        mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
        # Dummy data is generated internally, so per-prompt validation
        # (e.g. item-count limits) is skipped.
        parsed_items = self.info.parse_mm_data(mm_data, validate=False)

        return ProcessorInputs(
            prompt=text,
            mm_data_items=parsed_items,
            tokenization_kwargs={"truncation": False},
        )

    def _clamp_override(
        self,
        current: int,
        override: int | None,
        warn_fmt: str,
    ) -> int:
        """
        Apply a user-supplied size override to `current`.

        Overrides may only shrink the dimension; values above the model's
        maximum are ignored (with a warning). `warn_fmt` takes two %d args:
        the override value and the current maximum.
        """
        if not override:
            return current
        if override > current:
            logger.warning(warn_fmt, override, current)
        return min(current, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """Return `num_audios` identical all-zero waveforms of `length` samples."""
        if num_audios == 0:
            return []
        length = self._clamp_override(
            length,
            overrides.length if overrides else None,
            "audio.length override (%d) exceeds model's "
            "maximum length (%d), will be ignored",
        )
        return [np.zeros((length,))] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """Return `num_images` identical solid RGB images of `width` x `height`."""
        if num_images == 0:
            return []
        if overrides:
            width = self._clamp_override(
                width,
                overrides.width,
                "image.width override (%d) exceeds model's "
                "maximum width (%d), will be ignored",
            )
            height = self._clamp_override(
                height,
                overrides.height,
                "image.height override (%d) exceeds model's "
                "maximum height (%d), will be ignored",
            )
        return [Image.new("RGB", (width, height), color=255)] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """Return `num_videos` identical all-white uint8 frame stacks."""
        if num_videos == 0:
            return []
        if overrides:
            num_frames = self._clamp_override(
                num_frames,
                overrides.num_frames,
                "video.num_frames override (%d) exceeds model's "
                "maximum number of frames (%d), will be ignored",
            )
            width = self._clamp_override(
                width,
                overrides.width,
                "video.width override (%d) exceeds model's "
                "maximum width (%d), will be ignored",
            )
            height = self._clamp_override(
                height,
                overrides.height,
                "video.height override (%d) exceeds model's "
                "maximum height (%d), will be ignored",
            )
        # NOTE(review): axes here are (frames, width, height, 3), not the
        # common (frames, height, width, 3) — preserved as-is; confirm
        # downstream consumers expect this ordering.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos
|
||||
70
vllm/multimodal/processing/inputs.py
Normal file
70
vllm/multimodal/processing/inputs.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..hasher import MultiModalHasher
|
||||
from ..inputs import MultiModalHashes
|
||||
from ..parse import MultiModalDataItems, MultiModalUUIDItems
|
||||
|
||||
|
||||
@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    prompt: str | list[int]
    mm_data_items: MultiModalDataItems
    mm_uuid_items: MultiModalUUIDItems | None = None
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)

    def get_mm_hashes(self, model_id: str) -> MultiModalHashes:
        """
        Compute a content hash for every multi-modal item, per modality.

        User-provided UUIDs stand in for the raw-media hash where possible;
        otherwise the item content is hashed together with `model_id` and
        any processor kwargs.
        """
        kwargs = self.hf_processor_mm_kwargs
        uuid_map = self.mm_uuid_items or {}

        mm_hashes: MultiModalHashes = {}

        for modality, data_items in self.mm_data_items.items():
            if modality not in uuid_map:
                # No UUIDs provided: hash every item's content directly.
                mm_hashes[modality] = [
                    MultiModalHasher.hash_kwargs(
                        model_id=model_id,
                        **{modality: item},
                        **kwargs,
                    )
                    for item in data_items
                ]
                continue

            uuid_items = uuid_map[modality]
            per_item_hashes: list[str] = []
            for idx, item in enumerate(data_items.get_all_items_for_hash()):
                uuid_item = uuid_items[idx]

                # A provided UUID can only replace the hash when no processor
                # kwargs are set: the processed inputs depend on those kwargs,
                # so they must be folded into the hash. When a UUID exists we
                # hash the UUID string instead of the raw media, which is
                # cheaper.
                if uuid_item is None or kwargs:
                    hash_input = uuid_item if uuid_item is not None else item
                    per_item_hashes.append(
                        MultiModalHasher.hash_kwargs(
                            model_id=model_id,
                            **{modality: hash_input},
                            **kwargs,
                        )
                    )
                else:
                    per_item_hashes.append(uuid_item)

            mm_hashes[modality] = per_item_hashes

        return mm_hashes
|
||||
1791
vllm/multimodal/processing/processor.py
Normal file
1791
vllm/multimodal/processing/processor.py
Normal file
File diff suppressed because it is too large
Load Diff
362
vllm/multimodal/registry.py
Normal file
362
vllm/multimodal/registry.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
|
||||
from .cache import (
|
||||
BaseMultiModalProcessorCache,
|
||||
BaseMultiModalReceiverCache,
|
||||
MultiModalProcessorOnlyCache,
|
||||
MultiModalProcessorSenderCache,
|
||||
MultiModalReceiverCache,
|
||||
ShmObjectStoreReceiverCache,
|
||||
ShmObjectStoreSenderCache,
|
||||
)
|
||||
from .inputs import MultiModalInputs
|
||||
from .processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
TimingContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
N = TypeVar("N", bound=type["SupportsMultiModal"])
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
|
||||
|
||||
|
||||
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
|
||||
|
||||
|
||||
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    instance from the processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ...
|
||||
|
||||
|
||||
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the processing info and dummy-inputs builder.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """Bundles the three factory callables registered for a model class."""

    info: ProcessingInfoFactory[_I]
    processor: MultiModalProcessorFactory[_I]
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ):
        """Instantiate the processor for `ctx`, wiring in info and dummy-inputs."""
        processing_info = self.info(ctx)
        inputs_builder = self.dummy_inputs(processing_info)
        return self.processor(processing_info, inputs_builder, cache=cache)
|
||||
|
||||
|
||||
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.

    Model classes register their processor/info/dummy-inputs factories via
    `register_processor`; runners then use this registry to build processors
    and the appropriate processor caches for a given configuration.
    """

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.

        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        mm_config = model_config.get_multimodal_config()
        info = self._create_processing_info(model_config, tokenizer=None)

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in info.supported_mm_limits
        ):
            # If enable_mm_embeds is True, we still need MM infrastructure
            # to process pre-computed embeddings even though encoder won't run
            if mm_config.enable_mm_embeds:
                return True

            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.
        """

        def wrapper(model_cls: N) -> N:
            if "_processor_factory" in model_cls.__dict__:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            model_cls._processor_factory = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """Resolve the model class and assert it has registered factories."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        assert hasattr(model_cls, "_processor_factory")
        return cast("SupportsMultiModal", model_cls)

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """Build an `InputProcessingContext`, loading a tokenizer if not given."""
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)

        return InputProcessingContext(model_config, tokenizer)

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """Instantiate the model's `BaseProcessingInfo` from its factory."""
        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory
        ctx = self._create_processing_ctx(model_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If the model is not a multimodal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory

        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(model_config, cache=cache)

        mm_config = model_config.get_multimodal_config()
        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=mm_counts,
            mm_options=mm_config.limit_per_prompt,
        )
        mm_inputs = processor.apply(
            processor_inputs,
            timing_ctx=TimingContext(enabled=False),
        )

        # Pad the prompt to the full sequence length so profiling observes
        # the worst-case memory usage.
        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return mm_inputs

    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        """
        Determine which processor-cache flavor (if any) applies.

        Returns None when caching is disabled or the model takes no
        multimodal inputs; "processor_only" when IPC caching is unsupported.
        """
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # Check if IPC caching is supported.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        # Reuse the config fetched above (it was previously re-fetched here
        # redundantly; the config cannot change in between).
        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """Return a `BaseMultiModalProcessorCache`, if enabled."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None
        elif cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        elif cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        elif cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """Return a `MultiModalProcessorOnlyCache`, if enabled."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None

        return MultiModalProcessorOnlyCache(vllm_config.model_config)

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the engine process."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "shm"):
            return None
        elif cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the worker process."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "lru"):
            return None
        elif cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")
|
||||
|
||||
|
||||
class MultiModalTimingRegistry:
    """Collects per-request multi-modal processing timings when stats are enabled."""

    def __init__(self, observability_config: "ObservabilityConfig | None") -> None:
        super().__init__()

        self._enabled = bool(
            observability_config and observability_config.enable_mm_processor_stats
        )
        if self._enabled:
            # Contexts may be recorded from multiple threads; guard the map.
            self._lock = threading.Lock()
            self._ctx_by_request_id = defaultdict[str, TimingContext](TimingContext)

    def get(self, request_id: str) -> TimingContext:
        """Return the timing context for `request_id` (a disabled stub when off)."""
        if not self._enabled:
            return TimingContext(enabled=False)

        with self._lock:
            return self._ctx_by_request_id[request_id]

    def stat(self) -> dict[str, dict[str, float]]:
        """Drain and return all accumulated stats, keyed by request id."""
        if not self._enabled:
            return {}

        with self._lock:
            collected = {
                request_id: ctx.get_stats_dict()
                for request_id, ctx in self._ctx_by_request_id.items()
            }
            self._ctx_by_request_id.clear()
        return collected
|
||||
327
vllm/multimodal/utils.py
Normal file
327
vllm/multimodal/utils.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import mimetypes
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Sequence
|
||||
from itertools import groupby
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
from .hasher import MultiModalHasher
|
||||
from .inputs import (
|
||||
BatchedTensorInputs,
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalPlaceholderDict,
|
||||
MultiModalSharedField,
|
||||
)
|
||||
from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch.types
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Module-level fallback providing a deprecation shim for moved names."""
    if name != "MEDIA_CONNECTOR_REGISTRY":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from .media import MEDIA_CONNECTOR_REGISTRY

    warnings.warn(
        "`vllm.multimodal.utils.MEDIA_CONNECTOR_REGISTRY` "
        "has been moved to `vllm.multimodal.media.MEDIA_CONNECTOR_REGISTRY`. "
        "The old name will be removed in v0.17.",
        DeprecationWarning,
        stacklevel=2,
    )

    return MEDIA_CONNECTOR_REGISTRY
|
||||
|
||||
|
||||
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Encode audio as base64."""
    return AudioMediaIO().encode_base64((audio, sampling_rate), audio_format=format)
|
||||
|
||||
|
||||
def encode_audio_url(
    audio: np.ndarray,
    sampling_rate: int,
    *,
    format: str = "WAV",
) -> str:
    """Encode audio as a data URL."""
    payload = encode_audio_base64(audio, sampling_rate, format=format)
    # Fall back to the bare "audio" type when the extension is unknown.
    mime = mimetypes.types_map.get("." + format.lower(), "audio")
    return f"data:{mime};base64,{payload}"
|
||||
|
||||
|
||||
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.
    """
    return ImageMediaIO(image_mode=image_mode).encode_base64(
        image, image_format=format
    )
|
||||
|
||||
|
||||
def encode_image_url(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "PNG",
) -> str:
    """
    Encode a pillow image as a data URL.

    By default, the image is converted into RGB format before being encoded.
    """
    payload = encode_image_base64(image, image_mode=image_mode, format=format)
    # Fall back to the bare "image" type when the extension is unknown.
    mime = mimetypes.types_map.get("." + format.lower(), "image")
    return f"data:{mime};base64,{payload}"
|
||||
|
||||
|
||||
def encode_video_base64(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Encode video frames as base64."""
    video_io = VideoMediaIO(ImageMediaIO())
    return video_io.encode_base64(frames, video_format=format)
|
||||
|
||||
|
||||
def encode_video_url(
    frames: npt.NDArray,
    *,
    format: str = "JPEG",
) -> str:
    """Encode video frames as a data URL."""
    payload = encode_video_base64(frames, format=format)

    # mimetypes maps ".jpeg" to image/jpeg, so special-case JPEG video.
    mime = (
        "video/jpeg"
        if format.lower() == "jpeg"
        else mimetypes.types_map.get("." + format.lower(), "video")
    )

    return f"data:{mime};base64,{payload}"
|
||||
|
||||
|
||||
def argsort_mm_positions(
    mm_positions: "MultiModalPlaceholderDict",
) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    entries = [
        (modality, idx, item.offset)
        for modality, items in mm_positions.items()
        for idx, item in enumerate(items)
    ]

    # Stable sort on offset only, so ties keep their dict/enumeration order.
    entries.sort(key=lambda entry: entry[2])

    return [(modality, idx) for modality, idx, _ in entries]
|
||||
|
||||
|
||||
def _get_group_hash(elem: "MultiModalFieldElem"):
    """Hash of a shared field's data (grouping key); None for non-shared fields."""
    if isinstance(elem.field, MultiModalSharedField):
        return MultiModalHasher.hash_kwargs(data=elem.data)
    return None
|
||||
|
||||
|
||||
def _batch_mm_items(
    items: Sequence[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
):
    """Collect field elems by key across `items` and reduce each group to a batch."""
    grouped = defaultdict[str, list[MultiModalFieldElem]](list)
    for item in items:
        for key, elem in item.items():
            grouped[key].append(elem)

    # All elems under a key share a field type, so the first elem's field
    # knows how to batch the whole group.
    return {
        key: elem_group[0].field.reduce_data(
            elem_group,
            device=device,
            pin_memory=pin_memory,
        )
        for key, elem_group in grouped.items()
    }
|
||||
|
||||
|
||||
def group_and_batch_mm_items(
    items: Sequence[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[int, BatchedTensorInputs]]:
    """
    Group consecutive items (possibly from different requests) into batches.

    Items must be split across groups if any of the following occurs,
    as the batch would otherwise be invalid:

    - They have different fields (e.g. mixed image and embedding inputs).
    - They have different values in `MultiModalSharedField`.

    Args:
        items: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(num_items, grouped_kwargs)`, where:

        - `grouped_kwargs` is a dictionary of keyword arguments for the model;
        - `num_items` is the corresponding number of items.
    """
    # Two adjacent items can share a batch only if they have the same set of
    # keys AND identical shared-field values; the group id captures both.
    group_ids = [
        tuple(
            (key, _get_group_hash(elem))
            for key, elem in sorted(item.items(), key=lambda kv: kv[0])
        )
        for item in items
    ]

    offset = 0
    for _, run in groupby(group_ids):
        size = sum(1 for _ in run)
        batch = _batch_mm_items(
            items[offset : offset + size],
            device=device,
            pin_memory=pin_memory,
        )

        yield size, batch

        offset += size

    assert offset == len(items)
|
||||
|
||||
|
||||
def group_mm_kwargs_by_modality(
    mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
    """
    Group consecutive items (possibly from different requests) into batches.

    Items must be split across groups if any of the following occurs,
    as the batch would otherwise be invalid:

    - They have different fields (e.g. mixed image and embedding inputs).
    - They have different values in `MultiModalSharedField`.

    To simplify the implementation of `embed_multimodal`, we add another
    restriction that the items in a batch must belong to the same modality.

    Args:
        mm_kwargs: List of `(modality, item)`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`, where:

        - `modality` is the modality of the batch;
        - `grouped_kwargs` is a dictionary of keyword arguments for the model;
        - `num_items` is the corresponding number of items.
    """
    for modality, pairs in groupby(mm_kwargs, key=lambda pair: pair[0]):
        modality_items = [item for _, item in pairs]

        # Within one modality, batching is further split wherever the items'
        # fields or shared values differ.
        yield from (
            (modality, count, batch)
            for count, batch in group_and_batch_mm_items(
                modality_items,
                device=device,
                pin_memory=pin_memory,
            )
        )
|
||||
|
||||
|
||||
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, int | float]:
    """
    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    connector = MediaConnector(
        media_io_kwargs={"audio": audio_io_kwargs} if audio_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_audio(audio_url)
|
||||
|
||||
|
||||
def fetch_image(
    image_url: str,
    image_io_kwargs: dict[str, Any] | None = None,
) -> Image.Image:
    """
    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    connector = MediaConnector(
        media_io_kwargs={"image": image_io_kwargs} if image_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_image(image_url)
|
||||
|
||||
|
||||
def fetch_video(
    video_url: str,
    video_io_kwargs: dict[str, Any] | None = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.

    Warning:
        This method has direct access to local files and is only intended
        to be called by user code. Never call this from the online server!
    """
    connector = MediaConnector(
        media_io_kwargs={"video": video_io_kwargs} if video_io_kwargs else None,
        allowed_local_media_path="/",
    )
    return connector.fetch_video(video_url)
|
||||
836
vllm/multimodal/video.py
Normal file
836
vllm/multimodal/video.py
Normal file
@@ -0,0 +1,836 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import cv2
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Resize every frame of a ``(T, H, W, C)`` video to ``size`` (height, width)."""
    # lazy import cv2 to avoid bothering users who only use text models
    import cv2

    target_height, target_width = size
    total, _, _, num_channels = frames.shape
    out = np.empty(
        (total, target_height, target_width, num_channels), dtype=frames.dtype
    )
    for idx in range(total):
        # cv2.resize takes (width, height), i.e. the reverse of numpy order
        out[idx] = cv2.resize(frames[idx], (target_width, target_height))
    return out
|
||||
|
||||
|
||||
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
    """Scale a video's spatial dimensions by ``size_factor`` (truncating to int)."""
    _, height, width, _ = frames.shape
    scaled = (int(height * size_factor), int(width * size_factor))
    return resize_video(frames, scaled)
|
||||
|
||||
|
||||
def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArray:
    """Uniformly sample ``num_frames`` frames; ``num_frames == -1`` returns all."""
    if num_frames == -1:
        return frames

    last = frames.shape[0] - 1
    chosen = np.linspace(0, last, num_frames, dtype=int)
    return frames[chosen, ...]
|
||||
|
||||
|
||||
class VideoLoader:
    """Abstract base for video decoders.

    Subclasses decode raw video bytes into a ``(frames, metadata)`` pair via
    ``load_bytes``; the shared ``_read_frames*`` helpers implement sequential
    OpenCV grab/retrieve reading with optional recovery of broken frames.
    """

    @classmethod
    @abstractmethod
    def load_bytes(
        cls, data: bytes, num_frames: int = -1, **kwargs
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode ``data`` into ``(frames, metadata)``; implemented by backends."""
        raise NotImplementedError

    @staticmethod
    def _can_use_for_recovery(
        idx: int,
        failed_frames: list[int],
        next_target_map: dict[int, int],
        total_frames: int,
    ) -> bool:
        """Check if current frame can recover the oldest failed frame."""
        if not failed_frames:
            return False
        # A substitute is only valid while we are still before the *next*
        # target frame after the oldest failure; past that, the failed frame
        # is considered unrecoverable by a nearby neighbor.
        oldest_failed = failed_frames[0]
        limit = next_target_map.get(oldest_failed, total_frames)
        return idx < limit

    @staticmethod
    def _read_frames_with_recovery(
        cap: "cv2.VideoCapture",
        frame_indices: list[int],
        total_frames: int,
    ) -> tuple[npt.NDArray, list[int], dict[int, int]]:
        """
        Read frames with dynamic window forward-scan recovery.

        When a target frame fails to load, the next successfully grabbed
        frame (before the next target frame) will be used to recover it.

        Args:
            cap: OpenCV VideoCapture object
            frame_indices: Sorted list of target frame indices to load
            total_frames: Total number of frames in the video

        Returns:
            Tuple of (frames_array, valid_frame_indices, recovered_map)
            - frames_array: Array of loaded frames
            - valid_frame_indices: List of frame indices that were loaded
            - recovered_map: Dict mapping recovered_idx -> source_idx

        NOTE(review): an empty ``frame_indices`` raises IndexError below at
        ``frame_indices[-1]`` — callers appear to always pass a non-empty
        list; confirm before relying on empty input.
        """
        import cv2

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        assert width > 0 and height > 0, (
            f"Invalid video frame size: width={width}, height={height}"
        )

        frame_idx_set = set(frame_indices)
        max_frame_idx = frame_indices[-1] if frame_indices else 0

        # Build map: target_idx -> next_target_idx (for recovery window)
        next_target_map: dict[int, int] = {}
        for k in range(len(frame_indices) - 1):
            next_target_map[frame_indices[k]] = frame_indices[k + 1]
        next_target_map[frame_indices[-1]] = total_frames

        frames_list: list[npt.NDArray] = []
        valid_frame_indices: list[int] = []
        failed_frames_idx: list[int] = []
        recovered_map: dict[int, int] = {}

        # NOTE(review): `i` is incremented but never read in this method.
        i = 0
        for idx in range(max_frame_idx + 1):
            is_target_frame = idx in frame_idx_set

            # Attempt to grab the current frame
            ok = cap.grab()

            if not ok:
                if is_target_frame:
                    logger.warning(
                        "Failed to grab frame %d during video loading.",
                        idx,
                    )
                    failed_frames_idx.append(idx)
                continue

            # Check if we should retrieve: target frame OR can recover a failed one
            can_recover = VideoLoader._can_use_for_recovery(
                idx, failed_frames_idx, next_target_map, total_frames
            )

            if is_target_frame or can_recover:
                ret, frame = cap.retrieve()

                if ret and frame is not None and frame.size > 0:
                    # OpenCV decodes BGR; convert to the RGB layout expected
                    # downstream.
                    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames_list.append(rgb_frame)
                    valid_frame_indices.append(idx)
                    i += 1

                    if can_recover:
                        recovered_idx = failed_frames_idx.pop(0)
                        recovered_map[recovered_idx] = idx
                        logger.info(
                            "Recovered frame %d using frame %d (delay: %d)",
                            recovered_idx,
                            idx,
                            idx - recovered_idx,
                        )
                elif is_target_frame:
                    logger.warning(
                        "Failed to retrieve frame %d during video loading.",
                        idx,
                    )
                    failed_frames_idx.append(idx)

        # Log any remaining failed frames
        for failed_idx in failed_frames_idx:
            logger.warning(
                "Frame %d could not be recovered (end of video).",
                failed_idx,
            )

        # Stack frames
        if frames_list:
            frames = np.stack(frames_list)
        else:
            # No frame decoded successfully: return an empty array with the
            # video's spatial shape so callers can still inspect dimensions.
            frames = np.empty((0, height, width, 3), dtype=np.uint8)

        return frames, valid_frame_indices, recovered_map

    @staticmethod
    def _read_frames(
        cap,
        frame_indices: set[int],
        num_expected_frames: int,
        max_frame_idx: int,
    ) -> tuple[npt.NDArray, int, list[int]]:
        """Sequentially read the frames in ``frame_indices`` (no recovery).

        Broken frames are skipped with a warning. Returns
        ``(frames, valid_num_frames, valid_frame_indices)`` where ``frames``
        is trimmed to the number of frames actually decoded.
        """
        import cv2

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Pre-allocate for the expected count; trimmed on return.
        frames = np.empty((num_expected_frames, height, width, 3), dtype=np.uint8)

        i = 0
        valid_frame_indices = []
        for idx in range(max_frame_idx + 1):
            # grab() advances the stream cheaply; retrieve() only decodes the
            # frames we actually want.
            ok = cap.grab()
            if not ok:
                # Frame is broken/unreadable, log warning
                if idx in frame_indices:
                    logger.warning(
                        "Failed to grab frame %d during video loading. "
                        "This frame will be skipped.",
                        idx,
                    )
                continue
            if idx in frame_indices:
                ret, frame = cap.retrieve()
                if ret:
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    valid_frame_indices.append(idx)
                    i += 1
                else:
                    # retrieve() failed even though grab() succeeded
                    logger.warning(
                        "Failed to retrieve frame %d during video loading. "
                        "This frame will be skipped.",
                        idx,
                    )

        valid_num_frames = len(valid_frame_indices)
        if valid_num_frames < num_expected_frames:
            logger.warning(
                "Video loading completed with %d broken/unreadable frames. "
                "Expected %d frames but only loaded %d frames.",
                num_expected_frames - valid_num_frames,
                num_expected_frames,
                valid_num_frames,
            )

        return frames[:valid_num_frames], valid_num_frames, valid_frame_indices
|
||||
|
||||
|
||||
VIDEO_LOADER_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader):
    """Default OpenCV-based loader with uniform frame sampling."""

    def get_cv2_video_api(self):
        """Pick the first available buffered-stream capture backend.

        Non-built-in plugin backends must expose the buffered-stream plugin
        API (abi >= 1 and, at abi 1, api >= 2); older plugins are skipped.
        Returns ``None`` when no suitable backend is found.
        """
        import cv2.videoio_registry as vr

        api_pref = None
        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                if abi < 1 or (abi == 1 and api < 2):
                    continue
            api_pref = backend
            break
        return api_pref

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: int = -1,
        max_duration: int = 300,
        frame_recovery: bool = False,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video frames from bytes.

        Args:
            data: Raw video bytes
            num_frames: Target number of frames to sample (-1 for all)
            fps: Target FPS for sampling (-1 for original)
            max_duration: Maximum duration (unused in base backend)
            frame_recovery: Enable forward-scan recovery for failed frames

        Returns:
            Tuple of (frames_array, metadata_dict)

        Raises:
            ValueError: If the byte stream cannot be opened as a video.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        try:
            total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            original_fps = cap.get(cv2.CAP_PROP_FPS)
            duration = total_frames_num / original_fps if original_fps > 0 else 0

            # resample video to target num_frames and fps
            # - the minimum of the two will be used
            num_frames_to_sample = total_frames_num
            if num_frames > 0:
                num_frames_to_sample = min(num_frames, total_frames_num)
            if fps > 0:
                num_frames_to_sample = min(
                    num_frames_to_sample, math.floor(duration * fps)
                )
            num_frames_to_sample = max(1, num_frames_to_sample)  # at least one sample

            if num_frames_to_sample == total_frames_num:
                frame_idx = list(range(0, num_frames_to_sample))
            else:
                uniform_sampled_frames = np.linspace(
                    0, total_frames_num - 1, num_frames_to_sample, dtype=int
                )
                frame_idx = uniform_sampled_frames.tolist()

            if frame_recovery:
                frames, valid_frame_indices, recovered_map = (
                    cls._read_frames_with_recovery(cap, frame_idx, total_frames_num)
                )
                valid_num_frames = len(valid_frame_indices)

                if recovered_map:
                    logger.info(
                        "Frame recovery: %d frames recovered using forward scan.",
                        len(recovered_map),
                    )
            else:
                frame_idx_set = set(frame_idx)
                frames, valid_num_frames, valid_frame_indices = cls._read_frames(
                    cap, frame_idx_set, num_frames_to_sample, max(frame_idx)
                )
        finally:
            # Bug fix: release the capture so the decoder/file handle is not
            # leaked on every call (the original never called release()).
            cap.release()

        # Use transformers transformers.video_utils.VideoMetadata format
        # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
        # can cause incorrect timestamp calculation without num_frames=-1.
        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv",
            "frames_indices": valid_frame_indices,
            # extra field used to control hf processor's video
            # sampling behavior
            "do_sample_frames": valid_num_frames == total_frames_num,
        }

        return frames, metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv_dynamic")
class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
    """OpenCV loader that chooses frame indices dynamically from duration."""

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: int = 2,
        max_duration: int = 300,
        frame_recovery: bool = False,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video frames with dynamic sampling based on duration.

        Args:
            data: Raw video bytes
            num_frames: Not used in dynamic backend
            fps: Target FPS for sampling (default: 2)
            max_duration: Maximum video duration to process (default: 300s)
            frame_recovery: Enable forward-scan recovery for failed frames

        Returns:
            Tuple of (frames_array, metadata_dict)

        Raises:
            ValueError: If the byte stream cannot be opened as a video.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        try:
            total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            original_fps = cap.get(cv2.CAP_PROP_FPS)
            duration = total_frames_num / original_fps if original_fps > 0 else 0

            # resample video to target num_frames
            max_frame_idx = total_frames_num - 1
            # NOTE(review): if original_fps == 0 this fallback divides by
            # zero — upstream metadata presumably guarantees fps > 0; verify.
            duration = duration or round(max_frame_idx / original_fps) + 1

            # Refer to:
            # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
            frame_indices_list: list[int]
            if duration <= max_duration:
                n = int(math.floor(duration * fps))
                frame_indices_list = sorted(
                    {
                        min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
                        for i in range(n)
                    }
                )
            else:
                num_samples = int(max_duration * fps)
                if num_samples >= total_frames_num:
                    frame_indices_list = list(range(total_frames_num))
                else:
                    target_seconds = np.linspace(
                        0, duration, num_samples, endpoint=True
                    )
                    frame_indices_list = sorted(
                        {
                            min(max_frame_idx, int(math.ceil(t * original_fps)))
                            for t in target_seconds
                        }
                    )

            if frame_recovery:
                frames, valid_frame_indices, recovered_map = (
                    cls._read_frames_with_recovery(
                        cap, frame_indices_list, total_frames_num
                    )
                )
                valid_num_frames = len(valid_frame_indices)

                if recovered_map:
                    logger.info(
                        "Frame recovery: %d frames recovered using forward scan.",
                        len(recovered_map),
                    )
            else:
                frame_indices_set = set(frame_indices_list)
                frames, valid_num_frames, valid_frame_indices = cls._read_frames(
                    cap,
                    frame_indices_set,
                    len(frame_indices_list),
                    total_frames_num - 1,
                )
        finally:
            # Bug fix: release the capture so the decoder/file handle is not
            # leaked on every call (the original never called release()).
            cap.release()

        # Use transformers transformers.video_utils.VideoMetadata format
        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv_dynamic",
            "frames_indices": valid_frame_indices,
            "do_sample_frames": False,
        }

        return frames, metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("molmo2")
class Molmo2VideoBackend(VideoLoader):
    """OpenCV-based loader implementing Molmo2's frame-sampling strategies.

    Supports two sampling modes: ``"fps"`` (densest candidate FPS that still
    spans the video) and ``"uniform_last_frame"`` (uniform sampling that
    always includes the final frame).
    """

    def get_cv2_video_api(self):
        """Pick the first available buffered-stream capture backend (or None)."""
        import cv2.videoio_registry as vr

        api_pref = None
        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                if abi < 1 or (abi == 1 and api < 2):
                    continue
            api_pref = backend
            break
        return api_pref

    @classmethod
    def get_candidate_target_fps(
        cls,
        video_fps: float,
        sampling_fps: float,
        max_fps: float = 8.0,
    ) -> list[float]:
        """
        Return the subset of `video_fps` factors that remain multiples
        of `sampling_fps`, capped at `max_fps`.

        Examples:
            >>> Molmo2VideoBackend.get_candidate_target_fps(6, 2)
            [2.0, 6.0]
            >>> Molmo2VideoBackend.get_candidate_target_fps(5, 1)
            [1.0, 5.0]
            >>> Molmo2VideoBackend.get_candidate_target_fps(2, 2)
            [2.0]
            >>> Molmo2VideoBackend.get_candidate_target_fps(5, 2)
            Traceback (most recent call last):
            ...
            ValueError: sampling_fps=2 must divide video_fps=5.
        """
        # Bug fix: validate before converting — int(None) would raise a
        # confusing TypeError instead of the intended ValueError (the
        # original checked for None only after the conversions).
        if sampling_fps is None:
            raise ValueError("sampling_fps must be provided")

        video_fps = int(video_fps)
        sampling_fps = int(sampling_fps)
        max_fps = int(max_fps)

        if video_fps <= 0 or sampling_fps <= 0:
            raise ValueError(
                "video_fps and sampling_fps must be positive "
                f"(got {video_fps}, {sampling_fps})"
            )
        if video_fps % sampling_fps != 0:
            raise ValueError(
                f"sampling_fps={sampling_fps} must divide video_fps={video_fps}."
            )

        candidates = []
        for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
            if candidate > max_fps:
                break
            if video_fps % candidate == 0:
                candidates.append(float(candidate))

        return candidates

    @classmethod
    def get_target_fps(
        cls,
        video_fps: float,
        max_frames: int,
        total_frames: int,
        frame_sample_mode: str,
        candidate_target_fps: list[float],
    ) -> float | None:
        """
        Get the target fps that best spans the video and has the most
        frames sampled. Returns ``None`` when no candidate qualifies.
        """
        num_frames_sampled = 0
        selected_target_fps = None
        for target_fps in candidate_target_fps:
            step_size = max(int(video_fps / target_fps), 1)
            num_frames_sampled_at_fps = int(total_frames / step_size)
            if num_frames_sampled == 0:
                if (
                    "uniform" in frame_sample_mode
                    and num_frames_sampled_at_fps > max_frames
                ):
                    break
                selected_target_fps = target_fps
                num_frames_sampled = num_frames_sampled_at_fps

            else:
                # the candidate sampling fps increases so frame count can't decrease
                assert num_frames_sampled <= num_frames_sampled_at_fps
                if num_frames_sampled_at_fps > max_frames:
                    # choose the sampling fps that spans the video
                    continue

                elif num_frames_sampled_at_fps > num_frames_sampled:
                    # both are less than max_frames; choose the one with higher
                    # density of frames sampled
                    selected_target_fps = target_fps
                    num_frames_sampled = num_frames_sampled_at_fps
        return selected_target_fps

    @classmethod
    def get_frame_times_and_chosen_fps(
        cls,
        selected_target_fps: float | None,
        total_frames: int,
        max_frames: int,
        video_fps: float,
    ) -> tuple[float | None, npt.NDArray]:
        """Return ``(chosen_fps, frame_indices)`` for the selected FPS.

        With no FPS selected, fall back to uniform sampling of ``max_frames``;
        otherwise step through the video at the selected rate, truncating to
        ``max_frames``.
        """
        if selected_target_fps is None:
            frame_indices = np.linspace(
                0, total_frames, max_frames, endpoint=False, dtype=int
            )
        else:
            step_size = max(int(video_fps / selected_target_fps), 1)
            frame_indices = np.arange(0, total_frames, step_size)
            if len(frame_indices) > max_frames:
                frame_indices = frame_indices[:max_frames]
        return selected_target_fps, frame_indices

    @classmethod
    def sample_times(
        cls,
        duration: float,
        max_frames: int,
        frame_sample_mode: str,
        max_fps: int | None,
        candidate_target_fps: list[float] | None = None,
        **kwargs,
    ) -> npt.NDArray:
        """Return frame timestamps (seconds) for the given sampling mode.

        Raises:
            NotImplementedError: For unknown ``frame_sample_mode`` values.
        """
        if frame_sample_mode == "fps":
            assert candidate_target_fps is not None
            # Try larger and larger FPSs until we hit one that can't span the video
            sampling_fps = candidate_target_fps[0]
            for candidate_fps in candidate_target_fps[1:]:
                if max_frames / candidate_fps < duration:
                    break
                sampling_fps = candidate_fps
            times = np.arange(0, max_frames) / sampling_fps
            times = times[times < duration]
            return times
        elif frame_sample_mode == "uniform_last_frame":
            if max_fps is not None:
                max_duration = (
                    max_frames - 1
                ) / max_fps  # -1 to include the last frame
                if max_duration < duration:
                    times = np.linspace(
                        0, duration, num=max_frames, endpoint=True, dtype=np.float64
                    )
                else:
                    times = np.arange(0.0, stop=duration, step=1 / max_fps)
                    times = np.concatenate([times, [duration]], axis=0)
                    assert len(times) <= max_frames
            else:
                times = np.linspace(
                    0, duration, num=max_frames, endpoint=True, dtype=np.float64
                )
            return times
        else:
            raise NotImplementedError(frame_sample_mode)

    @classmethod
    def _sample_frames(
        cls,
        total_num_frames: int,
        video_fps: float,
        duration: float,
        frame_sample_mode: str,
        num_frames: int,
        max_fps: int,
        sampling_fps: int,
    ) -> npt.NDArray:
        """Return the frame indices to decode for the given sampling mode."""
        if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
            if total_num_frames <= 2:
                indices = np.arange(total_num_frames).astype(int)
            elif duration > (num_frames - 1) / max_fps:  # -1 to include the last frame
                # uniform fallback
                indices = np.linspace(
                    0,
                    total_num_frames - 1,
                    num=min(num_frames, total_num_frames),
                    endpoint=True,
                ).astype(int)
            else:
                float_indices = np.arange(
                    0.0,
                    stop=total_num_frames - 1,
                    step=float(video_fps / max_fps),
                )
                if np.round(float_indices[-1]) != total_num_frames - 1:
                    float_indices = np.concatenate(
                        [float_indices, [total_num_frames - 1]], axis=0
                    )
                indices = np.round(float_indices).astype(int)
                assert indices[-1] < total_num_frames
                assert len(float_indices) <= num_frames
        elif frame_sample_mode == "uniform_last_frame":
            indices = np.linspace(
                0,
                total_num_frames - 1,
                num=min(num_frames, total_num_frames),
                endpoint=True,
            ).astype(int)
        elif frame_sample_mode == "fps":
            candidate_target_fps = cls.get_candidate_target_fps(video_fps, sampling_fps)
            selected_target_fps = cls.get_target_fps(
                video_fps,
                num_frames,
                total_num_frames,
                frame_sample_mode,
                candidate_target_fps,
            )
            _, indices = cls.get_frame_times_and_chosen_fps(
                selected_target_fps,
                total_num_frames,
                num_frames,
                video_fps,
            )
        else:
            raise NotImplementedError(frame_sample_mode)

        return indices

    @classmethod
    def load_bytes_opencv(
        cls,
        data: bytes,
        frame_sample_mode: str | None = None,
        num_frames: int = -1,
        max_fps: int = 2,
        sampling_fps: int = 2,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode and sample frames with OpenCV, returning (frames, metadata).

        Raises:
            ValueError: If the byte stream cannot be opened as a video.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        try:
            total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            original_fps = cap.get(cv2.CAP_PROP_FPS)
            duration = total_frames_num / original_fps if original_fps > 0 else 0

            if frame_sample_mode is None:
                # Use transformers transformers.video_utils.VideoMetadata format
                frame_idx = list(range(0, total_frames_num))
                frame_idx_set = set(frame_idx)
                frames, valid_num_frames, valid_frame_indices = cls._read_frames(
                    cap, frame_idx_set, total_frames_num, max(frame_idx)
                )
                do_sample_frames = valid_num_frames == total_frames_num
                metadata = {
                    "total_num_frames": total_frames_num,
                    "fps": original_fps,
                    "duration": duration,
                    "video_backend": "opencv",
                    "do_sample_frames": do_sample_frames,
                }
                if not do_sample_frames:
                    metadata["frames_indices"] = valid_frame_indices
                return frames, metadata

            frame_idx = cls._sample_frames(
                total_frames_num,
                original_fps,
                duration,
                frame_sample_mode,
                num_frames,
                max_fps,
                sampling_fps,
            ).tolist()

            frames, valid_num_frames, valid_frame_indices = cls._read_frames(
                cap,
                set(frame_idx),
                len(frame_idx),
                total_frames_num - 1,
            )
        finally:
            # Bug fix: release the capture so the decoder/file handle is not
            # leaked on every call (the original never called release()).
            cap.release()

        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv",
            "frames_indices": valid_frame_indices,
            "do_sample_frames": False,
        }

        return frames, metadata

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Entry point: pop Molmo2 sampling kwargs and delegate to OpenCV path."""
        frame_sample_mode = cast(str | None, kwargs.pop("frame_sample_mode", None))
        max_fps = cast(int, kwargs.pop("max_fps", 2))
        sampling_fps = cast(int, kwargs.pop("sampling_fps", 2))
        out = cls.load_bytes_opencv(
            data,
            frame_sample_mode,
            num_frames,
            max_fps,
            sampling_fps,
            **kwargs,
        )
        return out
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("openpangu")
class OpenCVDynamicOpenPanguVideoBackend(OpenCVVideoBackend):
    """OpenCV loader for openPangu models: timestamp-uniform sampling."""

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = 32,
        fps: int = 1,
        max_duration: int = 300,
        frame_recovery: bool = False,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video frames with dynamic sampling based on duration.
        Assume that total_num_frames = 10 and fps = 1.
        The timestamp of frame 0 is 0.0.
        The timestamp of frame 1 is 1.0. ...
        The timestamp of frame 9 (the last frame) should be 9.0, that is,
        (total_frames_num - 1) / original_fps.

        Args:
            data: Raw video bytes
            num_frames: Maximum number of frames to sample (default: 32)
            fps: Target FPS for sampling (default: 1; -1 disables the limit)

        Returns:
            Tuple of (frames_array, metadata_dict)

        Raises:
            ValueError: If the stream cannot be opened or ``fps`` is invalid.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        try:
            total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            original_fps = float(cap.get(cv2.CAP_PROP_FPS))
            # The timestamp of the rightmost frame, cannot be used to calculate frame 0.
            if total_frames_num >= 1 and original_fps > 0:
                total_duration = (total_frames_num - 1) / original_fps
            else:
                total_duration = 0

            # `fps` is the FPS parameter passed in for sampling,
            # -1 indicates that sampling can be performed directly without FPS limitation.
            if fps > 0:
                # Num_frames is the maximum number of frames to sample.
                # If fewer frames are sampled at this sample_fps, the update duration will be longer. # noqa: E501
                if num_frames >= int(total_duration * fps) + 1:
                    num_frames = int(total_duration * fps) + 1
                # Under the new maximum frame rate, the video duration of the rightmost frame, # noqa: E501
                # cannot be calculated for frame 0.
                total_duration = min(total_duration, (num_frames - 1) / fps)
            elif fps != -1:
                raise ValueError(
                    f"requires dataset fps is -1 or greater than 0 but got {fps}"
                )

            # Uniform timestamps over [0, total_duration], mapped back to
            # frame indices (clamped to the last frame).
            sample_frame_timestamps = np.linspace(
                0, total_duration, num_frames, dtype=float
            )
            frames_indices = [
                min(total_frames_num - 1, round(t * original_fps))
                for t in sample_frame_timestamps
            ]

            frames, valid_frame_indices, recovered_map = (
                cls._read_frames_with_recovery(cap, frames_indices, total_frames_num)
            )
        finally:
            # Bug fix: release the capture so the decoder/file handle is not
            # leaked on every call (the original never called release()).
            cap.release()

        if recovered_map:
            logger.info(
                "Frame recovery: %d frames recovered using forward scan.",
                len(recovered_map),
            )

        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": total_duration,
            "video_backend": "opencv_dynamic_openpangu",
            "frames_indices": valid_frame_indices,
            "do_sample_frames": False,
        }
        return frames, metadata
|
||||
Reference in New Issue
Block a user