# SPDX-License-Identifier: Apache-2.0
import base64
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_video_processor
from vllm.utils import is_list_of
from .base import MediaIO, ModalityData
from .image import ImageMediaIO, ImagePlugin
from .inputs import MultiModalKwargs, VideoItem
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
class VideoPlugin(ImagePlugin):
    """Plugin for video data."""

    def get_data_key(self) -> str:
        """Modality key under which video inputs are registered."""
        return "video"

    def _get_hf_video_processor(
        self,
        model_config: "ModelConfig",
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Fetch the (cached) HuggingFace video processor for this model."""
        kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs
        return cached_get_video_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **kwargs)

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[VideoItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """Map raw video frame data to model inputs via the HF processor.

        Raises:
            RuntimeError: If no HF video processor is available.
            TypeError: If ``data`` is not an ndarray / list of ndarrays.
        """
        # Unwrap a singleton list so a single video is handled uniformly.
        if isinstance(data, list) and len(data) == 1:
            data = data[0]  # type: ignore

        if not (isinstance(data, np.ndarray)
                or is_list_of(data, np.ndarray)):
            raise TypeError(f"Invalid video type: {type(data)}")

        video_processor = self._get_hf_video_processor(
            ctx.model_config,
            mm_processor_kwargs,
        )
        if video_processor is None:
            raise RuntimeError("No HuggingFace processor is available "
                               "to process the video object")
        try:
            # NOTE: Similar to image; it may be a good idea to filter and
            # pass mm_processor_kwargs here too, but for now we don't to
            # avoid extra complexity if the initializer and preprocess
            # signatures of the processor don't align
            batch_data = video_processor(data, return_tensors="pt").data
        except Exception:
            logger.error("Failed to process video (%s)", data)
            raise

        return MultiModalKwargs(batch_data)

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """Upper bound on tokens a single video may consume."""
        return 4096
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Resize every frame of a (T, H, W, C) video to ``size`` = (height, width)."""
    # lazy import cv2 to avoid bothering users who only use text models
    import cv2

    target_h, target_w = size
    num_frames, _, _, channels = frames.shape
    out = np.empty((num_frames, target_h, target_w, channels),
                   dtype=frames.dtype)
    # cv2.resize expects (width, height) order.
    for idx in range(num_frames):
        out[idx] = cv2.resize(frames[idx], (target_w, target_h))
    return out
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
    """Scale a video's spatial dimensions by ``size_factor`` (truncated to int)."""
    _, height, width, _ = frames.shape
    scaled = (int(height * size_factor), int(width * size_factor))
    return resize_video(frames, scaled)
def sample_frames_from_video(frames: npt.NDArray,
                             num_frames: int) -> npt.NDArray:
    """Uniformly sample ``num_frames`` frames along axis 0; -1 keeps all frames."""
    if num_frames == -1:
        return frames
    total = frames.shape[0]
    indices = np.linspace(0, total - 1, num_frames, dtype=int)
    return frames[indices, ...]
class VideoLoader:
    """Abstract interface for decoding raw video bytes into frame arrays."""

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Decode ``data`` into a (T, H, W, C) frame array.

        Args:
            data: Encoded video container bytes.
            num_frames: Number of frames to sample uniformly, or -1 for all.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        # BUGFIX: the first parameter of a @classmethod is the class, so it
        # must be named ``cls`` (the original misleadingly named it ``self``).
        raise NotImplementedError
class OpenCVVideoBackend(VideoLoader):
    """Video decoder backed by OpenCV's buffered in-memory stream API."""

    def get_cv2_video_api(self):
        """Pick a usable buffered-stream backend.

        Returns the first available backend whose plugin supports
        buffered streams (plugin ABI >= 1.2), or ``None`` to let cv2
        fall back to its default.
        """
        import cv2.videoio_registry as vr

        for backend in vr.getStreamBufferedBackends():
            if not vr.hasBackend(backend):
                continue
            if not vr.isBackendBuiltIn(backend):
                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
                # Buffered-stream support requires plugin ABI >= 1.2.
                if abi < 1 or (abi == 1 and api < 2):
                    continue
            return backend
        return None

    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        """Decode video bytes into an RGB (T, H, W, 3) uint8 array.

        Args:
            data: Encoded video container bytes.
            num_frames: Frames to sample uniformly across the video,
                or -1 to load every frame.

        Raises:
            ValueError: If the stream cannot be opened, or fewer frames
                than expected could be decoded.
        """
        import cv2

        backend = cls().get_cv2_video_api()
        cap = cv2.VideoCapture(BytesIO(data), backend, [])
        if not cap.isOpened():
            raise ValueError("Could not open video stream")

        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        full_read = num_frames == -1 or total_frames_num < num_frames
        if full_read:
            frame_idx = list(range(0, total_frames_num))
        else:
            uniform_sampled_frames = np.linspace(0,
                                                 total_frames_num - 1,
                                                 num_frames,
                                                 dtype=int)
            frame_idx = uniform_sampled_frames.tolist()
        # Set membership makes the per-frame check O(1) instead of O(k).
        wanted = set(frame_idx)

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)

        i = 0
        for idx in range(total_frames_num):
            ok = cap.grab()  # advance to next frame without decoding
            if not ok:
                break
            if idx in wanted:  # only decompress needed frames
                ret, frame = cap.retrieve()
                if ret:
                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    i += 1

        # BUGFIX: the original `assert i == num_frames` always failed on a
        # full read (num_frames == -1, or fewer total frames than requested)
        # and was stripped under -O. Compare against the frames actually
        # requested and raise a real exception.
        if i != len(frame_idx):
            raise ValueError(
                f"Expected to load {len(frame_idx)} frames, got {i}")
        return frames
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Load and encode videos as (T, H, W, C) frame arrays."""

    def __init__(
        self,
        image_io: ImageMediaIO,
        *,
        num_frames: int = 32,
    ) -> None:
        super().__init__()

        # Per-frame codec used for the "video/jpeg" base64 representation.
        self.image_io = image_io
        # How many frames to sample uniformly when decoding a container.
        self.num_frames = num_frames
        self.video_loader = OpenCVVideoBackend

    def load_bytes(self, data: bytes) -> npt.NDArray:
        """Decode raw container bytes, sampling ``self.num_frames`` frames."""
        return self.video_loader.load_bytes(data, self.num_frames)

    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
        """Decode base64-encoded video data.

        ``video/jpeg`` is treated as a comma-separated list of base64 JPEG
        frames; any other media type is decoded as a whole video container.
        """
        if media_type.lower() != "video/jpeg":
            return self.load_bytes(base64.b64decode(data))

        decode_frame = partial(self.image_io.load_base64, "image/jpeg")
        frames = [
            np.array(decode_frame(chunk)) for chunk in data.split(",")
        ]
        return np.stack(frames)

    def load_file(self, filepath: Path) -> npt.NDArray:
        """Read a video file from disk and decode it."""
        return self.load_bytes(filepath.read_bytes())

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode frames as a comma-separated list of base64 JPEG images."""
        if video_format != "JPEG":
            raise NotImplementedError("Only JPEG format is supported for now.")

        encode_frame = partial(self.image_io.encode_base64,
                               image_format=video_format)
        return ",".join(
            encode_frame(Image.fromarray(frame)) for frame in media)