295 lines
9.7 KiB
Python
295 lines
9.7 KiB
Python
# Adapted from
|
|
# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py
|
|
|
|
from functools import partial
|
|
from typing import Optional, Type, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import SiglipVisionConfig
|
|
|
|
from sglang.srt.layers.activation import QuickGELU
|
|
from sglang.srt.layers.attention.vision import VisionAttention
|
|
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
from sglang.srt.utils import add_prefix
|
|
|
|
|
|
# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
|
|
class SiglipVisionEmbeddings(nn.Module):
|
|
|
|
def __init__(self, config: SiglipVisionConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.embed_dim = config.hidden_size
|
|
self.image_size = config.image_size
|
|
self.patch_size = config.patch_size
|
|
|
|
self.patch_embedding = nn.Conv2d(
|
|
in_channels=config.num_channels,
|
|
out_channels=self.embed_dim,
|
|
kernel_size=self.patch_size,
|
|
stride=self.patch_size,
|
|
padding="valid",
|
|
)
|
|
|
|
self.num_patches = (self.image_size // self.patch_size) ** 2
|
|
self.num_positions = self.num_patches
|
|
self.position_embedding = VocabParallelEmbedding(
|
|
self.num_positions, self.embed_dim
|
|
)
|
|
self.register_buffer(
|
|
"position_ids",
|
|
torch.arange(self.num_positions).expand((1, -1)),
|
|
persistent=False,
|
|
)
|
|
|
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
|
target_dtype = self.patch_embedding.weight.dtype
|
|
patch_embeds = self.patch_embedding(
|
|
pixel_values.to(dtype=target_dtype)
|
|
) # shape = [*, width, grid, grid]
|
|
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
|
# interpolate_pos_encoding is never used in sglang
|
|
embeddings = embeddings + self.position_embedding(self.position_ids)
|
|
|
|
return embeddings
|
|
|
|
|
|
# Copied from sglang.srt.models.clip.CLIPMLP
|
|
class SiglipMLP(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
config,
|
|
act_layer: Type[nn.Module] = QuickGELU,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
):
|
|
super().__init__()
|
|
self.fc1 = ColumnParallelLinear(
|
|
config.hidden_size,
|
|
config.intermediate_size,
|
|
quant_config=quant_config,
|
|
prefix=add_prefix("fc1", prefix),
|
|
)
|
|
self.act = act_layer()
|
|
self.fc2 = RowParallelLinear(
|
|
config.intermediate_size,
|
|
config.hidden_size,
|
|
quant_config=quant_config,
|
|
prefix=add_prefix("fc2", prefix),
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x_parallel, _ = self.fc1(x)
|
|
x_parallel = self.act(x_parallel)
|
|
x, _ = self.fc2(x_parallel)
|
|
return x
|
|
|
|
|
|
# Copied from sglang.srt.models.clip.CLIPEncoderLayer
|
|
class SiglipEncoderLayer(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
config: SiglipVisionConfig,
|
|
act_layer: Type[nn.Module] = QuickGELU,
|
|
norm_layer: Type[nn.Module] = None,
|
|
attn_implementation: Optional[str] = "sdpa",
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
) -> None:
|
|
super().__init__()
|
|
if norm_layer is None:
|
|
norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
|
|
self.layer_norm1 = norm_layer(config.hidden_size)
|
|
self.layer_norm2 = norm_layer(config.hidden_size)
|
|
if attn_implementation == "sdpa":
|
|
qkv_backend = "sdpa"
|
|
softmax_in_single_precision = False
|
|
elif attn_implementation == "flash_attention_2":
|
|
qkv_backend = "triton_attn"
|
|
softmax_in_single_precision = False
|
|
elif attn_implementation == "eager":
|
|
qkv_backend = "sdpa"
|
|
softmax_in_single_precision = True
|
|
self.self_attn = VisionAttention(
|
|
embed_dim=config.hidden_size,
|
|
num_heads=config.num_attention_heads,
|
|
projection_size=config.hidden_size,
|
|
use_qkv_parallel=True,
|
|
qkv_backend=qkv_backend,
|
|
softmax_in_single_precision=softmax_in_single_precision,
|
|
flatten_batch=True,
|
|
quant_config=quant_config,
|
|
prefix=add_prefix("self_attn", prefix),
|
|
)
|
|
self.mlp = SiglipMLP(
|
|
config,
|
|
act_layer=act_layer,
|
|
quant_config=quant_config,
|
|
prefix=add_prefix("mlp", prefix),
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
attention_mask: torch.Tensor,
|
|
causal_attention_mask: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
|
|
residual = hidden_states
|
|
hidden_states = self.layer_norm1(hidden_states)
|
|
# Siglip text model uses both `causal_attention_mask` and `attention_mask`
|
|
if attention_mask is not None and causal_attention_mask is not None:
|
|
attn_mask = attention_mask + causal_attention_mask
|
|
elif causal_attention_mask is not None:
|
|
attn_mask = causal_attention_mask
|
|
else:
|
|
attn_mask = attention_mask
|
|
hidden_states = self.self_attn(
|
|
hidden_states,
|
|
attention_mask=attn_mask,
|
|
# causal_attention_mask=causal_attention_mask,
|
|
)
|
|
|
|
hidden_states = residual + hidden_states
|
|
residual = hidden_states
|
|
hidden_states = self.layer_norm2(hidden_states)
|
|
hidden_states = self.mlp(hidden_states)
|
|
hidden_states = residual + hidden_states
|
|
return hidden_states
|
|
|
|
|
|
# Copied from sglang.srt.models.clip.CLIPEncoder
|
|
class SiglipEncoder(nn.Module):
|
|
"""
|
|
Transformer encoder consisting of `config.num_hidden_layers` self
|
|
attention layers. Each layer is a [`SiglipEncoderLayer`].
|
|
|
|
Args:
|
|
config: SiglipConfig
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
config: SiglipVisionConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
) -> None:
|
|
super().__init__()
|
|
|
|
self.config = config
|
|
|
|
num_hidden_layers = config.num_hidden_layers
|
|
norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
|
|
self.layers = nn.ModuleList(
|
|
[
|
|
SiglipEncoderLayer(
|
|
config=config,
|
|
norm_layer=norm_layer,
|
|
attn_implementation="sdpa",
|
|
quant_config=quant_config,
|
|
prefix=add_prefix(f"layers.{layer_idx}", prefix),
|
|
)
|
|
for layer_idx in range(num_hidden_layers)
|
|
]
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
inputs_embeds: torch.Tensor,
|
|
attention_mask: torch.Tensor = None,
|
|
causal_attention_mask: torch.Tensor = None,
|
|
return_all_hidden_states: bool = False,
|
|
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
|
hidden_states_pool = [inputs_embeds]
|
|
hidden_states = inputs_embeds
|
|
|
|
for encoder_layer in self.layers:
|
|
hidden_states = encoder_layer(
|
|
hidden_states, attention_mask, causal_attention_mask
|
|
)
|
|
if return_all_hidden_states:
|
|
hidden_states_pool.append(hidden_states)
|
|
if return_all_hidden_states:
|
|
return hidden_states_pool
|
|
return hidden_states
|
|
|
|
|
|
# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
|
|
class SiglipVisionTransformer(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
config: SiglipVisionConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
) -> None:
|
|
super().__init__()
|
|
|
|
self.config = config
|
|
embed_dim = config.hidden_size
|
|
|
|
self.embeddings = SiglipVisionEmbeddings(config)
|
|
|
|
self.encoder = SiglipEncoder(
|
|
config=config,
|
|
quant_config=quant_config,
|
|
prefix=add_prefix("encoder", prefix),
|
|
)
|
|
|
|
num_hidden_layers = config.num_hidden_layers
|
|
if len(self.encoder.layers) > config.num_hidden_layers:
|
|
raise ValueError(
|
|
f"The original encoder only has {num_hidden_layers} "
|
|
f"layers, but you requested {len(self.encoder.layers)} layers."
|
|
)
|
|
|
|
# VisionAttention in SiglipEncoderLayer is multihead attention
|
|
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
|
|
|
@property
|
|
def device(self) -> torch.device:
|
|
return self.encoder.layers[0].layer_norm1.weight.device
|
|
|
|
def forward(
|
|
self,
|
|
pixel_values: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
hidden_states = self.embeddings(pixel_values.to(self.device))
|
|
|
|
return_all_hidden_states = False
|
|
|
|
last_hidden_state = self.encoder(
|
|
inputs_embeds=hidden_states,
|
|
return_all_hidden_states=return_all_hidden_states,
|
|
)
|
|
|
|
last_hidden_state = self.post_layernorm(last_hidden_state)
|
|
|
|
return last_hidden_state
|
|
|
|
|
|
# Copied from sglang.srt.models.clip.CLIPVisionModel
|
|
class SiglipVisionModel(nn.Module):
|
|
def __init__(
|
|
self,
|
|
config: SiglipVisionConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
):
|
|
super().__init__()
|
|
self.vision_model = SiglipVisionTransformer(
|
|
config, quant_config, prefix=add_prefix("vision_model", prefix)
|
|
)
|
|
|
|
@property
|
|
def device(self) -> torch.device:
|
|
return self.vision_model.device
|
|
|
|
def forward(self, pixel_values: torch.Tensor):
|
|
return self.vision_model(pixel_values)
|