# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence from functools import partial from typing import Any, Literal, Optional, TypedDict, Union import numpy as np import torch import torch.nn as nn from einops import rearrange from transformers import PretrainedConfig from transformers.activations import GELUActivation from transformers.feature_extraction_utils import BatchFeature from transformers.modeling_outputs import (BaseModelOutput, BaseModelOutputWithPooling) from transformers.utils import torch_int from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, VideoItem) from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import ( cached_image_processor_from_config) from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) from .siglip import SiglipMLP from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, is_pp_missing_parameter, maybe_prefix, merge_multimodal_embeddings) from .vision import get_vit_attn_backend from vllm.platforms import current_platform logger = init_logger(__name__) _MAX_FRAMES_PER_VIDEO = 16 _MAX_IMAGE_SIZE = 9999999 def smart_resize( height: int, width: int, factor: int = 28, min_pixels: int = 28 * 28 * 130, max_pixels: int = 28 * 28 * 1280, ): if height < factor: logger.warning( "smart_resize: height=%s < factor=%s, reset height=factor", height, factor, ) width = round((width * factor) / height) height = factor if width < factor: logger.warning( "smart_resize: width=%s < factor=%s, reset width=factor", width, factor, ) height = round((height * factor) / width) width = factor if max(height, width) / min(height, width) > 200: raise ValueError("absolute aspect ratio must be smaller than 200, got " "{max(height, width) / min(height, width)}") h_bar = round(height / factor) * factor w_bar = round(width / factor) * factor if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) h_bar = math.floor(height / beta / factor) * factor w_bar = math.floor(width / beta / factor) * factor elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = math.ceil(height * beta / factor) * factor w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar class KeyeImagePixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values: torch.Tensor """Shape: `(num_patches, num_channels * patch_size * patch_size)` """ image_grid_thw: torch.Tensor """Shape: `(num_images, 3)` This should be in `(grid_t, grid_h, grid_w)` format. """ class KeyeImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] image_embeds: torch.Tensor """Supported types: - list[`torch.Tensor`]: A list of tensors holding all images' features. Each tensor holds an image's features. - `torch.Tensor`: A tensor holding all images' features (concatenation of all images' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on the number and resolution of the images. - `hidden_size` must match the hidden size of language model backbone. """ image_grid_thw: torch.Tensor """Shape: `(num_images, 3)` This should be in `(grid_t, grid_h, grid_w)` format. """ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs] class KeyeVideoPixelInputs(TypedDict): type: Literal["pixel_values_videos"] pixel_values_videos: torch.Tensor """Shape: `(num_patches, num_channels * temporal_patch_size * patch_size * patch_size)` """ video_grid_thw: torch.Tensor """Shape: `(num_videos, 3)` This should be in `(grid_t, grid_h, grid_w)` format. """ class KeyeVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] video_embeds: torch.Tensor """Supported types: - list[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. - `torch.Tensor`: A tensor holding all videos' features (concatenation of all videos' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on the number and resolution of the videos. - `hidden_size` must match the hidden size of language model backbone. """ video_grid_thw: torch.Tensor """Shape: `(num_videos, 3)` This should be in `(grid_t, grid_h, grid_w)` format. """ KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs] class KeyeVisionEmbeddings(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) self.num_patches = (self.image_size // self.patch_size)**2 self.num_positions = self.num_patches self.cache_position_embedding = dict() self.cache_position_count = dict() self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) self.register_buffer( "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False, ) def interpolate_pos_encoding( self, embeddings: torch.Tensor, height: int, width: int, is_after_patchify: bool = False, ) -> torch.Tensor: num_positions = self.position_embedding.weight.shape[0] patch_pos_embed = self.position_embedding.weight.unsqueeze(0) dim = embeddings.shape[-1] if is_after_patchify: new_height = height new_width = width else: new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bilinear", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache: int = 20): grid = (h, w) if grid in self.cache_position_embedding: self.cache_position_count[grid] += 1 return self.cache_position_embedding[grid] if len(self.cache_position_embedding) >= max_cache: min_hit_grid = min( self.cache_position_count, key=self.cache_position_count.get, ) self.cache_position_count.pop(min_hit_grid) self.cache_position_embedding.pop(min_hit_grid) position_embedding = self.interpolate_pos_encoding( embeddings, h, w, True) self.cache_position_count[grid] = 1 self.cache_position_embedding[grid] = position_embedding return position_embedding def forward( self, pixel_values: torch.FloatTensor, position_ids: Optional[torch.Tensor] = None, image_grid_thw: Optional[list[Union[ tuple[int, int, int], list[tuple[int, int, int]], ]]] = None, interpolate_pos_encoding=False, ) -> torch.Tensor: if pixel_values.dim() == 4: pixel_values = pixel_values.unsqueeze(0) if pixel_values.dim() == 5: if position_ids is None: raise ValueError( "position_ids cannot be None when pixel_values.dim() is 5." ) ( batch_size, squence_len, channel, height, width, ) = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") patch_embeds = self.patch_embedding( pixel_values.to(dtype=target_dtype)) embeddings = patch_embeds.flatten(-2).squeeze(-1) if interpolate_pos_encoding and image_grid_thw is not None: start = 0 tmp_embeddings = list() for image_grid in image_grid_thw: t, h, w = image_grid end = start + t * h * w image_embeddings = embeddings[start:end, :] position_embedding = (self.interpolate_pos_encoding( image_embeddings, h, w, True).squeeze(0).repeat(t, 1)) image_embeddings = image_embeddings + position_embedding tmp_embeddings.append(image_embeddings) start = end embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) else: embeddings = embeddings + self.packing_position_embedding( position_ids) return embeddings else: raise ValueError("Unsupported pixel_values dimension:" f" {pixel_values.dim()}. Expected 4 or 5.") def apply_rotary_pos_emb_flashatt( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() if not current_platform.is_rocm(): from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb else: from flash_attn.layers.rotary import apply_rotary_emb q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed class KeyeSiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper.""" def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.config = config hidden_size = config.hidden_size self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = config.num_attention_heads if self.total_num_kv_heads >= tp_size: assert self.total_num_kv_heads % tp_size == 0 else: assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = config.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scale = self.head_dim**-0.5 self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) self.out_proj = RowParallelLinear( input_size=hidden_size, output_size=hidden_size, quant_config=quant_config, prefix=f"{prefix}.out_proj", ) # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now.") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, cu_seqlens: Optional[list[torch.Tensor]] = None, rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split( [self.q_size, self.kv_size, self.kv_size], dim=-1, ) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() batch_size = q.shape[0] if rope_emb is None: q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) k = k.view( *k.shape[:-1], self.num_kv_heads, self.head_dim, ) v = v.view( *v.shape[:-1], self.num_kv_heads, self.head_dim, ) else: if cu_seqlens is None: raise ValueError( "cu_seqlens cannot be None when rope_emb is not None.") cos, sin = rope_emb q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) k = k.view( *k.shape[:-1], self.num_kv_heads, self.head_dim, ) q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin) v = v.view( *v.shape[:-1], self.num_kv_heads, self.head_dim, ) if self.attn_backend == _Backend.FLASH_ATTN: from flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) output = flash_attn_varlen_func( q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, causal=False, softmax_scale=self.scale, ) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, kv_seqlen=None, device=q.device) context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous() output, _ = self.out_proj(context_layer) return output class SigLIPRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta self.rope_init() def rope_init(self): inv_freq = 1.0 / (self.theta**( torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: seq = torch.arange( seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype, ) freqs = torch.outer(seq, self.inv_freq) return freqs class KeyeSiglipEncoderLayer(nn.Module): def __init__( self, config: Union[PretrainedConfig], quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = KeyeSiglipAttention( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( config, quant_config=quant_config, prefix=f"{prefix}.mlp", ) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, cu_seqlens: Optional[list[torch.Tensor]] = None, rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ) -> tuple[torch.FloatTensor]: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, cu_seqlens=cu_seqlens, rope_emb=rope_emb, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class KeyeSiglipEncoder(nn.Module): def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.config = config embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads self.layers = nn.ModuleList([ KeyeSiglipEncoderLayer( config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", ) for layer_idx in range(config.num_hidden_layers) ]) self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2) @staticmethod def flatten_list(image_grid_thw): tmp_image_grid_thw = list() for image_grid in image_grid_thw: if isinstance(image_grid, list): tmp_image_grid_thw.extend(image_grid) else: tmp_image_grid_thw.append(image_grid) return tmp_image_grid_thw def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cu_seqlens: Optional[list[torch.Tensor]] = None, image_grid_thw: Optional[list[Union[ tuple[int, int, int], list[tuple[int, int, int]], ]]] = None, height_position_ids: Optional[torch.Tensor] = None, width_position_ids: Optional[torch.Tensor] = None, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, vision_or_text: str = "vision", ) -> BaseModelOutput: device = inputs_embeds.device hidden_states = inputs_embeds if use_rope is True: flatten_image_grid_thw = self.flatten_list(image_grid_thw) if width_position_ids is None or height_position_ids is None: split_hids = list() split_wids = list() for t, h, w in flatten_image_grid_thw: image_pids = torch.arange(t * h * w, device=device) % (h * w) sample_hids = image_pids // w sample_wids = image_pids % w split_hids.append(sample_hids) split_wids.append(sample_wids) width_position_ids = torch.concat(split_wids, dim=0) height_position_ids = torch.concat(split_hids, dim=0) pids = torch.stack( [height_position_ids, width_position_ids], dim=-1, ) max_grid_size = pids.max() + 1 rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) rope_emb = rope_emb_max_grid[pids].flatten(1) rope_emb = rope_emb.repeat(1, 2) rope_emb = (rope_emb.cos(), rope_emb.sin()) else: rope_emb = None attn_cu_seqlens = cu_seqlens hidden_states = inputs_embeds assert attention_mask is None for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, cu_seqlens=attn_cu_seqlens, rope_emb=rope_emb, ) return hidden_states class KeyeSiglipVisionTransformer(nn.Module): def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = KeyeVisionEmbeddings(config) self.encoder = KeyeSiglipEncoder( config, quant_config=quant_config, prefix=f"{prefix}.encoder", ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = False, attention_mask: Optional[torch.Tensor] = None, sample_indices: Optional[torch.Tensor] = None, image_indices: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, height_position_ids: Optional[torch.Tensor] = None, width_position_ids: Optional[torch.Tensor] = None, cu_seqlens: Optional[list[torch.Tensor]] = None, padding_mask: Optional[torch.Tensor] = None, vision_return_embed_list: Optional[bool] = False, image_grid_thw: Optional[list[Union[ tuple[int, int, int], list[tuple[int, int, int]], ]]] = None, return_pooler_output: Optional[bool] = True, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, ) -> BaseModelOutputWithPooling: hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, position_ids=position_ids, image_grid_thw=image_grid_thw, ) last_hidden_state = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, attention_mask=attention_mask, cu_seqlens=cu_seqlens, image_grid_thw=image_grid_thw, use_rope=use_rope, height_position_ids=height_position_ids, width_position_ids=width_position_ids, window_size=window_size, vision_or_text="vision", ) last_hidden_state = self.post_layernorm(last_hidden_state) sample_hidden_state = list() if cu_seqlens is None: raise ValueError("cu_seqlens cannot be None for " "SiglipVisionTransformer output processing.") for i in range(cu_seqlens.shape[0] - 1): start = cu_seqlens[i] end = cu_seqlens[i + 1] tensor = last_hidden_state[:, start:end, :].squeeze(0) sample_hidden_state.append(tensor) return sample_hidden_state class KeyeSiglipVisionModel(nn.Module): config_class = PretrainedConfig main_input_name = "pixel_values" def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.vision_model = KeyeSiglipVisionTransformer( config, quant_config=quant_config, prefix=f"{prefix}.vision_model", ) self.quant_config = quant_config @property def dtype(self) -> torch.dtype: return self.vision_model.embeddings.patch_embedding.weight.dtype @property def device(self) -> torch.device: return self.vision_model.embeddings.patch_embedding.weight.device def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding def forward( self, pixel_values, sample_indices: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, position_ids: Optional[torch.Tensor] = None, vision_return_embed_list: Optional[bool] = False, image_grid_thw: Optional[list[Union[ tuple[int, int, int], list[tuple[int, int, int]], ]]] = None, cu_seqlens: Optional[list[torch.Tensor]] = None, return_pooler_output: Optional[bool] = True, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, ) -> BaseModelOutputWithPooling: return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, position_ids=position_ids, vision_return_embed_list=vision_return_embed_list, image_grid_thw=image_grid_thw, sample_indices=sample_indices, cu_seqlens=cu_seqlens, return_pooler_output=return_pooler_output, use_rope=use_rope, window_size=window_size, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if "head.attention" in name or "head.layernorm" in name: continue if "head.mlp" in name or "head.probe" in name: continue if self.quant_config is not None and ( scale_name := self.quant_config.get_cache_scale(name)): param = params_dict[scale_name] weight_loader = getattr( param, "weight_loader", default_weight_loader, ) loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue for ( param_name, weight_name, shard_id, ) in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: if name.endswith(".bias") and name not in params_dict: continue name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader, ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Projector(nn.Module): def __init__( self, text_config: PretrainedConfig, vision_config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.text_config = text_config self.vision_config = vision_config self.merge_kernel_size = (2, 2) self.hidden_size = (self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1]) self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) self.act = GELUActivation() self.linear_1 = ColumnParallelLinear( self.hidden_size, self.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.linear_1", ) self.linear_2 = RowParallelLinear( self.hidden_size, self.text_config.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.linear_2", ) def forward( self, image_features: torch.Tensor, image_grid_thw: list[tuple[int, int, int]], ) -> torch.Tensor: m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() for image_feature, image_grid in zip(image_features, image_grid_thw): image_feature = self.pre_norm(image_feature) t, h, w = image_grid image_feature = rearrange( image_feature, "(t h p1 w p2) d -> (t h w) (p1 p2 d)", t=t, h=h // m1, p1=m1, w=w // m2, p2=m2, ) hidden_states, _ = self.linear_1(image_feature) hidden_states = self.act(hidden_states) hidden_states, _ = self.linear_2(hidden_states) processed_features.append(hidden_states) return processed_features dims = image_features.shape[:-1] dim = image_features.shape[-1] image_features = image_features.view(np.prod(dims), dim) hidden_states = self.pre_norm(image_features).view( -1, self.hidden_size) hidden_states = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states.view(*dims, -1) def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_grid_sizes = image_grid_thw.prod(-1) video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_sizes = video_grid_thw.prod(-1) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( "image", image_grid_sizes), image_embeds=MultiModalFieldConfig.flat_from_sizes( "image", image_grid_sizes), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( "video", video_grid_sizes), video_embeds=MultiModalFieldConfig.flat_from_sizes( "video", video_grid_sizes), video_grid_thw=MultiModalFieldConfig.batched("video"), ) class KeyeMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): return DictEmbeddingItems( data, modality="image", required_fields={ "image_embeds", "image_grid_thw", }, fields_factory=_keye_field_config, ) return super()._parse_image_data(data) def _parse_video_data( self, data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): return DictEmbeddingItems( data, modality="video", required_fields={ "video_embeds", "video_grid_thw", }, fields_factory=_keye_field_config, ) return super()._parse_video_data(data) class KeyeProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(PretrainedConfig) def get_hf_processor( self, *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, size: Optional[dict[str, int]] = None, **kwargs: object, ): return self.ctx.get_hf_processor( image_processor=self.get_image_processor( min_pixels=min_pixels, max_pixels=max_pixels, size=size, ), **kwargs, ) def _get_image_processor_kwargs( self, *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, size: Optional[dict[str, int]] = None, **kwargs: object, ): if self.ctx.model_config.mm_processor_kwargs: kwargs.update(self.ctx.model_config.mm_processor_kwargs) if min_pixels is not None: kwargs["min_pixels"] = min_pixels if size is None: size = {"shortest_edge": min_pixels} else: size["shortest_edge"] = min_pixels if max_pixels is not None: kwargs["max_pixels"] = max_pixels if size is None: size = {"longest_edge": max_pixels} else: size["longest_edge"] = max_pixels if size is not None: kwargs["size"] = size return kwargs def get_image_processor( self, *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, size: Optional[dict[str, int]] = None, **kwargs: object, ): return cached_image_processor_from_config( self.ctx.model_config, **self._get_image_processor_kwargs( min_pixels=min_pixels, max_pixels=max_pixels, size=size, **kwargs, ), ) def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} def get_mm_max_tokens_per_item( self, seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: return { "image": self.get_max_image_tokens(), "video": self.get_max_video_tokens(seq_len), } def _get_vision_info( self, *, image_width: int, image_height: int, num_frames: int = 1, do_resize: bool = True, image_processor, ) -> tuple[ImageSize, int]: if image_processor is None: image_processor = self.get_image_processor() hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size temporal_patch_size = 1 if do_resize: resized_height, resized_width = smart_resize( height=image_height, width=image_width, factor=patch_size * merge_size, min_pixels=image_processor.min_pixels, max_pixels=image_processor.max_pixels, ) preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: preprocessed_size = ImageSize(width=image_width, height=image_height) padded_num_frames = num_frames + num_frames % temporal_patch_size grid_t = max(padded_num_frames // temporal_patch_size, 1) grid_h = preprocessed_size.height // patch_size grid_w = preprocessed_size.width // patch_size num_patches = grid_t * grid_h * grid_w num_vision_tokens = num_patches // (merge_size**2) return preprocessed_size, num_vision_tokens def get_num_image_tokens( self, *, image_width: int, image_height: int, image_processor, ) -> int: _, num_image_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, image_processor=image_processor, ) return num_image_tokens def get_num_video_tokens( self, *, image_width: int, image_height: int, num_frames: int, image_processor, ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, num_frames=num_frames, image_processor=image_processor, ) return num_video_tokens def get_image_size_with_most_features(self, ) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=_MAX_IMAGE_SIZE, image_height=_MAX_IMAGE_SIZE, image_processor=None, ) return max_image_size def get_max_image_tokens(self) -> int: target_width, target_height = self.get_image_size_with_most_features() return self.get_num_image_tokens( image_width=target_width, image_height=target_height, image_processor=None, ) def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 next_max_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, image_processor=None, ) if next_max_tokens > max_tokens: break num_frames = next_num_frames return num_frames def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.get_limit_per_prompt("image") max_videos = mm_config.get_limit_per_prompt("video") max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = min( max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO, ) return max(max_frames_per_video, 1) def get_max_video_tokens(self, seq_len: int) -> int: target_width, target_height = self.get_image_size_with_most_features() return self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=self.get_num_frames_with_most_features(seq_len), image_processor=None, ) class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) hf_processor = self.info.get_hf_processor() image_token: str = hf_processor.image_token video_token: str = hf_processor.video_token return image_token * num_images + video_token * num_videos def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) target_width, target_height = ( self.info.get_image_size_with_most_features()) target_num_frames = self.info.get_num_frames_with_most_features( seq_len) mm_data = { "image": self._get_dummy_images( width=target_width, height=target_height, num_images=num_images, ), "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, ), } return mm_data class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return KeyeMultiModalDataParser() def _call_hf_processor( self, prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs) return self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data), dict(**mm_kwargs, **tok_kwargs), ) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( **hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() placeholder = { "image": vocab[hf_processor.image_token], "video": vocab[hf_processor.video_token], } merge_length = image_processor.merge_size**2 def get_replacement_keye(item_idx: int, modality: str): grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length return [placeholder[modality]] * num_tokens return [ PromptReplacement( modality=modality, target=[placeholder[modality]], replacement=partial(get_replacement_keye, modality=modality), ) for modality in ("image", "video") ] def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return _keye_field_config(hf_inputs) @MULTIMODAL_REGISTRY.register_processor( KeyeMultiModalProcessor, info=KeyeProcessingInfo, dummy_inputs=KeyeDummyInputsBuilder, ) class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": [ "gate_proj", "up_proj", ], } hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", }) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): return "<|vision_start|><|video_pad|><|vision_end|>" raise ValueError("Only image or video modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: PretrainedConfig = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config self.visual = KeyeSiglipVisionModel( config.vision_config, quant_config=self._maybe_ignore_quant_config(quant_config), prefix=maybe_prefix(prefix, "visual"), ) self.mlp_AR = Projector( config, config.vision_config, quant_config=self._maybe_ignore_quant_config(quant_config), prefix=maybe_prefix(prefix, "mlp_AR"), ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model"), architectures=["Qwen3ForCausalLM"], ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): return None return quant_config def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim == 5: return mm_input if mm_input.ndim != 3: raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") return torch.concat(list(mm_input)) else: return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) if pixel_values is None and image_embeds is None: return None if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( pixel_values, "image pixel values") image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of image pixel values. " f"Got type: {type(pixel_values)}") return KeyeImagePixelInputs( type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw, ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( image_embeds, "image embeds") image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") return KeyeImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw, ) def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[KeyeVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) if pixel_values_videos is None and video_embeds is None: return None if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( pixel_values_videos, "video pixel values", ) video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") return KeyeVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, ) if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( video_embeds, "video embeds") video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") if not isinstance(video_embeds, torch.Tensor): raise ValueError("Incorrect type of video embeddings. " f"Got type: {type(video_embeds)}") return KeyeVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw, ) def _process_image_input( self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]: siglip_position_ids = list() image_grid_hws = list() sample_indices = list() cu_seqlens = [0] image_grid_thw = image_input["image_grid_thw"] assert image_grid_thw.ndim == 2 for idx, thaw in enumerate(image_grid_thw): thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) numel = np.prod(thw_tuple) image_grid_hws.append(thw_tuple) image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) siglip_position_ids.append(image_position_ids) sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) if image_input["type"] == "image_embeds": raise ValueError( "Image embeddings are not supported for this processing path.") else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(pixel_values.device) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( pixel_values.device) sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values.device) image_embeds = self.visual( pixel_values=pixel_values, image_grid_thw=image_grid_hws, position_ids=siglip_position_ids, vision_return_embed_list=False, interpolate_pos_encoding=True, sample_indices=sample_indices, cu_seqlens=cu_seqlens, use_rope=True, window_size=-1, ) image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw)) return image_embeds def _process_video_input( self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]: siglip_position_ids = list() video_grid_hws = list() sample_indices = list() cu_seqlens = [0] video_grid_thw = video_input["video_grid_thw"] assert video_grid_thw.ndim == 2 for idx, thaw in enumerate(video_grid_thw): thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) numel = np.prod(thw_tuple) video_grid_hws.append(thw_tuple) video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) siglip_position_ids.append(video_position_ids) sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) if video_input["type"] == "video_embeds": raise ValueError( "Video embeddings are not supported for this processing path.") else: pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( pixel_values_videos.device) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( pixel_values_videos.device) sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values_videos.device) video_embeds = self.visual( pixel_values=pixel_values_videos, image_grid_thw=video_grid_hws, position_ids=siglip_position_ids, vision_return_embed_list=True, interpolate_pos_encoding=True, sample_indices=sample_indices, cu_seqlens=cu_seqlens, use_rope=True, window_size=-1, ) video_embeds = tuple(self.mlp_AR(video_embeds, video_grid_thw)) return video_embeds def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} for input_key in kwargs: if (input_key in ("pixel_values", "image_embeds") and "images" not in modalities): modalities["images"] = self._parse_and_validate_image_input( **kwargs) if (input_key in ("pixel_values_videos", "video_embeds") and "videos" not in modalities): modalities["videos"] = self._parse_and_validate_video_input( **kwargs) return modalities def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return None multimodal_embeddings: tuple[torch.Tensor, ...] = () for modality in modalities: if modality == "images": image_input = modalities["images"] vision_embeddings = self._process_image_input(image_input) multimodal_embeddings += vision_embeddings if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) multimodal_embeddings += video_embeddings return multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, [ self.config.image_token_id, self.config.video_token_id, ], ) return inputs_embeds def get_input_embeddings_v0( self, input_ids: torch.Tensor, image_input: Optional[KeyeImagePixelInputs] = None, video_input: Optional[KeyeVideoPixelInputs] = None, ) -> torch.Tensor: inputs_embeds = self.get_input_embeddings(input_ids) if image_input is not None: image_embeds = self._process_image_input(image_input) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, image_embeds, placeholder_token_id=self.config.image_token_id, ) if video_input is not None: video_embeds = self._process_video_input(video_input) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, video_embeds, placeholder_token_id=self.config.video_token_id, ) return inputs_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Qwen2-VL. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. positions: Flattened (concatenated) position ids corresponding to a batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). pixel_values: Pixel values to be fed to a model. `None` if no images are passed. image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. `None` if no images are passed. pixel_values_videos: Pixel values of videos to be fed to a model. `None` if no videos are passed. video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. `None` if no videos are passed. """ if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: image_input = self._parse_and_validate_image_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs) if image_input is None and video_input is None: inputs_embeds = None else: if uses_mrope(self.config): assert positions.ndim == 2 and positions.size(0) == 3, ( "multimodal section rotary embedding requires " f"(3, seq_len) positions, but got {positions.size()}") inputs_embeds = self.get_input_embeddings_v0( input_ids, image_input=image_input, video_input=video_input, ) input_ids = None hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states, sampling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: """Get the module prefix in multimodal models.""" return MultiModelKeys.from_string_field( language_model="language_model", connector="visual.", tower_model="mlp_AR.", )