# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of Siglip2VisionModel intended to be only used
within a vision language model."""

from collections.abc import Iterable

import torch
from torch import nn
from torch.nn import functional as F
from transformers import Siglip2VisionConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from .vision import (
    is_vit_use_data_parallel,
    resolve_visual_encoder_outputs,
    should_torch_compile_mm_vit,
)


class Siglip2VisionEmbeddings(nn.Module):
    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )
        self.num_patches = config.num_patches
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def forward(
        self,
        pixel_values_packed: torch.FloatTensor,
        spatial_shapes: torch.LongTensor,
    ) -> torch.Tensor:
        """Embed patchified pixel values in packed (unpadded) form.

        Args:
            pixel_values_packed: (1, total_tokens, patch_dim) or
                (total_tokens, patch_dim), packed in tile order.
            spatial_shapes: (num_tiles, 2) on CPU, (height, width) per tile.

        Returns:
            (1, total_tokens, embed_dim) packed embeddings.
        """
        assert spatial_shapes.device.type == "cpu", (
            "Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
            "variable-length packing."
        )
        if pixel_values_packed.dim() == 3:
            assert pixel_values_packed.shape[0] == 1
            pixel_values_flat = pixel_values_packed[0]
        else:
            pixel_values_flat = pixel_values_packed

        lengths = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(dtype=torch.int64)
        lengths_list = lengths.tolist()
        total_tokens = int(sum(lengths_list))
        if total_tokens != pixel_values_flat.shape[0]:
            raise ValueError(
                "Packed pixel_values token count does not match spatial_shapes: "
                f"{pixel_values_flat.shape[0]} vs {total_tokens}."
            )

        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values_flat.to(dtype=target_dtype))

        positional_embeddings = self.position_embedding.weight.reshape(
            self.position_embedding_size, self.position_embedding_size, -1
        )
        packed_pos_embeds = self.resize_positional_embeddings_packed(
            positional_embeddings,
            spatial_shapes,
            lengths_list=lengths_list,
        )
        embeddings = patch_embeds + packed_pos_embeds
        return embeddings.unsqueeze(0)

    @staticmethod
    def resize_positional_embeddings_packed(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        lengths_list: list[int],
    ) -> torch.Tensor:
        """Resize positional embeddings per image and return a packed tensor.

        Args:
            positional_embeddings: (height, width, embed_dim) base grid.
            spatial_shapes: (batch_size, 2) on CPU, (height, width) per image.
            lengths_list: flattened token length per image (height * width).

        Returns:
            (total_tokens, embed_dim) packed positional embeddings,
            concatenated in the same order as `lengths_list`.
        """
        assert spatial_shapes.device.type == "cpu"

        embed_dim = positional_embeddings.shape[-1]
        source_dtype = positional_embeddings.dtype
        total_tokens = int(sum(lengths_list))
        packed_pos_embeds = torch.empty(
            (total_tokens, embed_dim),
            device=positional_embeddings.device,
            dtype=source_dtype,
        )

        # (height, width, embed_dim) -> (1, embed_dim, height, width)
        pos_4d = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
        # Upcast to float32 on CPU because antialias is not supported for
        # bfloat16/float16 on CPU.
        if pos_4d.device.type == "cpu":
            pos_4d = pos_4d.to(torch.float32)

        offset = 0
        for i, length in enumerate(lengths_list):
            if length <= 0:
                continue
            height, width = spatial_shapes[i].tolist()
            resized = F.interpolate(
                pos_4d,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            resized = resized.reshape(embed_dim, height * width).transpose(0, 1)
            resized = resized.to(source_dtype)
            packed_pos_embeds[offset : offset + length] = resized
            offset += length

        return packed_pos_embeds


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        use_data_parallel = is_vit_use_data_parallel()
        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0
        self.num_heads_per_partition = self.num_heads // tp_size

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )
        self.attn = MMEncoderAttention(
            num_heads=self.num_heads_per_partition,
            head_size=self.head_dim,
            scale=self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int | torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(
            hidden_states
        )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
        bsz, q_len, _ = qkv.shape
        query_states, key_states, value_states = qkv.chunk(3, dim=-1)
        query_states = query_states.view(
            bsz, q_len, self.num_heads_per_partition, self.head_dim
        )
        key_states = key_states.view(
            bsz, q_len, self.num_heads_per_partition, self.head_dim
        )
        value_states = value_states.view(
            bsz, q_len, self.num_heads_per_partition, self.head_dim
        )

        # Use the unified MMEncoderAttention implementation.
        out = self.attn(
            query=query_states,
            key=key_states,
            value=value_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        out = out.reshape(bsz, q_len, -1)
        attn_output, _ = self.out_proj(out)
        return attn_output


class Siglip2MLP(nn.Module):
    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
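        # Resolve the configured hidden activation to vLLM's implementation.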
        self.activation_fn = get_act_fn(config.hidden_act)
        use_data_parallel = is_vit_use_data_parallel()
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


@support_torch_compile(
    dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0},
    enable_if=should_torch_compile_mm_vit,
)
class Siglip2EncoderLayer(nn.Module):
    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int | torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Input tensor of shape (batch, seq_len, embed_dim).
            cu_seqlens: Cumulative sequence lengths tensor.
            max_seqlen: Maximum sequence length.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention
    layers. Each layer is a [`Siglip2EncoderLayer`].
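
    Inputs arrive in packed (unpadded) form; `cu_seqlens` and `max_seqlen`
    describe the per-tile sequence boundaries and are forwarded to every layer.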

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        num_hidden_layers_override: int | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList(
            [
                Siglip2EncoderLayer(
                    config=config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{idx}",
                )
                for idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int | torch.Tensor,
        return_all_hidden_states: bool = False,
    ) -> torch.Tensor | list[torch.Tensor]:
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)

        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


class Siglip2VisionTransformer(nn.Module):
    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ):
        super().__init__()
        embed_dim = config.hidden_size
        self.config = config
        self.embeddings = Siglip2VisionEmbeddings(config)

        # Keep the import local to avoid circular dependencies during model init.
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("Siglip2Encoder", is_encoder=True):
            self.encoder = Siglip2Encoder(
                config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers_override,
                prefix=f"{prefix}.encoder",
            )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers
        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values_packed: torch.FloatTensor,
        spatial_shapes: torch.LongTensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor,
        select_layers: list[int] | None = None,
    ) -> torch.Tensor:
        r"""
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the
            input images.
        select_layers (`list[int]` or `None`, defaults to `None`):
            Layer indices to select hidden states from. Supports negative
            indices (e.g., -1 for last layer, -2 for second-to-last).
            If None, returns the last layer output.
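        cu_seqlens (`torch.Tensor`):
            Cumulative sequence lengths of the packed tiles, forwarded to the
            encoder's attention layers.
        max_seqlen (`torch.Tensor`):
            Maximum sequence length among the packed tiles.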
""" hidden_states = self.embeddings(pixel_values_packed, spatial_shapes) encoder_outputs = self.encoder( inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, return_all_hidden_states=select_layers is not None, ) encoder_outputs = resolve_visual_encoder_outputs( encoder_outputs, self.post_layernorm, select_layers=select_layers, max_possible_layers=self.config.num_hidden_layers, ) return encoder_outputs class Siglip2Model(torch.nn.Module): def __init__( self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, prefix: str = "", ): super().__init__() self.vision_model = Siglip2VisionTransformer( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, require_post_norm=require_post_norm, prefix=f"{prefix}.vision_model", ) def forward( self, pixel_values_packed: torch.FloatTensor, spatial_shapes: torch.LongTensor, cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, select_layers: list[int] | None = None, ) -> torch.Tensor: """Forward pass through the vision model. Args: select_layers: Layer indices to select hidden states from. Supports negative indices (e.g., [-2] for second-to-last). If None, returns the last layer output with post_layernorm. Multiple layers can be selected and will be concatenated. """ return self.vision_model( pixel_values_packed=pixel_values_packed, spatial_shapes=spatial_shapes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, select_layers=select_layers, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: # post_layernorm is optional in Siglip2Model if ( name.startswith("vision_model.post_layernorm") and self.vision_model.post_layernorm is None ): continue # omit layers when num_hidden_layers_override is set if name.startswith("vision_model.encoder.layers"): layer_idx = int(name.split(".")[3]) if layer_idx >= layer_count: continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params