init
This commit is contained in:
503
vllm/multimodal/utils.py
Normal file
503
vllm/multimodal/utils.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections.abc import Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
from urllib.request import url2pathname
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .video import VideoMediaIO
|
||||
|
||||
# Type variable linking a `MediaIO[_M]` argument to the value that the
# `MediaConnector` loaders return for it.
_M = TypeVar("_M")

if TYPE_CHECKING:
    from .inputs import (
        BatchedTensorInputs,
        MultiModalKwargsItem,
        MultiModalKwargsItems,
        MultiModalPlaceholderDict,
    )
else:
    # At runtime these names are only needed as annotations, so they are
    # aliased to `Any` instead of importing `.inputs` here.
    # NOTE(review): presumably this avoids a circular import - confirm.
    BatchedTensorInputs = Any
    MultiModalKwargsItem = Any
    MultiModalKwargsItems = Any
    MultiModalPlaceholderDict = Any
|
||||
|
||||
# Process-wide worker pool; its size is taken from
# `VLLM_MEDIA_LOADING_THREAD_COUNT`.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT,
)
# Shut the pool down when the interpreter exits.
atexit.register(global_thread_pool.shutdown)
|
||||
|
||||
|
||||
class MediaConnector:
    """Fetch media referenced by a URL and decode it with a `MediaIO`.

    Supported URL schemes are HTTP(S), `data:` (base64 only) and `file:`
    (only when `allowed_local_media_path` is configured).
    """

    def __init__(
        self,
        media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                                      from. An empty string disables loading
                                      from `file:` URLs.
            allowed_media_domains: Hostnames that HTTP(S) media may be
                                   fetched from. `None` or an empty list
                                   allows every host.

        Raises:
            ValueError: If `allowed_local_media_path` is set but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs or {})
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_
        self.allowed_media_domains = ([] if allowed_media_domains is None
                                      else allowed_media_domains)

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a `data:<media type>[;param...];base64,<data>` URL.

        Raises:
            ValueError: If the URL has no `,` separator or no `;`-separated
                encoding marker.
            NotImplementedError: If the encoding is not base64.
        """
        data_spec, separator, data = url_spec.path.partition(",")
        if not separator:
            raise ValueError(
                "Invalid data URL: expected the form "
                "`data:<media type>;base64,<data>`.")

        # The media type may carry parameters (e.g. `;charset=utf-8`);
        # the encoding marker is always the last `;`-separated token.
        media_type, *params = data_spec.split(";")
        if not params:
            # Kept as ValueError: the original code failed here with an
            # unpacking ValueError, which callers may already handle.
            raise ValueError(
                "Invalid data URL: missing `;base64` before the data. "
                "Only base64 data URLs are supported for now.")

        if params[-1] != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a local file, restricted to `allowed_local_media_path`.

        Raises:
            RuntimeError: If no allowed local media path is configured.
            ValueError: If the file is not located under the allowed path.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        filepath = Path(url2pathname(url_spec.path))
        # Resolve BOTH sides before comparing. The file path is resolved to
        # defeat `..` and symlink escapes; the allowed directory must be
        # resolved too, otherwise a relative or symlinked allowed path can
        # never match the resolved parents and every file is rejected.
        allowed_root = allowed_local_media_path.resolve()
        if allowed_root not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
        """Raise `ValueError` if the URL host is not an allowed domain.

        The check is skipped when `allowed_media_domains` is empty.
        """
        if self.allowed_media_domains and url_spec.hostname not in \
                self.allowed_media_domains:
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}")

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:  # type: ignore[type-var]
        """Load media from an HTTP(S), data or file URL.

        Args:
            url: The URL to load.
            media_io: Decoder used to turn the fetched content into media.
            fetch_timeout: Timeout passed to the HTTP connection; unused
                for data and file URLs.

        Raises:
            ValueError: If the URL scheme is not supported or the host is
                not in the allowed domains.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Asynchronous counterpart of `load_from_url`.

        The download is awaited on the event loop; decoding and local file
        access are submitted to `global_thread_pool`.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool,
                                          media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(global_thread_pool,
                                          self._load_data_url, url_spec,
                                          media_io)
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(global_thread_pool,
                                          self._load_file_url, url_spec,
                                          media_io)
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the content cannot be identified as an image.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP or base64 data URL.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Load an image embedding from base64-encoded data.

        Unlike the other `fetch_*` methods, `data` is handed directly to
        `ImageEmbeddingMediaIO.load_base64`; it is not treated as a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)
|
||||
|
||||
|
||||
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Encode an audio array and its sampling rate as base64.

    The pair `(audio, sampling_rate)` is passed to a default-constructed
    `AudioMediaIO`.
    """
    return AudioMediaIO().encode_base64((audio, sampling_rate))
|
||||
|
||||
|
||||
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The image to encode.
        image_mode: Image mode handed to `ImageMediaIO`.
        format: Image format handed to the encoder as `image_format`.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    return encoder.encode_base64(image, image_format=format)
|
||||
|
||||
|
||||
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode video frames as base64.

    Uses a `VideoMediaIO` wrapping a default-constructed `ImageMediaIO`.
    """
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
|
||||
|
||||
|
||||
def argsort_mm_positions(
        mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
    """
    Compute the order in which the placeholders of a
    `MultiModalPlaceholderDict` appear, sorted by `offset` (starting index
    in the input sequence) in ascending order.

    Items with equal offsets keep the order in which they are encountered
    while iterating the dictionary.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    entries = [(item.offset, modality, idx)
               for modality, items in mm_positions.items()
               for idx, item in enumerate(items)]

    # Sort on the offset alone; the sort is stable, so ties are not broken
    # by modality name or index.
    entries.sort(key=lambda entry: entry[0])

    return [(modality, idx) for _, modality, idx in entries]
|
||||
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
@deprecated("`group_mm_inputs_by_modality` is superseded by "
            "`group_mm_kwargs_by_modality` and will be removed in v0.13. "
            "Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality(
    mm_inputs: list[MultiModalKwargsItems]
) -> list[list[MultiModalKwargsItems]]:
    """Group consecutive inputs that share the same single modality.

    An input containing more than one modality is keyed by its own `id`,
    so it is not merged with different neighbouring inputs.
    """
    if not mm_inputs:
        return []

    def _group_key(mm_input: MultiModalKwargsItems) -> Union[str, int]:
        modality_count = len(mm_input)

        if modality_count == 1:
            # Single modality: group by the modality name.
            return next(iter(mm_input.keys()))

        if modality_count > 1:
            # Multiple modalities: use the object's id as a unique key.
            return id(mm_input)

        raise AssertionError("This line should be unreachable.")

    grouped: list[list[MultiModalKwargsItems]] = []
    for _, members in groupby(mm_inputs, key=_group_key):
        grouped.append(list(members))
    return grouped
|
||||
|
||||
|
||||
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
    merge_by_field_config: bool = False,
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` that share
    a modality and yield the combined keyword arguments for each group.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
        merge_by_field_config: If True, the group's data is obtained in one
            step via `MultiModalKwargsItems.from_seq(items).get_data(...)`.
            Otherwise each item is converted on its own and the results are
            combined with `MultiModalKwargs.batch`.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems

    for modality, group in groupby(mm_kwargs, key=lambda item: item.modality):
        group_items = list(group)

        # TODO: Enable `merge_by_field_config` for all models
        # to avoid creating an extra batch dimension (except for fields
        # that are meant to be stacked anyway).
        # We will also need to update each model to remove `flatten_bn`.
        mm_kwargs_group: BatchedTensorInputs
        if not merge_by_field_config:
            per_item_data = [
                MultiModalKwargsItems.from_seq([item]).get_data()
                for item in group_items
            ]
            batched = MultiModalKwargs.batch(per_item_data,
                                             pin_memory=pin_memory)
            mm_kwargs_group = MultiModalKwargs.as_kwargs(batched,
                                                         device=device)
        else:
            merged = MultiModalKwargsItems.from_seq(group_items).get_data(
                pin_memory=pin_memory)
            mm_kwargs_group = dict(merged)

            if device is not None:
                mm_kwargs_group = json_map_leaves(
                    lambda x: x.to(device=device),
                    mm_kwargs_group,
                )

        yield modality, len(group_items), mm_kwargs_group
|
||||
|
||||
|
||||
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, Union[int, float]]:
    """
    Fetch audio through a freshly constructed `MediaConnector`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.
    """
    # Empty or missing kwargs are forwarded as `None`.
    media_io_kwargs = {"audio": audio_io_kwargs} if audio_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_audio(audio_url)
|
||||
|
||||
|
||||
def fetch_image(
    image_url: str,
    image_io_kwargs: Optional[dict[str, Any]] = None,
) -> Image.Image:
    """
    Fetch an image through a freshly constructed `MediaConnector`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.
    """
    # Empty or missing kwargs are forwarded as `None`.
    media_io_kwargs = {"image": image_io_kwargs} if image_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_image(image_url)
|
||||
|
||||
|
||||
def fetch_video(
    video_url: str,
    video_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch a video through a freshly constructed `MediaConnector`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.
    """
    # Empty or missing kwargs are forwarded as `None`.
    media_io_kwargs = {"video": video_io_kwargs} if video_io_kwargs else None
    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_video(video_url)
|
||||
Reference in New Issue
Block a user