From 4024e1d2a853903cfddfb70f5e10e165db30305c Mon Sep 17 00:00:00 2001
From: Jiajun Li <48857426+guapisolo@users.noreply.github.com>
Date: Tue, 20 May 2025 23:53:46 -0700
Subject: [PATCH] Implement Siglip Vision model, and support BNB quantization
 for gemma3-mm (#5339)

---
 python/sglang/srt/models/clip.py      |   6 +-
 python/sglang/srt/models/gemma3_mm.py |  78 ++++---
 python/sglang/srt/models/siglip.py    | 294 ++++++++++++++++++++++++++
 test/srt/test_bnb.py                  |   4 +
 4 files changed, 353 insertions(+), 29 deletions(-)
 create mode 100644 python/sglang/srt/models/siglip.py

diff --git a/python/sglang/srt/models/clip.py b/python/sglang/srt/models/clip.py
index 999f77ead..f271b45a4 100644
--- a/python/sglang/srt/models/clip.py
+++ b/python/sglang/srt/models/clip.py
@@ -168,7 +168,7 @@ class CLIPEncoderLayer(nn.Module):
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = CLIPMLP(
             config,
@@ -395,6 +395,10 @@ class CLIPVisionModel(nn.Module):
             config, quant_config, prefix=add_prefix("vision_model", prefix)
         )
 
+    @property
+    def device(self) -> torch.device:
+        return self.vision_model.device
+
     def forward(self, pixel_values: torch.Tensor):
         return self.vision_model(pixel_values)
 
diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py
index b26a1d603..a112417c5 100644
--- a/python/sglang/srt/models/gemma3_mm.py
+++ b/python/sglang/srt/models/gemma3_mm.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
 import torch
 from torch import nn
-from transformers import AutoModel, Gemma3Config, PreTrainedModel
+from transformers import Gemma3Config, PreTrainedModel
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -118,6 +119,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".out_proj.",
     ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
@@ -126,6 +128,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         "v_proj": ("qkv_proj", 2),
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
+        "out_proj": ("proj", 0),
     }
 
     packed_modules_mapping = {
@@ -161,20 +164,21 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         super().__init__(config=config)
         self.config = config
         self.quant_config = quant_config
-        # Vision components
-        # TODO: replace with vision attention
-        # self.vision_tower = SiglipVisionModel(
-        #     config.vision_config,
-        #     quant_config,
-        #     prefix=add_prefix("vision_tower", prefix),
-        # )
-        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+
+        self.vision_tower = SiglipVisionModel(
+            config=config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
         self.multi_modal_projector = Gemma3MultiModalProjector(config)
         self.vocab_size = config.text_config.vocab_size
 
         # Text model
         self.language_model = Gemma3ForCausalLM(
-            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+            config.text_config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
         )
         if self.language_model.logits_processor.logit_scale:
             logit_scale = getattr(config, "logit_scale", 1.0)
@@ -290,7 +294,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
 
-        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
+        vision_outputs = self.vision_tower(pixel_values=pixel_values)
         image_features = self.multi_modal_projector(vision_outputs)
 
         return image_features
@@ -366,6 +370,14 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return self.language_model.tie_weights()
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
         """Load weights for the model."""
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -379,21 +391,33 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                 loaded_params.update(causal_loaded_params)
                 continue
             else:
-                # Skip lm_head.weight as it's tied with embed_tokens
-                if "lm_head.weight" in name:
-                    continue
-
-                # Skip loading extra bias for GPTQ models
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                # Remapping the name of FP8 kv-scale
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    if "vision_model" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+                    # Skip loading extra bias for GPTQ models
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
@@ -404,5 +428,3 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
 
 
 EntryClass = Gemma3ForConditionalGeneration
-
-AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
diff --git a/python/sglang/srt/models/siglip.py b/python/sglang/srt/models/siglip.py
new file mode 100644
index 000000000..2a76dc286
--- /dev/null
+++ b/python/sglang/srt/models/siglip.py
@@ -0,0 +1,294 @@
+# Adapted from
+# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py
+
+from functools import partial
+from typing import Optional, Type, Union
+
+import torch
+import torch.nn as nn
+from transformers import SiglipVisionConfig
+
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.utils import add_prefix
+
+
+# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
+class SiglipVisionEmbeddings(nn.Module):
+
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding="valid",
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = VocabParallelEmbedding(
+            self.num_positions, self.embed_dim
+        )
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_positions).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(dtype=target_dtype)
+        )  # shape = [*, width, grid, grid]
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+        # interpolate_pos_encoding is never used in sglang
+        embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        return embeddings
+
+
+# Copied from sglang.srt.models.clip.CLIPMLP
+class SiglipMLP(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        act_layer: Type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
+        )
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+# Copied from sglang.srt.models.clip.CLIPEncoderLayer
+class SiglipEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: SiglipVisionConfig,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Type[nn.Module] = None,
+        attn_implementation: Optional[str] = "sdpa",
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layer_norm1 = norm_layer(config.hidden_size)
+        self.layer_norm2 = norm_layer(config.hidden_size)
+        if attn_implementation == "sdpa":
+            qkv_backend = "sdpa"
+            softmax_in_single_precision = False
+        elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
+            softmax_in_single_precision = False
+        elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
+            softmax_in_single_precision = True
+        self.self_attn = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.mlp = SiglipMLP(
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        hidden_states = self.layer_norm1(hidden_states)
+        # Siglip text model uses both `causal_attention_mask` and `attention_mask`
+        if attention_mask is not None and causal_attention_mask is not None:
+            attn_mask = attention_mask + causal_attention_mask
+        elif causal_attention_mask is not None:
+            attn_mask = causal_attention_mask
+        else:
+            attn_mask = attention_mask
+        hidden_states = self.self_attn(
+            hidden_states,
+            attention_mask=attn_mask,
+            # causal_attention_mask=causal_attention_mask,
+        )
+
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+# Copied from sglang.srt.models.clip.CLIPEncoder
+class SiglipEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self
+    attention layers. Each layer is a [`SiglipEncoderLayer`].
+
+    Args:
+        config: SiglipConfig
+    """
+
+    def __init__(
+        self,
+        config: SiglipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        num_hidden_layers = config.num_hidden_layers
+        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layers = nn.ModuleList(
+            [
+                SiglipEncoderLayer(
+                    config=config,
+                    norm_layer=norm_layer,
+                    attn_implementation="sdpa",
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        causal_attention_mask: torch.Tensor = None,
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = [inputs_embeds]
+        hidden_states = inputs_embeds
+
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states, attention_mask, causal_attention_mask
+            )
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        if return_all_hidden_states:
+            return hidden_states_pool
+        return hidden_states
+
+
+# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
+class SiglipVisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        config: SiglipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = SiglipVisionEmbeddings(config)
+
+        self.encoder = SiglipEncoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+
+        # VisionAttention in SiglipEncoderLayer is multihead attention
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @property
+    def device(self) -> torch.device:
+        return self.encoder.layers[0].layer_norm1.weight.device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = self.embeddings(pixel_values.to(self.device))
+
+        return_all_hidden_states = False
+
+        last_hidden_state = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        return last_hidden_state
+
+
+# Copied from sglang.srt.models.clip.CLIPVisionModel
+class SiglipVisionModel(nn.Module):
+    def __init__(
+        self,
+        config: SiglipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.vision_model = SiglipVisionTransformer(
+            config, quant_config, prefix=add_prefix("vision_model", prefix)
+        )
+
+    @property
+    def device(self) -> torch.device:
+        return self.vision_model.device
+
+    def forward(self, pixel_values: torch.Tensor):
+        return self.vision_model(pixel_values)
diff --git a/test/srt/test_bnb.py b/test/srt/test_bnb.py
index fd4050888..1d9f0201d 100644
--- a/test/srt/test_bnb.py
+++ b/test/srt/test_bnb.py
@@ -33,11 +33,14 @@ VISION_MODELS = [
     "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-11B-Vision-bnb-4bit",
+    "unsloth/gemma-3-4b-it-bnb-4bit",
+    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
 ]
 LANGUAGE_MODELS = [
     "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
     "unsloth/Qwen2-7B-Instruct-bnb-4bit",
     "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
+    "unsloth/gemma-3-1b-it-bnb-4bit",
 ]
 
 # image
@@ -256,6 +259,7 @@ class TestVisionModel(CustomTestCase):
             "0.6",
             "--load-format",
             "bitsandbytes",
+            "--enable-multimodal",
        ]
         try:
             process = popen_launch_server_wrapper(
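Note: the following standalone sketch is not part of the patch above. It illustrates the checkpoint-name remapping that the new `load_weights` branch in `gemma3_mm.py` applies to vision-tower weights: q/k/v (and gate/up) shards are folded into the fused parameters listed in `stacked_params_mapping`, and the Hugging Face Siglip name `self_attn.out_proj` is renamed to `self_attn.proj` to match sglang's `VisionAttention`. The mapping table and the rename come from the diff; the `remap_name` helper and the example weight names are hypothetical, for illustration only.

# Standalone sketch: mirrors the name remapping done by the patched load_weights.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    ("gate_up_proj", "up_proj", 1),
    ("gate_up_proj", "gate_proj", 0),
]


def remap_name(name: str):
    """Return (target_param_name, shard_id); shard_id is None for 1:1 copies."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            # q/k/v (and gate/up) shards land in one fused parameter.
            return name.replace(weight_name, param_name), shard_id
    if "vision_model" in name:
        # HF Siglip calls the attention output projection `out_proj`;
        # sglang's VisionAttention names it `proj`.
        name = name.replace(".self_attn.out_proj", ".self_attn.proj")
    return name, None


if __name__ == "__main__":
    examples = [
        "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight",
        "vision_tower.vision_model.encoder.layers.0.self_attn.out_proj.weight",
        "vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight",
    ]
    for hf_name in examples:
        print(hf_name, "->", remap_name(hf_name))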