# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-Omni model (thinker part)."""

from collections.abc import Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from typing import Annotated, Any, Callable, Literal, Optional, Union

import torch
import torch.nn as nn
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
    Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig)
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
    Qwen2_5OmniAudioEncoder)
from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import (
    Qwen2_5OmniProcessor)
from transformers.models.whisper import WhisperFeatureExtractor

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
    Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
    Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
    Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
from vllm.model_executor.models.qwen2_audio import (
    Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems, NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
                                   ModalityDataItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalPromptUpdates,
                                        PlaceholderFeaturesInfo,
                                        PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)

# flash_attn is optional; when it is missing the audio tower keeps its
# default attention implementation (see the model constructor below).
try:
    import flash_attn
except (ImportError, ModuleNotFoundError):
    flash_attn = None

logger = init_logger(__name__)


class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
    """
    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
        - msl: Maximum sequence length
        - tsl: Total sequence length
    """
    # NOTE(review): `_parse_and_validate_audio_input` also passes an
    # `audio_feature_lengths` keyword that is not declared here; presumably
    # TensorSchema stores undeclared keywords as plain attributes - confirm.
    type: Literal["audio_features"]
    input_features: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("nmb", "tsl"),
    ]

    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("na", "msl"),
    ]


def create_qwen2_5_omni_thinker_field_factory(
    spatial_merge_size: int
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
                                                    MultiModalFieldConfig]]:
    """Build the callable that maps HF processor outputs to field configs.

    Args:
        spatial_merge_size: Vision merge factor; the number of embedding
            rows per image/video is the grid size divided by its square.

    Returns:
        A function taking the HF inputs mapping and returning one
        `MultiModalFieldConfig` per audio/image/video field.
    """

    def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str,
                                                              torch.Tensor]):
        # Missing modalities fall back to empty tensors so the size
        # computations below still work.
        audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
                                              torch.empty((0, )))

        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_pixel_grid_sizes = image_grid_thw.prod(-1)
        image_embed_grid_sizes = (image_pixel_grid_sizes //
                                  spatial_merge_size // spatial_merge_size)

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)
        video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
                                  spatial_merge_size)

        num_videos = len(video_grid_sizes)

        return dict(
            input_audio_features=MultiModalFieldConfig.flat_from_sizes(
                "audio", audio_feature_lengths, dim=1),
            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
            audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_pixel_grid_sizes),
            image_embeds=MultiModalFieldConfig.flat_from_sizes(
                "image", image_embed_grid_sizes),
            image_grid_thw=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes),
            video_embeds=MultiModalFieldConfig.flat_from_sizes(
                "video", video_embed_grid_sizes),
            video_grid_thw=MultiModalFieldConfig.batched("video"),
            second_per_grid_ts=MultiModalFieldConfig.batched("video"),
            use_audio_in_video=MultiModalFieldConfig.shared(
                "video", num_videos),
        )

    return _qwen2_5_omni_thinker_field_config


class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
    """Qwen2-VL data parser extended with dict-style audio embeddings."""

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(self._spatial_merge_size, *args, **kwargs)

    def _parse_audio_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        """Parse audio data; a dict is treated as precomputed features.

        A dict must provide `input_audio_features` and
        `audio_feature_lengths`; anything else is delegated to the parent.
        """
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={
                    "input_audio_features", "audio_feature_lengths"
                },
                fields_factory=create_qwen2_5_omni_thinker_field_factory(
                    self._spatial_merge_size),
            )

        return super()._parse_audio_data(data)


class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
                                       Qwen2_5_VLProcessingInfo):
    """Processing info combining the Qwen2-Audio and Qwen2.5-VL helpers."""

    def get_hf_config(self):
        """Return the thinker sub-config of the Qwen2.5-Omni HF config."""
        return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config

    def get_hf_processor(self, **kwargs: object) -> Qwen2_5OmniProcessor:
        """Return the HF processor; `use_fast` defaults to True."""
        return self.ctx.get_hf_processor(
            Qwen2_5OmniProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_feature_extractor(self, **kwargs: object):
        """Return the processor's Whisper feature extractor."""
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Audio, image and video are supported, each with a `None` limit."""
        return {"audio": None, "image": None, "video": None}


class Qwen2_5OmniThinkerDummyInputsBuilder(
        BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]):
    """Builds dummy prompts and multimodal data for the thinker model."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return a prompt with one placeholder token per requested item."""
        num_audios = mm_counts.get("audio", 0)
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()

        audio_token: str = hf_processor.audio_token
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        return (audio_token * num_audios + image_token * num_images +
                video_token * num_videos)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy audio/image/video items matching `mm_counts`."""
        num_audios = mm_counts.get("audio", 0)
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        feature_extractor = self.info.get_feature_extractor()

        # Audio length in samples: chunk length (capped at 30) times the
        # sampling rate.
        target_audio_length = min(
            feature_extractor.chunk_length,
            30,
        ) * feature_extractor.sampling_rate
        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        target_num_frames = \
            self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        mm_data = {
            "audio":
            self._get_dummy_audios(length=target_audio_length,
                                   num_audios=num_audios),
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(width=target_width,
                                   height=target_height,
                                   num_frames=target_num_frames,
                                   num_videos=num_videos),
        }

        return mm_data


class Qwen2_5OmniThinkerMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]):
    """Multimodal processor for the thinker, aware of `use_audio_in_video`."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Create the data parser from the model's merge size and audio SR."""
        feature_extractor = self.info.get_feature_extractor()
        return Qwen2_5OmniThinkerMultiModalDataParser(
            spatial_merge_size=self.info.get_hf_config(
            ).vision_config.spatial_merge_size,
            target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalise its audio/video outputs.

        The returned features always contain `use_audio_in_video`; audio
        features are renamed to `input_audio_features` and, when a feature
        attention mask is present, packed so padding frames are dropped.
        """
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
        if audios:
            # NOTE: Qwen2.5-Omni processor accept "audio"
            mm_data["audio"] = audios
            mm_kwargs = dict(mm_kwargs)

        hf_inputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        input_features = hf_inputs.pop('input_features', None)
        feature_attention_mask = hf_inputs.get('feature_attention_mask', None)
        if ('input_audio_features' not in hf_inputs
                and input_features is not None):
            if feature_attention_mask is not None:
                # Keep only the frames selected by the mask, then move the
                # feature dimension back to the front.
                input_features = input_features.permute(
                    0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
            hf_inputs['input_audio_features'] = input_features
        if ('audio_feature_lengths' not in hf_inputs
                and feature_attention_mask is not None):
            hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1)

        video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
        if video_second_per_grid is not None:
            hf_inputs["second_per_grid_ts"] = video_second_per_grid

        use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
        hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video)

        return hf_inputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field configs for the given HF processor outputs."""
        return create_qwen2_5_omni_thinker_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size)(
                hf_inputs)

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargsItems,
        mm_prompt_updates: MultiModalPromptUpdates,
        is_update_applied: bool,
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Qwen2.5-Omni reimplements this function to handle
        `use_audio_in_video`.

        Returns the (possibly updated) prompt token ids, the prompt text
        decoded from those ids, and the located placeholders.
        """
        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        use_audio_in_video = (all(
            item["use_audio_in_video"].data
            for item in mm_kwargs["video"]) if "video" in mm_kwargs else False)

        if is_update_applied:
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids,
                mm_prompt_updates,
            )
        else:
            # The text returned here is discarded: the prompt is always
            # re-decoded from the final token ids below.
            prompt_ids, _, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )

        self._validate_mm_placeholders(mm_placeholders,
                                       mm_item_counts,
                                       use_audio_in_video=use_audio_in_video)

        tokenizer = self.info.get_tokenizer()
        prompt = decode_tokens(tokenizer, prompt_ids)

        return prompt_ids, prompt, mm_placeholders

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the audio/image/video placeholder replacements.

        When `use_audio_in_video` is set, the video replacement is produced
        by `MRotaryEmbedding.omni_get_updates_use_audio_in_video` and
        consumes one audio length per video.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        vocab = tokenizer.get_vocab()

        audio_token = processor.audio_token
        image_token = processor.image_token
        video_token = processor.video_token
        audio_token_id = vocab[audio_token]
        image_token_id = vocab[image_token]
        video_token_id = vocab[video_token]

        out_mm_data = out_mm_kwargs.get_data()
        audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if audio_feature_lengths is None and feature_attention_mask is None:
            audio_output_lengths = []
        elif audio_feature_lengths is not None:
            _, audio_output_lens = _get_feat_extract_output_lengths(
                audio_feature_lengths)
            audio_output_lengths = audio_output_lens.tolist()
        elif feature_attention_mask is not None:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1))
            audio_output_lengths = audio_output_lens.tolist()

        # number of audios read from video.
        audio_in_video_item_idx = 0

        def get_replacement_qwen2_audio(item_idx: int):
            # Skip the audio lengths already consumed by videos.
            item_idx += audio_in_video_item_idx

            num_features = audio_output_lengths[item_idx]
            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio = audios.get(item_idx)
                raise ValueError(
                    f"The audio {audio} (len={len(audio)}) is too short "
                    "to be represented inside the model")

            return [audio_token_id] * num_features

        def get_replacement_qwen2_vision(item_idx: int, modality: str):
            grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx]
            assert isinstance(grid_thw, torch.Tensor)
            merge_length = image_processor.merge_size**2

            token_id = image_token_id if modality == "image" else video_token_id
            return [token_id] * (int(grid_thw.prod()) // merge_length)

        use_audio_in_video = hf_processor_mm_kwargs.get(
            "use_audio_in_video", False)
        thinker_config = self.info.get_hf_config()

        def get_replacement_qwen2_use_audio_in_video(item_idx: int):
            nonlocal audio_in_video_item_idx

            audio_num_features = audio_output_lengths[audio_in_video_item_idx +
                                                      item_idx]
            video_grid_thw = out_mm_data["video_grid_thw"][item_idx]

            audio_in_video_item_idx += 1

            second_per_grid_ts = hf_processor_mm_kwargs.get(
                "second_per_grid_ts", None)
            if second_per_grid_ts:
                video_second_per_grid_t = second_per_grid_ts[item_idx]
            else:
                video_second_per_grid_t = 1.0

            return MRotaryEmbedding.omni_get_updates_use_audio_in_video(
                thinker_config=thinker_config,
                audio_len=audio_num_features,
                video_grid_thw=video_grid_thw,
                video_second_per_grid_t=video_second_per_grid_t,
            )

        video_replacement_fn = (
            get_replacement_qwen2_use_audio_in_video if use_audio_in_video else
            partial(get_replacement_qwen2_vision, modality="video"))

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            ),
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=partial(get_replacement_qwen2_vision,
                                    modality="image"),
            ),
            PromptReplacement(
                modality="video",
                target=video_token,
                replacement=video_replacement_fn,
            ),
        ]

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_update: bool,
    ) -> tuple[list[int], BatchFeature, bool]:
        """
        Qwen2.5-Omni reimplements this function to handle text only.

        Returns the prompt token ids, the processed multimodal data and a
        flag; the flag is always False on the paths implemented here, while
        the text+multimodal path returns whatever
        `_apply_hf_processor_text_mm` returns.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_update:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                    tokenization_kwargs=tokenization_kwargs,
                )
            tokenizer = self.info.get_tokenizer()
            prompt_ids = encode_tokens(tokenizer, prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        mm_processed_data = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return prompt_ids, mm_processed_data, False

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Qwen2.5-Omni reimplements this function to handle
        `use_audio_in_video`.
        """
        # Work on a private copy so the mapping handed back by `mm_items`
        # is never modified by the adjustment below.
        mm_counts = dict(mm_items.get_all_counts())

        use_audio_in_video = hf_processor_mm_kwargs.get(
            "use_audio_in_video", False)
        if use_audio_in_video and "video" in mm_counts:
            assert "audio" in mm_counts
            # One audio item per video is consumed by the video itself, so
            # it gets no audio placeholder in the dummy text.
            mm_counts["audio"] -= mm_counts["video"]

        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
            prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return mm_processed_data

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
        use_audio_in_video: bool = False,
    ) -> None:
        """Validate placeholders, discounting audio embedded in videos."""
        if use_audio_in_video:
            # Copy first: the caller's counts must stay untouched.
            mm_item_counts = copy(mm_item_counts)
            if "video" in mm_item_counts:
                assert "audio" in mm_item_counts
                mm_item_counts["audio"] -= mm_item_counts["video"]
        super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)


class Qwen2_5OmniConditionalGenerationMixin:
    """Input parsing and encoder helpers shared by Qwen2.5-Omni models.

    The methods read `self.audio_tower` and `self.visual`, which the
    concrete model class is expected to provide.
    """

    def _validate_and_reshape_mm_tensor(self,
                                        mm_input: object,
                                        name: str,
                                        dim: int = 0) -> torch.Tensor:
        """Flatten a batched tensor or list of tensors into one tensor.

        Args:
            mm_input: A tensor or a list of tensors.
            name: Field name used in the error message.
            dim: Dimension along which the items are concatenated.

        Raises:
            ValueError: If `mm_input` is neither a tensor nor a list.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if dim == 0:
                # Merge the two leading dimensions.
                return mm_input.reshape(-1, *mm_input.shape[2:])
            return torch.concat(list(mm_input), dim=dim)
        else:
            return torch.concat(mm_input, dim=dim)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
        """Build the audio inputs from kwargs, or None if no audio given."""
        input_audio_features = kwargs.pop('input_audio_features', None)
        audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
        if input_audio_features is None:
            return None
        input_audio_features = self._validate_and_reshape_mm_tensor(
            input_audio_features, 'input_audio_features', dim=1)
        if feature_attention_mask is not None:
            feature_attention_mask = self._validate_and_reshape_mm_tensor(
                feature_attention_mask, 'feature_attention_mask')
        if not isinstance(input_audio_features, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio input features. "
                             f"Got type: {type(input_audio_features)}")
        return Qwen2_5OmniAudioFeatureInputs(
            type="audio_features",
            input_features=input_audio_features,
            audio_feature_lengths=audio_feature_lengths,
            feature_attention_mask=feature_attention_mask)

    def _parse_and_validate_image_input(
        self,
        **kwargs: dict[str, Any],
    ) -> Optional[Qwen2_5_VLImageInputs]:
        """Build pixel or embedding image inputs, or None if no image given.

        Pixel values take precedence over precomputed embeddings.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return Qwen2_5_VLImagePixelInputs(type="pixel_values",
                                              pixel_values=pixel_values,
                                              image_grid_thw=image_grid_thw)

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return Qwen2_5_VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw)

    def _parse_and_validate_video_input(
        self,
        **kwargs: dict[str, Any],
    ) -> Optional[Qwen2_5_VLVideoInputs]:
        """Build pixel or embedding video inputs, or None if no video given.

        Pixel values take precedence over precomputed embeddings.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            if not isinstance(video_embeds, torch.Tensor):
                raise ValueError("Incorrect type of video embeddings. "
                                 f"Got type: {type(video_embeds)}")
            return Qwen2_5_VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw)

    def _process_audio_input(
        self,
        audio_input: Qwen2_5OmniAudioFeatureInputs,
        audio_hashes: Optional[list[str]] = None,
        cached_audio_features: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, ...]:
        """Encode audio features and split the result per audio item.

        `audio_hashes` and `cached_audio_features` are accepted but not
        used by this implementation.
        """
        input_features = audio_input["input_features"]
        audio_feature_lengths = audio_input["audio_feature_lengths"]
        if input_features.ndim == 3:
            assert input_features.shape[0] == 1
            input_features = input_features.squeeze(0)
        if audio_feature_lengths.ndim == 2:
            # Drop whichever of the two dimensions is the singleton one.
            assert audio_feature_lengths.shape[
                0] == 1 or audio_feature_lengths.shape[1] == 1
            if audio_feature_lengths.shape[0] == 1:
                audio_feature_lengths = audio_feature_lengths.squeeze(0)
            else:
                audio_feature_lengths = audio_feature_lengths.squeeze(1)

        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                audio_feature_lengths))

        audio_outputs = self.audio_tower(
            input_features.to(self.audio_tower.dtype),
            feature_lens=audio_feature_lengths,
            aftercnn_lens=audio_feat_lengths,
        )
        return audio_outputs.last_hidden_state.split(
            audio_output_lengths.tolist())

    def _process_image_input(
            self,
            image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image.

        Precomputed embeddings are cast to the vision tower dtype; pixel
        values are encoded by the vision tower. Both paths are split per
        image so callers always receive a tuple of tensors.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size

        return image_embeds.split(sizes.tolist())

    def _process_video_input(
        self,
        video_input: Qwen2_5_VLVideoInputs,
        video_hashes: Optional[list[str]] = None,
        cached_video_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video.

        `video_hashes` and `cached_video_embeds` are accepted but not used
        by this implementation. Both the precomputed-embedding and the
        pixel path are split per video so callers always receive a tuple.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype)
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size

        return video_embeds.split(sizes.tolist())


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5OmniThinkerMultiModalProcessor,
    info=Qwen2_5OmniThinkerProcessingInfo,
    dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniThinkerForConditionalGeneration(
        nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
        Qwen2_5OmniConditionalGenerationMixin):
    """Thinker part of Qwen2.5-Omni: audio/vision towers plus the LM."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "thinker.lm_head.": "language_model.lm_head.",
            "thinker.model.": "language_model.model.",
            "thinker.": "",
        })
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "attn.qkv": [
            "attn.q",
            "attn.k",
            "attn.v",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder string for the given modality.

        Raises:
            ValueError: If the modality is not image, video or audio.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|IMAGE|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|VIDEO|><|vision_end|>"
        if modality.startswith("audio"):
            return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only image, video or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        thinker_config: Qwen2_5OmniThinkerConfig = (
            vllm_config.model_config.hf_config.thinker_config)
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = thinker_config
        self.multimodal_config = multimodal_config

        # force "use_flash_attention_2=True" to audio tower to align
        # the results.
        if flash_attn is not None:
            audio_config = thinker_config.audio_config
            audio_config._attn_implementation_autoset = True
            audio_config._attn_implementation = "flash_attention_2"
        else:
            logger.warning(
                "flash_attn is not available, the model may not yield the "
                "exactly same result as the transformers implementation "
                "in the audio tower part.")

        # Each tower is only built when its modality has a non-zero limit.
        if multimodal_config.get_limit_per_prompt("audio"):
            self.audio_tower = Qwen2_5OmniAudioEncoder(
                thinker_config.audio_config)
        else:
            self.audio_tower = None

        if multimodal_config.get_limit_per_prompt(
                "image") or multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2_5_VisionTransformer(
                vision_config=thinker_config.vision_config,
                norm_eps=getattr(thinker_config.text_config, "rms_norm_eps",
                                 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )
        else:
            self.visual = None

        self.quant_config = quant_config
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            hf_config=thinker_config.text_config,
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse kwargs into per-modality inputs, in kwargs order."""
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values", "image_embeds"
                             ) and "image" not in mm_input_by_modality:
                mm_input_by_modality[
                    "image"] = self._parse_and_validate_image_input(**kwargs)
            if input_key in ("pixel_values_videos", "video_embeds"
                             ) and "video" not in mm_input_by_modality:
                mm_input_by_modality[
                    "video"] = self._parse_and_validate_video_input(**kwargs)
            # Exact match: a substring test against the key name would also
            # accept unrelated kwargs such as "audio" or "features".
            if (input_key == "input_audio_features"
                    and "audio" not in mm_input_by_modality):
                mm_input_by_modality[
                    "audio"] = self._parse_and_validate_audio_input(**kwargs)
        return mm_input_by_modality

    def get_language_model(self) -> torch.nn.Module:
        """Return the underlying language model module."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Encode all multimodal inputs into a flat tuple of embeddings.

        Returns an empty list when kwargs contain no multimodal input.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
            **kwargs)
        if not mm_input_by_modality:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += vision_embeddings
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                multimodal_embeddings += video_embeddings
            if modality == "audio":
                audio_embeddings = self._process_audio_input(multimodal_input)
                multimodal_embeddings += audio_embeddings
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in the multimodal embeddings."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:

            # TODO (ywang96): support overlapping modality embeddings so that
            # `use_audio_in_video` will work on V1.
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings, [
                    self.config.image_token_index,
                    self.config.video_token_index,
                    self.config.audio_token_index
                ])
        return inputs_embeds

    def get_multimodal_embeddings_v0(
            self, **kwargs: object) -> Optional[NestedTensors]:
        """Return `(embeddings, modality)` pairs, or None without inputs.

        Pairs are ordered audio, image, video.
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        image_input = self._parse_and_validate_image_input(**kwargs)
        video_input = self._parse_and_validate_video_input(**kwargs)

        if audio_input is None and image_input is None and video_input is None:
            return None

        multimodal_embeddings: list[tuple[NestedTensors, str]] = []

        if audio_input is not None:
            audio_embeds = self._process_audio_input(audio_input)
            multimodal_embeddings.append((audio_embeds, "audio"))
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            multimodal_embeddings.append((image_embeds, "image"))
        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            multimodal_embeddings.append((video_embeds, "video"))
        return multimodal_embeddings

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge each modality at its placeholder.

        Raises:
            ValueError: If a pair carries an unknown modality name.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        placeholder_token_ids = {
            "audio": self.config.audio_token_index,
            "image": self.config.image_token_index,
            "video": self.config.video_token_index,
        }
        for embeddings, modality in multimodal_embeddings:
            # Fail loudly instead of reusing a stale (or unbound) token id.
            if modality not in placeholder_token_ids:
                raise ValueError(
                    f"Unsupported modality for embeddings: {modality}")
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, embeddings,
                placeholder_token_ids[modality])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model, building input embeddings if needed."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings_v0(**kwargs)
            inputs_embeds = self.get_input_embeddings_v0(
                input_ids, multimodal_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights, skipping talker/token2wav and disabled towers."""
        skip_prefixes = ["talker.", "token2wav."]
        if self.audio_tower is None:
            skip_prefixes.extend(["audio_tower."])
        if self.visual is None:
            skip_prefixes.extend(["visual."])

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        )
        loaded_weights = loader.load_weights(weights,
                                             mapper=self.hf_to_vllm_mapper)

        return loaded_weights

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="merger.",
            tower_model=["visual.", "audio_tower."])