[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
33
vllm/multimodal/__init__.py
Normal file
33
vllm/multimodal/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .base import MultiModalPlaceholderMap
|
||||
from .hasher import MultiModalHashDict, MultiModalHasher
|
||||
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
|
||||
MultiModalDataDict, MultiModalKwargs,
|
||||
MultiModalPlaceholderDict, NestedTensors)
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
||||
"""
|
||||
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
|
||||
is used by model runners to dispatch data processing according to the target
|
||||
model.
|
||||
|
||||
Info:
|
||||
[mm_processing](../../../design/mm_processing.html)
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"BatchedTensorInputs",
|
||||
"ModalityData",
|
||||
"MultiModalDataBuiltins",
|
||||
"MultiModalDataDict",
|
||||
"MultiModalHashDict",
|
||||
"MultiModalHasher",
|
||||
"MultiModalKwargs",
|
||||
"MultiModalPlaceholderDict",
|
||||
"MultiModalPlaceholderMap",
|
||||
"NestedTensors",
|
||||
"MULTIMODAL_REGISTRY",
|
||||
"MultiModalRegistry",
|
||||
]
|
||||
106
vllm/multimodal/audio.py
Normal file
106
vllm/multimodal/audio.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import soundfile
|
||||
except ImportError:
|
||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||
|
||||
|
||||
def resample_audio_librosa(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample `audio` from `orig_sr` to `target_sr` using librosa."""
    return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
def resample_audio_scipy(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
):
    """Resample `audio` from `orig_sr` to `target_sr` via polyphase filtering.

    `scipy.signal.resample_poly` requires an integer up/down factor pair, so
    the two rates are reduced by their GCD. This resamples non-multiple rate
    pairs (e.g. 44100 -> 16000) at the exact ratio; the previous
    `orig_sr // target_sr` truncated such ratios (and produced a float for
    float-valued rates, which `resample_poly` rejects).
    """
    if orig_sr == target_sr:
        return audio

    import math

    # lazy import scipy.signal, otherwise it will crash doc build.
    import scipy.signal

    # NOTE: sampling rates are assumed to be representable as integer Hz.
    up, down = int(target_sr), int(orig_sr)
    g = math.gcd(up, down)
    return scipy.signal.resample_poly(audio, up // g, down // g)
class AudioResampler:
    """Resample audio data to a target sample rate."""

    def __init__(
        self,
        target_sr: Optional[float] = None,
        method: Literal["librosa", "scipy"] = "librosa",
    ):
        # Target rate may be omitted; `resample` then refuses to run.
        self.target_sr = target_sr
        self.method = method

    def resample(
        self,
        audio: npt.NDArray[np.floating],
        *,
        orig_sr: float,
    ) -> npt.NDArray[np.floating]:
        """Resample `audio` from `orig_sr` to the configured target rate.

        Raises:
            RuntimeError: if no `target_sr` was configured.
            ValueError: if the configured method is unknown.
        """
        target_sr = self.target_sr
        if target_sr is None:
            raise RuntimeError("Audio resampling is not supported when "
                               "`target_sr` is not provided")

        if self.method == "librosa":
            return resample_audio_librosa(audio,
                                          orig_sr=orig_sr,
                                          target_sr=target_sr)
        if self.method == "scipy":
            return resample_audio_scipy(audio,
                                        orig_sr=orig_sr,
                                        target_sr=target_sr)

        raise ValueError(f"Invalid resampling method: {self.method}. "
                         "Supported methods are 'librosa' and 'scipy'.")
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Decode/encode audio as ``(samples, sample_rate)`` tuples."""

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        # sr=None keeps the file's native sampling rate.
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        decoded = base64.b64decode(data)
        return self.load_bytes(decoded)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
        audio, sr = media

        with BytesIO() as buffer:
            # Always serialize as WAV regardless of the original container.
            soundfile.write(buffer, audio, sr, format="WAV")
            raw = buffer.getvalue()

        return base64.b64encode(raw).decode('utf-8')
219
vllm/multimodal/base.py
Normal file
219
vllm/multimodal/base.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
|
||||
from .inputs import MultiModalKwargs, PlaceholderRange
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.

    Note: This is only used in V0.
    """

    class IndexMap(NamedTuple):
        # Parallel index lists: embedding at `src[i]` is written to
        # position `dest[i]` of the destination tensor.
        src: list[int]
        dest: list[int]

    src_ranges: list[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: list[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self):
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        Examples:

        ```
        Prompt:    |AAAA BBBB What's in these images?|
        Positions: |.................................|

            images      = [A, B]
            src_ranges  = [(0, 4), (4, 8)]
            dest_ranges = [(0, 4), (5, 9)]

        Prompt:    |AAAA BBBB What's in these images?|
        Positions: |  .....                          |

            images      = [A, B]
            src_ranges  = [(2, 4), (4, 6)]
            dest_ranges = [(0, 2), (3, 5)]

        Prompt:    |AAAA BBBB What's in these images?|
        Positions: |     .........                   |

            images      = [B]
            src_ranges  = [(0, 4)]
            dest_ranges = [(0, 4)]

        Prompt:    |AAAA BBBB What's in these images?|
        Positions: |          .......................|

            images      = []
            src_ranges  = []
            dest_ranges = []
        ```
        """
        seq_mm_data = seq_group.multi_modal_data
        seq_mm_placeholders = seq_group.multi_modal_placeholders

        # No multi-modal content at all: nothing to map.
        if not seq_mm_data or not seq_mm_placeholders:
            return MultiModalKwargs({}), {}

        placeholder_maps = dict[str, MultiModalPlaceholderMap]()

        for modality, placeholders in seq_mm_placeholders.items():
            placeholder_map = MultiModalPlaceholderMap()

            if positions:
                placeholder_map.append_items_from_seq_group(
                    positions,
                    # Dummy, since we don't care about intersecting items
                    [None] * len(placeholders),
                    placeholders,
                )

            placeholder_maps[modality] = placeholder_map

        return seq_mm_data, placeholder_maps

    def append_items_from_seq_group(
        self,
        positions: range,
        multi_modal_items: list[_T],
        multi_modal_placeholders: Sequence[PlaceholderRange],
    ) -> list[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            # Absolute token positions covered by this item's placeholder.
            placeholder = range(
                placeholder_dict.offset,
                placeholder_dict.offset + placeholder_dict.length,
            )
            # Overlap between the placeholder and the requested window.
            intersection = range(
                max(positions.start, placeholder.start),
                min(positions.stop, placeholder.stop),
            )

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices, relative to the start of `positions`.
            token_embedding_range = range(
                intersection.start - positions.start,
                intersection.stop - positions.start,
            )

            # Source indices, relative to this item's first embedding and
            # offset by the embeddings appended so far.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len,
            )

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # NOTE: the full placeholder length is counted even when only
            # part of it intersects `positions`.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """

        # Shift the other map's ranges past this map's current extent.
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.
        """

        # Flatten the per-item ranges into element-level index lists.
        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return self.IndexMap(src=src_indices, dest=dest_indices)
class MediaIO(ABC, Generic[_T]):
    """Abstract decoder for one kind of media item."""

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Decode an in-memory byte string into a media item."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Decode a base64 payload into a media item.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Read a media item from a file on disk."""
        raise NotImplementedError
118
vllm/multimodal/hasher.py
Normal file
118
vllm/multimodal/hasher.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from blake3 import blake3
|
||||
from PIL import Image
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.inputs import TokensPrompt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MultiModalHashDict = Mapping[str, list[str]]
|
||||
"""
|
||||
A dictionary containing hashes for items in each modality.
|
||||
"""
|
||||
|
||||
|
||||
class MultiModalHasher:
    """Deterministically hashes multi-modal items for caching/deduplication."""

    @classmethod
    def serialize_item(cls, obj: object) -> bytes:
        """Convert a single leaf value into a deterministic byte string."""
        # Cheap conversions first.
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, (int, float)):
            return np.array(obj).tobytes()

        # Media/tensor types are routed through item_to_bytes so that a type
        # tag (and for ndarrays, dtype/shape) is hashed along with the data.
        if isinstance(obj, Image.Image):
            rgba = np.asarray(convert_image_mode(obj, "RGBA"))
            return cls.item_to_bytes("image", rgba)
        if isinstance(obj, torch.Tensor):
            return cls.item_to_bytes("tensor", obj.numpy())
        if isinstance(obj, np.ndarray):
            arr_meta = {
                "dtype": obj.dtype.str,
                "shape": obj.shape,
                "data": obj.tobytes(),
            }
            return cls.item_to_bytes("ndarray", arr_meta)

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    @classmethod
    def item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> bytes:
        """Flatten `obj` under `key` into one contiguous byte string."""
        pairs = cls.iter_item_to_bytes(key, obj)
        return b''.join(kb + vb for kb, vb in pairs)

    @classmethod
    def iter_item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """Yield `(key, value)` byte pairs for `obj`, recursing into containers."""
        # Containers are flattened recursively, extending the key with the
        # index / sub-key so that structure influences the hash.
        if isinstance(obj, (list, tuple)):
            for idx, member in enumerate(obj):
                yield from cls.iter_item_to_bytes(f"{key}.{idx}", member)
        elif isinstance(obj, dict):
            for sub_key, member in obj.items():
                yield from cls.iter_item_to_bytes(f"{key}.{sub_key}", member)
        else:
            yield cls.serialize_item(key), cls.serialize_item(obj)

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Fold all keyword arguments into a single BLAKE3 hex digest."""
        hasher = blake3()

        for name, value in kwargs.items():
            for key_bytes, value_bytes in cls.iter_item_to_bytes(name, value):
                hasher.update(key_bytes)
                hasher.update(value_bytes)

        return hasher.hexdigest()

    @classmethod
    def hash_prompt_mm_data(
            cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]:
        """Hash multimodal data in the user input prompt if they exist."""

        # mm_data can be absent, None, or an empty dict.
        mm_data = prompt.get("multi_modal_data")
        if not mm_data:
            return None

        mm_hashes = {}
        for modality, data in mm_data.items():
            # Normalize single items to one-element lists.
            items = data if isinstance(data, list) else [data]
            mm_hashes[modality] = [
                cls.hash_kwargs(**{modality: item}) for item in items
            ]

        return mm_hashes
97
vllm/multimodal/image.py
Normal file
97
vllm/multimodal/image.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from .base import MediaIO
|
||||
|
||||
|
||||
def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Rescale the dimensions of an image by a constant factor."""
    scaled_size = (int(image.width * size_factor),
                   int(image.height * size_factor))
    result = image.resize(scaled_size)
    # A non-negative `transpose` additionally applies the corresponding
    # PIL transpose operation (flip/rotate).
    if transpose >= 0:
        result = result.transpose(Image.Transpose(transpose))
    return result
# TODO: Support customizable background color to fill in.
def rgba_to_rgb(
        image: Image.Image, background_color=(255, 255, 255)) -> Image.Image:
    """Convert an RGBA image to RGB with filled background color."""
    assert image.mode == "RGBA"
    background = Image.new("RGB", image.size, background_color)
    alpha = image.split()[3]  # band 3 is the alpha channel
    background.paste(image, mask=alpha)
    return background
def convert_image_mode(image: Image.Image, to_mode: str):
    """Return `image` converted to `to_mode` (no-op if already there)."""
    if image.mode == to_mode:
        return image
    # RGBA -> RGB goes through rgba_to_rgb so the alpha channel is
    # composited onto a background rather than simply discarded.
    if image.mode == "RGBA" and to_mode == "RGB":
        return rgba_to_rgb(image)
    return image.convert(to_mode)
class ImageMediaIO(MediaIO[Image.Image]):
    """Load/encode PIL images, normalizing them to `image_mode`."""

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()

        self.image_mode = image_mode

    def _normalize(self, image: Image.Image) -> Image.Image:
        # Force-decode now so errors surface here, then normalize the mode.
        image.load()
        return convert_image_mode(image, self.image_mode)

    def load_bytes(self, data: bytes) -> Image.Image:
        return self._normalize(Image.open(BytesIO(data)))

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> Image.Image:
        return self._normalize(Image.open(filepath))

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        with BytesIO() as buffer:
            converted = convert_image_mode(media, self.image_mode)
            converted.save(buffer, image_format)
            raw = buffer.getvalue()

        return base64.b64encode(raw).decode('utf-8')
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Load/encode precomputed image embeddings stored as tensors."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        # weights_only guards against arbitrary pickled payloads.
        return torch.load(BytesIO(data), weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        raw = media.numpy()
        return base64.b64encode(raw).decode('utf-8')
876
vllm/multimodal/inputs.py
Normal file
876
vllm/multimodal/inputs.py
Normal file
@@ -0,0 +1,876 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from itertools import accumulate
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
|
||||
Union, cast, final)
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import NotRequired, TypeAlias
|
||||
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.utils import LazyLoader, full_groupby, is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL.Image import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
from .hasher import MultiModalHashDict
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
|
||||
"""
|
||||
A `transformers.image_utils.ImageInput` representing a single image
|
||||
item, which can be passed to a HuggingFace `ImageProcessor`.
|
||||
"""
|
||||
|
||||
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
|
||||
list[np.ndarray], list["torch.Tensor"]]
|
||||
"""
|
||||
A `transformers.image_utils.VideoInput` representing a single video
|
||||
item, which can be passed to a HuggingFace `VideoProcessor`.
|
||||
"""
|
||||
|
||||
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
|
||||
"""
|
||||
Represents a single audio
|
||||
item, which can be passed to a HuggingFace `AudioProcessor`.
|
||||
"""
|
||||
|
||||
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
|
||||
"""
|
||||
A `transformers.image_utils.ImageInput` representing a single image
|
||||
item, which can be passed to a HuggingFace `ImageProcessor`.
|
||||
|
||||
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
||||
which are treated as image embeddings;
|
||||
these are directly passed to the model without HF processing.
|
||||
"""
|
||||
|
||||
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
|
||||
"""
|
||||
A `transformers.image_utils.VideoInput` representing a single video
|
||||
item, which can be passed to a HuggingFace `VideoProcessor`.
|
||||
|
||||
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
||||
which are treated as video embeddings;
|
||||
these are directly passed to the model without HF processing.
|
||||
"""
|
||||
|
||||
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
|
||||
"torch.Tensor"]
|
||||
"""
|
||||
Represents a single audio
|
||||
item, which can be passed to a HuggingFace `AudioProcessor`.
|
||||
|
||||
Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
|
||||
is different from that expected by the model;
|
||||
these are resampled to the model's sampling rate before being processed by HF.
|
||||
|
||||
Alternatively, a 3-D tensor or batch of 2-D tensors,
|
||||
which are treated as audio embeddings;
|
||||
these are directly passed to the model without HF processing.
|
||||
"""
|
||||
|
||||
ModalityData: TypeAlias = Union[_T, list[_T]]
|
||||
"""
|
||||
Either a single data item, or a list of data items.
|
||||
|
||||
The number of data items allowed per modality is restricted by
|
||||
`--limit-mm-per-prompt`.
|
||||
"""
|
||||
|
||||
|
||||
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    # total=False: every modality key is optional; a prompt supplies only
    # the modalities it actually uses.

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""
|
||||
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
|
||||
"""
|
||||
A dictionary containing an entry for each modality type to input.
|
||||
|
||||
The built-in modalities are defined by
|
||||
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PlaceholderRange:
    """
    Placeholder location information for multi-modal data.

    Example:

        Prompt: `AAAA BBBB What is in these images?`

        Images A and B will have:

        ```
        A: PlaceholderRange(offset=0, length=4)
        B: PlaceholderRange(offset=5, length=4)
        ```
    """

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""

    is_embed: Optional["torch.Tensor"] = None
    """
    A boolean mask of shape `(length,)` indicating which positions
    between `offset` and `offset + length` to assign embeddings to.
    """

    def get_num_embeds(self) -> int:
        # Without a mask, every position in the range gets an embedding.
        mask = self.is_embed
        return self.length if mask is None else int(mask.sum().item())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if (self.offset, self.length) != (other.offset, other.length):
            return False

        mine, theirs = self.is_embed, other.is_embed
        if mine is None or theirs is None:
            # Equal only when both masks are absent.
            return mine is None and theirs is None
        return nested_tensors_equal(mine, theirs)
|
||||
NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
|
||||
"torch.Tensor", tuple["torch.Tensor", ...]]
|
||||
"""
|
||||
Uses a list instead of a tensor if the dimensions of each element do not match.
|
||||
"""
|
||||
|
||||
|
||||
def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """Equality check between
    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects."""
    if isinstance(a, torch.Tensor):
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    elif isinstance(b, torch.Tensor):
        return isinstance(a, torch.Tensor) and torch.equal(b, a)

    if isinstance(a, list):
        # Fix: `zip` alone ignores trailing elements of the longer list,
        # wrongly reporting e.g. [x] == [x, y]; compare lengths first.
        return (isinstance(b, list) and len(a) == len(b)
                and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))
    if isinstance(b, list):
        return (isinstance(a, list) and len(a) == len(b)
                and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)))

    # Both a and b are scalars
    return a == b
|
||||
BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
|
||||
"""
|
||||
A dictionary containing nested tensors which have been batched via
|
||||
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MultiModalFieldElem:
    """
    Represents a keyword argument corresponding to a multi-modal item
    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
    """

    modality: str
    """
    The modality of the corresponding multi-modal item.
    Each multi-modal item can consist of multiple keyword arguments.
    """

    key: str
    """
    The key of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the name of the keyword argument to be passed to the model.
    """

    data: NestedTensors
    """
    The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
    i.e. the value of the keyword argument to be passed to the model.
    """

    field: "BaseMultiModalField"
    """
    Defines how to combine the tensor data of this field with others
    in order to batch multi-modal items together for model inference.
    """

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False

        # Identity (modality/key), then data, then combination policy.
        return ((self.modality, self.key) == (other.modality, other.key)
                and nested_tensors_equal(self.data, other.data)
                and type(self.field) == type(other.field))  # noqa: E721
|
||||
@dataclass(frozen=True)
class BaseMultiModalField(ABC):
    """
    Defines how to interpret tensor data belonging to a keyword argument in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
    multi-modal items, and vice versa.
    """

    def _field_factory(self, *, modality: str, key: str):
        # Pre-bind everything except the tensor data itself.
        bound = partial(
            MultiModalFieldElem,
            modality=modality,
            key=key,
            field=self,
        )

        # Allow passing data as positional argument
        def factory(data: NestedTensors) -> MultiModalFieldElem:
            return bound(data=data)

        return factory

    @abstractmethod
    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """
        Construct
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
        instances to represent the provided data.

        This is the inverse of
        [`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
        """
        raise NotImplementedError

    @abstractmethod
    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        raise NotImplementedError

    def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
        """
        Merge the data from multiple instances of
        [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].

        This is the inverse of
        [`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].
        """
        # All elements must share one combination policy.
        field_types = [type(elem.field) for elem in elems]
        if len(set(field_types)) > 1:
            raise ValueError(f"Cannot merge different {field_types=}")

        return self._reduce_data([elem.data for elem in elems])
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
    """

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # One element per entry along the leading dimension.
        make_elem = self._field_factory(modality=modality, key=key)
        return [make_elem(item) for item in data]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        if batch and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.stack(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].unsqueeze(0).contiguous()
            lead_shape = batch[0].shape
            # Stacking requires every tensor to share the same shape;
            # otherwise fall through and keep the ragged list.
            if all(t.shape == lead_shape for t in batch):
                return torch.stack(batch)

        return batch
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
    """
    # Per-item slices into the concatenated tensor: one slice per item when
    # dim == 0, or a sequence of slices per item for dim > 0.
    slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
    dim: int = 0

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        field_factory = self._field_factory(modality=modality, key=key)
        if not is_list_of(self.slices, slice, check="all"):
            # Multi-dimensional slicing only works on an actual tensor.
            assert isinstance(data, torch.Tensor), \
                "torch.Tensor is required for multiple slices"
        return [field_factory(data[cast(slice, s)]) for s in self.slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            if len(batch) == 1:
                # An optimization when `batch` contains only one tensor:
                # - produce exactly same result as `torch.concat(batch)`
                # - will achieve zero-copy if the tensor is contiguous
                return batch[0].contiguous()

            def _expect_same_shape(tensor: torch.Tensor):
                # Every dim except the concat dim must match to concatenate.
                return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]

            first_shape = _expect_same_shape(batch[0])

            if all(_expect_same_shape(elem) == first_shape for elem in batch):
                return torch.concat(batch, dim=self.dim)

        # NOTE(review): reached both for ragged tensor shapes and for
        # non-tensor batches; flattening one level only makes sense at dim 0.
        assert self.dim == 0, "dim == 0 is required for nested list"
        return [e for elem in batch for e in elem]
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
    """
    Info:
        [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
    """
    batch_size: int

    def build_elems(
        self,
        modality: str,
        key: str,
        data: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        # Every item in the batch shares one and the same element.
        make_elem = self._field_factory(modality=modality, key=key)
        shared = make_elem(data)
        return [shared] * self.batch_size

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        # All entries are identical by construction; keep the first.
        return batch[0]
|
||||
class MultiModalFieldConfig:
    """
    Describes how a keyword argument emitted by the HF processor is split
    into per-item elements (and later re-batched). Construct instances via
    the static factory methods below.
    """

    @staticmethod
    def batched(modality: str):
        """
        Defines a field where an element in the batch is obtained by
        indexing into the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.

        Example:

            ```
            Input:
                Data: [[AAAA]
                       [BBBB]
                       [CCCC]]

            Output:
                Element 1: [AAAA]
                Element 2: [BBBB]
                Element 3: [CCCC]
            ```
        """
        return MultiModalFieldConfig(
            field=MultiModalBatchedField(),
            modality=modality,
        )

    @staticmethod
    def flat(modality: str,
             slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
             dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            slices: For each multi-modal item, a slice (dim=0) or a tuple of
                slices (dim>0) that is used to extract the data corresponding
                to it.
            dim: The dimension to extract data, default to 0.

        Example:

            ```
            Given:
                slices: [slice(0, 3), slice(3, 7), slice(7, 9)]

            Input:
                Data: [AAABBBBCC]

            Output:
                Element 1: [AAA]
                Element 2: [BBBB]
                Element 3: [CC]
            ```

            ```
            Given:
                slices: [
                    (slice(None), slice(0, 3)),
                    (slice(None), slice(3, 7)),
                    (slice(None), slice(7, 9))]
                dim: 1

            Input:
                Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

            Output:
                Element 1: [[A],[A],[A]]
                Element 2: [[B],[B],[B],[B]]
                Element 3: [[C],[C]]
            ```
        """
        return MultiModalFieldConfig(
            field=MultiModalFlatField(slices=slices, dim=dim),
            modality=modality,
        )

    @staticmethod
    def flat_from_sizes(modality: str,
                        size_per_item: "torch.Tensor",
                        dim: int = 0):
        """
        Defines a field where an element in the batch is obtained by
        slicing along the first dimension of the underlying data.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            size_per_item: For each multi-modal item, the size of the slice
                that is used to extract the data corresponding to it.
            dim: The dimension to slice, default to 0.

        Example:

            ```
            Given:
                size_per_item: [3, 4, 2]

            Input:
                Data: [AAABBBBCC]

            Output:
                Element 1: [AAA]
                Element 2: [BBBB]
                Element 3: [CC]
            ```

            ```
            Given:
                size_per_item: [3, 4, 2]
                dim: 1

            Input:
                Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]

            Output:
                Element 1: [[A],[A],[A]]
                Element 2: [[B],[B],[B],[B]]
                Element 3: [[C],[C]]
            ```

        Info:
            [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
        """

        if size_per_item.ndim != 1:
            raise ValueError("size_per_item should be a 1-D tensor, "
                             f"but found shape: {size_per_item.shape}")

        # Convert sizes into cumulative boundaries, e.g. [3, 4, 2]
        # -> [0, 3, 7, 9], then into one slice per item. For dim > 0 the
        # leading dimensions are selected in full via `slice(None)`.
        slice_idxs = [0, *accumulate(size_per_item)]
        slices = [(slice(None, None, None), ) * dim +
                  (slice(slice_idxs[i], slice_idxs[i + 1]), )
                  for i in range(len(size_per_item))]

        return MultiModalFieldConfig.flat(modality, slices, dim=dim)

    @staticmethod
    def shared(modality: str, batch_size: int):
        """
        Defines a field where an element in the batch is obtained by
        taking the entirety of the underlying data.

        This means that the data is the same for each element in the batch.

        Args:
            modality: The modality of the multi-modal item that uses this
                keyword argument.
            batch_size: The number of multi-modal items which share this data.

        Example:

            ```
            Given:
                batch_size: 4

            Input:
                Data: [XYZ]

            Output:
                Element 1: [XYZ]
                Element 2: [XYZ]
                Element 3: [XYZ]
                Element 4: [XYZ]
            ```
        """
        return MultiModalFieldConfig(
            field=MultiModalSharedField(batch_size),
            modality=modality,
        )

    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
        super().__init__()

        # The splitting/merging strategy and the modality it applies to.
        self.field = field
        self.modality = modality

    def build_elems(
        self,
        key: str,
        batch: NestedTensors,
    ) -> Sequence[MultiModalFieldElem]:
        """Split `batch` into per-item elements using the configured field."""
        return self.field.build_elems(self.modality, key, batch)
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of
    [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
    corresponding to a data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """Build an item keyed by each element's `key` attribute."""
        keyed = {}
        for elem in elems:
            keyed[elem.key] = elem
        return MultiModalKwargsItem(keyed)

    @property
    def modality(self) -> str:
        """The single modality shared by every element in this item."""
        modalities = set()
        for elem in self.data.values():
            modalities.add(elem.modality)
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    [`torch.nn.Module.forward`][].

    The metadata `items` enables us to obtain the keyword arguments
    corresponding to each data item in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via
    [`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and
    [`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items].
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        """Split HF processor outputs into per-item kwargs using the
        field configs, then merge them back into a `MultiModalKwargs`."""
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items = list[MultiModalKwargsItem]()
        for modality, keys in keys_by_modality.items():
            # All keys of one modality must agree on the number of items
            # so they can be zipped together item-by-item.
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}")

            batch_size = next(iter(batch_sizes.values()))
            for item_idx in range(batch_size):
                elems = [v[item_idx] for v in elems_in_modality.values()]
                items.append(MultiModalKwargsItem.from_elems(elems))

        return MultiModalKwargs.from_items(items)

    @staticmethod
    def from_items(items: Sequence[MultiModalKwargsItem]):
        """Construct a new
        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
        from multiple items."""
        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
        for item in items:
            for key, elem in item.items():
                elems_by_key[key].append(elem)

        # Each key is re-batched by the field strategy of its first element.
        data = {
            key: elems[0].field.reduce_data(elems)
            for key, elems in elems_by_key.items() if len(elems) > 0
        }

        return MultiModalKwargs(data, items=items)

    def __init__(
        self,
        data: Mapping[str, NestedTensors],
        *,
        items: Optional[Sequence[MultiModalKwargsItem]] = None,
    ) -> None:
        super().__init__(data)

        # Index items by modality for `get_item`/`get_items` lookups.
        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
        self._items_by_modality = dict(items_by_modality)

    @property
    def modalities(self):
        """The modalities for which per-item metadata is available."""
        return self._items_by_modality.keys()

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors,
                   pin_memory: bool = False) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        # Recursively stack the children first, bottom-up.
        stacked = [
            MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
        ]
        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(list[torch.Tensor], stacked)
        if len(tensors_) == 1:
            # An optimization when `tensors_` contains only one tensor:
            # - produce exactly same result as `torch.stack(tensors_)`
            # - will achieve zero-copy if the tensor is contiguous
            return tensors_[0].unsqueeze(0).contiguous()

        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        # Pre-allocate the output so we can request pinned memory up front.
        outputs = torch.empty(len(tensors_),
                              *tensors_[0].shape,
                              dtype=tensors_[0].dtype,
                              device=tensors_[0].device,
                              pin_memory=pin_memory)
        return torch.stack(tensors_, out=outputs)

    @staticmethod
    def batch(inputs_list: list["MultiModalKwargs"],
              pin_memory: bool = False) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.
        """
        if len(inputs_list) == 0:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists = defaultdict[str, list[NestedTensors]](list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list, pin_memory)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """Move every tensor leaf of `batched_inputs` onto `device`."""
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        # non_blocking pairs with the pinned-memory allocation in
        # `_try_stack` to overlap the host-to-device copy.
        json_mapped = json_map_leaves(
            lambda x: x.to(device=device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)

    def __delitem__(self, key: str) -> None:
        super().__delitem__(key)

        # Keep the per-item metadata consistent with the main dict.
        for items in self._items_by_modality.values():
            for item in items:
                item.pop(key, None)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        if self._items_by_modality != other._items_by_modality:
            return False

        ks = self.keys()
        return (ks == other.keys()
                and all(nested_tensors_equal(self[k], other[k]) for k in ks))

    def _validate_modality(self, method_name: str, modality: str) -> None:
        # Raise early with a clear message when per-item metadata is
        # missing or the requested modality is absent.
        if not self._items_by_modality:
            raise RuntimeError(
                f"`{method_name}` is not supported when "
                "MultiModalKwargs is not initialized with `items`")

        if modality not in self._items_by_modality:
            available_modalities = set(self._items_by_modality.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

    def get_item_count(self, modality: str) -> int:
        """Get the number of items belonging to a modality."""
        self._validate_modality("get_item_count", modality)
        return len(self._items_by_modality[modality])

    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
        """
        Get the keyword arguments corresponding to an item identified by
        its modality and index.
        """
        self._validate_modality("get_item", modality)
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
        """
        Get the keyword arguments corresponding to each item belonging to
        a modality.
        """
        self._validate_modality("get_items", modality)
        return self._items_by_modality[modality]
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges for each modality:
modality name -> the placeholder ranges of each item of that modality.
"""
class MultiModalInputs(TypedDict):
    """
    Represents the outputs of
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
    ready to be passed to vLLM internals.
    """

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """The processed prompt text."""

    prompt_token_ids: list[int]
    """The processed token IDs which includes placeholder tokens."""

    token_type_ids: NotRequired[list[int]]
    """The token type IDs of the prompt."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: Optional["MultiModalHashDict"]
    """The hashes of the multi-modal data, if hashing is enabled."""

    mm_placeholders: "MultiModalPlaceholderDict"
    """
    For each modality, information about the placeholder tokens in
    `prompt_token_ids`.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """
class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of
    [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
    ready to be passed to vLLM internals.

    Extends the decoder-only inputs with the encoder-side prompt fields.
    """

    encoder_prompt: str
    """The processed encoder prompt text."""

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""

    encoder_token_type_ids: NotRequired[list[int]]
    """The token type IDs of the encoder prompt."""
461
vllm/multimodal/parse.py
Normal file
461
vllm/multimodal/parse.py
Normal file
@@ -0,0 +1,461 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict
|
||||
from collections.abc import Callable, Iterator, Mapping, Sequence
|
||||
from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
|
||||
TypeVar, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import TypeAlias, TypeGuard, assert_never
|
||||
|
||||
from vllm.utils import LazyLoader, is_list_of
|
||||
|
||||
from .audio import AudioResampler
|
||||
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
|
||||
ImageItem, ModalityData, MultiModalDataDict,
|
||||
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_I = TypeVar("_I")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL.Image as PILImage
|
||||
else:
|
||||
PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
|
||||
|
||||
|
||||
class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    `_T` is the type of the underlying data container; `_I` is the type
    of a single item within it.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.data = data
        self.modality = modality

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"len={len(self)})")

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # Auto-generated
        # At runtime, iteration works via the `__getitem__` fallback
        # protocol; this stub only informs the type checker.
        def __iter__(self) -> Iterator[_I]:
            ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(idx) for idx in range(self.get_count())]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Data items stored as a plain sequence, one entry per item."""

    def get_count(self) -> int:
        """The item count is simply the sequence length."""
        return len(self.data)

    def get(self, index: int) -> _T:
        """Index directly into the underlying sequence."""
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # HF processors take the pluralized modality name, e.g. "images".
        processor_key = f"{self.modality}s"
        return {processor_key: self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        # Raw items always go through the HF processor; nothing is
        # forwarded straight to the model.
        return {}
class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
                                       torch.Tensor]):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def get_count(self) -> int:
        # Works for both forms: first-dimension size of a batched tensor,
        # or length of the per-item list.
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        # Indexing a batched tensor yields one embedding; indexing a list
        # yields the corresponding element.
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        # Embeddings skip the HF processor entirely.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        passthrough_key = f"{self.modality}_embeds"
        return {passthrough_key: self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Number of embedding vectors for the item at `item_idx`."""
        return len(self.get(item_idx))
class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
                                           Mapping[str, torch.Tensor]]):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        # Maps the raw data dict to the field configs that describe how
        # each key is split into per-item elements.
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        # Imported lazily to avoid a hard dependency at module import time.
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        # Validate that the data contains every required key...
        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (f"The data should contain the fields: {required_fields}, "
                   f"but only found the following keys: {data_keys}")
            raise ValueError(msg)

        # ...and that the factory provides a config for every required key.
        fields_config = fields_factory(data)
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)

        self.fields_config = fields_config
        self.required_fields = required_fields

        # Pre-split the data into per-item kwargs so that `get_count` and
        # `get` become simple lookups.
        self._kwargs = MultiModalKwargs.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        return self._kwargs.get_item_count(self.modality)

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        # Unwrap each field element back into a plain key -> tensor mapping.
        return {
            k: v.data
            for k, v in self._kwargs.get_item(self.modality, index).items()
        }

    def get_processor_data(self) -> Mapping[str, object]:
        # Pre-computed embeddings bypass the HF processor.
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return self.data
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """Raw audio items destined for the HF processor."""

    def __init__(self, data: Sequence[HfAudioItem]) -> None:
        super().__init__(data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        """Number of samples in the audio item at `item_idx`."""
        return len(self.get(item_idx))
class AudioEmbeddingItems(EmbeddingItems):
    """Pre-computed audio embeddings, bypassing the HF processor."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "audio")
class ImageSize(NamedTuple):
    """Image dimensions in pixels, in (width, height) order (PIL convention)."""
    width: int
    height: int
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """Raw image items destined for the HF processor."""

    def __init__(self, data: Sequence[HfImageItem]) -> None:
        super().__init__(data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """Pixel size of the image at `item_idx`."""
        image = self.get(item_idx)

        if isinstance(image, PILImage.Image):
            # PIL's `size` is already (width, height).
            width, height = image.size
            return ImageSize(width, height)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            # Arrays/tensors are channels-first: (C, H, W).
            _, height, width = image.shape
            return ImageSize(width, height)

        assert_never(image)
class ImageEmbeddingItems(EmbeddingItems):
    """Pre-computed image embeddings, bypassing the HF processor."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "image")
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """Raw video items (sequences of frames) destined for the HF processor."""

    def __init__(self, data: Sequence[HfVideoItem]) -> None:
        super().__init__(data, "video")

    def get_num_frames(self, item_idx: int) -> int:
        """Number of frames in the video at `item_idx`."""
        return len(self.get(item_idx))

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """Pixel size of the first frame of the video at `item_idx`."""
        frame = self.get(item_idx)[0]  # Assume that the video isn't empty

        if isinstance(frame, PILImage.Image):
            # PIL's `size` is already (width, height).
            width, height = frame.size
            return ImageSize(width, height)
        if isinstance(frame, (np.ndarray, torch.Tensor)):
            # Frames are channels-first: (C, H, W).
            _, height, width = frame.shape
            return ImageSize(width, height)

        assert_never(frame)
class VideoEmbeddingItems(EmbeddingItems):
    """Pre-computed video embeddings, bypassing the HF processor."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "video")
# Bound to a concrete `ModalityDataItems` subclass so that
# `MultiModalDataItems.get_items` can return the precise requested type.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
    normalized such that each entry corresponds to a list.

    Maps modality name -> parsed items for that modality.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising [`KeyError`][]
        even if the modality is not found.
        """
        if modality not in self:
            if strict:
                available_modalities = set(self.keys())
                raise KeyError(f"Modality {modality!r} not found. "
                               f"Available modalities: {available_modalities}")

            return 0

        return self[modality].get_count()

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {m: items.get_count() for m, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: Union[type[_D], tuple[type[_D], ...]],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises:
            KeyError: If the modality is not present.
            TypeError: If the stored items are not an instance of `typ`.
        """
        if modality not in self:
            available_modalities = set(self.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

        items = self[modality]
        if not isinstance(items, typ):
            raise TypeError(f"Invalid type of data items for {modality=}. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(items)}")

        # Narrowed above by the isinstance check.
        return items  # type: ignore[return-value]
# A per-modality parsing function: raw modality data in, parsed items out
# (`None` means the data is empty and should be skipped).
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
                                         Optional[ModalityDataItems[Any, Any]]]
class MultiModalDataParser:
    """
    Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
    into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
        audio_resample_method: Resampling backend to use ("librosa" or
            "scipy").
    """

    def __init__(
        self,
        *,
        target_sr: Optional[float] = None,
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
    ) -> None:
        super().__init__()

        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )

    def _is_embeddings(
        self, data: object
    ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
        """Whether `data` looks like pre-computed embeddings rather than
        raw media: a 3-D batched tensor or a list of 2-D tensors."""
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
            return data[0].ndim == 2

        return False

    def _is_empty(self, data: object) -> TypeGuard[None]:
        """Whether `data` contains no items and should be skipped."""
        if isinstance(data, list):
            return len(data) == 0
        if isinstance(data, np.ndarray):
            return data.size == 0
        # BUGFIX: `torch.Tensor.size` is a method (unlike the
        # `np.ndarray.size` attribute), so the previous combined check
        # `data.size == 0` was always False for tensors and empty tensors
        # were never skipped. Use `numel()` for the element count.
        if isinstance(data, torch.Tensor):
            return data.numel() == 0

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, Optional[float]]:
        """Normalize an audio item to `(samples, sampling_rate)`, where the
        rate is `None` when the input carries no rate information."""
        if isinstance(audio, tuple):
            # Already an (audio, sampling_rate) pair.
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse audio input, resampling items that carry a sampling rate."""
        # also check single audio item with sampling rate
        if self._is_empty(data) or (isinstance(data, tuple)
                                    and self._is_empty(data[0])):
            return None

        if self._is_embeddings(data):
            return AudioEmbeddingItems(data)

        # Normalize to a list of per-item audio inputs.
        if (is_list_of(data, float)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 1
                or isinstance(data, tuple)):
            # A single audio item (waveform or (waveform, sr) pair).
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one item per row.
            data_items = [elem for elem in data]
        else:
            data_items = data

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                # No known source rate; assume it already matches.
                new_audio = audio
            else:
                new_audio = self.audio_resampler.resample(audio,
                                                          orig_sr=orig_sr)

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse image input into embedding or raw-image items."""
        if self._is_empty(data):
            return None

        if self._is_embeddings(data):
            return ImageEmbeddingItems(data)

        if (isinstance(data, PILImage.Image)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 3):
            # A single image (PIL image or a CHW array).
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one image per leading index.
            data_items = [elem for elem in data]
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse video input into embedding or raw-video items."""
        if self._is_empty(data):
            return None

        if self._is_embeddings(data):
            return VideoEmbeddingItems(data)

        if (is_list_of(data, PILImage.Image)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 4):
            # A single video (list of frames or a FCHW array).
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A batched array: one video per leading index.
            data_items = [elem for elem in data]
        else:
            data_items = data

        return VideoProcessorItems(data_items)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Map each supported modality name to its parsing function."""
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
        }

    def parse_mm_data(self,
                      mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse a full multi-modal data dict, one modality at a time.

        Raises:
            ValueError: If `mm_data` contains an unsupported modality.
        """
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data

        return mm_items
1895
vllm/multimodal/processing.py
Normal file
1895
vllm/multimodal/processing.py
Normal file
File diff suppressed because it is too large
Load Diff
258
vllm/multimodal/profiling.py
Normal file
258
vllm/multimodal/profiling.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalInputs, MultiModalKwargs,
|
||||
MultiModalPlaceholderDict)
|
||||
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
|
||||
EncDecMultiModalProcessor)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """
    # The text prompt or pre-tokenized prompt (token ids) to process.
    prompt: Union[str, list[int]]
    # Raw multi-modal data, keyed by modality name.
    mm_data: MultiModalDataDict
    # Extra keyword arguments forwarded to the HF processor.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DummyEncoderData(NamedTuple):
    """Dummy encoder input used for memory profiling."""

    # Token ids that make up the dummy encoder prompt.
    prompt_token_ids: list[int]
|
||||
|
||||
|
||||
class DummyDecoderData(NamedTuple):
    """Dummy data used for profiling."""

    # Token ids of the dummy decoder prompt.
    prompt_token_ids: list[int]
    # Processed multi-modal keyword inputs for the model.
    multi_modal_data: MultiModalKwargs
    # Placeholder ranges marking where multi-modal embeddings go in the
    # prompt.
    multi_modal_placeholders: MultiModalPlaceholderDict
|
||||
|
||||
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
|
||||
|
||||
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()
        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.
        """
        text = self.get_dummy_text(mm_counts)
        mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
        return ProcessorInputs(prompt=text, mm_data=mm_data)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        """Return `num_audios` references to one zero-valued waveform."""
        if num_audios == 0:
            return []
        # All entries alias the same array; callers treat them read-only.
        silence = np.zeros((length, ))
        return [silence] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        """Return `num_images` references to one solid-color RGB image."""
        if num_images == 0:
            return []
        blank = Image.new("RGB", (width, height), color=255)
        return [blank] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        """Return `num_videos` references to one all-255 frame stack."""
        if num_videos == 0:
            return []
        # NOTE(review): the frame array is built as
        # (num_frames, width, height, 3); most image conventions order the
        # spatial axes as (height, width) — confirm downstream expectation.
        frames = np.full((num_frames, width, height, 3), 255)
        return [frames] * num_videos
|
||||
|
||||
|
||||
class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        # Processor whose dummy inputs drive the profiling run.
        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        # Convenience accessor for the processor's processing info.
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        # Convenience accessor for the processor's dummy-input builder.
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Return the allowed number of items per modality per prompt."""
        return self.processing_info.get_allowed_mm_limits()

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> MultiModalInputs:
        """Build dummy inputs and run them through the full processor.

        Falls back to the per-modality limits when `mm_counts` is omitted.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts)

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
    ) -> Mapping[str, int]:
        """Count placeholder embeddings per modality in `mm_inputs`."""
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(item.get_num_embeds() for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """Build the dummy encoder prompt used for memory profiling.

        Either pads the prompt up to `seq_len` (when the processor requests
        it) or warns on V0 if the processed prompt exceeds `seq_len`.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            # max(...) guards against negative padding when the prompt is
            # already longer than seq_len.
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        # NOTE: Whisper allows total_len > seq_len.
        elif total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
                "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
                "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
                "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
                seq_len,
                total_len,
                str(self._get_mm_num_tokens(mm_inputs)),
            )

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """Build dummy decoder tokens plus processed multi-modal data."""
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        # V0 does not support chunked prefill.
        if total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
                "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
                "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
                "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
                seq_len,
                total_len,
                str(self._get_mm_num_tokens(mm_inputs)),
            )

        # Pad with zeros so profiling always sees a full-length sequence.
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> Mapping[str, int]:
        """Return the worst-case multi-modal token count per modality."""
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        return self._get_mm_num_tokens(mm_inputs)
|
||||
331
vllm/multimodal/registry.py
Normal file
331
vllm/multimodal/registry.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
from vllm.utils import ClassRegistry
|
||||
|
||||
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
|
||||
ProcessingCache)
|
||||
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
|
||||
DummyEncoderData, MultiModalProfiler)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
N = TypeVar("N", bound=type[nn.Module])
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
|
||||
|
||||
|
||||
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """
    # Fix: the previous docstring claimed this factory constructs a
    # `BaseMultiModalProcessor`, but `_I_co` is bound to
    # `BaseProcessingInfo` — it constructs the processing info object.

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        ...
|
||||
|
||||
|
||||
class DummyInputsBuilderFactory(Protocol[_I]):
    """
    Constructs a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance from the processing info.
    """
    # Fix: the previous docstring said "from the context", but the factory
    # is called with the processing info (`info`), not a context object.

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        ...
|
||||
|
||||
|
||||
class MultiModalProcessorFactory(Protocol[_I]):
    """
    Constructs a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from the context.
    """

    # `cache` is keyword-only; None disables processing-cache reuse.
    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """Bundle of factory callables used to lazily build a processor."""

    # Factory for the processing-info object.
    info: ProcessingInfoFactory[_I]
    # Factory for the multi-modal processor itself.
    processor: MultiModalProcessorFactory[_I]
    # Factory for the dummy-input builder used in profiling.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ):
        """Instantiate the processor for `ctx`, threading through `cache`."""
        processing_info = self.info(ctx)
        builder = self.dummy_inputs(processing_info)
        return self.processor(processing_info, builder, cache=cache)
|
||||
|
||||
|
||||
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def __init__(self) -> None:
        # Maps a model class to the factories that build its processor.
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

        # Shared cache for processed multi-modal inputs (size in GiB).
        self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)

    def reset_processor_cache(self) -> bool:
        """Reset the multi-modal processing cache."""
        self._processing_cache.reset()

        return True  # Success

    @deprecated("Legacy input processor/mapper pipeline has been removed. "
                "Please update your model runner to use "
                "`seq_group_metadata.multi_modal_data` directly without "
                "further processing.")
    def create_input_mapper(self, model_config: "ModelConfig"):
        # Deprecated no-op mapper kept for backward compatibility.
        return lambda data, mm_processor_kwargs: data

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, disable_cache=False)
        profiler = MultiModalProfiler(processor)

        seq_len = model_config.max_model_len
        mm_limits = self.get_mm_limits_per_prompt(model_config)

        # Profile with one item per enabled modality to get a per-item count.
        return profiler.get_mm_max_tokens(
            seq_len,
            {
                modality: 1
                for modality, limit in mm_limits.items() if limit > 0
            },
        )

    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config)

        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
            if mm_limits[key] > 0
        }

    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens from each modality
        for profiling the memory usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config)

        # Worst case: per-item token count times the allowed item count.
        return {
            key: mm_limits[key] * max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
        }

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())

    @deprecated("Legacy input processor/mapper pipeline has been removed. "
                "Please update your model runner to use "
                "`seq_group_metadata.multi_modal_data` directly without "
                "further processing.")
    def init_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> None:
        # Deprecated no-op kept for backward compatibility.
        pass

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(model_config, disable_cache=False)
        profiler = MultiModalProfiler(processor)
        return profiler.get_mm_limits()

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.
        """

        def wrapper(model_cls: N) -> N:
            # Re-registration is allowed but logged, since it silently
            # replaces the previous factories.
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class registered for this configuration."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    @deprecated("Legacy input processor/mapper pipeline has been removed. "
                "Please update your model runner to use "
                "`seq_group_metadata.multi_modal_data` directly without "
                "further processing.")
    def has_processor(self, model_config: "ModelConfig") -> bool:
        # Deprecated; the new pipeline always has a processor.
        return True

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        disable_cache: Optional[bool] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Raises:
            ValueError: If the model is not a multi-modal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)
        if disable_cache is None:
            # Default to the model's own multi-modal cache setting.
            mm_config = model_config.get_multimodal_config()
            disable_cache = mm_config.disable_mm_preprocessor_cache

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = InputProcessingContext(model_config, tokenizer)
        cache = None if disable_cache else self._processing_cache

        return factories.build_processor(ctx, cache=cache)

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.
        """
        processor = self.create_processor(model_config, disable_cache=False)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead.")

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.
        """
        processor = self.create_processor(model_config, disable_cache=False)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            # Unlike the decoder case, a short encoder prompt only warns
            # (e.g. Whisper's encoder length is fixed by the model).
            logger.warning_once(
                "Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.",  # noqa: E501
                seq_len,
                len(token_ids),
            )

        return dummy_data
|
||||
436
vllm/multimodal/utils.py
Normal file
436
vllm/multimodal/utils.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .inputs import PlaceholderRange
|
||||
from .video import VideoMediaIO
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .hasher import MultiModalHashDict
|
||||
from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
|
||||
else:
|
||||
MultiModalHashDict = Any
|
||||
MultiModalKwargs = Any
|
||||
MultiModalPlaceholderDict = Any
|
||||
|
||||
|
||||
class MediaConnector:
    """Loads media (audio, images, video) from HTTP, data and file URLs."""

    def __init__(
        self,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
    ) -> None:
        """
        Args:
            connection: HTTP connection used to fetch remote media.
            allowed_local_media_path: If non-empty, the only directory from
                which ``file://`` URLs may be loaded; it must exist and be
                a directory. Empty disables local file access entirely.
        """
        super().__init__()

        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")
        else:
            # Local file access is disabled unless a path is configured.
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Decode a ``data:<media_type>;base64,<payload>`` URL."""
        data_spec, data = url_spec.path.split(",", 1)
        media_type, data_type = data_spec.split(";", 1)

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        """Load a ``file://`` URL, restricted to the allowed directory."""
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        filepath = Path(url_spec.path)
        # Resolve symlinks and ".." before the containment check so that
        # paths cannot escape the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Load media from a HTTP, data or file URL using `media_io`."""
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = connection.get_bytes(url, timeout=fetch_timeout)

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Asynchronous variant of `load_from_url`."""
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = await connection.async_get_bytes(url, timeout=fetch_timeout)

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO()

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO()

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Load video from a HTTP or base64 data URL.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Asynchronously load video from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load image embedding from a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        # Embeddings are raw base64 payloads; no media type is needed.
        return image_embedding_io.load_base64("", data)
|
||||
|
||||
|
||||
global_media_connector = MediaConnector()
"""The global [`MediaConnector`][vllm.multimodal.utils.MediaConnector]
instance used by vLLM."""

# Convenience module-level aliases bound to the global connector.
fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image
fetch_video = global_media_connector.fetch_video
|
||||
|
||||
|
||||
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: float,
) -> str:
    """Encode an audio waveform and its sampling rate as base64."""
    media_io = AudioMediaIO()
    return media_io.encode_base64((audio, sampling_rate))
|
||||
|
||||
|
||||
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    The image is first converted into `image_mode` (RGB by default)
    before being encoded in the requested `format`.
    """
    media_io = ImageMediaIO(image_mode=image_mode)
    return media_io.encode_base64(image, image_format=format)
|
||||
|
||||
|
||||
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode a stack of video frames as base64."""
    frame_io = ImageMediaIO()
    media_io = VideoMediaIO(frame_io)
    return media_io.encode_base64(frames)
|
||||
|
||||
|
||||
def merge_and_sort_multimodal_metadata(
    mm_positions: MultiModalPlaceholderDict,
    mm_hashes: Optional[MultiModalHashDict],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
    """Given a MultiModalPlaceholderDict, merge all PlaceholderRange
    objects from all available modalities into a single list of
    PlaceholderRange, sorted by their offset (starting index in the input
    sequence) in the ascending order.

    Optionally if a `MultiModalHashDict` is given, same operation will be
    applied to the object and the sorted list of hashes will be returned.

    Returns:
        list[str]: List of item modalities in order of their positions in the
            input sequence.
        list[PlaceholderRange]: Sorted list of all PlaceholdeRanges from
            mm_positions.
        Optional[list[str]]: Sorted list of all hashes from mm_hashes if given,
            None otherwise.
    """

    modalities = list(mm_positions.keys())

    assert len(modalities) > 0, "No modalities found in the mm_positions."

    # Fast path: a single modality is already sorted by offset, so the
    # inputs can be returned essentially as-is.
    if len(modalities) == 1:
        modality = modalities[0]
        placeholders = list(mm_positions[modality])
        hashes = None if not mm_hashes else mm_hashes[modality]
        return [modality] * len(placeholders), placeholders, hashes

    # Collect (modality, placeholder, hash) triples across all modalities.
    triples: list = []
    for modality in modalities:
        placeholders = list(mm_positions[modality])
        if mm_hashes and modality in mm_hashes:
            modality_hashes: list[Optional[str]] = list(mm_hashes[modality])
        else:
            modality_hashes = [None] * len(placeholders)

        for placeholder, hash_value in zip(placeholders, modality_hashes):
            triples.append((modality, placeholder, hash_value))

    # Order every placeholder by its starting index in the input sequence.
    triples.sort(key=lambda t: t[1].offset)

    sorted_modalities = [m for m, _, _ in triples]
    merged_placeholders = [p for _, p, _ in triples]
    if mm_hashes is None:
        merged_hashes = None
    else:
        # NOTE(review): when mm_hashes is given but a modality has no hash
        # entry, str(None) yields the literal string "None" here — confirm
        # this is the intended behavior.
        merged_hashes = [str(h) for _, _, h in triples]

    return sorted_modalities, merged_placeholders, merged_hashes
|
||||
|
||||
|
||||
def group_mm_inputs_by_modality(
        mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
    """Group consecutive MultiModalKwargs from mm_inputs with the same modality
    together into the same list for batching purpose. For MultiModalKwargs with
    multiple modalities, put them into their own list.

    Args:
        mm_inputs: List of MultiModalKwargs.

    Returns:
        list[list[vllm.multimodal.MultiModalKwargs]]: List of list of
            `MultiModalKwargs`, each inner list contains consecutive
            `MultiModalKwargs` with same modality.
    """
    if not mm_inputs:
        return []

    def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]:
        num_modalities = len(mm_input.modalities)
        if num_modalities > 1:
            # Multi-modality inputs key on their identity so each forms
            # its own singleton group.
            return id(mm_input)
        if num_modalities == 1:
            return next(iter(mm_input.modalities))
        # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty,
        # this is used to make InternVL with legacy pipeline still work with v1.
        return ""

    grouped = groupby(mm_inputs, key=modality_group_func)
    return [list(group) for _, group in grouped]
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                vision_model: torch.nn.Module) -> torch.Tensor:
    """Run a vision model with data parallelism (DP) sharding.

    The input is split along its first dimension across tensor-parallel
    ranks; each rank runs the vision model on its slice and the outputs are
    gathered back (padding rows dropped) so every rank returns the full
    embedding tensor.

    Args:
        image_input (torch.Tensor): Image input tensor.
        vision_model (torch.nn.Module): Vision model.

    Returns:
        torch.Tensor: Output image embeddings
    """
    total_chunks = image_input.shape[0]
    world_size = get_tensor_model_parallel_world_size()
    # Ceil-divide so every rank gets an equally sized slice.
    chunks_per_rank = (total_chunks + world_size - 1) // world_size
    padding_rows = chunks_per_rank * world_size - total_chunks

    # F.pad specifies pairs starting from the LAST dim, so zero-pad every
    # trailing dim and extend only the leading (chunk) dimension.
    pad_spec = (0, ) * (2 * (image_input.dim() - 1)) + (0, padding_rows)
    padded_input = torch.nn.functional.pad(image_input, pad_spec)

    rank = get_tensor_model_parallel_rank()
    start = rank * chunks_per_rank
    local_slice = padded_input[start:start + chunks_per_rank, ...]

    local_embeddings = vision_model(local_slice)
    gathered = tensor_model_parallel_all_gather(local_embeddings, dim=0)
    # Discard the rows that correspond to padding chunks.
    return gathered[:total_chunks, ...]
|
||||
198
vllm/multimodal/video.py
Normal file
198
vllm/multimodal/video.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm import envs
|
||||
|
||||
from .base import MediaIO
|
||||
from .image import ImageMediaIO
|
||||
|
||||
|
||||
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Resize every frame of a (num_frames, H, W, C) video to ``size``.

    Args:
        frames: Video frames with shape (num_frames, height, width, channels).
        size: Target (new_height, new_width).

    Returns:
        Array of shape (num_frames, new_height, new_width, channels) with
        the same dtype as the input.
    """
    # NOTE(review): assumes a channel count cv2.resize preserves (e.g. 3);
    # single-channel frames may not round-trip — confirm with callers.
    target_height, target_width = size
    frame_count = frames.shape[0]
    channels = frames.shape[-1]
    resized = np.empty((frame_count, target_height, target_width, channels),
                       dtype=frames.dtype)
    # lazy import cv2 to avoid bothering users who only use text models
    import cv2
    for idx in range(frame_count):
        # cv2.resize takes (width, height), i.e. the reverse of numpy order.
        resized[idx] = cv2.resize(frames[idx], (target_width, target_height))
    return resized
|
||||
|
||||
|
||||
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
    """Scale a video's height and width by ``size_factor``.

    Args:
        frames: Video frames with shape (num_frames, height, width, channels).
        size_factor: Multiplier applied to both spatial dims (truncated to int).

    Returns:
        The resized video frames.
    """
    _, height, width, _ = frames.shape
    target_size = (int(height * size_factor), int(width * size_factor))
    return resize_video(frames, target_size)
|
||||
|
||||
|
||||
def sample_frames_from_video(frames: npt.NDArray,
                             num_frames: int) -> npt.NDArray:
    """Uniformly sample ``num_frames`` frames from a video.

    Args:
        frames: Video frames, first axis is time.
        num_frames: Number of frames to keep; -1 returns the video unchanged.

    Returns:
        The sampled frames (the input itself when num_frames == -1).
    """
    if num_frames == -1:
        return frames

    total = frames.shape[0]
    # Evenly spaced indices covering the full clip, endpoints included.
    indices = np.linspace(0, total - 1, num_frames, dtype=int)
    return frames[indices, ...]
|
||||
|
||||
|
||||
class VideoLoader:
    """Abstract interface for decoding raw video bytes into frame arrays."""

    @classmethod
    @abstractmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Decode ``data`` into an array of frames.

        Args:
            data: Encoded video bytes.
            num_frames: Number of frames to sample; -1 keeps every frame.

        Returns:
            The decoded frames as a numpy array.
        """
        raise NotImplementedError


class VideoLoaderRegistry:
    """Registry of named `VideoLoader` backends, populated via decorator."""

    def __init__(self) -> None:
        # Maps registered backend name -> VideoLoader subclass.
        self.name2class: dict[str, type] = {}

    def register(self, name: str):
        """Return a class decorator that registers a loader under ``name``."""

        def wrap(cls_to_register):
            self.name2class[name] = cls_to_register
            return cls_to_register

        return wrap

    @staticmethod
    def load(cls_name: str) -> VideoLoader:
        """Instantiate the loader registered under ``cls_name``.

        Raises:
            ValueError: If no loader was registered under that name.
        """
        cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
        if cls is None:
            # A bare ``assert`` here would be stripped under ``python -O``
            # and surface later as a confusing "'NoneType' object is not
            # callable"; raise explicitly so a bad backend name fails loudly.
            raise ValueError(f"VideoLoader class {cls_name} not found")
        return cls()


# Global singleton used by `VideoLoaderRegistry.load` and the
# `@VIDEO_LOADER_REGISTRY.register(...)` decorators below.
VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader):
    """`VideoLoader` that decodes in-memory video via OpenCV VideoCapture."""

    def get_cv2_video_api(self):
        """Pick a cv2 buffered-stream backend for in-memory decoding.

        Returns the first usable backend identifier, or None (letting cv2
        pick a default) when no buffered backend qualifies.
        """
        import cv2.videoio_registry as vr

        api_pref = None
        for backend in vr.getStreamBufferedBackends():
            # Skip backends that are not actually available on this system.
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                # Plugin backends older than ABI 1 / API 2 lack the
                # buffered-stream entry points needed below; skip them.
                if (abi < 1 or (abi == 1 and api < 2)):
                    continue
            api_pref = backend
            break
        return api_pref

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Decode video bytes into a (num_frames, H, W, 3) RGB uint8 array.

        Args:
            data: Encoded video bytes, decoded in memory (no temp file).
            num_frames: Frames to sample uniformly across the clip; -1 (or
                any value larger than the clip length) reads every frame.

        Raises:
            ValueError: If OpenCV cannot open the stream.
            AssertionError: If fewer frames than expected were decoded.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        # NOTE(review): CAP_PROP_FRAME_COUNT is container metadata and may be
        # approximate for some formats — the assert below would then fire.
        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        full_read = num_frames == -1 or total_frames_num < num_frames
        if full_read:
            num_frames = total_frames_num
            frame_idx = list(range(0, num_frames))
        else:
            # Evenly spaced indices over the whole clip, endpoints included.
            uniform_sampled_frames = np.linspace(0,
                                                 total_frames_num - 1,
                                                 num_frames,
                                                 dtype=int)
            frame_idx = uniform_sampled_frames.tolist()

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Pre-allocate the output so decoded frames are written in place.
        frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)

        # grab() advances the stream cheaply; retrieve() fully decodes only
        # the frames we intend to keep.
        i = 0
        for idx in range(total_frames_num):
            ok = cap.grab()  # next img
            if not ok:
                break
            if idx in frame_idx:  # only decompress needed
                ret, frame = cap.retrieve()
                if ret:
                    # OpenCV decodes to BGR; convert to RGB for consumers.
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1
        # we expect all frames loaded
        assert i == num_frames, (f"Expected reading {num_frames} frames, "
                                 f"but only loaded {i} frames from video.")
        return frames
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Load and encode videos as (num_frames, H, W, C) numpy frame arrays."""

    def __init__(
        self,
        image_io: ImageMediaIO,
        *,
        num_frames: int = 32,
    ) -> None:
        """Create a video IO helper.

        Args:
            image_io: Image IO used for per-frame JPEG encode/decode.
            num_frames: Maximum number of frames to sample per video.
        """
        super().__init__()

        self.image_io = image_io
        self.num_frames = num_frames
        # The decoding backend is fixed once, from the environment config.
        backend_name = envs.VLLM_VIDEO_LOADER_BACKEND
        self.video_loader = VIDEO_LOADER_REGISTRY.load(backend_name)

    def load_bytes(self, data: bytes) -> npt.NDArray:
        """Decode raw video bytes, sampling at most ``self.num_frames``."""
        return self.video_loader.load_bytes(data, self.num_frames)

    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
        """Decode base64 video data.

        ``video/jpeg`` input is a comma-separated list of base64 JPEG
        frames; anything else is treated as an encoded video container.
        """
        if media_type.lower() != "video/jpeg":
            return self.load_bytes(base64.b64decode(data))

        decoded_frames = [
            np.asarray(self.image_io.load_base64("image/jpeg", chunk))
            for chunk in data.split(",")
        ]
        return np.stack(decoded_frames)

    def load_file(self, filepath: Path) -> npt.NDArray:
        """Read a video file from disk and decode it."""
        return self.load_bytes(filepath.read_bytes())

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode frames as a comma-joined string of base64 JPEG images.

        Raises:
            NotImplementedError: For any format other than "JPEG".
        """
        if video_format != "JPEG":
            msg = "Only JPEG format is supported for now."
            raise NotImplementedError(msg)

        encoded_frames = [
            self.image_io.encode_base64(Image.fromarray(frame),
                                        image_format=video_format)
            for frame in media
        ]
        return ",".join(encoded_frames)
|
||||
Reference in New Issue
Block a user