# Adapted from # https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, Type, TypedDict, Union) import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from PIL import Image # Conditional import for transformers.image_utils compatibility try: from transformers.image_utils import (get_image_size, infer_channel_dimension_format, to_numpy_array) except ImportError: import numpy as np def get_image_size(image): if hasattr(image, 'size'): return image.size[::-1] # PIL returns (width, height) return image.shape[:2] def infer_channel_dimension_format(image): return "channels_last" def to_numpy_array(image): if isinstance(image, np.ndarray): return image return np.array(image) from transformers.models.qwen2_vl.configuration_qwen2_vl import ( Qwen2VLConfig, Qwen2VLVisionConfig) # Conditional import for transformers compatibility try: from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( make_batched_images, make_batched_videos, smart_resize) except ImportError: # Fallback implementations for older transformers versions def smart_resize(height, width, factor, min_pixels, max_pixels): """Fallback implementation of smart_resize""" # Simple resize logic if height * width > max_pixels: scale = (max_pixels / (height * width)) ** 0.5 height = int(height * scale) width = int(width * scale) elif height * width < min_pixels: scale = (min_pixels / (height * width)) ** 0.5 height = int(height * scale) width = int(width * scale) # Round to factor height = (height // factor) * factor width = (width // factor) * factor return height, width def make_batched_images(images): """Fallback implementation of make_batched_images""" return images def make_batched_videos(videos): """Fallback implementation of make_batched_videos""" return videos from vllm.attention import AttentionMetadata from vllm.attention.selector import _Backend from vllm.config import VllmConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import (GPTQConfig, GPTQMarlinConfig, QuantizationConfig) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs) from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import cached_get_processor from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (PPMissingLayer, get_vit_attn_backend, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, maybe_prefix) logger = init_logger(__name__) # === Vision Inputs === # class Qwen2VLImagePixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values: torch.Tensor """Shape: `(num_patches, num_channels * patch_size * patch_size)` """ image_grid_thw: torch.Tensor """Shape: `(num_images, 3)` This should be in `(grid_t, grid_h, grid_w)` format. """ class Qwen2VLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] image_embeds: torch.Tensor """Supported types: - List[`torch.Tensor`]: A list of tensors holding all images' features. Each tensor holds an image's features. - `torch.Tensor`: A tensor holding all images' features (concatenation of all images' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on the number and resolution of the images. - `hidden_size` must match the hidden size of language model backbone. """ image_grid_thw: torch.Tensor """Shape: `(num_images, 3)` This should be in `(grid_t, grid_h, grid_w)` format. """ Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] class Qwen2VLVideoPixelInputs(TypedDict): type: Literal["pixel_values_videos"] pixel_values_videos: torch.Tensor """Shape: `(num_patches, num_channels * temporal_patch_size * patch_size * patch_size)` """ video_grid_thw: torch.Tensor """Shape: `(num_videos, 3)` This should be in `(grid_t, grid_h, grid_w)` format. """ class Qwen2VLVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] video_embeds: torch.Tensor """Supported types: - List[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. - `torch.Tensor`: A tensor holding all videos' features (concatenation of all videos' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on the number and resolution of the videos. - `hidden_size` must match the hidden size of language model backbone. """ video_grid_thw: torch.Tensor """Shape: `(num_videos, 3)` This should be in `(grid_t, grid_h, grid_w)` format. """ Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, Qwen2VLVideoEmbeddingInputs] # === Vision Encoder === # class Qwen2VisionMLP(nn.Module): def __init__( self, in_features: int, hidden_features: int = None, act_layer: Type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.fc1 = ColumnParallelLinear(in_features, hidden_features, quant_config=quant_config, prefix=f"{prefix}.fc1") self.act = act_layer() self.fc2 = RowParallelLinear(hidden_features, in_features, quant_config=quant_config, prefix=f"{prefix}.fc2") def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) x_parallel = self.act(x_parallel) x, _ = self.fc2(x_parallel) return x def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: if not interleaved: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) def apply_rotary_emb_torch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor: """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) """ ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat( cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") sin = repeat( sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") return torch.cat( [ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] ], dim=-1, ) def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: t_ = t.float() cos = freqs.cos() sin = freqs.sin() output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) return output class Qwen2VisionAttention(nn.Module): def __init__( self, embed_dim: Optional[int] = None, num_heads: Optional[int] = None, projection_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() # Per attention head and per partition values. world_size = parallel_state.get_tensor_model_parallel_world_size() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, world_size) self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, quant_config=quant_config, prefix=f"{prefix}.qkv") self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, quant_config=quant_config, prefix=f"{prefix}.proj") # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend() if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now.") def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] new_x_shape = x.size()[:-1] + ( self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head, ) x = x.view(*new_x_shape) # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) if self.attn_backend == _Backend.FLASH_ATTN: # from vllm_flash_attn.flash_attn_interface import ( # flash_attn_varlen_func) from flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() output = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, dropout_p=0, causal=False) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) elif self.attn_backend == _Backend.TORCH_SDPA: seq_length = q.size(1) q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = True output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) context_layer = rearrange(output, "b h s d -> b s h d ") elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, kv_seqlen=None) context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output class Qwen2VisionBlock(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, norm_layer: Type[nn.Module] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.attn = Qwen2VisionAttention(embed_dim=dim, num_heads=num_heads, projection_size=dim, quant_config=quant_config, prefix=f"{prefix}.attn") self.mlp = Qwen2VisionMLP(dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config, prefix=f"{prefix}.mlp") def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) x = x + self.mlp(self.norm2(x)) return x class Qwen2VisionPatchEmbed(nn.Module): def __init__( self, patch_size: int = 14, temporal_patch_size: int = 2, in_chans: int = 3, embed_dim: int = 1152, ) -> None: super().__init__() self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.embed_dim = embed_dim kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x).view(L, self.embed_dim) return x class Qwen2VisionPatchMerger(nn.Module): def __init__( self, d_model: int, context_dim: int, norm_layer: Type[nn.Module] = None, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) self.mlp = nn.ModuleList([ ColumnParallelLinear(self.hidden_size, self.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.mlp.0"), nn.GELU(), RowParallelLinear(self.hidden_size, d_model, bias=True, quant_config=quant_config, prefix=f"{prefix}.mlp.2"), ]) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) x = x.view(-1, self.hidden_size) mlp_fc1, mlp_act, mlp_fc2 = self.mlp x_parallel, _ = mlp_fc1(x) x_parallel = mlp_act(x_parallel) out, _ = mlp_fc2(x_parallel) return out class Qwen2VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta inv_freq = 1.0 / (theta **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen self.inv_freq = 1.0 / (self.theta**(torch.arange( 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim)) seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs def forward(self, seqlen: int) -> torch.Tensor: self.update_freqs_cache(seqlen) return self._freqs_cached[:seqlen] class Qwen2VisionTransformer(nn.Module): def __init__( self, vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() patch_size: int = vision_config.patch_size temporal_patch_size: int = vision_config.temporal_patch_size spatial_merge_size: int = vision_config.spatial_merge_size in_chans: int = vision_config.in_chans hidden_size: int = vision_config.hidden_size embed_dim: int = vision_config.embed_dim depth: int = vision_config.depth num_heads: int = vision_config.num_heads mlp_ratio: float = vision_config.mlp_ratio self.spatial_merge_size = spatial_merge_size self.patch_embed = Qwen2VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, in_chans=in_chans, embed_dim=embed_dim, ) norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList([ Qwen2VisionBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}") for layer_idx in range(depth) ]) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.merger", ) @property def dtype(self) -> torch.dtype: return self.patch_embed.proj.weight.dtype @property def device(self) -> torch.device: return self.patch_embed.proj.weight.device def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: pos_ids = [] for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) hpos_ids = hpos_ids.reshape( h // self.spatial_merge_size, self.spatial_merge_size, w // self.spatial_merge_size, self.spatial_merge_size, ).permute(0, 2, 1, 3).flatten() wpos_ids = wpos_ids.reshape( h // self.spatial_merge_size, self.spatial_merge_size, w // self.spatial_merge_size, self.spatial_merge_size, ).permute(0, 2, 1, 3).flatten() pos_ids.append( torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb def forward( self, x: torch.Tensor, grid_thw: torch.Tensor, ) -> torch.Tensor: # patchify x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) # compute position embedding rotary_pos_emb = self.rot_pos_emb(grid_thw) # compute cu_seqlens cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers x = x.unsqueeze(1) for blk in self.blocks: x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) # adapter x = self.merger(x) return x # === Vision input helpers === # def get_mm_processor_kwargs( min_pixels: Optional[int] = None, max_pixels: Optional[int] = None) -> Dict[str, int]: mm_processor_kwargs = {} if min_pixels: mm_processor_kwargs["min_pixels"] = min_pixels if max_pixels: mm_processor_kwargs["max_pixels"] = max_pixels return mm_processor_kwargs def mm_input_mapper_for_qwen2_vl( ctx: InputContext, data: MultiModalData[object], data_type_key: str, *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, ) -> MultiModalKwargs: """Input mapper for Qwen2-VL.""" if data_type_key == "image" and isinstance(data, dict): return MultiModalKwargs({ "image_embeds": data.get("image_embeds"), "image_grid_thw": data.get("image_grid_thw"), }) if data_type_key == "video" and isinstance(data, dict): return MultiModalKwargs({ "video_embeds": data.get("video_embeds"), "video_grid_thw": data.get("video_grid_thw"), }) model_config = ctx.model_config # Handle mm processor kwargs; we pass these at creation time # because preprocess() in transformers doesn't expose them mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, max_pixels=max_pixels) image_processor = cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, **mm_processor_kwargs, ) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") images = None videos = None if data_type_key == "image": images = data else: assert data_type_key == "video" videos = data try: batch_data = image_processor \ .preprocess(images=images, videos=videos, return_tensors="pt") \ .data except Exception: logger.error("Failed to process image (%s)", data) raise return MultiModalKwargs(batch_data) image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, data_type_key="image") video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, data_type_key="video") def _get_vision_info( image_processor, height: int, width: int, min_pixels: int, max_pixels: int, do_resize: bool = True, data_type_key: str = "image", mm_count: int = 1, ): """Get information (resized height / width and number of vision tokens) of input image / video frame.""" if do_resize: resized_height, resized_width = smart_resize( height=height, width=width, factor=image_processor.patch_size * image_processor.merge_size, min_pixels=min_pixels, max_pixels=max_pixels, ) else: resized_height, resized_width = height, width if data_type_key == "image": grid_t = mm_count else: assert data_type_key == "video" grid_t = max(mm_count // image_processor.temporal_patch_size, 1) grid_h = resized_height // image_processor.patch_size grid_w = resized_width // image_processor.patch_size vision_tokens = grid_t * grid_h * grid_w llm_num_vision_tokens = (vision_tokens // image_processor.merge_size // image_processor.merge_size) return resized_height, resized_width, llm_num_vision_tokens def _get_max_image_info( image_processor, data_type_key: str = "image", mm_count: int = 1, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, ): # Limit min / max pixels unless they're explicitly provided if min_pixels is None: min_pixels = max(image_processor.min_pixels, 28 * 28) if max_pixels is None: max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28) return _get_vision_info( image_processor, height=9999999, width=9999999, min_pixels=min_pixels, max_pixels=max_pixels, data_type_key=data_type_key, mm_count=mm_count, ) def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str, *, min_pixels=None, max_pixels=None) -> int: mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, max_pixels=max_pixels) image_processor = cached_get_image_processor(ctx.model_config.model, **mm_processor_kwargs) max_resized_height, max_resized_width, max_llm_image_tokens = \ _get_max_image_info(image_processor, data_type_key=data_type_key, mm_count=1, min_pixels=min_pixels, max_pixels=max_pixels) return max_llm_image_tokens get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens, data_type_key="image") get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens, data_type_key="video") def dummy_data_for_qwen2_vl( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, max_pixels=max_pixels) image_processor = cached_get_image_processor(ctx.model_config.model, **mm_processor_kwargs) num_images = mm_counts["image"] max_resized_height, max_resized_width, max_llm_image_tokens = \ _get_max_image_info(image_processor, data_type_key="image", mm_count=num_images, min_pixels=min_pixels, max_pixels=max_pixels) if seq_len - max_llm_image_tokens - 2 < 0: raise RuntimeError( f"Qwen2-VL cannot process {num_images} images in a prompt, " "please increase max_model_len or reduce image limit by " "--limit-mm-per-prompt.") # Check video counts. num_videos = mm_counts["video"] max_resized_height, max_resized_width, max_llm_video_tokens = \ _get_max_image_info(image_processor, data_type_key="video", mm_count=num_videos, min_pixels=min_pixels, max_pixels=max_pixels) if seq_len - max_llm_video_tokens - 2 < 0: raise RuntimeError( f"Qwen2-VL cannot process {num_videos} videos in a prompt, " "please increase max_model_len or reduce video limit by " "--limit-mm-per-prompt.") hf_config = ctx.get_hf_config(Qwen2VLConfig) dummy_seqdata = SequenceData.from_prompt_token_counts( (hf_config.vision_start_token_id, 1), (hf_config.image_token_id, max_llm_image_tokens), (hf_config.vision_end_token_id, 1), (0, seq_len - max_llm_image_tokens - 2), ) dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), color=0) return DummyData(dummy_seqdata, { "image": dummy_image if num_images == 1 else [dummy_image] * num_images }) def _get_llm_num_vision_tokens( mm_inputs: list, data_type_key: str, image_processor, min_pixels: int, max_pixels: int, ): """Get number of vision tokens of multimodal inputs. This method is derived from `transformers.models.qwen2_vl. image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`. """ image = to_numpy_array(mm_inputs[0]) input_data_format = infer_channel_dimension_format(image) height, width = get_image_size(image, channel_dim=input_data_format) _, _, llm_num_vision_tokens = _get_vision_info( image_processor, height=height, width=width, min_pixels=min_pixels, max_pixels=max_pixels, do_resize=image_processor.do_resize, data_type_key=data_type_key, mm_count=len(mm_inputs), ) return llm_num_vision_tokens def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, data_type_key: str, image_processor: Any, prompt_token_ids: List[int], min_pixels: Optional[int], max_pixels: Optional[int]) -> List[int]: """ Expand pad tokens for multi-modal inputs (e.g., images or videos). Args: inputs (list): The multi-modal inputs (e.g., images or videos). token_id (int): The token ID used to represent the multi-modal input. make_batched_fn (Callable): A function to batch the inputs. data_type_key (str): The type of the multi-modal input. image_processor (Any): The image processor used to process the inputs. prompt_token_ids (List[int]): The list of token IDs in the prompt. min_pixels (int): min pixels to used for img processing max_pixels (int): max pixels to be used for img processing Returns: List[int]: The list of token IDs for the multi-modal inputs. """ indices = [ idx for idx, token in enumerate(prompt_token_ids) if token == token_id ] inputs = make_batched_fn(inputs) assert len(indices) == len(inputs) prompt_token_ids_with_data = [] for cnt, data in enumerate(inputs): num_tokens = _get_llm_num_vision_tokens( [data] if data_type_key == "image" else data, data_type_key=data_type_key, image_processor=image_processor, min_pixels=min_pixels, max_pixels=max_pixels, ) if cnt == 0: end_idx = indices[cnt] non_data_tokens = prompt_token_ids[:end_idx] else: non_data_tokens = prompt_token_ids[indices[cnt - 1] + 1:indices[cnt]] prompt_token_ids_with_data.extend(non_data_tokens) prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens)) prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:]) return prompt_token_ids_with_data def input_processor_for_qwen2_vl( ctx: InputContext, inputs: DecoderOnlyInputs, *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, ) -> DecoderOnlyInputs: multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None: return inputs image_inputs = multi_modal_data.get("image", None) video_inputs = multi_modal_data.get("video", None) processor = cached_get_processor(ctx.model_config.model) image_processor = processor.image_processor # Apply processor kwarg overrides for image processor options min_pixels = min_pixels if min_pixels else image_processor.min_pixels max_pixels = max_pixels if max_pixels else image_processor.max_pixels model_config = ctx.model_config hf_config = ctx.get_hf_config(Qwen2VLConfig) # To avoid redundant processing of vision objects (resize, rescale, etc.), # we extract code of calculating number of vision tokens from # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`. # # The following code is equivalent to: # prompt = inputs["prompt"] # inputs = processor(text=[prompt], # images=image_inputs, # videos=video_inputs, # padding=True, # return_tensors="pt") # prompt_token_ids = inputs["input_ids"][0].tolist() tokenizer = cached_get_tokenizer( model_config.tokenizer, trust_remote_code=model_config.trust_remote_code) prompt_token_ids = inputs["prompt_token_ids"] # Expand image pad tokens. if image_inputs is not None: if isinstance(image_inputs, dict): prompt_token_ids_with_image = [] image_indices = [ idx for idx, token in enumerate(prompt_token_ids) if token == hf_config.image_token_id ] # ensure all image tokens have grid_thw assert \ len(image_indices) == image_inputs["image_grid_thw"].size(0), \ "image token num does not match image_grid_thw.shape" image_counter = 0 pad_token_counter = 0 for idx, token in enumerate(prompt_token_ids): if idx in image_indices: grid_thw = image_inputs["image_grid_thw"][image_counter] grid_t, grid_h, grid_w = grid_thw num_pad_tokens = (grid_t * grid_h * grid_w // image_processor.merge_size // image_processor.merge_size) prompt_token_ids_with_image.extend([token] * num_pad_tokens) image_counter += 1 pad_token_counter += num_pad_tokens else: prompt_token_ids_with_image.append(token) # ensure all embeddings are used assert \ pad_token_counter == image_inputs["image_embeds"].size(0), \ "image_embeds.shape does not match image_grid_thw" prompt_token_ids = prompt_token_ids_with_image else: prompt_token_ids = _expand_pad_tokens(image_inputs, hf_config.image_token_id, make_batched_images, "image", image_processor, prompt_token_ids, min_pixels=min_pixels, max_pixels=max_pixels) if video_inputs is not None: if isinstance(video_inputs, dict): prompt_token_ids_with_video = [] video_indices = [ idx for idx, token in enumerate(prompt_token_ids) if token == hf_config.video_token_id ] # ensure all video tokens have grid_thw assert \ len(video_indices) == video_inputs["video_grid_thw"].size(0), \ "video token num does not match video_grid_thw.shape" video_counter = 0 pad_token_counter = 0 for idx, token in enumerate(prompt_token_ids): if idx in video_indices: grid_thw = video_inputs["video_grid_thw"][video_counter] grid_t, grid_h, grid_w = grid_thw num_pad_tokens = (grid_t * grid_h * grid_w // image_processor.merge_size // image_processor.merge_size) prompt_token_ids_with_video.extend([token] * num_pad_tokens) video_counter += 1 pad_token_counter += num_pad_tokens else: prompt_token_ids_with_video.append(token) # ensure all embeddings are used assert \ pad_token_counter == video_inputs["video_embeds"].size(0), \ "video_embeds.shape does not match video_grid_thw" prompt_token_ids = prompt_token_ids_with_video else: prompt_token_ids = _expand_pad_tokens(video_inputs, hf_config.video_token_id, make_batched_videos, "video", image_processor, prompt_token_ids, min_pixels=min_pixels, max_pixels=max_pixels) prompt = inputs.get("prompt") if prompt is None: prompt = tokenizer.decode(prompt_token_ids) return token_inputs( prompt_token_ids=prompt_token_ids, prompt=prompt, multi_modal_data=multi_modal_data, ) @MULTIMODAL_REGISTRY.register_image_input_mapper( image_input_mapper_for_qwen2_vl) @MULTIMODAL_REGISTRY.register_input_mapper("video", video_input_mapper_for_qwen2_vl) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": [ "gate_proj", "up_proj", ], } # LoRA specific attributes # TODO Support LoRA for the visual encoder in the future. supported_lora_modules = [ "qkv_proj", "o_proj", "gate_up_proj", "down_proj", ] embedding_modules = {} embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config assert not cache_config.enable_prefix_caching, \ "Qwen2-VL currently does not support prefix caching" self.config = config self.multimodal_config = multimodal_config self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self._maybe_ignore_quant_config(quant_config), prefix=maybe_prefix(prefix, "visual"), ) self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config, prefix=maybe_prefix( prefix, "lm_head")) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST, normalize=True, softmax=False) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ # seems to avoid vision encoder sections for some models. # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4 if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): return None return quant_config def _validate_and_reshape_mm_tensor(self, mm_input: Union[torch.Tensor, List[torch.Tensor]], name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim != 3: raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim}") return torch.concat(list(mm_input)) else: return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) if pixel_values is None and image_embeds is None: return None if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( pixel_values, "image pixel values") image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of image pixel values. " f"Got type: {type(pixel_values)}") return Qwen2VLImagePixelInputs(type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( image_embeds, "image embeds") image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") return Qwen2VLImageEmbeddingInputs(type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw) def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) if pixel_values_videos is None and video_embeds is None: return None if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( pixel_values_videos, "video pixel values") video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") return Qwen2VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, ) if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( video_embeds, "video embeds") video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") if not isinstance(video_embeds, torch.Tensor): raise ValueError("Incorrect type of video embeddings. " f"Got type: {type(video_embeds)}") return Qwen2VLVideoEmbeddingInputs(type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw) def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["image_embeds"].type(self.visual.dtype) pixel_values = image_input["pixel_values"].type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) return image_embeds def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: if video_input["type"] == "video_embeds": return video_input["video_embeds"].type(self.visual.dtype) pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_input["video_grid_thw"]) return video_embeds def _merge_multimodal_embeddings( self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, multimodal_embeddings: torch.Tensor, placeholder_token_id: int, ) -> torch.Tensor: mask = (input_ids == placeholder_token_id) inputs_embeds[mask, :] = multimodal_embeddings return inputs_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Qwen2-VL. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. positions: Flattened (concatenated) position ids corresponding to a batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). pixel_values: Pixel values to be fed to a model. `None` if no images are passed. image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. `None` if no images are passed. pixel_values_videos: Pixel values of videos to be fed to a model. `None` if no videos are passed. video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. `None` if no videos are passed. """ if intermediate_tensors is not None: input_ids = None inputs_embeds = None else: image_input = self._parse_and_validate_image_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs) if image_input is None and video_input is None: inputs_embeds = None else: if uses_mrope(self.config): assert positions.ndim == 2 and positions.size(0) == 3, ( "multimodal section rotary embedding requires " f"(3, seq_len) positions, but got {positions.size()}") inputs_embeds = self.model.embed_tokens(input_ids) if image_input is not None: image_embeds = self._process_image_input(image_input) inputs_embeds = self._merge_multimodal_embeddings( input_ids, inputs_embeds, image_embeds, placeholder_token_id=self.config.image_token_id, ) if video_input is not None: video_embeds = self._process_video_input(video_input) inputs_embeds = self._merge_multimodal_embeddings( input_ids, inputs_embeds, video_embeds, placeholder_token_id=self.config.video_token_id, ) input_ids = None hidden_states = self.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def pooler( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "up_proj", 1), ("gate_up_proj", "gate_proj", 0), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if self.config.tie_word_embeddings and "lm_head.weight" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: if "visual" in name and name.endswith("qkv.weight"): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads loaded_weight = loaded_weight.view(3, visual_num_heads, head_size, visual_embed_dim) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) elif "visual" in name and name.endswith("qkv.bias"): visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim head_size = visual_embed_dim // visual_num_heads loaded_weight = loaded_weight.view(3, visual_num_heads, head_size) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1) try: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] except KeyError: raise ValueError(f"Unexpected weight: {name}") from None weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight)