init
This commit is contained in:
33
vllm/multimodal/__init__.py
Normal file
33
vllm/multimodal/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .hasher import MultiModalHasher
|
||||
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
|
||||
MultiModalDataDict, MultiModalKwargs,
|
||||
MultiModalKwargsItems, MultiModalPlaceholderDict,
|
||||
MultiModalUUIDDict, NestedTensors)
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
||||
"""
|
||||
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
|
||||
is used by model runners to dispatch data processing according to the target
|
||||
model.
|
||||
|
||||
Info:
|
||||
[mm_processing](../../../design/mm_processing.md)
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"BatchedTensorInputs",
|
||||
"ModalityData",
|
||||
"MultiModalDataBuiltins",
|
||||
"MultiModalDataDict",
|
||||
"MultiModalHasher",
|
||||
"MultiModalKwargs",
|
||||
"MultiModalKwargsItems",
|
||||
"MultiModalPlaceholderDict",
|
||||
"MultiModalUUIDDict",
|
||||
"NestedTensors",
|
||||
"MULTIMODAL_REGISTRY",
|
||||
"MultiModalRegistry",
|
||||
]
|
||||
BIN
vllm/multimodal/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/audio.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/audio.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/base.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/cache.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/cache.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/hasher.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/hasher.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/image.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/image.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/inputs.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/inputs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/parse.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/parse.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/processing.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/processing.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/profiling.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/profiling.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/registry.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/registry.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/utils.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/multimodal/__pycache__/video.cpython-312.pyc
Normal file
BIN
vllm/multimodal/__pycache__/video.cpython-312.pyc
Normal file
Binary file not shown.
116
vllm/multimodal/audio.py
Normal file
116
vllm/multimodal/audio.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import soundfile
|
||||
except ImportError:
|
||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||
|
||||
|
||||
def resample_audio_librosa(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
target_sr: float,
|
||||
) -> npt.NDArray[np.floating]:
|
||||
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
|
||||
|
||||
|
||||
def resample_audio_scipy(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
target_sr: float,
|
||||
):
|
||||
# lazy import scipy.signal, otherwise it will crash doc build.
|
||||
import scipy.signal
|
||||
|
||||
if orig_sr > target_sr:
|
||||
return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr)
|
||||
elif orig_sr < target_sr:
|
||||
return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1)
|
||||
return audio
|
||||
|
||||
|
||||
class AudioResampler:
|
||||
"""Resample audio data to a target sample rate."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_sr: Optional[float] = None,
|
||||
method: Literal["librosa", "scipy"] = "librosa",
|
||||
):
|
||||
self.target_sr = target_sr
|
||||
self.method = method
|
||||
|
||||
def resample(
|
||||
self,
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
orig_sr: float,
|
||||
) -> npt.NDArray[np.floating]:
|
||||
if self.target_sr is None:
|
||||
raise RuntimeError("Audio resampling is not supported when "
|
||||
"`target_sr` is not provided")
|
||||
if self.method == "librosa":
|
||||
return resample_audio_librosa(audio,
|
||||
orig_sr=orig_sr,
|
||||
target_sr=self.target_sr)
|
||||
elif self.method == "scipy":
|
||||
return resample_audio_scipy(audio,
|
||||
orig_sr=orig_sr,
|
||||
target_sr=self.target_sr)
|
||||
else:
|
||||
raise ValueError(f"Invalid resampling method: {self.method}. "
|
||||
"Supported methods are 'librosa' and 'scipy'.")
|
||||
|
||||
|
||||
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
self.kwargs = kwargs
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(BytesIO(data), sr=None)
|
||||
|
||||
def load_base64(
|
||||
self,
|
||||
media_type: str,
|
||||
data: str,
|
||||
) -> tuple[npt.NDArray, float]:
|
||||
return self.load_bytes(base64.b64decode(data))
|
||||
|
||||
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
|
||||
return librosa.load(filepath, sr=None)
|
||||
|
||||
def encode_base64(self, media: tuple[npt.NDArray, int]) -> str:
|
||||
audio, sr = media
|
||||
|
||||
with BytesIO() as buffer:
|
||||
soundfile.write(buffer, audio, sr, format="WAV")
|
||||
data = buffer.getvalue()
|
||||
|
||||
return base64.b64encode(data).decode('utf-8')
|
||||
27
vllm/multimodal/base.py
Normal file
27
vllm/multimodal/base.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class MediaIO(ABC, Generic[_T]):
|
||||
|
||||
@abstractmethod
|
||||
def load_bytes(self, data: bytes) -> _T:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_base64(self, media_type: str, data: str) -> _T:
|
||||
"""
|
||||
List of media types:
|
||||
https://www.iana.org/assignments/media-types/media-types.xhtml
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_file(self, filepath: Path) -> _T:
|
||||
raise NotImplementedError
|
||||
697
vllm/multimodal/cache.py
Normal file
697
vllm/multimodal/cache.py
Normal file
@@ -0,0 +1,697 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import operator
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeAlias, override
|
||||
|
||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||
MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
|
||||
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import GiB_bytes, LRUCache, MiB_bytes
|
||||
from vllm.utils.jsontree import (json_count_leaves, json_map_leaves,
|
||||
json_reduce_leaves)
|
||||
|
||||
from .inputs import (MultiModalBatchedField, MultiModalFeatureSpec,
|
||||
MultiModalFieldElem, MultiModalKwargs,
|
||||
MultiModalKwargsItem, MultiModalKwargsItems,
|
||||
NestedTensors)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from .processing import ResolvedPromptUpdate
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItem:
|
||||
"""
|
||||
The data to store inside `MultiModalProcessorOnlyCache`.
|
||||
|
||||
Args:
|
||||
item: The processed tensor data corresponding to a multi-modal item.
|
||||
prompt_updates: The prompt updates corresponding to `item`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item: MultiModalKwargsItem,
|
||||
prompt_updates: Sequence["ResolvedPromptUpdate"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.item = item
|
||||
self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItemMetadata:
|
||||
"""
|
||||
The metadata to store inside `MultiModalProcessorSenderCache`.
|
||||
|
||||
Args:
|
||||
item: The processed tensor data corresponding to a multi-modal item.
|
||||
Since P1 already stores the tensor data, we only store its size
|
||||
metadata in P0 to reduce memory usage. The size metadata is still
|
||||
needed to keep the same cache eviction policy as P0.
|
||||
prompt_updates: The prompt updates corresponding to `item`.
|
||||
This needs to stay on P0 because for some models, they are
|
||||
dependent on the processed tensor data (cached on P1).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item: MultiModalKwargsItem,
|
||||
prompt_updates: Sequence["ResolvedPromptUpdate"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.item_size = MultiModalCache.get_item_size(item)
|
||||
self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
MultiModalCacheValue = Union[
|
||||
MultiModalProcessorCacheItem,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargs,
|
||||
Mapping[str, NestedTensors],
|
||||
]
|
||||
|
||||
_V = TypeVar("_V", bound=MultiModalCacheValue)
|
||||
|
||||
|
||||
class MultiModalCache:
|
||||
|
||||
@classmethod
|
||||
def get_leaf_size(cls, leaf: object) -> int:
|
||||
if isinstance(leaf, MultiModalProcessorCacheItem):
|
||||
return cls.get_leaf_size(leaf.item)
|
||||
if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
|
||||
return leaf.item_size
|
||||
|
||||
# These are not subclasses of dict
|
||||
if isinstance(leaf, (MultiModalKwargs, MultiModalKwargsItems,
|
||||
MultiModalKwargsItem, MultiModalFieldElem)):
|
||||
return cls.get_item_size(leaf.data) # type: ignore
|
||||
|
||||
# sys.getsizeof doesn't work for tensors
|
||||
if isinstance(leaf, torch.Tensor):
|
||||
return leaf.nbytes
|
||||
|
||||
return sys.getsizeof(leaf)
|
||||
|
||||
@classmethod
|
||||
def get_item_size(
|
||||
cls,
|
||||
value: MultiModalCacheValue,
|
||||
*,
|
||||
debug: bool = False,
|
||||
) -> int:
|
||||
size = json_reduce_leaves(operator.add,
|
||||
json_map_leaves(cls.get_leaf_size, value))
|
||||
|
||||
if debug:
|
||||
leaf_count = json_count_leaves(value)
|
||||
logger.debug(
|
||||
"Calculated size of %s to be %.2f GiB (%d leaves)",
|
||||
type(value),
|
||||
size / GiB_bytes,
|
||||
leaf_count,
|
||||
)
|
||||
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
|
||||
"""
|
||||
Get the number of leaf elements in a multi-modal cache value.
|
||||
|
||||
This provides a measure of structural complexity that can be useful
|
||||
for debugging cache performance and understanding data patterns.
|
||||
|
||||
Args:
|
||||
value: The multi-modal cache value to analyze.
|
||||
|
||||
Returns:
|
||||
The number of leaf elements in the nested structure.
|
||||
"""
|
||||
return json_count_leaves(value)
|
||||
|
||||
@classmethod
|
||||
def get_lru_cache(
|
||||
cls,
|
||||
capacity_gb: float,
|
||||
value_type: type[_V],
|
||||
*,
|
||||
debug: bool = False,
|
||||
) -> LRUCache[str, _V]:
|
||||
return LRUCache(
|
||||
GiB_bytes * capacity_gb,
|
||||
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
|
||||
)
|
||||
|
||||
|
||||
_I = TypeVar("_I", contravariant=True)
|
||||
_O = TypeVar("_O", covariant=True)
|
||||
|
||||
|
||||
class BaseMultiModalCache(ABC, Generic[_I, _O]):
|
||||
"""
|
||||
Abstract base class to read/write multi-modal items from cache.
|
||||
|
||||
The idea of multi-modal caching is based on having a client and server
|
||||
where the client executes in the frontend process (=P0) and
|
||||
the server in the core process (=P1). The data flow is as follows:
|
||||
|
||||
```
|
||||
is_cached() x N get_and_update()
|
||||
P0: From API -----------------> -----------------> To P1
|
||||
|
||||
get_and_update()
|
||||
P1: From P0 -----------------> To model
|
||||
```
|
||||
|
||||
`is_cached()` can be called any number of times in P0. However,
|
||||
`get_and_update()` must be called in P0 and P1 one after another
|
||||
so that their cache eviction order remains the same.
|
||||
|
||||
This ensures that the keys in P0 and P1 caches are mirrored,
|
||||
allowing us to determine whether a key is cached in P1 by looking
|
||||
up the P0 cache, without having to communicate with P1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: _I,
|
||||
mm_hash: str,
|
||||
) -> _O:
|
||||
"""
|
||||
Possibly update a multi-modal item based on whether it is
|
||||
in the underlying cache.
|
||||
|
||||
This update is done out-of-place and updates the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_item: The multi-modal item to update.
|
||||
mm_hash: The hash of `mm_item`.
|
||||
|
||||
Returns:
|
||||
The update multi-modal item.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_and_update(
|
||||
self,
|
||||
mm_items: Sequence[_I],
|
||||
mm_hashes: list[str],
|
||||
) -> list[_O]:
|
||||
"""
|
||||
Possibly update a sequence of multi-modal items based on whether they
|
||||
are in the underlying cache.
|
||||
|
||||
This update is done out-of-place and updates the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_items: The multi-modal items to update.
|
||||
mm_hashes: The hash of each item in `mm_items`.
|
||||
|
||||
Returns:
|
||||
A new list of updated multi-modal items.
|
||||
"""
|
||||
assert len(mm_items) == len(mm_hashes)
|
||||
|
||||
return [
|
||||
self.get_and_update_item(mm_item, mm_hash)
|
||||
for mm_item, mm_hash in zip(mm_items, mm_hashes)
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the underlying cache."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
MultiModalProcessorCacheInItem: TypeAlias = \
|
||||
Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]]
|
||||
|
||||
|
||||
MultiModalProcessorCacheOutItem: TypeAlias = \
|
||||
tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]]
|
||||
|
||||
|
||||
class BaseMultiModalProcessorCache(
|
||||
BaseMultiModalCache[MultiModalProcessorCacheInItem,
|
||||
MultiModalProcessorCacheOutItem]):
|
||||
"""The required interface for caches on P0."""
|
||||
|
||||
@abstractmethod
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
"""
|
||||
Check whether a multi-modal item is
|
||||
in the underlying cache.
|
||||
|
||||
This **DOES NOT** update the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_hash: The hash of the item to check.
|
||||
|
||||
Returns:
|
||||
`True` if the item is cached, otherwise `False`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_cached(self, mm_hashes: list[str]) -> list[bool]:
|
||||
"""
|
||||
Check whether a sequence of multi-modal items are
|
||||
in the underlying cache.
|
||||
|
||||
This **DOES NOT** update the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_hashes: The hash of each item to check.
|
||||
|
||||
Returns:
|
||||
For each item, `True` if the item is cached, otherwise `False`.
|
||||
"""
|
||||
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
|
||||
|
||||
|
||||
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is disabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is in the cache, replace the input with the cached item.
|
||||
- If the item is not in the cache, store that item (which includes
|
||||
tensor data and metadata) into the cache, and return the input.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalProcessorCacheItem,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return mm_hash in self._cache
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return cached_item.item, cached_item.prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is already in the cache, clear the input to avoid
|
||||
unnecessary IPC.
|
||||
|
||||
- If the item is not in the cache, store the metadata of that item so
|
||||
that the eviction policy remains the same as the cache on P1,
|
||||
and return the input.
|
||||
By only storing the metadata, we avoid keeping the data itself in
|
||||
memory inside P0.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return mm_hash in self._cache
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return None, cached_item.prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is already in the cache, clear the input to avoid
|
||||
unnecessary IPC.
|
||||
|
||||
- If the item is not in the cache, store the data in shared memory.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
|
||||
name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
|
||||
create=True, # sender is the writer
|
||||
)
|
||||
self._shm_cache = SingleWriterShmObjectStorage(
|
||||
max_object_size=mm_config.mm_shm_cache_max_object_size_mb *
|
||||
MiB_bytes,
|
||||
n_readers=self.world_size,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=MsgpackSerde,
|
||||
)
|
||||
# cache (prompt_updates, modality) for P0 only
|
||||
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate],
|
||||
str]] = {}
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return self._shm_cache.is_cached(mm_hash)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
|
||||
if self._shm_cache.is_cached(mm_hash):
|
||||
address, monotonic_id = self._shm_cache.get_cached(mm_hash)
|
||||
prompt_updates, modality = self._p0_cache[mm_hash]
|
||||
return self.address_as_item(address, monotonic_id,
|
||||
modality), prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
try:
|
||||
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
|
||||
# Try to remove dangling items if p0 cache is too large.
|
||||
if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
|
||||
self.remove_dangling_items()
|
||||
self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
|
||||
address_item = self.address_as_item(address, monotonic_id,
|
||||
mm_item[0].modality)
|
||||
return address_item, mm_item[1]
|
||||
except (ValueError, MemoryError) as e:
|
||||
# put may fail if the object is too large or
|
||||
# the cache is full.
|
||||
# In this case we log the error and keep the original mm_input.
|
||||
logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash,
|
||||
e)
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._shm_cache.clear()
|
||||
self._p0_cache.clear()
|
||||
|
||||
def remove_dangling_items(self) -> None:
|
||||
"""Remove items that are no longer in the shared memory cache."""
|
||||
cached_hashes = self._shm_cache.key_index.keys()
|
||||
dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
|
||||
for mm_hash in dangling_hashes:
|
||||
del self._p0_cache[mm_hash]
|
||||
|
||||
def address_as_item(self, address: int, monotonic_id: int,
|
||||
modality: str) -> MultiModalKwargsItem:
|
||||
addr_elem = MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key="address",
|
||||
data=address,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
id_elem = MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key="monotonic_id",
|
||||
data=monotonic_id,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
|
||||
return mm_item
|
||||
|
||||
|
||||
def _enable_processor_cache(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> bool:
|
||||
if not mm_registry.supports_multimodal_inputs(model_config):
|
||||
return False
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
return mm_config.mm_processor_cache_gb > 0
|
||||
|
||||
|
||||
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
supports_ipc_cache = ((parallel_config._api_process_count == 1
|
||||
and parallel_config.data_parallel_size == 1)
|
||||
or parallel_config.data_parallel_external_lb)
|
||||
|
||||
return supports_ipc_cache
|
||||
|
||||
|
||||
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
|
||||
"""Whether the shared memory based cache should be enabled."""
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return False
|
||||
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
return mm_config.mm_processor_cache_type == "shm"
|
||||
|
||||
|
||||
def processor_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> Optional[BaseMultiModalProcessorCache]:
|
||||
"""Return a `BaseMultiModalProcessorCache`, if enabled."""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return MultiModalProcessorSenderCache(model_config)
|
||||
return ShmObjectStoreSenderCache(vllm_config)
|
||||
|
||||
|
||||
def processor_only_cache_from_config(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
):
|
||||
"""Return a `MultiModalProcessorOnlyCache`, if enabled."""
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
|
||||
class BaseMultiModalReceiverCache(
|
||||
BaseMultiModalCache[Optional[MultiModalKwargsItem],
|
||||
MultiModalKwargsItem]):
|
||||
"""The required interface for caches on P1."""
|
||||
|
||||
def get_and_update_features(
|
||||
self,
|
||||
mm_features: list["MultiModalFeatureSpec"],
|
||||
) -> list["MultiModalFeatureSpec"]:
|
||||
"""Update multimodal features with cached encoder outputs."""
|
||||
for feature in mm_features:
|
||||
feature.data = self.get_and_update_item(feature.data,
|
||||
feature.identifier)
|
||||
return mm_features
|
||||
|
||||
|
||||
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
|
||||
"""
|
||||
The cache which is used on P1 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is in the cache, replace the input with the cached item.
|
||||
- If the item is not in the cache, store that item (which includes tensor
|
||||
data) into the cache, and return the input.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalKwargsItem,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: Optional[MultiModalKwargsItem],
|
||||
mm_hash: str,
|
||||
) -> MultiModalKwargsItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return cached_item
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = mm_item
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
|
||||
"""
|
||||
The cache which is used on P1 Worker Process when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item has an address, replace the input with the cached item.
|
||||
- If not, return the input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
shared_worker_lock: LockType,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
mm_config = vllm_config.model_config.get_multimodal_config()
|
||||
|
||||
ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
|
||||
name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
|
||||
create=False, # Server is a reader
|
||||
)
|
||||
self._shm_cache = SingleWriterShmObjectStorage(
|
||||
max_object_size=mm_config.mm_shm_cache_max_object_size_mb *
|
||||
MiB_bytes,
|
||||
n_readers=self.world_size,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=MsgpackSerde,
|
||||
reader_lock=shared_worker_lock,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: Optional[MultiModalKwargsItem],
|
||||
mm_hash: str,
|
||||
) -> MultiModalKwargsItem:
|
||||
assert mm_item is not None, f"Expected an address item for {mm_hash=}"
|
||||
if "address" in mm_item:
|
||||
address = cast(int, mm_item["address"].data)
|
||||
monotonic_id = cast(int, mm_item["monotonic_id"].data)
|
||||
return self._shm_cache.get(address, monotonic_id)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._shm_cache.clear()
|
||||
|
||||
|
||||
def engine_receiver_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> Optional[BaseMultiModalReceiverCache]:
|
||||
"""
|
||||
This is used in the engine process.
|
||||
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
|
||||
mm_processor_cache_type=="lru".
|
||||
"""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return None
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return MultiModalReceiverCache(model_config)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def worker_receiver_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
shared_worker_lock: LockType,
|
||||
) -> Optional[BaseMultiModalReceiverCache]:
|
||||
"""
|
||||
This is used in the worker process.
|
||||
Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
|
||||
mm_processor_cache_type=="shm".
|
||||
"""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return None
|
||||
|
||||
if not _enable_mm_input_shm_cache(vllm_config):
|
||||
return None
|
||||
|
||||
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
|
||||
273
vllm/multimodal/evs.py
Normal file
273
vllm/multimodal/evs.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def compute_retained_tokens_count(video_size_thw: torch.LongTensor,
                                  spatial_merge_size: int, q: float) -> int:
    """
    Compute the number of retained tokens for a given video.
    Method ensures that we retain all the tokens from the first frame
    regardless of the pruning rate.

    Args:
        video_size_thw: The size of the video in the format of (T, H, W).
        spatial_merge_size: The size of the spatial merge.
        q: The pruning rate.

    Returns:
        The number of retained tokens.
    """
    num_frames, height, width = map(int, video_size_thw)
    tokens_per_frame = ((height // spatial_merge_size) *
                        (width // spatial_merge_size))
    # Token budget after pruning at rate q over the whole video.
    pruned_budget = int(num_frames * tokens_per_frame * (1 - q))
    # Never retain fewer tokens than one full (first) frame.
    return max(tokens_per_frame, pruned_budget)
|
||||
|
||||
|
||||
def compute_retention_mask(
    video_embeds: torch.Tensor,
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    q: float,
) -> torch.Tensor:
    """
    Computes the retention mask for input video embeddings.

    Args:
        video_embeds (`torch.Tensor`): The input video embeddings
            of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
        video_size_thw (`torch.LongTensor` of shape `(3)`):
            The temporal, height and width of video.
        spatial_merge_size: Size reduction for rows & cols dimensions.
        q: (`float`): Pruning rate factor [0,1)

    Returns:
        `torch.Tensor`: The retention mask for the video embeddings of
        `(T * H * W // spatial_merge_size ^ 2)` shape.
    """
    num_frames, height, width = video_size_thw

    # Use reshape instead of einops to avoid graph breaks.
    grid = video_embeds.reshape(
        num_frames,
        height // spatial_merge_size,
        width // spatial_merge_size,
        video_embeds.size(-1),
    )

    # Core EVS: dissimilarity of every cell w.r.t. the same cell in the
    # previous frame.
    frame_similarity = torch.nn.functional.cosine_similarity(grid[1:, ...],
                                                             grid[:-1, ...],
                                                             dim=-1)
    dissimilarity = 1 - frame_similarity

    # Always ensure we include all tokens from the first frame: give it a
    # score far above the cosine-derived range.
    dissimilarity = torch.cat(
        [255 * torch.ones_like(grid[:1, :, :, 0]), dissimilarity], dim=0)

    flat_scores = dissimilarity.view(-1)
    # Stable sort keeps earlier tokens ahead on score ties.
    ranking = torch.argsort(flat_scores,
                            dim=-1,
                            descending=True,
                            stable=True)
    keep_count = compute_retained_tokens_count(video_size_thw,
                                               spatial_merge_size, q)

    retention_mask = torch.zeros_like(flat_scores, dtype=torch.bool)
    retention_mask[ranking[:keep_count]] = True
    # Already flat: "T H W -> (T H W)"
    return retention_mask
|
||||
|
||||
|
||||
def compute_mrope_for_media(
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    tokens_per_second: float = 1.0,
    video_second_per_grid: float = 1.0,
) -> torch.Tensor:
    """
    Computes the mrope for video embeddings based on the grid dimensions.
    Computed mrope positions match original qwen 2.5 implementation,
    but positions are built for media being the first element in sequence.

    Args:
        video_size_thw: Media size (num frames, rows, cols)
        spatial_merge_size: Size reduction for rows & cols dimensions.
        tokens_per_second: Number of tokens per second.
        video_second_per_grid: Number of seconds per video.

    Returns:
        Tensor of shape `(T * H * W, 4)` where last dimension
        represents mrope positions [0:3), while the last channel
        contains value of llm_grid_w repeated for all positions.
    """
    grid_t = video_size_thw[0]
    grid_h = video_size_thw[1] // spatial_merge_size
    grid_w = video_size_thw[2] // spatial_merge_size

    time_scale = tokens_per_second * video_second_per_grid

    # Temporal channel: frame index repeated per spatial cell, scaled to
    # token time and truncated back to integers.
    t_index = (torch.arange(grid_t).view(-1, 1).expand(
        -1, grid_h * grid_w).mul(time_scale).long().flatten())
    # Row channel: row index tiled over frames and columns.
    h_index = (torch.arange(grid_h).view(1, -1, 1).expand(
        grid_t, -1, grid_w).flatten())
    # Column channel: column index tiled over frames and rows.
    w_index = (torch.arange(grid_w).view(1, 1, -1).expand(
        grid_t, grid_h, -1).flatten())
    # Fourth channel: the merged grid width, broadcast everywhere.
    w_channel = (torch.tensor([grid_w]).view(1, 1, 1).expand(
        grid_t, grid_h, grid_w).flatten())

    return torch.stack([t_index, h_index, w_index, w_channel], dim=1)
|
||||
|
||||
|
||||
def recompute_mrope_positions(
    input_ids: torch.LongTensor,
    multimodal_positions: list[torch.Tensor],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
    vision_start_token_id: int,
    image_token_id: int,
    video_token_id: int,
) -> tuple[torch.LongTensor, int]:
    """
    Update part of input mrope positions.
    Original mrope_positions are computed incorrectly, so once we prune media
    tokens we should reflect this in the mrope positions for the LLM.

    This method supports chunked prefill approach where
    multimodal_embeddings are passed to LLM in chunks, so input
    multimodal_embeddings may contain zero, some or even some part of all
    multimodal_embeddings for a given prompt.

    Each multimodal_positions has 4 extra channels
    (First 3 channels corresponds to original 3 mrope positions, last channel
    is the maximum width of the media repeated). Provided multimodal_positions
    do not reflect location of media position in sequence - they are computed
    like the media is in the 0-th position in the sequence.

    Method works as follows: it recomputes mrope_positions starting from the
    `num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
    shifts all text tokens that goes after total_len_of_multimodal_embeddings.

    It also handles case when multimodal_embeddings is partial
    (e.g. one media is split into two prefill stages)

    Args:
        input_ids: (N,) All input tokens of the prompt (entire sequence).
        multimodal_positions: List of mrope positions for each media.
        mrope_positions: Existing mrope positions (4, N) for entire sequence.
        num_computed_tokens: A number of computed tokens so far.
        vision_start_token_id: Token indicating start of vision media.
        image_token_id: Image token id
        video_token_id: Video token id

    Returns:
        Tuple of (mrope_positions, mrope_position_delta).
    """

    # Tensors
    # Work on a copy so the caller's tensor is never mutated in place.
    positions: torch.LongTensor = typing.cast(
        torch.LongTensor, mrope_positions.clone())  # (3, N)
    N = input_ids.numel()

    # Boolean masks over the whole prompt: which token ids are media
    # placeholders vs. ordinary text.
    image_mask = input_ids.eq(image_token_id)
    video_mask = input_ids.eq(video_token_id)
    media_mask = image_mask | video_mask
    text_mask = ~media_mask

    # Early exit: no media in this chunk.
    # Delta is "next position minus sequence length" so the decode phase can
    # continue numbering contiguously.
    if len(multimodal_positions) == 0:
        delta = (int((positions.max().item() + 1) -
                     N) if positions.numel() else -N)
        return positions, delta

    total_mm_tokens = torch.count_nonzero(media_mask)
    seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])

    # Early exit: we've updated positions for all media tokens
    # (and consequently - for all remaining text tokens)
    if seen_mm_tokens == total_mm_tokens:
        delta = (int((positions.max().item() + 1) -
                     N) if positions.numel() else -N)
        return positions, delta

    # Indices of every vision-start marker token in the full prompt.
    vision_start_indices = (input_ids == vision_start_token_id).nonzero(
        as_tuple=True)[0]

    for mm_pos in multimodal_positions:
        # Each mm_pos can be a complete embedding for single media
        # or it can be a part of a single media (due to chunked prefill)

        # Cases to cover
        # - Current prefill chunk has no vision start indexes at all
        # - Vision start token appeared in previous prefill round
        # - Regular case
        seen_vision_start_indices = vision_start_indices[vision_start_indices <
                                                         num_computed_tokens]

        if len(seen_vision_start_indices):
            # If we have encountered some vision start indexes,
            # then we should check the condition:
            # | --- prefill 1 ------| ---- prefill 2 ----- |
            # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
            last_vision_start_token = seen_vision_start_indices[-1]
            # Media tokens that appear strictly before the last seen
            # vision-start marker belong to earlier, fully-processed media.
            seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
                media_mask[:last_vision_start_token])
            in_the_middle_of_media = (
                seen_mm_tokens > seem_mm_tokens_before_last_vision_start)

            if in_the_middle_of_media:
                # Resume the partially-processed media: this chunk continues
                # `mm_embeddings_seen` tokens into it.
                mm_embeddings_seen = (seen_mm_tokens -
                                      seem_mm_tokens_before_last_vision_start)
                global_mm_start = last_vision_start_token
            else:
                # We have completed previous mm_embedding part and
                # ready to start a new one
                next_vision_start_token = vision_start_indices[
                    vision_start_indices >= num_computed_tokens][0]
                mm_embeddings_seen = 0
                global_mm_start = next_vision_start_token

        else:
            # If there were no vision start indexes so far,
            # let's find first vision start index
            next_vision_start_token = vision_start_indices[
                vision_start_indices >= num_computed_tokens][0]

            mm_embeddings_seen = 0
            global_mm_start = next_vision_start_token

        # Offset right after vision_start_token
        base = positions[-1, global_mm_start] + 1
        # Splice this chunk's media positions (first 3 channels) into place,
        # shifted by the position right after the vision-start token.
        local_start = global_mm_start + 1 + mm_embeddings_seen
        local_end = local_start + mm_pos.shape[1]
        positions[:, local_start:local_end] = mm_pos[0:3] + base

        # mm_pos[3, 0] is the max width of the media
        offset = mm_pos[3, 0] + base

        # Renumber every subsequent text token sequentially after the media;
        # media tokens later in the sequence get overwritten by later
        # iterations / later prefill chunks.
        text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)

        positions[:, local_end:N] = text_pos_sum + offset - 1

        # Include distance to the next vision start token
        num_computed_tokens += mm_pos.shape[1]

    mrope_positions_delta = (positions.max() + 1 - N).item()
    return positions, mrope_positions_delta
|
||||
102
vllm/multimodal/hasher.py
Normal file
102
vllm/multimodal/hasher.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
import uuid
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from blake3 import blake3
|
||||
from PIL import Image
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalHasher:
    """Serializes multi-modal items into a byte stream and hashes it.

    NOTE(review): the exact byte layout (dotted key prefixes, recursion
    order, dtype/shape metadata) defines the cache identity of items —
    any change here silently invalidates previously computed hashes.
    """

    @classmethod
    def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]:
        """Yield the raw bytes representing a single leaf object."""
        # Simple cases
        if isinstance(obj, (bytes, memoryview)):
            return (obj, )
        if isinstance(obj, str):
            return (obj.encode("utf-8"), )
        if isinstance(obj, (int, float)):
            # np.array picks a platform dtype (int64/float64 on common
            # platforms); the hash covers the resulting raw bytes.
            return (np.array(obj).tobytes(), )

        if isinstance(obj, Image.Image):
            exif = obj.getexif()
            if Image.ExifTags.Base.ImageID in exif and isinstance(
                    exif[Image.ExifTags.Base.ImageID], uuid.UUID):
                # If the image has exif ImageID tag, use that
                return (exif[Image.ExifTags.Base.ImageID].bytes, )
            # Otherwise hash pixel data plus mode (and palette, if any),
            # so visually distinct images cannot collide on pixels alone.
            data = {"mode": obj.mode, "data": np.asarray(obj)}
            if obj.palette is not None:
                data["palette"] = obj.palette.palette
                if obj.palette.rawmode is not None:
                    data["palette_rawmode"] = obj.palette.rawmode
            return cls.iter_item_to_bytes("image", data)
        if isinstance(obj, torch.Tensor):
            tensor_obj: torch.Tensor = obj.cpu()
            tensor_dtype = tensor_obj.dtype
            tensor_shape = tensor_obj.shape

            # NumPy does not support bfloat16.
            # Workaround: View the tensor as a contiguous 1D array of bytes
            if tensor_dtype == torch.bfloat16:
                tensor_obj = tensor_obj.contiguous()
                tensor_obj = tensor_obj.view(
                    (tensor_obj.numel(), )).view(torch.uint8)

                # Record the original dtype/shape so the byte view of a
                # bfloat16 tensor cannot collide with a real uint8 tensor.
                return cls.iter_item_to_bytes(
                    "tensor", {
                        "original_dtype": str(tensor_dtype),
                        "original_shape": tuple(tensor_shape),
                        "data": tensor_obj.numpy(),
                    })
            return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())
        if isinstance(obj, np.ndarray):
            # If the array is non-contiguous, we need to copy it first
            arr_data = obj.view(
                np.uint8).data if obj.flags.c_contiguous else obj.tobytes()
            return cls.iter_item_to_bytes("ndarray", {
                "dtype": obj.dtype.str,
                "shape": obj.shape,
                "data": arr_data,
            })
        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        # NOTE(review): pickle output is not guaranteed stable across
        # Python/library versions, so fallback hashes may differ between
        # environments.
        return (pickle.dumps(obj), )

    @classmethod
    def iter_item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[Union[bytes, memoryview]]:
        """Yield key-prefixed bytes for `obj`, recursing into containers.

        Container elements get dotted key paths (e.g. ``image.mode``,
        ``tensor.data``) so structure participates in the hash.
        """
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
        else:
            # Leaf: emit the key, then the serialized value.
            yield key.encode("utf-8")
            yield from cls.serialize_item(obj)

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Return a blake3 hex digest over all keyword arguments."""
        hasher = blake3()

        for k, v in kwargs.items():
            for bytes_ in cls.iter_item_to_bytes(k, v):
                hasher.update(bytes_)

        return hasher.hexdigest()
|
||||
130
vllm/multimodal/image.py
Normal file
130
vllm/multimodal/image.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
|
||||
def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor."""
    target_size = (int(image.width * size_factor),
                   int(image.height * size_factor))
    scaled = image.resize(target_size)
    # A negative `transpose` means "no transpose".
    if transpose >= 0:
        scaled = scaled.transpose(Image.Transpose(transpose))
    return scaled
|
||||
|
||||
|
||||
def rgba_to_rgb(
    image: Image.Image,
    background_color: Union[tuple[int, int, int], list[int]] = (255, 255, 255)
) -> Image.Image:
    """Convert an RGBA image to RGB with filled background color."""
    assert image.mode == "RGBA"
    alpha_band = image.split()[3]  # band 3 is the alpha channel
    flattened = Image.new("RGB", image.size, background_color)
    flattened.paste(image, mask=alpha_band)
    return flattened
|
||||
|
||||
|
||||
def convert_image_mode(image: Image.Image, to_mode: str):
    """Convert `image` to `to_mode`, flattening RGBA onto a background."""
    if image.mode == to_mode:
        return image
    # RGBA -> RGB goes through rgba_to_rgb so transparent areas are filled
    # with the background color instead of whatever convert() produces.
    if image.mode == "RGBA" and to_mode == "RGB":
        return rgba_to_rgb(image)
    return image.convert(to_mode)
|
||||
|
||||
|
||||
class ImageMediaIO(MediaIO[Image.Image]):
    """Loads and serializes images, normalizing them to a target mode."""

    def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
        super().__init__()

        self.image_mode = image_mode
        # `kwargs` carries custom arguments from --media-io-kwargs for this
        # modality; they can be forwarded to the underlying media loaders
        # (e.g. custom implementations) for flexible control.
        self.kwargs = kwargs

        # Background color used when flattening RGBA onto RGB.
        # Defaults to white for backward compatibility.
        background = kwargs.get('rgba_background_color', (255, 255, 255))
        if isinstance(background, list):
            background = tuple(background)

        # Validate rgba_background_color format.
        is_valid = (isinstance(background, tuple) and len(background) == 3
                    and all(
                        isinstance(channel, int) and 0 <= channel <= 255
                        for channel in background))
        if not is_valid:
            raise ValueError(
                "rgba_background_color must be a list or tuple of 3 integers "
                "in the range [0, 255].")
        self.rgba_background_color = background

    def _convert_image_mode(self, image: Image.Image) -> Image.Image:
        """Convert image mode with custom background color."""
        if image.mode == self.image_mode:
            return image
        if image.mode == "RGBA" and self.image_mode == "RGB":
            return rgba_to_rgb(image, self.rgba_background_color)
        return convert_image_mode(image, self.image_mode)

    def load_bytes(self, data: bytes) -> Image.Image:
        image = Image.open(BytesIO(data))
        # Force-decode now so errors surface here, not lazily later.
        image.load()
        return self._convert_image_mode(image)

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> Image.Image:
        image = Image.open(filepath)
        image.load()
        return self._convert_image_mode(image)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        with BytesIO() as buffer:
            self._convert_image_mode(media).save(buffer, image_format)
            raw = buffer.getvalue()

        return pybase64.b64encode(raw).decode('utf-8')
|
||||
|
||||
|
||||
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Loads and serializes precomputed image-embedding tensors."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        # weights_only=True blocks arbitrary-code pickle payloads.
        return torch.load(BytesIO(data), weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> torch.Tensor:
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        return pybase64.b64encode(media.numpy()).decode('utf-8')
|
||||
987
vllm/multimodal/inputs.py
Normal file
987
vllm/multimodal/inputs.py
Normal file
@@ -0,0 +1,987 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from itertools import accumulate
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union,
|
||||
cast, final)
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
|
||||
|
||||
from vllm.utils import LazyLoader, full_groupby, is_list_of
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL.Image import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
from .processing import MultiModalHashes
|
||||
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
|
||||
"""
|
||||
A `transformers.image_utils.ImageInput` representing a single image
|
||||
item, which can be passed to a HuggingFace `ImageProcessor`.
|
||||
"""
|
||||
|
||||
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
|
||||
list[np.ndarray], list["torch.Tensor"]]
|
||||
"""
|
||||
A `transformers.image_utils.VideoInput` representing a single video
|
||||
item, which can be passed to a HuggingFace `VideoProcessor`.
|
||||
"""
|
||||
|
||||
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
|
||||
"""
|
||||
Represents a single audio
|
||||
item, which can be passed to a HuggingFace `AudioProcessor`.
|
||||
"""
|
||||
|
||||
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
|
||||
"""
|
||||
A `transformers.image_utils.ImageInput` representing a single image
|
||||
item, which can be passed to a HuggingFace `ImageProcessor`.
|
||||
|
||||
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
||||
which are treated as image embeddings;
|
||||
these are directly passed to the model without HF processing.
|
||||
"""
|
||||
|
||||
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor",
|
||||
tuple[HfVideoItem, dict[str, Any]]]
|
||||
"""
|
||||
A `transformers.video_utils.VideoInput` representing a single video item.
|
||||
This can be passed to a HuggingFace `VideoProcessor`
|
||||
with `transformers.video_utils.VideoMetadata`.
|
||||
|
||||
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
||||
which are treated as video embeddings;
|
||||
these are directly passed to the model without HF processing.
|
||||
"""
|
||||
|
||||
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
|
||||
"torch.Tensor"]
|
||||
"""
|
||||
Represents a single audio
|
||||
item, which can be passed to a HuggingFace `AudioProcessor`.
|
||||
|
||||
Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
|
||||
is different from that expected by the model;
|
||||
these are resampled to the model's sampling rate before being processed by HF.
|
||||
|
||||
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
||||
which are treated as audio embeddings;
|
||||
these are directly passed to the model without HF processing.
|
||||
"""
|
||||
|
||||
ModalityData: TypeAlias = Union[_T, list[Optional[_T]], None]
|
||||
"""
|
||||
Either a single data item, or a list of data items. Can only be None if UUID
|
||||
is provided.
|
||||
|
||||
The number of data items allowed per modality is restricted by
|
||||
`--limit-mm-per-prompt`.
|
||||
"""
|
||||
|
||||
|
||||
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM.

    `total=False`: a prompt may supply any subset of these modalities.
    """

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""
|
||||
|
||||
|
||||
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
|
||||
"""
|
||||
A dictionary containing an entry for each modality type to input.
|
||||
|
||||
The built-in modalities are defined by
|
||||
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
|
||||
"""
|
||||
|
||||
MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]]
|
||||
"""
|
||||
A dictionary containing user-provided UUIDs for items in each modality.
|
||||
If a UUID for an item is not provided, its entry will be `None` and
|
||||
MultiModalHasher will compute a hash for the item.
|
||||
|
||||
The UUID will be used to identify the item for all caching purposes
|
||||
(input processing caching, embedding caching, prefix caching, etc).
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

        Prompt: `AAAA BBBB What is in these images?`

        Images A and B will have:

        ```
        A: PlaceholderRange(offset=0, length=4)
        B: PlaceholderRange(offset=5, length=4)
        ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        """Number of positions that actually receive embeddings."""
        if self.is_embed is None:
            return self.length
        return int(self.is_embed.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False
        # Masks must either both be absent or match element-wise.
        if self.is_embed is None or other.is_embed is None:
            return self.is_embed is other.is_embed
        return nested_tensors_equal(self.is_embed, other.is_embed)
|
||||
|
||||
|
||||
NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
|
||||
"torch.Tensor", tuple["torch.Tensor", ...]]
|
||||
"""
|
||||
Uses a list instead of a tensor if the dimensions of each element do not match.
|
||||
"""
|
||||
|
||||
|
||||
def nested_tensors_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.

    Tensors compare by value (`torch.equal`), lists compare element-wise,
    and anything else falls back to `==`.
    """
    if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
        # Both sides must be tensors with identical shape and values.
        return (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)
                and torch.equal(a, b))

    if isinstance(a, list) or isinstance(b, list):
        # Both sides must be lists of equal length.
        # Fix: the previous zip-based check silently ignored trailing
        # elements, so `[t]` compared equal to `[t, t2]`.
        return (isinstance(a, list) and isinstance(b, list)
                and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))

    # Both a and b are scalars
    return a == b
|
||||
|
||||
|
||||
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
|
||||
"""
|
||||
A dictionary containing nested tensors which have been batched via
|
||||
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class MultiModalFeatureSpec:
    """
    Represents a single multimodal input with its processed data and metadata.

    Used by the V1 engine to track multimodal data through processing and
    caching. A request containing multiple multimodal items will have one
    MultiModalFeatureSpec per item.
    """

    data: Optional["MultiModalKwargsItem"]
    """Multimodal data for this feature"""

    modality: str
    """Based on the input, e.g., "image", "audio", "video"."""

    identifier: str
    """mm_hash or uuid for caching encoder outputs."""

    mm_position: PlaceholderRange
    """e.g., PlaceholderRange(offset=2, length=336)"""
|
||||
|
||||
|
||||
@dataclass
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.

    It may be set to `None` if it is determined that the item is cached
    in `EngineCore`.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        # Cheap identity checks first: same (modality, key) pair and the
        # same combine strategy.
        if (self.modality, self.key) != (other.modality, other.key):
            return False
        if type(self.field) != type(other.field):  # noqa: E721
            return False
        # `data` may be None when the item is cached in EngineCore; two
        # elems are equal only if both sides are missing data.
        if self.data is None or other.data is None:
            return self.data is other.data
        return nested_tensors_equal(self.data, other.data)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        # Bind everything except `data` so subclasses can construct one
        # elem per item with a single positional argument.
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return MultiModalFieldElem(
                modality=modality,
                key=key,
                data=data,
                field=self,
            )

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Subclass hook implementing the actual merge strategy."""
        raise NotImplementedError

    def reduce_data(
        self,
        elems: list[MultiModalFieldElem],
        *,
        pin_memory: bool = False,
    ) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].
        """
        # All elems must have been built by the same field type, otherwise
        # their data cannot be combined consistently.
        field_types = [type(item.field) for item in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        return self._reduce_data([elem.data for elem in elems],
                                 pin_memory=pin_memory)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        make_elem = self._field_factory(modality=modality, key=key)
        return [make_elem(item) for item in data]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        # Empty or heterogeneous batches are returned unmodified.
        if not batch or not is_list_of(batch, torch.Tensor, check="all"):
            return batch

        tensors = cast(list[torch.Tensor], batch)
        if len(tensors) == 1:
            # An optimization when the batch contains only one tensor:
            # - produces exactly the same result as `torch.stack(batch)`
            # - achieves zero-copy if the tensor is contiguous
            return tensors[0].unsqueeze(0).contiguous()

        first_shape = tensors[0].shape
        if any(elem.shape != first_shape for elem in tensors):
            # Ragged shapes cannot be stacked; keep the list form.
            return batch

        out = torch.empty((len(tensors), *first_shape),
                          dtype=tensors[0].dtype,
                          device=tensors[0].device,
                          pin_memory=pin_memory)
        return torch.stack(tensors, out=out)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A field whose per-item elements are slices of the underlying data along
    `dim`; they are merged back by concatenation along the same dimension.

    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """
    # For each item: a single `slice` (dim=0) or a tuple of slices (dim>0)
    # selecting the item's portion of the data.
    slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Extract one element per entry of `self.slices`."""
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Multi-dimensional (tuple-of-slices) indexing only works on
            # tensors, not on nested lists.
            assert isinstance(data, torch.Tensor), \
                "torch.Tensor is required for multiple slices"
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """
        Concatenate the elements along `self.dim` when they are all tensors
        with matching non-concat dimensions; otherwise flatten the nested
        list (dim == 0 only).
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            batch = cast(list[torch.Tensor], batch)
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            # Normalize a negative `dim` to its positive equivalent.
            dim = self.dim + (self.dim < 0) * len(batch[0].shape)

            def _shape_before_after(tensor: torch.Tensor):
                # Shape excluding the concatenation dimension.
                return tensor.shape[:dim], tensor.shape[dim + 1:]

            first_shape = _shape_before_after(batch[0])

            if all(_shape_before_after(elem) == first_shape for elem in batch):
                shape_before, shape_after = first_shape
                shape_concat = sum(item.shape[dim] for item in batch)
                # Pre-allocate the output (optionally in pinned host memory)
                # so `torch.concat` writes directly into it.
                out = torch.empty((*shape_before, shape_concat, *shape_after),
                                  dtype=batch[0].dtype,
                                  device=batch[0].device,
                                  pin_memory=pin_memory)
                return torch.concat(batch, dim=self.dim, out=out)

        # Nested-list fallback: only flat (dim == 0) concatenation is
        # well-defined for lists.
        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    A field whose underlying data is shared by every item in the batch.

    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Produce `batch_size` elements that all reference the same data."""
        make_elem = self._field_factory(modality=modality, key=key)
        shared_elem = make_elem(data)
        return [shared_elem] * self.batch_size

    def _reduce_data(
        self,
        batch: list[NestedTensors],
        *,
        pin_memory: bool,
    ) -> NestedTensors:
        """Return the shared data (every entry is the same by construction)."""
        return batch[0]
|
||||
|
||||
|
||||
class MultiModalFieldConfig:
    """
    Pairs a [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    strategy with the modality it applies to. Construct instances via the
    `batched`/`flat`/`flat_from_sizes`/`shared` factory methods.
    """

    @staticmethod
    def batched(modality: str) -> "MultiModalFieldConfig":
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

            ```
            Input:
                Data: [[AAAA]
                    [BBBB]
                    [CCCC]]

            Output:
                Element 1: [AAAA]
                Element 2: [BBBB]
                Element 3: [CCCC]
            ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str,
             slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
             dim: int = 0) -> "MultiModalFieldConfig":
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

            ```
            Given:
                slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

            Input:
                Data: [AAABBBBCC]

            Output:
                Element 1: [AAA]
                Element 2: [BBBB]
                Element 3: [CC]
            ```

            ```
            Given:
                slices: [
                    (slice(None), slice(0, 3)),
                    (slice(None), slice(3, 7)),
                    (slice(None), slice(7, 9))]
                dim: 1

            Input:
                Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

            Output:
                Element 1: [[A],[A],[A]]
                Element 2: [[B],[B],[B],[B]]
                Element 3: [[C],[C]]
            ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str,
                        size_per_item: "torch.Tensor",
                        dim: int = 0) -> "MultiModalFieldConfig":
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.

        Example:

            ```
            Given:
                size_per_item: [3, 4, 2]

            Input:
                Data: [AAABBBBCC]

            Output:
                Element 1: [AAA]
                Element 2: [BBBB]
                Element 3: [CC]
            ```

            ```
            Given:
                size_per_item: [3, 4, 2]
                dim: 1

            Input:
                Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

            Output:
                Element 1: [[A],[A],[A]]
                Element 2: [[B],[B],[B],[B]]
                Element 3: [[C],[C]]
            ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Convert sizes to cumulative slice boundaries, e.g. [3, 4, 2]
        # -> boundaries [0, 3, 7, 9]; prepend full slices for leading dims.
        slice_idxs = [0, *accumulate(size_per_item)]
        slices = [(slice(None, None, None), ) * dim +
                  (slice(slice_idxs[i], slice_idxs[i + 1]), )
                  for i in range(len(size_per_item))]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int) -> "MultiModalFieldConfig":
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

            ```
            Given:
                batch_size: 4

            Input:
                Data: [XYZ]

            Output:
                Element 1: [XYZ]
                Element 2: [XYZ]
                Element 3: [XYZ]
                Element 4: [XYZ]
            ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        # The splitting/merging strategy for this field.
        self.field = field
        # The modality whose items this field belongs to.
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` into per-item elements under keyword `key`."""
        return self.field.build_elems(self.modality, key, batch)
|
||||
|
||||
|
||||
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def dummy(modality: str) -> "MultiModalKwargsItem":
        """Convenience class for testing."""
        mm_elem = MultiModalFieldElem(
            modality=modality,
            key="dummy",
            data=torch.empty(1),
            field=MultiModalSharedField(1),
        )
        return MultiModalKwargsItem.from_elems([mm_elem])

    @staticmethod
    def from_elems(
            elems: Sequence[MultiModalFieldElem]) -> "MultiModalKwargsItem":
        """Build an item keyed by each element's keyword-argument name."""
        return MultiModalKwargsItem({elem.key: elem for elem in elems})

    def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
        # NOTE: The `{}` default is safe because `UserDict.__init__` copies
        # the mapping rather than aliasing it.
        super().__init__(data)

        modalities = {elem.modality for elem in self.values()}
        # NOTE(review): this also rejects an *empty* item (len == 0), and
        # `assert` is stripped under `python -O` — consider raising instead.
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        self._modality = next(iter(modalities))

    @property
    def modality(self) -> str:
        # The single modality shared by all elements (validated in __init__).
        return self._modality

    def get_data(self) -> dict[str, NestedTensors]:
        """Extract the raw data keyed by keyword-argument name."""
        return {key: elem.data for key, elem in self.items()}
|
||||
|
||||
|
||||
# Item type parameter for `MultiModalKwargsItems`. The `Optional` variant
# allows entries to be missing (see `require_data`).
# NOTE: `default=` requires the PEP 696 `TypeVar` from typing_extensions.
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    Optional[MultiModalKwargsItem],
    default=MultiModalKwargsItem,
)
|
||||
|
||||
|
||||
class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of
    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
    by modality.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ) -> "MultiModalKwargsItems":
        """
        Split HF processor outputs into per-item elements according to
        `config_by_key`, then regroup them into one item per data item.

        Raises:
            ValueError: If the fields of one modality disagree on the
                number of items.
        """
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            # All fields of a modality must agree on the item count.
            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                # Collect the item_idx-th element of every field into one item.
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargsItems.from_seq(items)

    @staticmethod
    def from_seq(
            items: Sequence[MultiModalKwargsItem]) -> "MultiModalKwargsItems":
        """Group a flat sequence of items by their modality."""
        items_by_modality = full_groupby(items, key=lambda x: x.modality)
        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        # Override to produce a more helpful error message on missing keys.
        if modality not in self:
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {set(self.keys())}")

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        """Assert that no item is `None`, narrowing the item type."""
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(
                        f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
        """
        Merge the per-item elements back into batched keyword arguments,
        using each field's own reduction strategy.
        """
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError("Cannot build data from empty "
                                       f"mm_items[{modality}][{i}]")

                for key, elem in item.items():
                    elems_by_key[key].append(elem)

        return MultiModalKwargs({
            key:
            elems[0].field.reduce_data(elems, pin_memory=pin_memory)
            for key, elems in elems_by_key.items()
        })
|
||||
|
||||
|
||||
# Items whose entries may or may not have been filled in yet; use
# `require_data()` to narrow to the non-optional variant.
MultiModalKwargsOptionalItems: TypeAlias = Union[
    MultiModalKwargsItems[MultiModalKwargsItem],
    MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
]
|
||||
|
||||
|
||||
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].
    """

    @staticmethod
    @deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and "
                "will be removed in v0.13. "
                "Please use `MultiModalKwargsItems.from_hf_inputs` and "
                "access the tensor data using `.get_data()`.")
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \
            .get_data()

    @staticmethod
    @deprecated("`MultiModalKwargs.from_items` is deprecated and "
                "will be removed in v0.13. "
                "Please use `MultiModalKwargsItems.from_seq` and "
                "access the tensor data using `.get_data()`.")
    def from_items(
        items: Sequence[MultiModalKwargsItem],
        *,
        pin_memory: bool = False,
    ):
        return MultiModalKwargsItems.from_seq(items) \
            .get_data(pin_memory=pin_memory)

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors,
                   pin_memory: bool = False) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        # Recursively stack the inner levels first.
        stacked = [
            MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
        ]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        # Pre-allocate the output (optionally in pinned host memory) so
        # `torch.stack` writes directly into it.
        outputs = torch.empty(len(tensors_),
                              *tensors_[0].shape,
                              dtype=tensors_[0].dtype,
                              device=tensors_[0].device,
                              pin_memory=pin_memory)
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"],
              pin_memory: bool = False) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of `batched_inputs` to `device`."""
        return json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            batched_inputs,
        )

    def __getitem__(self, key: str):
        # Override to produce a more helpful error message on missing keys.
        if key not in self:
            raise KeyError(f"Keyword argument {key!r} not found. "
                           f"Available keys: {set(self.keys())}")

        return super().__getitem__(key)

    def __eq__(self, other: object) -> bool:
        """Element-wise equality using `nested_tensors_equal`."""
        if not isinstance(other, self.__class__):
            return False

        # FIX: also require the key sets to match; previously `self` being a
        # subset of `other` compared equal, making `__eq__` asymmetric.
        if len(self) != len(other):
            return False

        for k in self:
            if k not in other:
                return False
            if not nested_tensors_equal(self[k], other[k]):
                return False

        return True
|
||||
|
||||
|
||||
# Maps a modality name to the placeholder-token ranges of its items.
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality.
"""
|
||||
|
||||
|
||||
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """
|
||||
|
||||
|
||||
class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.

    Extends `MultiModalInputs` with the encoder-side prompt fields.
    """

    encoder_prompt: str
    """The processed encoder prompt text."""

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""
|
||||
511
vllm/multimodal/parse.py
Normal file
511
vllm/multimodal/parse.py
Normal file
@@ -0,0 +1,511 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from collections.abc import Callable, Iterator, Mapping, Sequence
|
||||
from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
|
||||
TypeVar, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import TypeAlias, TypeGuard, assert_never
|
||||
|
||||
from vllm.utils import LazyLoader, is_list_of
|
||||
|
||||
from .audio import AudioResampler
|
||||
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
|
||||
ImageItem, ModalityData, MultiModalDataDict,
|
||||
MultiModalFieldConfig, MultiModalKwargsItems, VideoItem)
|
||||
|
||||
# Generic type parameters used by the parser classes below:
# _T is the raw data container, _I is the per-item type.
_T = TypeVar("_T")
_I = TypeVar("_I")

if TYPE_CHECKING:
    import PIL.Image as PILImage
else:
    # Defer importing PIL until first attribute access to keep module
    # import cheap when images are not used.
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
|
||||
|
||||
|
||||
class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        # The raw data container for this modality.
        self.data: _T = data
        self.modality = modality

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"len={len(self)})")

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # Auto-generated
        # At runtime, iteration works via the legacy __len__/__getitem__
        # sequence protocol; this stub only informs type checkers.
        def __iter__(self) -> Iterator[_I]:
            ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(idx) for idx in range(self.get_count())]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are arranged in a list."""

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> _T:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # Pluralized key, e.g. "images"/"videos"/"audios", is what the HF
        # processor expects.
        return {f"{self.modality}s": self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # Items in this form always go through the HF processor.
        return {}
|
||||
|
||||
|
||||
class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
                                       torch.Tensor]):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def get_count(self) -> int:
        # Works for both forms: first dimension of a batched tensor,
        # or the length of the list.
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # Embeddings skip the HF processor entirely.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        # Length of the item's first dimension (its number of embeddings).
        return len(self.get(item_idx))
|
||||
|
||||
|
||||
class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
                                           Mapping[str, torch.Tensor]]):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        """
        Args:
            data: Tensor dictionary for this modality.
            modality: Name of the modality the data belongs to.
            required_fields: Keys that must be present in both `data` and
                the factory's field configuration.
            fields_factory: Produces the per-key field configuration used
                to split `data` into items.

        Raises:
            ValueError: If a required field is missing from `data` or from
                the factory's output.
        """
        # Local import: transformers is only needed when this class is used.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (f"The data should contain the fields: {required_fields}, "
                   f"but only found the following keys: {data_keys}")
            raise ValueError(msg)

        fields_config = fields_factory(data)
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)

        self.fields_config = fields_config
        self.required_fields = required_fields

        # Pre-split the data into per-item kwargs for `get_count`/`get`.
        self._kwargs = MultiModalKwargsItems.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        return len(self._kwargs[self.modality])

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        return self._kwargs[self.modality][index].get_data()

    def get_processor_data(self) -> Mapping[str, object]:
        # Pre-embedded data skips the HF processor.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return self.data
|
||||
|
||||
|
||||
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Audio items to be passed through the HF processor."""

    def __init__(self, data: Optional[Sequence[HfAudioItem]]) -> None:
        # A `None` payload is normalized to a single placeholder item.
        super().__init__([None] if data is None else data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Return the length of the audio item at `item_idx`."""
        return len(self.get(item_idx))
|
||||
|
||||
|
||||
class AudioEmbeddingItems(EmbeddingItems):
    """Pre-computed audio embeddings (bypass the HF processor)."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "audio")
|
||||
|
||||
|
||||
class ImageSize(NamedTuple):
    # Width-first ordering, matching PIL's `Image.size`.
    width: int
    height: int
|
||||
|
||||
|
||||
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Image items to be passed through the HF processor."""

    def __init__(self, data: Optional[Sequence[HfImageItem]]) -> None:
        if data is None:
            # Normalize a `None` payload to a single placeholder item.
            data = [None]
        super().__init__(data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Return the (width, height) of the image at `item_idx`."""
        image = self.get(item_idx)

        if isinstance(image, PILImage.Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Assumes channels-first (C, H, W) layout — TODO confirm for
            # every array-producing caller.
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)
|
||||
|
||||
|
||||
class ImageEmbeddingItems(EmbeddingItems):
    """Pre-computed image embeddings (bypass the HF processor)."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "image")
|
||||
|
||||
|
||||
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """Video items (frame sequences) to be passed through the HF processor."""

    def __init__(
        self,
        data: Optional[Sequence[HfVideoItem]],
        metadata: Optional[Union[dict[str, Any],
                                 list[Optional[dict[str, Any]]]]] = None,
    ) -> None:
        if data is None:
            # Normalize a `None` payload to a single placeholder item.
            data = [None]
        super().__init__(data, "video")
        # Optional per-video (or shared) metadata; not validated here.
        self.metadata = metadata

    def get_num_frames(self, item_idx: int) -> int:
        """Return the number of frames in the video at `item_idx`."""
        return len(self.get(item_idx))

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Return the (width, height) of the video's first frame."""
        image = self.get(item_idx)[0]  # Assume that the video isn't empty

        if isinstance(image, PILImage.Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Assumes channels-first (C, H, W) frame layout — TODO confirm.
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)
|
||||
|
||||
|
||||
class VideoEmbeddingItems(EmbeddingItems):
    """Pre-computed video embeddings (bypass the HF processor)."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "video")
|
||||
|
||||
|
||||
# Concrete `ModalityDataItems` subtype requested via `get_items`.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
|
||||
|
||||
|
||||
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
    normalized such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising [`KeyError`][]
        even if the modality is not found.
        """
        if modality not in self:
            if strict:
                available_modalities = set(self.keys())
                raise KeyError(f"Modality {modality!r} not found. "
                               f"Available modalities: {available_modalities}")

            return 0

        return self[modality].get_count()

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {m: items.get_count() for m, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: Union[type[_D], tuple[type[_D], ...]],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises:
            KeyError: If the modality is not present.
            TypeError: If the items are not an instance of `typ`.
        """
        if modality not in self:
            available_modalities = set(self.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

        items = self[modality]
        if not isinstance(items, typ):
            raise TypeError(f"Invalid type of data items for {modality=}. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(items)}")

        return items  # type: ignore[return-value]
|
||||
|
||||
|
||||
# Converts one modality's raw data into parsed items; `None` means the
# data is empty and the modality should be dropped.
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
                                         Optional[ModalityDataItems[Any, Any]]]
|
||||
|
||||
|
||||
class MultiModalDataParser:
|
||||
"""
|
||||
Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
|
||||
into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
|
||||
|
||||
Args:
|
||||
target_sr (float, optional): Enables automatic resampling of audio
|
||||
items to the model's expected sampling rate.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        *,
        target_sr: Optional[float] = None,
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
        video_needs_metadata: bool = False,
    ) -> None:
        """
        Args:
            target_sr: If set, audio items are resampled to this rate.
            audio_resample_method: Backend used by the resampler.
            video_needs_metadata: Whether parsed videos must carry metadata
                (consumed by video parsing, not validated here).
        """
        super().__init__()

        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )
        self.video_needs_metadata = video_needs_metadata
|
||||
|
||||
    def _is_embeddings(
        self, data: object
    ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
        """
        Heuristically detect pre-computed embeddings: a 3-D batched tensor,
        or a list of 2-D tensors (one per item).
        """
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
            # NOTE(review): only the first element's ndim is checked here;
            # a mixed-rank list would slip through — verify upstream inputs.
            return data[0].ndim == 2

        return False
|
||||
|
||||
def _is_empty(self, data: object) -> TypeGuard[None]:
|
||||
if isinstance(data, list):
|
||||
return len(data) == 0
|
||||
if isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
return data.size == 0
|
||||
|
||||
return False
|
||||
|
||||
    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, Optional[float]]:
        """
        Normalize an audio item to `(waveform, sampling_rate)`, where the
        rate is `None` when the input does not carry one.
        """
        if isinstance(audio, tuple):
            # Already an (audio, sampling_rate) pair.
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            # NOTE(review): `Tensor.numpy()` requires a CPU tensor without
            # grad — presumably guaranteed by callers; confirm.
            return audio.numpy(), None

        assert_never(audio)
|
||||
|
||||
    def _get_video_with_metadata(
        self,
        video: VideoItem,
    ) -> tuple[np.ndarray, Optional[dict[str, Any]]]:
        """
        Normalize a video item to `(frames, metadata)`, where the metadata
        is `None` when the input does not carry any.
        """
        if isinstance(video, tuple):
            # Already a (video, metadata) pair.
            return video
        if isinstance(video, list):
            return np.array(video), None
        if isinstance(video, np.ndarray):
            return video, None
        if isinstance(video, torch.Tensor):
            # NOTE(review): `Tensor.numpy()` requires a CPU tensor without
            # grad — presumably guaranteed by callers; confirm.
            return video.numpy(), None

        assert_never(video)
|
||||
|
||||
def _parse_audio_data(
    self,
    data: ModalityData[AudioItem],
) -> Optional[ModalityDataItems[Any, Any]]:
    """
    Parse raw audio input into processor items, resampling any item
    that carries a sampling rate.

    Returns ``None`` for empty input (so the caller can drop it), an
    embedding wrapper for precomputed embeddings, and otherwise a list
    of normalized waveforms.
    """
    if data is None:
        return AudioProcessorItems(None)

    # also check single audio item with sampling rate
    if self._is_empty(data) or (isinstance(data, tuple)
                                and self._is_empty(data[0])):
        return None

    if self._is_embeddings(data):
        return AudioEmbeddingItems(data)

    # Decide whether ``data`` is a single item (wrap in a list) or a
    # batch (iterate it). A 1-D array/tensor or a flat list of floats
    # is one waveform; a (waveform, sr) tuple is also a single item.
    if (is_list_of(data, float)
            or isinstance(data,
                          (np.ndarray, torch.Tensor)) and data.ndim == 1
            or isinstance(data, tuple)):
        data_items = [data]
    elif isinstance(data, (np.ndarray, torch.Tensor)):
        # A batched array/tensor: iterate over the leading dimension.
        data_items = [elem for elem in data]
    else:
        data_items = data

    new_audios = list[np.ndarray]()
    for data_item in data_items:
        audio, orig_sr = self._get_audio_with_sr(data_item)
        if orig_sr is None:
            # No source rate known: pass the waveform through unchanged.
            new_audio = audio
        else:
            new_audio = self.audio_resampler.resample(audio,
                                                      orig_sr=orig_sr)

        new_audios.append(new_audio)

    return AudioProcessorItems(new_audios)
|
||||
|
||||
def _parse_image_data(
    self,
    data: ModalityData[ImageItem],
) -> Optional[ModalityDataItems[Any, Any]]:
    """
    Parse raw image input into processor items.

    Returns ``None`` for empty input, an embedding wrapper for
    precomputed embeddings, and otherwise the list of images.
    """
    if data is None:
        return ImageProcessorItems(None)

    if self._is_empty(data):
        return None

    if self._is_embeddings(data):
        return ImageEmbeddingItems(data)

    # Single item: a PIL image, or a 3-D array/tensor (one image).
    is_single = (isinstance(data, PILImage.Image)
                 or isinstance(data,
                               (np.ndarray, torch.Tensor)) and data.ndim == 3)

    if is_single:
        items = [data]
    elif isinstance(data, (np.ndarray, torch.Tensor)):
        # Batched array/tensor: split along the leading dimension.
        items = [elem for elem in data]
    else:
        items = data

    return ImageProcessorItems(items)
|
||||
|
||||
def _parse_video_data(
    self,
    data: ModalityData[VideoItem],
) -> Optional[ModalityDataItems[Any, Any]]:
    """
    Parse raw video input into processor items.

    Returns ``None`` for empty input, an embedding wrapper for
    precomputed embeddings, and otherwise the list of videos. When
    ``self.video_needs_metadata`` is set, each parsed video is kept as
    a ``(frames, metadata)`` pair and the metadata is also collected
    into the list passed to ``VideoProcessorItems``; otherwise only the
    frames are kept and the metadata list stays empty.
    """
    if data is None:
        return VideoProcessorItems(None)

    if self._is_empty(data):
        return None

    if self._is_embeddings(data):
        return VideoEmbeddingItems(data)

    # Single item: a list of PIL frames, or one 4-D array/tensor.
    if (is_list_of(data, PILImage.Image)
            or isinstance(data,
                          (np.ndarray, torch.Tensor)) and data.ndim == 4):
        data_items = [data]
    elif isinstance(data, (np.ndarray, torch.Tensor)):
        # Batched array/tensor: split along the leading dimension.
        data_items = [elem for elem in data]
    elif isinstance(data, tuple) and len(data) == 2:
        # A single (frames, metadata) pair.
        data_items = [data]
    else:
        data_items = data

    # NOTE: when metadata is not needed, bare frame arrays (not tuples)
    # are appended, so the element type declared here only holds in the
    # needs-metadata case.
    new_videos = list[tuple[np.ndarray, Optional[dict[str, Any]]]]()
    metadata_lst: list[Optional[dict[str, Any]]] = []
    for data_item in data_items:
        video, metadata = self._get_video_with_metadata(data_item)
        if self.video_needs_metadata:
            new_videos.append((video, metadata))
            metadata_lst.append(metadata)
        else:
            new_videos.append(video)

    # (Removed a dead ``metadata = None`` reassignment here: ``metadata``
    # was never read after the loop — the return uses ``metadata_lst``.)
    return VideoProcessorItems(new_videos, metadata=metadata_lst)
|
||||
|
||||
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
    """Map each supported modality name to its parsing method."""
    subparsers: dict[str, ModalityDataParser] = {
        "audio": self._parse_audio_data,
        "image": self._parse_image_data,
        "video": self._parse_video_data,
    }
    return subparsers
|
||||
|
||||
def parse_mm_data(self,
                  mm_data: MultiModalDataDict) -> MultiModalDataItems:
    """
    Parse a multi-modal data dict into `MultiModalDataItems`, dispatching
    each modality to its registered subparser.

    Raises:
        ValueError: If a modality has no registered subparser.
    """
    subparsers = self._get_subparsers()

    mm_items = MultiModalDataItems()
    for modality, value in mm_data.items():
        subparser = subparsers.get(modality)
        if subparser is None:
            raise ValueError(f"Unsupported modality: {modality}")

        parsed = subparser(value)
        # A ``None`` result means the input was empty — skip it.
        if parsed is not None:
            mm_items[modality] = parsed

    return mm_items
|
||||
2148
vllm/multimodal/processing.py
Normal file
2148
vllm/multimodal/processing.py
Normal file
File diff suppressed because it is too large
Load Diff
284
vllm/multimodal/profiling.py
Normal file
284
vllm/multimodal/profiling.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalInputs, MultiModalKwargsItems,
|
||||
MultiModalPlaceholderDict)
|
||||
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
|
||||
EncDecMultiModalProcessor)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """
    # Text prompt, or pre-tokenized prompt token IDs.
    prompt: Union[str, list[int]]
    # Multi-modal data keyed by modality (e.g. "image", "audio", "video").
    mm_data: MultiModalDataDict
    # Extra keyword arguments forwarded to the HF processor.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Extra keyword arguments forwarded to tokenization.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DummyEncoderData(NamedTuple):
    """Dummy data used for profiling."""

    # Token IDs of the dummy encoder prompt.
    prompt_token_ids: list[int]
|
||||
|
||||
|
||||
class DummyDecoderData(NamedTuple):
    """Dummy data used for profiling."""

    # Token IDs of the dummy decoder prompt (padded up to the target
    # sequence length by the profiler).
    prompt_token_ids: list[int]
    # Processed multi-modal keyword arguments for the dummy inputs.
    multi_modal_data: MultiModalKwargsItems
    # Placeholder token ranges per modality within the prompt.
    multi_modal_placeholders: MultiModalPlaceholderDict
|
||||
|
||||
|
||||
# Processing-info type parameter shared by the generic classes below.
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
|
||||
|
||||
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        # Model-specific processing info used by subclasses to size
        # the dummy inputs.
        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.
        """
        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
        # Disable truncation so the worst-case prompt keeps all its
        # placeholder tokens.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(prompt=dummy_text,
                               mm_data=dummy_mm_data,
                               tokenization_kwargs=tokenization_kwargs)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        """Return `num_audios` all-zero waveforms of `length` samples."""
        if num_audios == 0:
            return []
        # The same array object is shared between entries; callers must
        # treat the returned audios as read-only.
        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        """Return `num_images` solid-color RGB images of the given size."""
        if num_images == 0:
            return []
        # Shared object, as above.
        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        """
        Return `num_videos` constant-valued videos with shape
        (num_frames, width, height, 3).
        """
        if num_videos == 0:
            return []
        # Shared object, as above.
        video = np.full((num_frames, width, height, 3), 255)
        return [video] * num_videos
|
||||
|
||||
|
||||
class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        # Model-specific processing info attached to the processor.
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        # Builder that fabricates worst-case dummy inputs.
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Return the allowed number of input items per modality."""
        return self.processor.allowed_mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> MultiModalInputs:
        """
        Build dummy processor inputs (defaulting `mm_counts` to the
        model's limits) and run them through the processor.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts)

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """
        Count placeholder tokens per modality. With `mm_embeddings_only`,
        only positions holding embeddings are counted; otherwise the full
        placeholder length is used.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality:
            sum(item.get_num_embeds() if mm_embeddings_only else item.length
                for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Build dummy encoder token IDs for profiling an encoder-decoder
        model, padding to `seq_len` when the processor requests it.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        # NOTE: Whisper allows total_len > seq_len.
        elif total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
                "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
                "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
                "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
                seq_len,
                total_len,
                str(self._get_mm_num_tokens(mm_inputs)),
            )

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Build dummy decoder data for profiling, right-padding the prompt
        with token 0 up to `seq_len`.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def _get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """
        Max multi-modal tokens per modality; uses the model's fast path
        when available, otherwise falls back to running dummy inputs.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        # Fast path: some models can report this without processing.
        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            return max_tokens_per_item

        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        return self._get_mm_num_tokens(mm_inputs,
                                       mm_embeddings_only=mm_embeddings_only)

    def get_mm_max_contiguous_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ):
        """
        Returns the maximum length of the multimodal (image placeholders+text)
        tokens, including any break/text tokens in-between image embeddings.

        `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
        Returns 9, even when the number of image embeddings is 6.

        This is important to take into account when profiling and
        initializing the encoder cache size.
        """

        return self._get_mm_max_tokens(seq_len,
                                       mm_counts,
                                       mm_embeddings_only=False)
|
||||
345
vllm/multimodal/registry.py
Normal file
345
vllm/multimodal/registry.py
Normal file
@@ -0,0 +1,345 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
from vllm.utils import ClassRegistry
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
|
||||
InputProcessingContext)
|
||||
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
|
||||
DummyEncoderData, MultiModalProfiler)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Model class being decorated by `register_processor`.
N = TypeVar("N", bound=type[nn.Module])
# Processing-info type parameter (invariant and covariant flavors).
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
|
||||
|
||||
|
||||
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """
    # NOTE: The original docstring said "BaseMultiModalProcessor", but
    # `__call__` returns `_I_co`, which is bound to `BaseProcessingInfo`.

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        ...
|
||||
|
||||
|
||||
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance from the context.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        ...
|
||||
|
||||
|
||||
class MultiModalProcessorFactory(Protocol[_I]):
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the context.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """Bundle of the three factories registered for one model class."""

    info: ProcessingInfoFactory[_I]
    processor: MultiModalProcessorFactory[_I]
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ):
        """Instantiate the processor for `ctx` (info -> builder -> processor)."""
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)
|
||||
|
||||
|
||||
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def __init__(self) -> None:
        # Maps each registered model class to its processor factories.
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.
        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        # Tokenizer is not needed just to enumerate supported modalities.
        info = self._create_processing_info(model_config, tokenizer=None)
        supported_modalities = info.get_supported_mm_limits()

        mm_config = model_config.get_multimodal_config()

        # Check if all supported modalities have limit == 0
        if all(
                mm_config.get_limit_per_prompt(modality) == 0
                for modality in supported_modalities):
            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode.")
            return False

        return True

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler = MultiModalProfiler(processor)

        seq_len = model_config.max_model_len
        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)

        # Profile with exactly one item per enabled modality to measure
        # the per-item worst case.
        return profiler.get_mm_max_contiguous_tokens(
            seq_len,
            {
                modality: 1
                for modality, limit in mm_limits.items() if limit > 0
            },
        )

    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
            model_config,
            cache=cache,
        )

        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
            if mm_limits[key] > 0
        }

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, cache=cache)
        profiler = MultiModalProfiler(processor)
        return profiler.get_mm_limits()

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.
        """

        def wrapper(model_cls: N) -> N:
            # Re-registration is allowed but flagged; the new factories win.
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class from its configuration."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        tokenizer: Optional[AnyTokenizer] = None,
    ) -> InputProcessingContext:
        """Build the processing context, loading a tokenizer if needed."""
        if tokenizer is None and not model_config.skip_tokenizer_init:
            tokenizer = cached_tokenizer_from_config(model_config)
        return InputProcessingContext(model_config, tokenizer)

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
    ) -> BaseProcessingInfo:
        """Build only the processing info for the model (no processor)."""
        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]
        ctx = self._create_processing_ctx(model_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If the model is not multimodal.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = self._create_processing_ctx(model_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead.")

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        *,
        cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.
        """
        processor = self.create_processor(model_config, cache=cache)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            # Unlike the decoder path, a short encoder prompt is only a
            # warning, not an error.
            logger.warning_once(
                "Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.",  # noqa: E501
                seq_len,
                len(token_ids),
            )

        return dummy_data

    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum length of the encoder input for encoder-decoder models.
        """
        if not model_config.is_encoder_decoder:
            return 0
        max_tokens = self.\
            get_max_tokens_per_item_by_nonzero_modality(model_config)
        if not max_tokens:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0
        assert len(max_tokens) == 1, "Encoder-decoder models are expected \
to implement the multimodal interface with at most one modality."

        first_modality = next(iter(max_tokens))
        return max_tokens[first_modality]
|
||||
503
vllm/multimodal/utils.py
Normal file
503
vllm/multimodal/utils.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections.abc import Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
from urllib.request import url2pathname
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .video import VideoMediaIO
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .inputs import (BatchedTensorInputs, MultiModalKwargsItem,
|
||||
MultiModalKwargsItems, MultiModalPlaceholderDict)
|
||||
else:
|
||||
BatchedTensorInputs = Any
|
||||
MultiModalKwargsItem = Any
|
||||
MultiModalKwargsItems = Any
|
||||
MultiModalPlaceholderDict = Any
|
||||
|
||||
# Shared pool for offloading blocking media decoding from the event loop;
# sized by VLLM_MEDIA_LOADING_THREAD_COUNT and shut down at interpreter exit.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT)
atexit.register(global_thread_pool.shutdown)
|
||||
|
||||
|
||||
class MediaConnector:
|
||||
|
||||
def __init__(
    self,
    media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None,
    connection: HTTPConnection = global_http_connection,
    *,
    allowed_local_media_path: str = "",
    allowed_media_domains: Optional[list[str]] = None,
) -> None:
    """
    Args:
        media_io_kwargs: Additional args passed to process media
            inputs, keyed by modalities. For example,
            to set num_frames for video, set
            `--media-io-kwargs '{"video":{"num_frames":40}}'`
        connection: HTTP connection client to download media contents.
        allowed_local_media_path: A local directory to load media files
            from.
        allowed_media_domains: If given, only these hostnames may be
            fetched over HTTP.
    """
    super().__init__()

    self.media_io_kwargs: dict[str, dict[
        str, Any]] = media_io_kwargs if media_io_kwargs else {}
    self.connection = connection

    # Validate the local-media sandbox directory up front so that
    # misconfiguration fails at construction time, not on first load.
    allowed_local_media_path_: Optional[Path] = None
    if allowed_local_media_path:
        allowed_local_media_path_ = Path(allowed_local_media_path)

        if not allowed_local_media_path_.exists():
            raise ValueError(
                "Invalid `--allowed-local-media-path`: The path "
                f"{allowed_local_media_path_} does not exist.")
        if not allowed_local_media_path_.is_dir():
            raise ValueError(
                "Invalid `--allowed-local-media-path`: The path "
                f"{allowed_local_media_path_} must be a directory.")

    self.allowed_local_media_path = allowed_local_media_path_
    self.allowed_media_domains = (allowed_media_domains
                                  if allowed_media_domains is not None
                                  else [])
|
||||
|
||||
def _load_data_url(
    self,
    url_spec: ParseResult,
    media_io: MediaIO[_M],
) -> _M:  # type: ignore[type-var]
    """Decode a ``data:`` URL (base64 only) via the given media IO."""
    header, payload = url_spec.path.split(",", 1)
    media_type, encoding = header.split(";", 1)

    if encoding != "base64":
        raise NotImplementedError(
            "Only base64 data URLs are supported for now.")

    return media_io.load_base64(media_type, payload)
|
||||
|
||||
def _load_file_url(
    self,
    url_spec: ParseResult,
    media_io: MediaIO[_M],
) -> _M:  # type: ignore[type-var]
    """
    Load a ``file:`` URL, enforcing the `--allowed-local-media-path`
    sandbox.
    """
    root = self.allowed_local_media_path
    if root is None:
        raise RuntimeError("Cannot load local files without "
                           "`--allowed-local-media-path`.")

    filepath = Path(url2pathname(url_spec.path))
    # resolve() canonicalizes the path so escapes via ".." or symlinks
    # are rejected as well.
    if root not in filepath.resolve().parents:
        raise ValueError(
            f"The file path {filepath} must be a subpath "
            f"of `--allowed-local-media-path` {root}.")

    return media_io.load_file(filepath)
|
||||
|
||||
def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
|
||||
if self.allowed_media_domains and url_spec.hostname not in \
|
||||
self.allowed_media_domains:
|
||||
raise ValueError(
|
||||
f"The URL must be from one of the allowed domains: "
|
||||
f"{self.allowed_media_domains}. Input URL domain: "
|
||||
f"{url_spec.hostname}")
|
||||
|
||||
def load_from_url(
    self,
    url: str,
    media_io: MediaIO[_M],
    *,
    fetch_timeout: Optional[int] = None,
) -> _M:  # type: ignore[type-var]
    """
    Synchronously load media from an HTTP(S), ``data:`` or ``file:`` URL
    using the given media IO.
    """
    url_spec = urlparse(url)
    scheme = url_spec.scheme

    if scheme.startswith("http"):
        # Enforce the domain allowlist before any network access.
        self._assert_url_in_allowed_media_domains(url_spec)

        data = self.connection.get_bytes(
            url,
            timeout=fetch_timeout,
            allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
        )
        return media_io.load_bytes(data)

    if scheme == "data":
        return self._load_data_url(url_spec, media_io)

    if scheme == "file":
        return self._load_file_url(url_spec, media_io)

    raise ValueError("The URL must be either a HTTP, data or file URL.")
|
||||
|
||||
async def load_from_url_async(
    self,
    url: str,
    media_io: MediaIO[_M],
    *,
    fetch_timeout: Optional[int] = None,
) -> _M:
    """
    Asynchronously load media from an HTTP(S), ``data:`` or ``file:``
    URL. Decoding is offloaded to ``global_thread_pool`` so it does not
    block the event loop; HTTP downloads use the async connection API.
    """
    url_spec = urlparse(url)
    loop = asyncio.get_running_loop()

    if url_spec.scheme.startswith("http"):
        # Enforce the domain allowlist before any network access.
        self._assert_url_in_allowed_media_domains(url_spec)

        connection = self.connection
        data = await connection.async_get_bytes(
            url,
            timeout=fetch_timeout,
            allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
        )
        # Decode the downloaded bytes off the event loop.
        future = loop.run_in_executor(global_thread_pool,
                                      media_io.load_bytes, data)
        return await future

    if url_spec.scheme == "data":
        future = loop.run_in_executor(global_thread_pool,
                                      self._load_data_url, url_spec,
                                      media_io)
        return await future

    if url_spec.scheme == "file":
        future = loop.run_in_executor(global_thread_pool,
                                      self._load_file_url, url_spec,
                                      media_io)
        return await future
    msg = "The URL must be either a HTTP, data or file URL."
    raise ValueError(msg)
|
||||
|
||||
def fetch_audio(
    self,
    audio_url: str,
) -> tuple[np.ndarray, Union[int, float]]:
    """
    Load audio from a URL.
    """
    io_kwargs = self.media_io_kwargs.get("audio", {})

    return self.load_from_url(
        audio_url,
        AudioMediaIO(**io_kwargs),
        fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
    )
|
||||
|
||||
async def fetch_audio_async(
    self,
    audio_url: str,
) -> tuple[np.ndarray, Union[int, float]]:
    """
    Asynchronously fetch audio from a URL.
    """
    io_kwargs = self.media_io_kwargs.get("audio", {})

    return await self.load_from_url_async(
        audio_url,
        AudioMediaIO(**io_kwargs),
        fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
    )
|
||||
|
||||
def fetch_image(
    self,
    image_url: str,
    *,
    image_mode: str = "RGB",
) -> Image.Image:
    """
    Load a PIL image from an HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    io_kwargs = self.media_io_kwargs.get("image", {})
    image_io = ImageMediaIO(image_mode=image_mode, **io_kwargs)

    try:
        return self.load_from_url(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
    except UnidentifiedImageError as e:
        # convert to ValueError to be properly caught upstream
        raise ValueError(str(e)) from e
|
||||
|
||||
async def fetch_image_async(
    self,
    image_url: str,
    *,
    image_mode: str = "RGB",
) -> Image.Image:
    """
    Asynchronously load a PIL image from an HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    io_kwargs = self.media_io_kwargs.get("image", {})
    image_io = ImageMediaIO(image_mode=image_mode, **io_kwargs)

    try:
        return await self.load_from_url_async(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
    except UnidentifiedImageError as e:
        # convert to ValueError to be properly caught upstream
        raise ValueError(str(e)) from e
|
||||
|
||||
def fetch_video(
    self,
    video_url: str,
    *,
    image_mode: str = "RGB",
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Load video from an HTTP or base64 data URL.

    Returns the decoded frames together with the loader's metadata dict.
    """
    frame_io = ImageMediaIO(image_mode=image_mode,
                            **self.media_io_kwargs.get("image", {}))
    video_io = VideoMediaIO(frame_io,
                            **self.media_io_kwargs.get("video", {}))

    return self.load_from_url(
        video_url,
        video_io,
        fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
    )
|
||||
|
||||
async def fetch_video_async(
    self,
    video_url: str,
    *,
    image_mode: str = "RGB",
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Asynchronously load video from an HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    frame_io = ImageMediaIO(image_mode=image_mode,
                            **self.media_io_kwargs.get("image", {}))
    video_io = VideoMediaIO(frame_io,
                            **self.media_io_kwargs.get("video", {}))

    return await self.load_from_url_async(
        video_url,
        video_io,
        fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
    )
|
||||
|
||||
def fetch_image_embedding(
    self,
    data: str,
) -> torch.Tensor:
    """
    Load image embedding from a URL.
    """
    embedding_io = ImageEmbeddingMediaIO()
    return embedding_io.load_base64("", data)
|
||||
|
||||
|
||||
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Encode audio as base64."""
    return AudioMediaIO().encode_base64((audio, sampling_rate))
|
||||
|
||||
|
||||
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    return encoder.encode_base64(image, image_format=format)
|
||||
|
||||
|
||||
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode a stack of video frames to base64 via the video IO handler."""
    frame_io = ImageMediaIO()
    return VideoMediaIO(frame_io).encode_base64(frames)
|
||||
|
||||
|
||||
def argsort_mm_positions(
        mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    indexed = [(item.offset, modality, idx)
               for modality, items in mm_positions.items()
               for idx, item in enumerate(items)]
    # Sort on the offset only (stable), so ties keep dict/iteration order.
    indexed.sort(key=lambda entry: entry[0])

    return [(modality, idx) for _, modality, idx in indexed]
|
||||
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
# Temporary back-compatibility for plugins that define model runner
@deprecated("`group_mm_inputs_by_modality` is superseded by "
            "`group_mm_kwargs_by_modality` and will be removed in v0.13. "
            "Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality(
    mm_inputs: list[MultiModalKwargsItems]
) -> list[list[MultiModalKwargsItems]]:
    """Group consecutive inputs that share the same (single) modality."""
    if not mm_inputs:
        return []

    def modality_group_func(
            mm_input: MultiModalKwargsItems) -> Union[str, int]:
        num_modalities = len(mm_input)

        if num_modalities == 1:
            # Single modality: group by the modality name.
            return next(iter(mm_input.keys()))
        if num_modalities > 1:
            # Multiple modalities: use object identity so this input
            # forms a group of its own.
            return id(mm_input)

        raise AssertionError("This line should be unreachable.")

    return [
        list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
    ]
|
||||
|
||||
|
||||
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
    merge_by_field_config: bool = False,
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together into the same `MultiModalKwargs` instance.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
        merge_by_field_config: Whether to build the group directly from the
            items' field data instead of batching per-item kwargs.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    # Imported inside the function rather than at module top — presumably
    # to avoid a circular import; confirm before hoisting.
    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems

    # NOTE: groupby only merges *consecutive* items with equal modality;
    # input order is preserved within and across the yielded groups.
    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        items_lst = list(items)

        # TODO: Enable `merge_by_field_config` for all models
        # to avoid creating an extra batch dimension (except for fields
        # that are meant to be stacked anyway).
        # We will also need to update each model to remove `flatten_bn`.
        if merge_by_field_config:
            mm_kwargs_group: BatchedTensorInputs = dict(
                MultiModalKwargsItems.from_seq(items_lst).get_data(
                    pin_memory=pin_memory))

            if device is not None:
                # Move every tensor leaf of the (possibly nested) structure.
                mm_kwargs_group = json_map_leaves(
                    lambda x: x.to(device=device),
                    mm_kwargs_group,
                )
        else:
            # Legacy path: batch each item individually, then merge the
            # per-item kwargs into a single batched structure.
            mm_kwargs_group = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(
                    [
                        MultiModalKwargsItems.from_seq([item]).get_data()
                        for item in items_lst
                    ],
                    pin_memory=pin_memory,
                ),
                device=device,
            )

        yield modality, len(items_lst), mm_kwargs_group
|
||||
|
||||
|
||||
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, Union[int, float]]:
    """
    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.
    """
    io_kwargs = {"audio": audio_io_kwargs} if audio_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_audio(audio_url)
|
||||
|
||||
|
||||
def fetch_image(
    image_url: str,
    image_io_kwargs: Optional[dict[str, Any]] = None,
) -> Image.Image:
    """
    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.
    """
    io_kwargs = {"image": image_io_kwargs} if image_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_image(image_url)
|
||||
|
||||
|
||||
def fetch_video(
    video_url: str,
    video_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.
    """
    io_kwargs = {"video": video_io_kwargs} if video_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=io_kwargs)
    return connector.fetch_video(video_url)
|
||||
319
vllm/multimodal/video.py
Normal file
319
vllm/multimodal/video.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm import envs
|
||||
|
||||
from .base import MediaIO
|
||||
from .image import ImageMediaIO
|
||||
|
||||
|
||||
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Resize every frame of an (N, H, W, C) video to ``size`` (H, W)."""
    target_h, target_w = size
    num_frames = frames.shape[0]
    channels = frames.shape[-1]
    out = np.empty((num_frames, target_h, target_w, channels),
                   dtype=frames.dtype)

    # lazy import cv2 to avoid bothering users who only use text models
    import cv2

    for i, frame in enumerate(frames):
        # cv2.resize expects (width, height) ordering.
        out[i] = cv2.resize(frame, (target_w, target_h))

    return out
|
||||
|
||||
|
||||
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
    """Uniformly scale a video's spatial dimensions by ``size_factor``."""
    _, height, width, _ = frames.shape
    scaled = (int(height * size_factor), int(width * size_factor))

    return resize_video(frames, scaled)
|
||||
|
||||
|
||||
def sample_frames_from_video(frames: npt.NDArray,
                             num_frames: int) -> npt.NDArray:
    """Uniformly sample ``num_frames`` frames; ``-1`` keeps all frames."""
    if num_frames == -1:
        return frames

    last_idx = frames.shape[0] - 1
    chosen = np.linspace(0, last_idx, num_frames, dtype=int)
    return frames[chosen, ...]
|
||||
|
||||
|
||||
class VideoLoader:
    """Abstract base for video decoding backends.

    Concrete backends register themselves on `VIDEO_LOADER_REGISTRY`
    and must implement `load_bytes`.
    """

    @classmethod
    @abstractmethod
    def load_bytes(cls,
                   data: bytes,
                   num_frames: int = -1,
                   **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode raw video bytes into ``(frames, metadata)``.

        Args:
            data: Encoded video file contents.
            num_frames: Number of frames to sample; ``-1`` means all.
            **kwargs: Backend-specific options.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class VideoLoaderRegistry:
    """Maps backend names to `VideoLoader` classes."""

    def __init__(self) -> None:
        # Backend name -> loader class.
        self.name2class: dict[str, type] = {}

    def register(self, name: str):
        """Class decorator that registers the decorated loader under ``name``."""

        def _decorator(loader_cls):
            self.name2class[name] = loader_cls
            return loader_cls

        return _decorator

    @staticmethod
    def load(cls_name: str) -> VideoLoader:
        """Instantiate the loader registered under ``cls_name``."""
        loader_cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
        assert loader_cls is not None, f"VideoLoader class {cls_name} not found"
        return loader_cls()
|
||||
|
||||
|
||||
# Global singleton; video backends self-register via
# `@VIDEO_LOADER_REGISTRY.register(name)`.
VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader):
    """Default video loader: decodes in-memory bytes with OpenCV and
    uniformly samples up to ``num_frames`` frames."""

    def get_cv2_video_api(self):
        """Return the first usable buffered-stream backend, or None.

        Plugin (non-built-in) backends are skipped unless their plugin
        version supports the stream-buffer API (abi.api >= 1.2).
        """
        import cv2.videoio_registry as vr

        api_pref = None
        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                if abi < 1 or (abi == 1 and api < 2):
                    continue
            api_pref = backend
            break
        return api_pref

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode video bytes into an ``(N, H, W, 3)`` uint8 array.

        Args:
            data: Raw encoded video bytes.
            num_frames: Max frames to sample uniformly; ``-1`` keeps all.
            **kwargs: Ignored by this backend.

        Returns:
            The frame array and a metadata dict following the
            transformers ``VideoMetadata`` layout.

        Raises:
            ValueError: If OpenCV cannot open the byte stream.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        # duration stays 0 when the container reports no FPS.
        duration = total_frames_num / original_fps if original_fps > 0 else 0

        # resample video to target num_frames
        full_read = num_frames == -1 or total_frames_num < num_frames
        if full_read:
            num_frames = total_frames_num
            frame_idx = list(range(0, num_frames))
        else:
            uniform_sampled_frames = np.linspace(0,
                                                 total_frames_num - 1,
                                                 num_frames,
                                                 dtype=int)
            frame_idx = uniform_sampled_frames.tolist()

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)

        # Sequential decode: grab() every frame (cheap header advance),
        # retrieve() (full decode) only for sampled indices.
        # NOTE(review): `idx in frame_idx` is a list scan per frame,
        # i.e. O(total * sampled); a set would be O(1) — confirm before
        # changing, since frame_idx order is also used implicitly.
        i = 0
        for idx in range(total_frames_num):
            ok = cap.grab()
            if not ok:
                break
            if idx in frame_idx:
                ret, frame = cap.retrieve()
                if ret:
                    # OpenCV decodes BGR; convert to RGB for downstream use.
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1

        assert i == num_frames, (f"Expected reading {num_frames} frames, "
                                 f"but only loaded {i} frames from video.")

        # Use transformers transformers.video_utils.VideoMetadata format
        # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
        # can cause incorrect timestamp calculation without num_frames=-1.
        # NOTE(review): "fps" divides by `duration`, which is 0 when the
        # container reports no FPS (would raise ZeroDivisionError), and
        # "frames_indices" is range(num_frames) rather than the sampled
        # `frame_idx` — confirm both are intended.
        metadata = {
            "total_num_frames": num_frames,
            "fps": num_frames / duration,
            "duration": duration,
            "video_backend": "opencv",
            "frames_indices": list(range(num_frames)),
            # extra field used to control hf processor's video
            # sampling behavior
            "do_sample_frames": num_frames == total_frames_num,
        }

        return frames, metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv_dynamic")
class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
    """OpenCV backend that samples at a target ``fps`` instead of a fixed
    frame count, capping coverage at ``max_duration`` seconds.

    Mirrors the GLM-4V video processor's sampling scheme (see link below).
    """

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,  # unused; sampling is driven by `fps`
        fps: int = 2,
        max_duration: int = 300,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode video bytes, sampling ~``fps`` frames per second.

        Raises:
            ValueError: If OpenCV cannot open the byte stream.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames_num / original_fps if original_fps > 0 else 0

        # resample video to target num_frames
        max_frame_idx = total_frames_num - 1
        # Fallback duration estimate when the container reports none.
        # NOTE(review): divides by original_fps, which may be 0 here —
        # confirm this path is unreachable with fps == 0.
        duration = duration or round(max_frame_idx / original_fps) + 1

        # Refer to:
        # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
        frame_indices: Union[range, list[int]]
        if duration <= max_duration:
            # Short video: ~fps frames per second across the whole clip.
            n = int(math.floor(duration * fps))
            frame_indices = sorted({
                min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
                for i in range(n)
            })
        else:
            # Long video: cap at max_duration * fps samples, spread evenly.
            num_samples = int(max_duration * fps)
            if num_samples >= total_frames_num:
                frame_indices = range(total_frames_num)
            else:
                target_seconds = np.linspace(0,
                                             duration,
                                             num_samples,
                                             endpoint=True)
                frame_indices = sorted({
                    min(max_frame_idx, int(math.ceil(t * original_fps)))
                    for t in target_seconds
                })

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((len(frame_indices), height, width, 3),
                          dtype=np.uint8)

        # Sequential decode: grab() everything, retrieve() only samples.
        i = 0
        for idx in range(total_frames_num):
            ok = cap.grab()
            if not ok:
                break
            if idx in frame_indices:
                ret, frame = cap.retrieve()
                if ret:
                    # OpenCV decodes BGR; convert to RGB.
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1

        assert i == len(frame_indices), (
            f"Expected reading {len(frame_indices)} frames, "
            f"but only loaded {i} frames from video.")

        # Use transformers transformers.video_utils.VideoMetadata format
        metadata = {
            "total_num_frames": total_frames_num,
            "fps": original_fps,
            "duration": duration,
            "video_backend": "opencv_dynamic",
            "frames_indices": list(frame_indices),
            # Frames are already sampled here, so the HF processor must
            # not resample them.
            "do_sample_frames": False,
        }

        return frames, metadata
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Media IO adapter that decodes videos into ``(frames, metadata)``.

    Frame decoding is delegated to the loader backend selected by
    ``VLLM_VIDEO_LOADER_BACKEND``; per-frame image work goes through
    `image_io`.
    """

    def __init__(
        self,
        image_io: ImageMediaIO,
        num_frames: int = 32,
        **kwargs,
    ) -> None:
        super().__init__()

        self.image_io = image_io
        self.num_frames = num_frames
        # `kwargs` contains custom arguments from
        # --media-io-kwargs for this modality.
        # They can be passed to the underlying
        # media loaders (e.g. custom implementations)
        # for flexible control.
        self.kwargs = kwargs
        video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
        self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode raw video bytes via the configured loader backend."""
        return self.video_loader.load_bytes(data,
                                            num_frames=self.num_frames,
                                            **self.kwargs)

    def load_base64(self, media_type: str,
                    data: str) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode base64 video data.

        For ``video/jpeg``, `data` is a comma-separated list of
        base64-encoded JPEG frames (metadata is empty in that case);
        otherwise it is a base64-encoded video file.
        """
        if media_type.lower() == "video/jpeg":
            load_frame = partial(
                self.image_io.load_base64,
                "image/jpeg",
            )

            return np.stack([
                np.asarray(load_frame(frame_data))
                for frame_data in data.split(",")
            ]), {}

        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
        """Read a video file from disk and decode it."""
        with filepath.open("rb") as f:
            data = f.read()

        return self.load_bytes(data)

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode frames as a comma-separated list of base64 JPEG images.

        Raises:
            NotImplementedError: If `video_format` is not ``"JPEG"``.
        """
        video = media

        if video_format == "JPEG":
            encode_frame = partial(
                self.image_io.encode_base64,
                image_format=video_format,
            )

            return ",".join(
                encode_frame(Image.fromarray(frame)) for frame in video)

        msg = "Only JPEG format is supported for now."
        raise NotImplementedError(msg)
|
||||
Reference in New Issue
Block a user