# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Multimodal processing and model definitions for LLaVA-OneVision.

Builds on the LLaVA-NeXT machinery (see `.llava_next`) and adds video
support: per-frame pooled vision tokens plus a trailing newline token.
"""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Final, Literal, Protocol, TypeAlias

import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor
from transformers.models.llava_onevision.modeling_llava_onevision import (
    get_anyres_image_grid_shape,
    unpad_image,
)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageSize,
    MultiModalDataItems,
    VideoEmbeddingItems,
    VideoProcessorItems,
)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (
    BaseLlavaNextMultiModalProcessor,
    LlavaNextLikeConfig,
    LlavaNextProcessingInfo,
)
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)

# For profile run
_MAX_FRAMES_PER_VIDEO = 16


class LlavaOnevisionVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of videos
        - f: Number of frames
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Note that `f` may be different for each batch, and 'num_frames' may be
    different for each video, in which case the data is passed as a list
    instead of a batched tensor.
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    # Frame count `f` is a dynamic dim, hence the optional list-of-tensors form.
    pixel_values_videos: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}),
    ]


class LlavaOnevisionImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches (1 + num_patches)
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Note that `num_patches` may be different per batch and image, in which
    case the data is passed as a list instead of a batched tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Patch count `np` is a dynamic dim, hence the optional list-of-tensors form.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
    ]

    # Original (height, width) per image; optional.
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]


class LlavaOnevisionImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    data: Annotated[
        torch.Tensor,
        TensorShape("bn", "ifs", "hs"),
    ]


LlavaOnevisionImageInputs: TypeAlias = (
    LlavaOnevisionImagePixelInputs | LlavaOnevisionImageEmbeddingInputs
)

LlavaOnevisionMultiInputs: TypeAlias = (
    LlavaOnevisionImageInputs | LlavaOnevisionVideoPixelInputs
)


class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
    """Structural type for configs that additionally expose a video token id."""

    video_token_index: Final[int]


class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
    """Processing metadata (token counts, limits) for LLaVA-OneVision."""

    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
        return self.ctx.get_hf_config(LlavaOnevisionConfig)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # `None` means no fixed per-prompt limit for that modality.
        return {"image": None, "video": None}

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    # with additional logic afterwards taken from LlavaOnevisionProcessor
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Return (unpadded_features, newline_features) for an anyres image.

        Mirrors the HF processor's unpadding: strips the letterbox padding
        implied by the aspect-ratio mismatch, then (LLaVA-OneVision-specific)
        downscales when the unpadded grid is much larger than 3x3 base tiles.
        """
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height
        if aspect_ratio > current_aspect_ratio:
            # Wider than the patch grid: padding was added top/bottom.
            new_height = int(
                round(original_height * (current_width / original_width), 7)
            )
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            # Taller than the patch grid: padding was added left/right.
            new_width = int(
                round(original_width * (current_height / original_height), 7)
            )
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        # One newline token is appended per remaining row.
        newline_features = current_height

        # OneVision extra step: if the unpadded grid exceeds ~9 base tiles,
        # scale both sides down by `ratio` (bilinear-resize token budget cap).
        ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
        if ratio > 1.1:
            height_factor = int(current_height // ratio)
            width_factor = int(current_width // ratio)
            unpadded_features = height_factor * width_factor
            newline_features = height_factor

        return (unpadded_features, newline_features)

    def get_image_size_with_most_features(self) -> ImageSize:
        # NOTE: This hardcoded value is found via processor tests
        return ImageSize(width=1153, height=944)

    def _get_num_frame_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Tokens per video frame after spatial pooling of the patch grid."""
        hf_config = self.get_hf_config()
        # Default stride of 2 matches the HF config default when unset.
        spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

        vision_encoder_info = self.get_vision_encoder_info()
        patch_grid_length = vision_encoder_info.get_patch_grid_length()
        pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

        return pooled_grid_length * pooled_grid_length

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        """Total placeholder tokens for a video of `num_frames` frames."""
        num_frame_tokens = self._get_num_frame_tokens(
            image_width=image_width,
            image_height=image_height,
        )

        return num_frame_tokens * num_frames + 1  # Newline token

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count whose video token total fits in `max_tokens`."""
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video for profiling: budget split across all videos,
        capped at _MAX_FRAMES_PER_VIDEO, and at least 1."""
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Upper bound of video placeholder tokens for one video."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
        )


class LlavaOnevisionDummyInputsBuilder(
    LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]
):
    """Builds worst-case dummy text/images/videos for memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        video_token = processor.video_token

        # One placeholder token per item; the processor expands them later.
        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        # Use the most-token-hungry image size so profiling is a true upper bound.
        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }


class LlavaOnevisionMultiModalProcessor(
    BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]
):
    """HF-processor wrapper that adds per-video processing on top of
    the LLaVA-NeXT image pipeline."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.batched("video"),
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        mm_data = dict(mm_data)
        videos = mm_data.pop("videos", [])
        assert isinstance(videos, list)

        # Fast path: no videos -> defer entirely to the LLaVA-NeXT pipeline.
        if not videos:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        # LLaVA-OneVision processor doesn't support multiple videos
        # with different sizes when converting back to tensors
        # So, we process each component separately
        # NOTE: No prompt replacement is applied in this case
        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        video_token = processor.video_token

        # Tokenize the prompt alone (no media) ...
        text_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data={},
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # ... then process all images in one batch, keeping only image fields.
        images = mm_data.pop("images", [])
        assert isinstance(images, list)
        if images:
            processor_outputs = super()._call_hf_processor(
                prompt=image_token * len(images),
                mm_data={"images": images},
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )
            image_outputs = {
                k: v
                for k, v in processor_outputs.items()
                if k in ("pixel_values", "image_sizes")
            }
        else:
            image_outputs = {}

        # ... and each video individually, since sizes may differ per video.
        pixel_values_videos = []
        for video in videos:
            item_outputs = super()._call_hf_processor(
                prompt=video_token,
                mm_data={"videos": video},
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

            pixel_values_videos.append(item_outputs["pixel_values_videos"][0])
        video_outputs = {"pixel_values_videos": pixel_values_videos}

        combined_outputs = dict(
            text_outputs,
            **image_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """The HF processor only expands placeholders itself when there are
        no videos (see the per-video path in `_call_hf_processor`)."""
        base_result = super()._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return base_result and mm_items.get_count("video", strict=False) == 0

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        # Image replacements come from the LLaVA-NeXT base class.
        image_repls = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        hf_config = self.info.get_hf_config()
        video_token_id = hf_config.video_token_index

        def get_video_replacement(item_idx: int):
            # Expand one video placeholder into its full token count.
            videos = mm_items.get_items(
                "video", (VideoEmbeddingItems, VideoProcessorItems)
            )

            if isinstance(videos, VideoEmbeddingItems):
                # Precomputed embeddings carry their own feature size.
                num_video_tokens = videos.get_feature_size(item_idx)
            else:
                image_size = videos.get_frame_size(item_idx)
                num_video_tokens = self.info.get_num_video_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    num_frames=videos.get_num_frames(item_idx),
                )

            return [video_token_id] * num_video_tokens

        return [
            *image_repls,
            PromptReplacement(
                modality="video",
                target=[video_token_id],
                replacement=get_video_replacement,
            ),
        ]


class LlavaOnevisionMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision hidden states to the LM hidden size."""

    def __init__(self, config: LlavaOnevisionConfig):
        super().__init__()

        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            bias=config.multimodal_projector_bias,
        )
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            bias=config.multimodal_projector_bias,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    LlavaOnevisionMultiModalProcessor,
    info=LlavaOnevisionProcessingInfo,
    dummy_inputs=LlavaOnevisionDummyInputsBuilder,
)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        # NOTE(review): the image branch returns an empty string here, which
        # looks like an extraction artifact (an angle-bracket placeholder such
        # as "<image>" may have been stripped) — verify against upstream.
        if modality.startswith("image"):
            return ""
        if modality.startswith("video"):
            return "