# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, Optional, TypeVar, Union

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.utils import torch_int

from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems, VideoItem)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
                                   ModalityDataItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
from .siglip import SiglipMLP
from .utils import (AutoWeightsLoader, WeightsMapper,
                    init_vllm_registered_model, is_pp_missing_parameter,
                    maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vit_attn_backend

logger = init_logger(__name__)


def smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
):
    """Compute the (height, width) an image is resized to before patching.

    Sides shorter than ``factor`` are first stretched up to ``factor``
    (scaling the other side proportionally). Both sides are then rounded to
    multiples of ``factor``. If the resulting area exceeds ``max_pixels`` the
    image is scaled down (flooring to multiples of ``factor``); if it is below
    ``min_pixels`` it is scaled up (ceiling to multiples of ``factor``).

    Returns:
        ``(h_bar, w_bar)``, both multiples of ``factor``.

    Raises:
        ValueError: if the (clamped) aspect ratio exceeds 200.
    """
    if height < factor:
        logger.warning(
            "smart_resize: height=%s < factor=%s, reset height=factor",
            height,
            factor,
        )
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        logger.warning(
            "smart_resize: width=%s < factor=%s, reset width=factor",
            width,
            factor,
        )
        height = round((height * factor) / width)
        width = factor

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        # The original message lacked the f-prefix and printed the literal
        # expression instead of the offending ratio.
        raise ValueError("absolute aspect ratio must be smaller than 200, got "
                         f"{aspect_ratio}")
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


class KeyeImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - np: Number of patches
        - c: Number of channels
        - ps: Patch size
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["pixel_values"]
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class KeyeImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of image features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["image_embeds"]
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]


class KeyeVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - np: Number of patches
        - c: Number of channels
        - ps: Patch size
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


class KeyeVideoEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of video features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["video_embeds"]
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs]


class KeyeVisionEmbeddings(nn.Module):
    """Patch embedding plus position embedding for the Keye SigLIP tower."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        # grid -> interpolated position embedding / grid -> hit count; used by
        # fetch_position_embedding_lfu_cache.
        self.cache_position_embedding = dict()
        self.cache_position_count = dict()
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        # Indexed directly by packed position_ids when interpolation of the
        # learned grid embedding is not requested.
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)

        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:
        """Bilinearly resample the learned position table to a new grid.

        The table is reshaped to a ``sqrt(N) x sqrt(N)`` grid (so it assumes
        ``num_positions`` is a perfect square) and resized to
        ``(new_height, new_width)``. When ``is_after_patchify`` is True,
        ``height``/``width`` are already in patch units; otherwise they are
        in pixels and divided by ``patch_size``.

        Returns:
            Tensor of shape ``(1, new_height * new_width, dim)``.
        """
        num_positions = self.position_embedding.weight.shape[0]

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions,
                                                  sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def fetch_position_embedding_lfu_cache(self,
                                           embeddings,
                                           h,
                                           w,
                                           max_cache: int = 20):
        """Return the interpolated position embedding for an ``(h, w)`` grid.

        Results are memoized per grid; when ``max_cache`` entries are held,
        the entry with the lowest hit count is evicted first.
        """
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]

        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count,
                key=self.cache_position_count.get,
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)

        position_embedding = self.interpolate_pos_encoding(
            embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        """Embed patches and add position information.

        Args:
            pixel_values: 4-D or 5-D patch tensor; a 4-D input gets a leading
                batch dim. The 5-D layout is ``(b, l, c, h, w)``.
            position_ids: indices into ``packing_position_embedding``; used
                when interpolation is not applied.
            image_grid_thw: per-image ``(t, h, w)`` grids; consecutive
                ``t * h * w`` patch rows belong to each image.
            interpolate_pos_encoding: if True (and ``image_grid_thw`` is
                given), add the resampled learned grid embedding per image.

        Raises:
            ValueError: if ``position_ids`` is None for a 5-D input, or the
                input is neither 4-D nor 5-D.
        """
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)
        if pixel_values.dim() != 5:
            raise ValueError("Unsupported pixel_values dimension:"
                             f" {pixel_values.dim()}. Expected 4 or 5.")
        if position_ids is None:
            raise ValueError(
                "position_ids cannot be None when pixel_values.dim() is 5.")

        target_dtype = self.patch_embedding.weight.dtype
        pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype))
        # Each input is a single patch, so the conv output is (N, D, 1, 1);
        # collapse it to (N, D).
        embeddings = patch_embeds.flatten(-2).squeeze(-1)

        if interpolate_pos_encoding and image_grid_thw is not None:
            start = 0
            tmp_embeddings = list()
            for image_grid in image_grid_thw:
                t, h, w = image_grid
                end = start + t * h * w
                image_embeddings = embeddings[start:end, :]
                # The same spatial embedding is reused for every frame.
                position_embedding = (self.interpolate_pos_encoding(
                    image_embeddings, h, w, True).squeeze(0).repeat(t, 1))
                image_embeddings = image_embeddings + position_embedding
                tmp_embeddings.append(image_embeddings)
                start = end
            embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
        else:
            embeddings = embeddings + self.packing_position_embedding(
                position_ids)
        return embeddings


def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to q and k using vLLM's flash-attn kernel.

    Only the first half of ``cos``/``sin`` along the last dim is used; the
    encoder builds them with ``repeat(1, 2)``, so both halves are equal.
    Computation is done in float32 and cast back to the input dtypes.
    """
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()

    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
    return q_embed, k_embed


class KeyeSiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        # Vision attention is plain MHA: as many KV heads as query heads.
        self.total_num_kv_heads = config.num_attention_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Detect attention implementation; prefer upstream flash-attn when
        # vLLM's selector did not pick FLASH_ATTN but upstream FA is usable.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.head_dim, dtype=torch.get_default_dtype())

        self.use_upstream_fa = False
        if (self.attn_backend != _Backend.FLASH_ATTN
                and check_upstream_fa_availability(
                    torch.get_default_dtype())):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Variable-length (packed) self-attention over ``hidden_states``.

        Args:
            hidden_states: ``(b, s, hidden)`` packed sequences.
            attention_mask, output_attentions: accepted for interface
                compatibility; unused.
            cu_seqlens: cumulative sequence boundaries (indexed/sliced as a
                tensor). Required.
            rope_emb: optional ``(cos, sin)`` pair for rotary embeddings.

        Raises:
            ValueError: if ``cu_seqlens`` is None.
        """
        # Both backends need the sequence boundaries; validate before they
        # are dereferenced (previously only checked on the rope path, after
        # cu_seqlens had already been used).
        if cu_seqlens is None:
            raise ValueError(
                "cu_seqlens cannot be None for KeyeSiglipAttention.")

        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split(
            [self.q_size, self.kv_size, self.kv_size],
            dim=-1,
        )

        seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lengths.max().item()
        seqlens = seq_lengths.tolist()

        batch_size = q.shape[0]

        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)
        if rope_emb is not None:
            cos, sin = rope_emb
            q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)

        if self.attn_backend == _Backend.FLASH_ATTN:
            if self.use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func

            q, k, v = (rearrange(x, "b s ... -> (b s) ...")
                       for x in [q, k, v])

            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=False,
                softmax_scale=self.scale,
            )
            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal bias keeps each packed sequence isolated.
            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        context_layer = rearrange(context_layer,
                                  "b s h d -> b s (h d)").contiguous()

        output, _ = self.out_proj(context_layer)
        return output


class SigLIPRotaryEmbedding(nn.Module):
    """Produces the rotary angle table ``outer(arange(seqlen), inv_freq)``."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        inv_freq = 1.0 / (self.theta**(
            torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a ``(seqlen, dim // 2)`` table of rotation angles."""
        seq = torch.arange(
            seqlen,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class KeyeSiglipEncoderLayer(nn.Module):
    """Pre-norm transformer block: LN -> attention -> residual, LN -> MLP."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.self_attn = KeyeSiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.FloatTensor]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            rope_emb=rope_emb,
        )

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        return hidden_states


class KeyeSiglipEncoder(nn.Module):
    """Stack of ``KeyeSiglipEncoderLayer`` with optional 2-D rotary."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.ModuleList([
            KeyeSiglipEncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            ) for layer_idx in range(config.num_hidden_layers)
        ])
        # Half of the head dim rotates with the row index and half with the
        # column index, hence head_dim // 2.
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one level of nested lists of grids into a flat list."""
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        use_rope: Optional[bool] = False,
        window_size: int = -1,
        vision_or_text: str = "vision",
    ) -> BaseModelOutput:
        """Run all encoder layers and return the final hidden states.

        When ``use_rope`` is True, 2-D rotary embeddings are built from
        per-patch (row, col) ids. If either id tensor is missing, both are
        derived from ``image_grid_thw``; ids restart every ``h * w`` patches,
        so every frame of a grid shares the same spatial ids.
        ``window_size`` and ``vision_or_text`` are accepted but unused here.
        """
        device = inputs_embeds.device
        if use_rope is True:
            flatten_image_grid_thw = self.flatten_list(image_grid_thw)

            if width_position_ids is None or height_position_ids is None:
                split_hids = list()
                split_wids = list()
                for t, h, w in flatten_image_grid_thw:
                    image_pids = torch.arange(t * h * w,
                                              device=device) % (h * w)
                    sample_hids = image_pids // w
                    sample_wids = image_pids % w
                    split_hids.append(sample_hids)
                    split_wids.append(sample_wids)
                width_position_ids = torch.concat(split_wids, dim=0)
                height_position_ids = torch.concat(split_hids, dim=0)

            pids = torch.stack(
                [height_position_ids, width_position_ids],
                dim=-1,
            )
            max_grid_size = pids.max() + 1
            rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
            # (num_patches, 2, rot_dim) -> (num_patches, 2 * rot_dim).
            rope_emb = rope_emb_max_grid[pids].flatten(1)
            rope_emb = rope_emb.repeat(1, 2)
            rope_emb = (rope_emb.cos(), rope_emb.sin())
        else:
            rope_emb = None

        hidden_states = inputs_embeds
        assert attention_mask is None

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
                cu_seqlens=cu_seqlens,
                rope_emb=rope_emb,
            )
        return hidden_states


class KeyeSiglipVisionTransformer(nn.Module):
    """Embeddings -> encoder -> post layer norm, split per sequence."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = KeyeVisionEmbeddings(config)
        self.encoder = KeyeSiglipEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        attention_mask: Optional[torch.Tensor] = None,
        sample_indices: Optional[torch.Tensor] = None,
        image_indices: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: int = -1,
    ) -> BaseModelOutputWithPooling:
        """Encode packed patches.

        Returns:
            A list with one ``(len_i, hidden)`` tensor per sequence delimited
            by ``cu_seqlens`` (not a ``BaseModelOutputWithPooling`` despite
            the annotation).

        Raises:
            ValueError: if ``cu_seqlens`` is None.
        """
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)

        sample_hidden_state = list()
        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None for "
                             "SiglipVisionTransformer output processing.")
        for i in range(cu_seqlens.shape[0] - 1):
            start = cu_seqlens[i]
            end = cu_seqlens[i + 1]
            tensor = last_hidden_state[:, start:end, :].squeeze(0)
            sample_hidden_state.append(tensor)

        return sample_hidden_state


class KeyeSiglipVisionModel(nn.Module):
    """Thin wrapper around ``KeyeSiglipVisionTransformer`` plus weights."""

    config_class = PretrainedConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.vision_model = KeyeSiglipVisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
        )
        self.quant_config = quant_config

    @property
    def dtype(self) -> torch.dtype:
        return self.vision_model.embeddings.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.vision_model.embeddings.patch_embedding.weight.device

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        sample_indices: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        position_ids: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: int = -1,
    ) -> BaseModelOutputWithPooling:
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            vision_return_embed_list=vision_return_embed_list,
            image_grid_thw=image_grid_thw,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=return_pooler_output,
            use_rope=use_rope,
            window_size=window_size,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v into ``qkv_proj``.

        Skips rotary ``inv_freq`` buffers and the SigLIP pooling-head
        (``head.*``) weights. Returns the set of loaded parameter names.
        """
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "head.attention" in name or "head.layernorm" in name:
                continue
            if "head.mlp" in name or "head.probe" in name:
                continue
            if self.quant_config is not None and (
                    scale_name := self.quant_config.get_cache_scale(name)):
                param = params_dict[scale_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (
                    param_name,
                    weight_name,
                    shard_id,
            ) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Projector(nn.Module):
    """Merge 2x2 neighbouring patches and project to the text hidden size."""

    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        self.merge_kernel_size = (2, 2)

        self.hidden_size = (self.vision_config.hidden_size *
                            self.merge_kernel_size[0] *
                            self.merge_kernel_size[1])

        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size,
                                           eps=1e-05)
        self.act = GELUActivation()

        self.linear_1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.linear_2 = RowParallelLinear(
            self.hidden_size,
            self.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(
        self,
        image_features: Union[torch.Tensor, list[torch.Tensor]],
        image_grid_thw: list[tuple[int, int, int]],
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Project vision features.

        A list/tuple input is processed per image, regrouping each
        ``(t, h, w)`` grid into ``(t * h/2 * w/2, 4 * d)`` before the MLP.
        A single tensor is flattened and every 4 consecutive rows merged.
        """
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            processed_features = list()
            for image_feature, image_grid in zip(image_features,
                                                 image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                t, h, w = image_grid

                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                hidden_states, _ = self.linear_1(image_feature)
                hidden_states = self.act(hidden_states)
                hidden_states, _ = self.linear_2(hidden_states)
                processed_features.append(hidden_states)

            return processed_features

        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        image_features = image_features.view(np.prod(dims), dim)
        hidden_states = self.pre_norm(image_features).view(
            -1, self.hidden_size)
        # The parallel linear layers return (output, bias); unpack them as
        # the list path does, otherwise GELU would receive a tuple.
        hidden_states, _ = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)

        # NOTE(review): rows were merged 4:1 above, so this view folds the
        # text hidden dim back across `dims`; confirm this shape is intended.
        return hidden_states.view(*dims, -1)


def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ):
    """Describe how HF processor outputs are split into per-item fields.

    Pixel values / embeddings are flat tensors sliced by each item's
    ``prod(grid_thw)``; the grid tensors are batched one row per item.
    """
    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
    image_grid_sizes = image_grid_thw.prod(-1)

    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
    video_grid_sizes = video_grid_thw.prod(-1)

    return dict(
        pixel_values=MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes),
        image_embeds=MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes),
        image_grid_thw=MultiModalFieldConfig.batched("image"),
        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
            "video", video_grid_sizes),
        video_embeds=MultiModalFieldConfig.flat_from_sizes(
            "video", video_grid_sizes),
        video_grid_thw=MultiModalFieldConfig.batched("video"),
    )


class KeyeMultiModalDataParser(MultiModalDataParser):
    """Accepts precomputed embedding dicts in addition to raw media."""

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={
                    "image_embeds",
                    "image_grid_thw",
                },
                fields_factory=_keye_field_config,
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="video",
                required_fields={
                    "video_embeds",
                    "video_grid_thw",
                },
                fields_factory=_keye_field_config,
            )

        return super()._parse_video_data(data)


class KeyeProcessingInfo(BaseProcessingInfo):
    """Token-count and size bookkeeping for Keye-VL multimodal inputs."""

    def get_max_image_size(self) -> int:
        return 9999999  # _MAX_IMAGE_SIZE

    def get_max_frame_per_video(self) -> int:
        return 16  # _MAX_FRAMES_PER_VIDEO

    def get_image_processor(self, **kwargs: object):
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor,
    ) -> tuple[ImageSize, int]:
        """Return the preprocessed size and the number of vision tokens.

        Tokens = ``grid_t * grid_h * grid_w / merge_size**2`` where the grid
        is measured in patches of the (optionally smart-resized) image.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = 1

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        padded_num_frames = num_frames + num_frames % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor,
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor,
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        max_image_size, _ = self._get_vision_info(
            image_width=self.get_max_image_size(),
            image_height=self.get_max_image_size(),
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count whose video token count fits ``max_tokens``."""
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        """Frames per video for profiling, after reserving image tokens.

        Capped by ``get_max_frame_per_video`` and never below 1.
        """
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1),
            self.get_max_frame_per_video(),
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(self, seq_len: int) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len),
            image_processor=None,
        )


_I = TypeVar("_I", bound=KeyeProcessingInfo)


class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds worst-case dummy prompts and media for memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = (
            self.info.get_image_size_with_most_features())
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len)

        mm_data = {
            "image":
            self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
            ),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

        return mm_data


class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
    ...
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
    """Multimodal processor for Keye-VL.

    Each image/video placeholder token in the prompt is replaced by
    ``prod(grid_thw) // merge_size**2`` copies of itself, where ``grid_thw``
    is the item's grid as produced by the HF processor.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        return KeyeMultiModalDataParser()

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return one ``PromptReplacement`` per modality (image, video)."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        # Token ids of the single placeholder token used for each modality.
        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        # Every merge_size x merge_size block of patches becomes one token.
        merge_length = image_processor.merge_size**2

        def get_replacement_keye(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(get_replacement_keye, modality=modality),
            ) for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _keye_field_config(hf_inputs)


class BaseKeyeModule(nn.Module):
    """Shared Keye-VL model: SigLIP-style vision tower + projector + LM.

    Subclasses provide the projector (``_build_projector``) and the
    ``_parse_and_validate_image_input`` / ``_parse_and_validate_video_input``
    / ``_process_video_input`` hooks used below.
    """

    # Fused vLLM parameter name -> checkpoint parameter names it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # Checkpoint LM weights live at the top level; here they sit under
    # ``language_model``.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
        "model.": "language_model.model.",
    })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt snippet that stands for one image/video item.

        Raises:
            ValueError: If ``modality`` is neither image nor video.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: PretrainedConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = KeyeSiglipVisionModel(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
        )

        self.mlp_AR = self._build_projector(
            config,
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "mlp_AR"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen3ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @abstractmethod
    def _build_projector(self,
                         text_config: PretrainedConfig,
                         vision_config: PretrainedConfig,
                         quant_config: Optional[QuantizationConfig] = None,
                         prefix: str = "") -> nn.Module:
        """Build the vision-to-text projector stored as ``self.mlp_AR``."""
        raise ValueError("Need projector")

    @staticmethod
    def _build_siglip_inputs(
        grid_thw: torch.Tensor,
    ) -> tuple[list[tuple[int, ...]], torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Build the per-patch metadata passed to the vision tower.

        Args:
            grid_thw: 2-D tensor with one ``(t, h, w)`` grid row per item.

        Returns:
            ``(grid_thws, position_ids, sample_indices, cu_seqlens)``, all on
            CPU:

            - the grids as Python tuples;
            - each patch's index within its ``h * w`` frame;
            - the item index each patch belongs to (int64);
            - cumulative patch counts starting at 0 (int32).
        """
        assert grid_thw.ndim == 2

        grid_thws: list[tuple[int, ...]] = []
        position_ids: list[torch.Tensor] = []
        sample_indices: list[torch.Tensor] = []
        cu_seqlens = [0]
        for idx, thw in enumerate(grid_thw):
            thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
            numel = np.prod(thw_tuple)
            grid_thws.append(thw_tuple)
            # Positions restart for every temporal slice (modulo h * w).
            position_ids.append(torch.arange(numel) % np.prod(thw_tuple[1:]))
            sample_indices.append(
                torch.full((numel, ), idx, dtype=torch.int64))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        return (
            grid_thws,
            torch.concat(position_ids, dim=0),
            torch.concat(sample_indices, dim=0),
            torch.tensor(cu_seqlens, dtype=torch.int32),
        )

    def _process_image_input(self,
                             image_input: Any) -> tuple[torch.Tensor, ...]:
        """Run pixel-value image inputs through the vision tower + projector.

        Returns:
            A tuple with whatever ``self.mlp_AR`` yields per image.

        Raises:
            ValueError: For ``image_embeds`` inputs, which this path does not
                support.
        """
        if image_input["type"] == "image_embeds":
            raise ValueError(
                "Image embeddings are not supported for this processing path.")

        image_grid_thw = image_input["image_grid_thw"]
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        device = pixel_values.device
        (image_grid_hws, siglip_position_ids, sample_indices,
         cu_seqlens) = self._build_siglip_inputs(image_grid_thw)

        image_embeds = self.visual(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_hws,
            position_ids=siglip_position_ids.to(device),
            vision_return_embed_list=False,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices.to(device),
            cu_seqlens=cu_seqlens.to(device),
            use_rope=True,
            window_size=-1,
        )
        return tuple(self.mlp_AR(image_embeds, image_grid_thw))

    def _process_video_embeds(
        self,
        video_type: Literal["video_embeds", "pixel_values_videos"],
        video_grid_thw: torch.Tensor,
        pixel_values_videos: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run pixel-value video inputs through the vision tower + projector.

        Args:
            video_type: Input kind; only ``"pixel_values_videos"`` is handled.
            video_grid_thw: 2-D tensor with one ``(t, h, w)`` row per video.
            pixel_values_videos: Video patches; required for pixel inputs.

        Returns:
            The projector output for all videos, unwrapped.

        Raises:
            ValueError: For ``video_embeds`` inputs, or when
                ``pixel_values_videos`` is missing.
        """
        if video_type == "video_embeds":
            raise ValueError(
                "Video embeddings are not supported for this processing path.")
        if pixel_values_videos is None:
            raise ValueError(
                "pixel_values_videos is required for pixel video inputs.")

        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        device = pixel_values_videos.device
        (video_grid_hws, siglip_position_ids, sample_indices,
         cu_seqlens) = self._build_siglip_inputs(video_grid_thw)

        video_embeds = self.visual(
            pixel_values=pixel_values_videos,
            image_grid_thw=video_grid_hws,
            position_ids=siglip_position_ids.to(device),
            vision_return_embed_list=True,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices.to(device),
            cu_seqlens=cu_seqlens.to(device),
            use_rope=True,
            window_size=-1,
        )
        return self.mlp_AR(video_embeds, video_grid_thw)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video kwargs into ``{"images": ..., "videos": ...}``.

        Keys are inserted in the order their first input key appears in
        ``kwargs``, which fixes the order embeddings are produced in.
        """
        modalities = {}

        for input_key in kwargs:
            if (input_key in ("pixel_values", "image_embeds")
                    and "images" not in modalities):
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if (input_key in ("pixel_values_videos", "video_embeds")
                    and "videos" not in modalities):
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return per-item vision embeddings, or ``None`` without mm inputs."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``; scatter mm embeddings onto placeholder ids."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                [
                    self.config.image_token_id,
                    self.config.video_token_id,
                ],
            )
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[Any] = None,
        video_input: Optional[Any] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge image, then video, embeddings.

        Used by ``forward`` when it receives raw multimodal kwargs.
        """
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )

        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                video_embeds,
                placeholder_token_id=self.config.video_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Keye-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch. **NOTE**: If mrope is enabled
                (``uses_mrope(self.config)``), the shape will be
                ``(3, seq_len)``, otherwise it will be ``(seq_len,)``.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Optional tensor of input embeddings.
            **kwargs: Raw multimodal inputs (e.g. ``pixel_values``,
                ``image_grid_thw``) used only when neither
                ``intermediate_tensors`` nor ``inputs_embeds`` is given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        elif inputs_embeds is None:
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)
            if image_input is None and video_input is None:
                inputs_embeds = None
            else:
                if uses_mrope(self.config):
                    assert positions.ndim == 2 and positions.size(0) == 3, (
                        "multimodal section rotary embedding requires "
                        f"(3, seq_len) positions, but got {positions.size()}")
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input,
                )
                input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping LM prefixes via the mapper."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp_AR.",
            tower_model="visual.",
        )


@MULTIMODAL_REGISTRY.register_processor(
    KeyeMultiModalProcessor,
    info=KeyeProcessingInfo,
    dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                                   SupportsLoRA, SupportsPP):
    """Keye-VL model using ``Projector`` as its vision-to-text connector."""

    def _build_projector(self,
                         text_config: PretrainedConfig,
                         vision_config: PretrainedConfig,
                         quant_config: Optional[QuantizationConfig] = None,
                         prefix: str = "") -> nn.Module:
        return Projector(text_config, vision_config, quant_config, prefix)

    def _validate_and_reshape_mm_tensor(
            self, mm_input: NestedTensors,
            name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Normalize a multimodal tensor input.

        Tensor inputs:

        - 2-D and 5-D tensors are returned unchanged;
        - 3-D tensors are flattened to ``(-1, last_dim)``;
        - any other rank is rejected.

        List inputs:

        - a list of tensors that are all 4-D, or all 2-D, is returned as is;
        - any other list of tensors is concatenated with ``torch.concat``.

        Raises:
            ValueError: On a non-tensor/non-list input, an unsupported
                tensor rank, or a list whose elements are not tensors.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim in (2, 5):
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D, batched 3D or 5D "
                                 f"tensor. Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return mm_input.reshape(-1, mm_input.shape[-1])

        if not is_list_of(mm_input, torch.Tensor):
            # Previously this fell through to an implicit None / an opaque
            # torch.concat failure; reject it explicitly instead.
            raise ValueError(f"Incorrect type of {name}. Expected a list of "
                             "tensors.")
        if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
                                                      for p in mm_input):
            return mm_input
        return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[KeyeImageInputs]:
        """Build image inputs from kwargs; ``None`` if no image data."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return KeyeImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return KeyeImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[KeyeVideoInputs]:
        """Build video inputs from kwargs; ``None`` if no video data."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos,
                "video pixel values",
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return KeyeVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return KeyeVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_video_input(
            self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
        """Embed video inputs and return them as a tuple."""
        video_type = video_input["type"]
        video_grid_thw = video_input["video_grid_thw"]
        pixel_values_videos = video_input.get("pixel_values_videos", None)

        return tuple(
            self._process_video_embeds(video_type, video_grid_thw,
                                       pixel_values_videos))