diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index ea7c34dd1..09dc551d2 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -32,6 +32,7 @@ - Phi-3-Small - IBM Granite 3 - Janus-Pro-1B / Janus-Pro-7B +- Gemma 3 (it) ## Embedding Models diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index e17168b90..8677b99b3 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -520,6 +520,14 @@ def match_granite_instruct(model_path: str): return get_chat_template("granite-3-instruct") +@register_chat_template_matching_function +def match_gemma3_instruct(model_path: str): + model_path = model_path.lower() + if "gemma-3" in model_path and "1b" not in model_path: + # gemma-3-1b-it is completion model + return get_chat_template("gemma-it") + + if __name__ == "__main__": messages = [ {"role": "system", "content": None}, # None means default diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index f7ae0108e..765a3f7e2 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -1,6 +1,7 @@ from sglang.srt.configs.chatglm import ChatGLMConfig from sglang.srt.configs.dbrx import DbrxConfig from sglang.srt.configs.exaone import ExaoneConfig +from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.qwen2_5_vl_config import ( Qwen2_5_VLConfig, @@ -14,4 +15,6 @@ __all__ = [ "Qwen2_5_VLConfig", "Qwen2_5_VLVisionConfig", "MultiModalityConfig", + "Gemma3Config", + "Gemma3TextConfig", ] diff --git a/python/sglang/srt/configs/gemma3.py b/python/sglang/srt/configs/gemma3.py new file mode 100644 index 000000000..b70089f5c --- /dev/null +++ b/python/sglang/srt/configs/gemma3.py @@ -0,0 +1,1086 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
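As an aside on the `match_gemma3_instruct` matcher added to `python/sglang/lang/chat_template.py` above: it routes Gemma 3 instruction-tuned checkpoints to the `gemma-it` chat template while skipping the 1B variant, which is treated as a completion model. A minimal standalone sketch of that path check (the SGLang registration decorator and `get_chat_template` registry are omitted, so the helper name below is illustrative only):

```python
def looks_like_gemma3_instruct(model_path: str) -> bool:
    # Mirrors the matcher above: any Gemma 3 checkpoint except the 1B one
    # should be served with the "gemma-it" chat template.
    model_path = model_path.lower()
    return "gemma-3" in model_path and "1b" not in model_path


assert looks_like_gemma3_instruct("google/gemma-3-4b-it")
assert looks_like_gemma3_instruct("google/gemma-3-27b-it")
assert not looks_like_gemma3_instruct("google/gemma-3-1b-it")  # text-only completion model
```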
+import itertools +import logging +import math +import re +from typing import Dict, Iterable, List, Optional, Union + +import numpy as np +import PIL +import transformers +from torch import TensorType +from transformers import ( + AutoImageProcessor, + AutoProcessor, + BatchFeature, + PretrainedConfig, + SiglipVisionConfig, +) +from transformers.image_processing_utils import BaseImageProcessor, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from transformers.image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_pil_image, + is_scaled_image, + is_valid_image, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from transformers.modeling_rope_utils import rope_config_validation +from transformers.processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + Unpack, +) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +from transformers.utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + filter_out_non_signature_kwargs, + to_py_obj, +) + +logger = logging.getLogger(__name__) + + +def is_valid_list_of_images(images: List): + return images and all(is_valid_image(image) for image in images) + + +# copied from transformer +def make_nested_list_of_images( + images: Union[List[ImageInput], ImageInput], +) -> ImageInput: + """ + Ensure that the output is a nested list of images. + Args: + images (`Union[List[ImageInput], ImageInput]`): + The input image. + Returns: + list: A list of list of images or a list of 4d array of images. + """ + # If it's a list of batches, it's already in the right format + if ( + isinstance(images, (list, tuple)) + and all(isinstance(images_i, (list, tuple)) for images_i in images) + and all(is_valid_list_of_images(images_i) for images_i in images) + ): + return images + + # If it's a list of images, it's a single batch, so convert it to a list of lists + if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): + if is_pil_image(images[0]) or images[0].ndim == 3: + return [images] + if images[0].ndim == 4: + return [list(image) for image in images] + + # If it's a single image, convert it to a list of lists + if is_valid_image(images): + if is_pil_image(images) or images.ndim == 3: + return [[images]] + if images.ndim == 4: + return [list(images)] + + raise ValueError( + "Invalid input type. Must be a single image, a list of images, or a list of batches of images." + ) + + +def rescale( + image: np.ndarray, + scale: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, +) -> np.ndarray: + """ + Rescale an image by a scale factor. image = image * scale. + + Args: + image (`np.ndarray`): + Image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. 
If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The rescaled image. + """ + return transformers.image_transforms.rescale( + image, + scale=scale, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + +def normalize( + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, +) -> np.ndarray: + """ + Normalize an image. image = (image - image_mean) / image_std. + + Args: + image (`np.ndarray`): + Image to normalize. + mean (`float` or `Iterable[float]`): + Image mean to use for normalization. + std (`float` or `Iterable[float]`): + Image standard deviation to use for normalization. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The normalized image. 
+ """ + return transformers.image_transforms.normalize( + image, + mean=mean, + std=std, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + +class Gemma3ImagesKwargs(ImagesKwargs): + do_pan_and_scan: Optional[bool] + pan_and_scan_min_crop_size: Optional[int] + pan_and_scan_max_num_crops: Optional[int] + pan_and_scan_min_ratio_to_activate: Optional[float] + do_convert_rgb: Optional[bool] + + +class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Gemma3ImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + "do_pan_and_scan": False, + "pan_and_scan_min_crop_size": 256, + "pan_and_scan_max_num_crops": 4, + "pan_and_scan_min_ratio_to_activate": 1.2, + }, + } + + +class Gemma3Processor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "image_seq_length"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor, + tokenizer, + chat_template=None, + image_seq_length: int = 256, + **kwargs, + ): + + self.image_seq_length = image_seq_length + self.image_token_id = tokenizer.image_token_id + self.boi_token = tokenizer.boi_token + image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length) + self.full_image_sequence = ( + f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n" + ) + + super().__init__( + image_processor=image_processor, + tokenizer=tokenizer, + chat_template=chat_template, + **kwargs, + ) + + # TODO: if transformers is updated, the chat_template needs to be adjusted + self.tokenizer.add_bos_token = False + + def __call__( + self, + images: ImageInput = None, + text: Union[ + TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput] + ] = None, + videos=None, + audio=None, + **kwargs: Unpack[Gemma3ProcessorKwargs], + ) -> BatchFeature: + if text is None and images is None: + raise ValueError("Provide at least one of `text` or `images`.") + # print(f"processing, text:{text}") + output_kwargs = self._merge_kwargs( + Gemma3ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError( + "Invalid input text. Please provide a string, or a list of strings" + ) + + image_inputs = {} + if images is not None: + batched_images = make_nested_list_of_images(images) + image_inputs = self.image_processor( + batched_images, **output_kwargs["images_kwargs"] + ) + + # Create empty text to be replaced with placeholders + if not text: + text = [ + " ".join([self.boi_token] * len(images)) + for images in batched_images + ] + + if len(batched_images) != len(text): + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})." + ) + + # Replace image tokens by the full expanded sequence + batch_num_crops = to_py_obj(image_inputs.pop("num_crops")) + text_with_crops = text + + for batch_idx, (prompt, images, num_crops) in enumerate( + zip(text, batched_images, batch_num_crops) + ): + + image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)] + + if len(images) != len(image_indexes): + raise ValueError( + f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images." 
+ ) + + # Insert additional image tokens for Pan-and-Scan crops + for num, idx in reversed(list(zip(num_crops, image_indexes))): + if num: + formatted_image_text = ( + f"Here is the original image {self.boi_token} and here are some crops to help you see better " + + " ".join([self.boi_token] * num) + ) + prompt = ( + prompt[:idx] + + formatted_image_text + + prompt[idx + len(self.boi_token) :] + ) + text_with_crops[batch_idx] = prompt + + # Expand placeholder image tokens to the full image token sequence + text = [ + prompt.replace(self.boi_token, self.full_image_sequence) + for prompt in text + ] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer( + text=text, **output_kwargs["text_kwargs"], return_tensors="np" + ) + + # print(f"processing, text_inputs:{text_inputs}") + + # Add token type ids manually, as tokenizer can't do arbitrary position token types + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs = { + k: v.tolist() for k, v in text_inputs.items() + } # in case user requested list inputs + text_inputs["token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature( + data={**text_inputs, **image_inputs}, tensor_type=return_tensors + ) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"] + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +class Gemma3ImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. 
+ do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_pan_and_scan (`bool`, *optional*): + Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + """ + + model_input_names = ["pixel_values", "num_crops"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + do_pan_and_scan: bool = None, + pan_and_scan_min_crop_size: int = None, + pan_and_scan_max_num_crops: int = None, + pan_and_scan_min_ratio_to_activate: float = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size, default_to_square=True) + image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.do_pan_and_scan = do_pan_and_scan + self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size + self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops + self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate + + def pan_and_scan( + self, + image: np.ndarray, + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pan and Scan and image, by cropping into smaller images when the aspect ratio exceeds + minumum allowed ratio. + + Args: + image (`np.ndarray`): + Image to resize. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. 
+ pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + height, width = get_image_size(image) + + # Square or landscape image. + if width >= height: + # Only apply PaS if the image is sufficiently exaggerated + if width / height < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_w = int( + math.floor(width / height + 0.5) + ) # Half round up rounding. + num_crops_w = min( + int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w + ) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_w = max(2, num_crops_w) + num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) + num_crops_h = 1 + + # Portrait image. + else: + # Only apply PaS if the image is sufficiently exaggerated + if height / width < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_h = int(math.floor(height / width + 0.5)) + num_crops_h = min( + int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h + ) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_h = max(2, num_crops_h) + num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) + num_crops_w = 1 + + crop_size_w = int(math.ceil(width / num_crops_w)) + crop_size_h = int(math.ceil(height / num_crops_h)) + + # Don't apply PaS if crop size is too small. 
+ if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: + return [] + + crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] + crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] + + if input_data_format == ChannelDimension.LAST: + image_crops = [ + image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product( + crop_positions_h, crop_positions_w + ) + ] + else: + image_crops = [ + image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product( + crop_positions_h, crop_positions_w + ) + ] + + return image_crops + + def _process_images_for_pan_and_scan( + self, + images: List[np.ndarray], + do_pan_and_scan: bool, + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + pas_images_list = [] + num_crops = [] + for image in images: + pas_images = self.pan_and_scan( + image=image, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + data_format=data_format, + input_data_format=input_data_format, + ) + pas_images_list.extend([image] + pas_images) + num_crops.append(len(pas_images)) + return pas_images_list, num_crops + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: bool = None, + do_pan_and_scan: bool = None, + pan_and_scan_min_crop_size: int = None, + pan_and_scan_max_num_crops: int = None, + pan_and_scan_min_ratio_to_activate: float = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*, defaults to `self.pan_and_scan_min_crop_size`): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*, defaults to `self.pan_and_scan_max_num_crops`): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*, defaults to `self.pan_and_scan_min_ratio_to_activate`): + Minimum aspect ratio to activate pan and scan. 
+ """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = ( + rescale_factor if rescale_factor is not None else self.rescale_factor + ) + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = ( + do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + ) + do_pan_and_scan = ( + do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan + ) + pan_and_scan_min_crop_size = ( + pan_and_scan_min_crop_size + if pan_and_scan_min_crop_size is not None + else self.pan_and_scan_min_crop_size + ) + pan_and_scan_max_num_crops = ( + pan_and_scan_max_num_crops + if pan_and_scan_max_num_crops is not None + else self.pan_and_scan_max_num_crops + ) + pan_and_scan_min_ratio_to_activate = ( + pan_and_scan_min_ratio_to_activate + if pan_and_scan_min_ratio_to_activate is not None + else self.pan_and_scan_min_ratio_to_activate + ) + + images_list = make_nested_list_of_images(images) + + if not valid_images(images_list[0]): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + if do_convert_rgb: + images_list = [ + [convert_to_rgb(image) for image in images] for images in images_list + ] + + # All transformations expect numpy arrays. + images_list = [ + [to_numpy_array(image) for image in images] for images in images_list + ] + + if do_rescale and is_scaled_image(images_list[0][0]): + logger.warning( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images_list[0][0]) + + if do_pan_and_scan: + images_list_and_num_crops = [ + self._process_images_for_pan_and_scan( + images=images, + do_pan_and_scan=do_pan_and_scan, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + data_format=data_format, + input_data_format=input_data_format, + ) + for images in images_list + ] + images_list = [images for images, _ in images_list_and_num_crops] + num_crops = [num_crops for _, num_crops in images_list_and_num_crops] + else: + num_crops = [[0] for images in images_list] + + processed_images = [] + for images in images_list: + for image in images: + if do_resize: + height, width = size["height"], size["width"] + image = resize( + image=image, + size=(height, width), + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = rescale( + image=image, + scale=rescale_factor, + input_data_format=input_data_format, + ) + + if do_normalize: + image = normalize( + image=image, + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) + + image = to_channel_dimension_format( + image, data_format, input_channel_dim=input_data_format + ) + processed_images.append(image) + + data = {"pixel_values": processed_images, "num_crops": num_crops} + return BatchFeature(data=data, tensor_type=return_tensors) + + +class Gemma3TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma3Text-7B. + e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 262208): + Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma3TextModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. 
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the + size of the sliding window. + final_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. 
Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. + + ```python + >>> from transformers import Gemma3TextModel, Gemma3TextConfig + >>> # Initializing a Gemma3Text gemma3_text-7b style configuration + >>> configuration = Gemma3TextConfig() + >>> # Initializing a model from the gemma3_text-7b style configuration + >>> model = Gemma3TextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. 
+ """ + + model_type = "gemma3_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=262_208, + hidden_size=2304, + intermediate_size=9216, + num_hidden_layers=26, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=131_072, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=1_000_000.0, + attention_bias=False, + attention_dropout=0.0, + query_pre_attn_scalar=256, + sliding_window=4096, + final_logit_softcapping=None, + attn_logit_softcapping=None, + cache_implementation="hybrid", + rope_scaling=None, + rope_local_base_freq=10_000.0, + sliding_window_pattern=6, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.cache_implementation = cache_implementation + + self.rope_local_base_freq = rope_local_base_freq + # For configuring HybridCache to work with 5:1 attention pattern + self.sliding_window_pattern = sliding_window_pattern + self.rope_scaling = rope_scaling + rope_config_validation(self) + + +class Gemma3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an + Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`Union[Gemma3TextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. 
+ mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 262144): + The image token index to encode the image prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + + Example: + + ```python + >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a Gemma3 Text config + >>> text_config = Gemma3TextConfig() + + >>> # Initializing a Gemma3 gemma-3-4b style configuration + >>> configuration = Gemma3Config(vision_config, text_config) + + >>> # Initializing a model from the gemma-3-4b style configuration + >>> model = Gemma3TextConfig(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma3" + sub_configs = { + "text_config": Gemma3TextConfig, + "vision_config": SiglipVisionConfig, + } + + def __init__( + self, + text_config: Optional[Gemma3TextConfig] = None, + vision_config: Optional[SiglipVisionConfig] = None, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 262_144, + initializer_range: float = 0.02, + **kwargs, + ): + if text_config is None: + text_config = Gemma3TextConfig() + # logger.info( + # "text_config is None, using default Gemma3TextConfig config." + # ) + elif isinstance(text_config, dict): + text_config = Gemma3TextConfig(**text_config) + + if isinstance(vision_config, dict): + vision_config = SiglipVisionConfig(**vision_config) + elif isinstance(vision_config, SiglipVisionConfig): + pass + else: + # logger.info( + # "vision_config is None or incompatible with Gemma3VisionConfig initialization. Gemma3 will be limited " + # "to text tasks." 
+            # )
+            # logger.info(f"vision_config: {vision_config}")
+            vision_config = SiglipVisionConfig()
+
+        self.text_config = text_config
+        self.vision_config = vision_config
+        self.mm_tokens_per_image = mm_tokens_per_image
+        self.boi_token_index = boi_token_index
+        self.eoi_token_index = eoi_token_index
+        self.image_token_index = image_token_index
+        self.initializer_range = initializer_range
+
+        super().__init__(**kwargs)
+
+
+AutoProcessor.register(
+    config_class=Gemma3Config, processor_class=Gemma3Processor, exist_ok=True
+)
+
+AutoImageProcessor.register(
+    config_class=Gemma3Config,
+    image_processor_class=None,
+    slow_image_processor_class=Gemma3ImageProcessor,
+    fast_image_processor_class=None,
+    exist_ok=True,
+)
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index cb31edd1b..22174c922 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
     dtype = dtype.lower()
     if dtype == "auto":
         if config_dtype == torch.float32:
-            if config.model_type == "gemma2":
+            if config.model_type.startswith("gemma"):
+                if config.model_type == "gemma":
+                    gemma_version = ""
+                else:
+                    gemma_version = config.model_type[5]
                 logger.info(
-                    "For Gemma 2, we downcast float32 to bfloat16 instead "
+                    f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                     "of float16 by default. Please specify `dtype` if you "
                     "want to use float16."
                 )
@@ -453,6 +457,7 @@ multimodal_model_archs = [
     "LlavaQwenForCausalLM",
     "LlavaMistralForCausalLM",
     "LlavaVidForCausalLM",
+    "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "MllamaForConditionalGeneration",
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 5c8f37643..6255126be 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -45,6 +45,7 @@ class SeparatorStyle(IntEnum):
     DEEPSEEK_CHAT = auto()
     METAMATH = auto()
     QWEN2_VL_EMBED = auto()
+    GEMMA3 = auto()


 @dataclasses.dataclass
@@ -285,6 +286,18 @@ class Conversation:
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.GEMMA3:
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if i == 0:
+                        ret += message + self.sep
+                    else:
+                        ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")

@@ -604,6 +617,20 @@ register_conv_template(
     )
 )

+# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
+register_conv_template(
+    Conversation(
+        name="gemma-it",
+        system_message="You are a helpful assistant.",
+        system_template="<start_of_turn>user{system_message}\n\n",
+        roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
+        sep="<end_of_turn>\n",
+        sep_style=SeparatorStyle.GEMMA3,
+        stop_str=["<end_of_turn>"],
+        image_token="<start_of_image>",
+    )
+)
+
 # Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
 register_conv_template(
     Conversation(
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 5cd52bf4a..987cc98dc 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -34,6 +34,8 @@ from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
     ExaoneConfig,
+    Gemma3Config,
+    Gemma3TextConfig,
     MultiModalityConfig,
     Qwen2_5_VLConfig,
 )
@@ -46,6 +48,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ExaoneConfig.model_type: ExaoneConfig,
     Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
MultiModalityConfig.model_type: MultiModalityConfig, + Gemma3Config.model_type: Gemma3Config, + Gemma3TextConfig.model_type: Gemma3TextConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index f27d8c781..7b23ae82e 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -19,34 +19,10 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half from sglang.srt.utils import add_prefix -# Copied from transformers, modeling_qwen2_vl.py -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - orig_q_dtype = q.dtype - orig_k_dtype = k.dtype - q, k = q.float(), k.float() - - cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - - q_embed = q_embed.to(orig_q_dtype) - k_embed = k_embed.to(orig_k_dtype) - - return q_embed, k_embed - - class VisionAttention(nn.Module): r""" Multi-headed attention without any cache, mostly used for ViT. @@ -168,7 +144,7 @@ class VisionAttention(nn.Module): cos, sin = position_embeddings original_shape = q.shape q, k = q.view(s, head, -1), k.view(s, head, -1) - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + q, k = apply_rotary_pos_emb(q, k, cos, sin) q, k = q.reshape(original_shape), k.reshape(original_shape) if self.use_qkv_parallel: diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 7f6ac5f69..47dccc9f9 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -119,6 +119,26 @@ class GemmaRMSNorm(CustomOp): return out +class Gemma3RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + if not _is_cuda: logger.info( "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." 
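The `Gemma3RMSNorm` added to `layernorm.py` above keeps the whole computation in float32 and only casts back after multiplying by `(1 + weight)`, whereas a Llama-style RMSNorm casts the normalized activations back to the input dtype before applying the weight (the ordering difference discussed in the transformers PR referenced in the comment). A small standalone sketch of the difference, not part of the patch; both helper names are illustrative only:

```python
import torch


def gemma_style_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize in float32 and scale by (1 + weight) before casting back,
    # matching Gemma3RMSNorm (weight is zero-initialized, so 1 + weight starts at identity).
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1.0 + weight.float())).type_as(x)


def llama_style_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math, but cast back to the activation dtype before applying the weight.
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * (1.0 + weight)


x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.randn(8, dtype=torch.bfloat16) * 0.1
# The extra bfloat16 round-trip in the Llama-style version introduces small rounding drift.
print((gemma_style_rmsnorm(x, w) - llama_style_rmsnorm(x, w)).abs().max())
```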
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index f5608cf5a..3cab22306 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1173,6 +1173,37 @@ def get_rope( return rotary_emb +# Copied from transformers +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim=1, +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + + # embedding is performed in float + cos = cos.unsqueeze(unsqueeze_dim).float() + sin = sin.unsqueeze(unsqueeze_dim).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + + return q_embed, k_embed + + def get_rope_cpu( head_size: int, rotary_dim: int, diff --git a/python/sglang/srt/managers/image_processors/base_image_processor.py b/python/sglang/srt/managers/image_processors/base_image_processor.py index c4349d16c..86bacc2f4 100644 --- a/python/sglang/srt/managers/image_processors/base_image_processor.py +++ b/python/sglang/srt/managers/image_processors/base_image_processor.py @@ -111,7 +111,7 @@ class BaseImageProcessor(ABC): def load_images( self, - input_ids: list, + input_ids: list[int], image_data, image_token: str, max_req_input_len: int, @@ -122,22 +122,21 @@ class BaseImageProcessor(ABC): Each frame of video/image will be replaced by a single image token Args: - discard_alpha_channel: if True, discards the alpha channel in the returned images """ - image_hashes, image_sizes = [], [] - all_frames = [] - new_text_parts = [] if isinstance(input_ids, list) and return_text: assert len(input_ids) and isinstance(input_ids[0], int) input_text = self._processor.tokenizer.decode(input_ids) else: input_text = input_ids - if return_text: - text_parts = input_text.split(image_token) + import re + + pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")" + # split text into list of normal text and special tokens + text_parts = re.split(pattern, input_text) # TODO(mick): load from server_args, env, or sampling_params MAX_NUM_FRAMES = 30 @@ -145,53 +144,65 @@ class BaseImageProcessor(ABC): total_frame_count = sum(estimated_frames_list) # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs. 
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used - scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count) + _scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count)) assert len(image_data) == len(estimated_frames_list) - # Process each input with allocated frames - for image_index, (image, estimated_frames) in enumerate( - zip(image_data, estimated_frames_list) - ): - if len(all_frames) >= MAX_NUM_FRAMES: - max_frames_to_process = 0 - else: - max_frames_to_process = max(1, int(estimated_frames * scaling_factor)) - - if max_frames_to_process == 0: - frames = [] - else: - try: - if isinstance(image, str) and image.startswith("video:"): - path = image[len("video:") :] - frames = BaseImageProcessor.encode_video( - path, frame_count_limit=max_frames_to_process - ) + image_index, audio_index = 0, 0 + hashes, image_sizes, images, audios = [], [], [], [] + new_text = "" + for index, text_part in enumerate(text_parts): + try: + if text_part == image_token: + # load as image + frames_to_process = estimated_frames_list[image_index] + if frames_to_process == 0: + frames = [] else: - raw_image, _size = load_image(image) - if discard_alpha_channel: - raw_image = raw_image.convert("RGB") - frames = [raw_image] - assert len(frames) != 0 - except FileNotFoundError as e: - print(e) - return None + image_file = image_data[image_index] + if isinstance(image_file, str) and image_file.startswith( + "video:" + ): + # video + path = image_file[len("video:") :] + frames = self.encode_video( + path, frame_count_limit=frames_to_process + ) + else: + # image + raw_image, _size = load_image(image_file) + if discard_alpha_channel: + raw_image = raw_image.convert("RGB") + frames = [raw_image] + if len(frames) == 0: + continue - image_sizes += [frames[0].size] * len(frames) - image_hashes += [hash(image)] * len(frames) - all_frames += frames + image_sizes += frames[0].size * len(frames) + hashes += [hash(image_file)] * len(frames) + images += frames + image_index += 1 + if frames_to_process != 0: + new_text += image_token * len(frames) + assert frames_to_process == len(frames) + else: + # TODO(mick): handle video + # normal text + new_text += text_part - if return_text: - new_text_parts.append(text_parts[image_index]) - if max_frames_to_process != 0: - new_text_parts.append(image_token * len(frames)) - assert max_frames_to_process >= len(frames) - if return_text: - new_text_parts.append(text_parts[-1]) + except Exception as e: + import openai + + logger.error(f"An exception occurred while loading images: {e}") + raise BadRequestError( + f"An exception occurred while loading images: {e}" + ) + continue - input_text = "".join(new_text_parts) return BaseImageProcessorOutput( - image_hashes, image_sizes, all_frames, input_text + image_hashes=hashes, + image_sizes=image_sizes, + all_frames=images, + input_text=new_text, ) diff --git a/python/sglang/srt/managers/image_processors/gemma3.py b/python/sglang/srt/managers/image_processors/gemma3.py new file mode 100644 index 000000000..a54efb8d9 --- /dev/null +++ b/python/sglang/srt/managers/image_processors/gemma3.py @@ -0,0 +1,100 @@ +import asyncio +from typing import List, Union + +from transformers.utils import logging + +from sglang.srt.managers.image_processor import ( + BaseImageProcessor as SGLangBaseImageProcessor, +) +from sglang.srt.managers.image_processors.base_image_processor import ( + get_global_processor, +) +from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration + +# Copied from: 
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
+# will be removed in the future
+logger = logging.get_logger(__name__)
+
+
+class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<start_of_image>"
+        self.IM_START_TOKEN_ID = hf_config.boi_token_index
+        self.IM_END_TOKEN_ID = hf_config.eoi_token_index
+
+    @staticmethod
+    def _process_images_task(images, input_text, _hf_config):
+        if isinstance(images, list) and len(images) == 0:
+            images = None
+        processor = get_global_processor()
+        result = processor.__call__(
+            text=[input_text],
+            images=images,
+            padding=True,
+            return_tensors="pt",
+            # if RGBA, this needs to be set
+            # images_kwargs={
+            #     "input_data_format": ChannelDimension.FIRST
+            # }
+        )
+
+        pixel_values = getattr(result, "pixel_values", None)
+
+        return {
+            "input_ids": result.input_ids,
+            "pixel_values": pixel_values,
+        }
+
+    async def _process_images(self, images, input_text) -> dict:
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                Gemma3SGLangImageProcessor._process_images_task,
+                images,
+                input_text,
+                self.hf_config,
+            )
+        else:
+            return self._process_images_task(images, input_text, self.hf_config)
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_ids,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        image_token = self.IMAGE_TOKEN
+        base_output = self.load_images(
+            input_ids=input_ids,
+            image_data=image_data,
+            image_token=image_token,
+            max_req_input_len=max_req_input_len,
+            discard_alpha_channel=True,
+        )
+
+        ret = await self._process_images(
+            input_text=base_output.input_text, images=base_output.all_frames
+        )
+
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "pixel_values": ret["pixel_values"],
+            "image_hashes": base_output.image_hashes,
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+        }
+
+
+ImageProcessorMapping = {
+    Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
+}
diff --git a/python/sglang/srt/managers/image_processors/janus_pro.py b/python/sglang/srt/managers/image_processors/janus_pro.py
index a3d25c989..36db528d3 100644
--- a/python/sglang/srt/managers/image_processors/janus_pro.py
+++ b/python/sglang/srt/managers/image_processors/janus_pro.py
@@ -60,7 +60,10 @@ class JanusProProcessor(SGLangBaseImageProcessor):
             image_data = [image_data]

         base_out = self.load_images(
-            input_ids, image_data, "<image_placeholder>", max_req_input_len
+            input_ids=input_ids,
+            image_data=image_data,
+            image_token="<image_placeholder>",
+            max_req_input_len=max_req_input_len,
         )
         images = base_out.all_frames
         res = await self._process_images(images=images, input_text=base_out.input_text)
diff --git a/python/sglang/srt/managers/image_processors/minicpmv.py b/python/sglang/srt/managers/image_processors/minicpmv.py
index 1b36f7fe4..4e10092bf 100644
--- a/python/sglang/srt/managers/image_processors/minicpmv.py
+++ b/python/sglang/srt/managers/image_processors/minicpmv.py
@@ -52,7 +52,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             image_data = [image_data]

         base_output = self.load_images(
-            input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
+            input_ids=input_ids,
+            image_data=image_data,
+            image_token=self.IMAGE_TOKEN,
+
max_req_input_len=max_req_input_len, ) if base_output is None: return None diff --git a/python/sglang/srt/managers/image_processors/qwen_vl.py b/python/sglang/srt/managers/image_processors/qwen_vl.py index a0594cd1d..46add1383 100644 --- a/python/sglang/srt/managers/image_processors/qwen_vl.py +++ b/python/sglang/srt/managers/image_processors/qwen_vl.py @@ -72,10 +72,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor): image_token = self.IMAGE_TOKEN base_output = self.load_images( - input_ids, - image_data, - image_token, - max_req_input_len, + input_ids=input_ids, + image_data=image_data, + image_token=image_token, + max_req_input_len=max_req_input_len, ) def smart_resize( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 011ff3ed0..86be904e8 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -49,7 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_compiler_backend, next_power_of_2 +from sglang.srt.utils import get_compiler_backend if TYPE_CHECKING: from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput @@ -207,6 +207,9 @@ class ImageInputs: return ret def merge(self, other): + """ + merge image inputs when requests are being merged + """ assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:] self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values]) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 35667465a..230e9e0f7 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -33,6 +33,7 @@ from dataclasses import dataclass from enum import IntEnum, auto from typing import TYPE_CHECKING, List, Optional, Union +import numpy as np import torch import triton import triton.language as tl @@ -331,6 +332,32 @@ class ForwardBatch: return ret + def get_merged_image_inputs(self) -> Optional[ImageInputs]: + """ + Merge all image inputs in the batch into a single ImageInputs object. + + Returns: + if none, current batch contains no image input + + """ + if not self.image_inputs or all(x is None for x in self.image_inputs): + return None + + # Filter out None values + valid_inputs = [x for x in self.image_inputs if x is not None] + + # Start with the first valid image input + merged = valid_inputs[0] + + # Merge remaining inputs + for img_input in valid_inputs[1:]: + merged.merge(img_input) + + if isinstance(merged.pixel_values, np.ndarray): + merged.pixel_values = torch.from_numpy(merged.pixel_values) + + return merged + def _compute_mrope_positions( self, model_runner: ModelRunner, batch: ModelWorkerBatch ): diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py new file mode 100644 index 000000000..b2fa84eb3 --- /dev/null +++ b/python/sglang/srt/models/gemma3_causal.py @@ -0,0 +1,687 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
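The `merge` helper on `ImageInputs` and the new `ForwardBatch.get_merged_image_inputs` above collapse the per-request image inputs of a batch into a single object (converting NumPy pixel values to torch at the end) before the vision tower runs. A minimal, self-contained sketch of that idea; the class and field names below are simplified stand-ins, not the real SGLang types:

```python
# Simplified stand-ins illustrating how per-request image inputs can be
# merged into one batch-level object before the model consumes them.
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import torch


@dataclass
class ToyImageInputs:  # hypothetical, not the SGLang ImageInputs class
    pixel_values: np.ndarray  # (num_images, channels, height, width)
    image_hashes: List[int] = field(default_factory=list)

    def merge(self, other: "ToyImageInputs") -> None:
        # All images must share the same (C, H, W) shape to be concatenated.
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
        self.image_hashes += other.image_hashes


def merge_batch(inputs: List[Optional[ToyImageInputs]]) -> Optional[ToyImageInputs]:
    valid = [x for x in inputs if x is not None]
    if not valid:
        return None  # the batch carries no image input
    merged = valid[0]
    for extra in valid[1:]:
        merged.merge(extra)
    # The model consumes torch tensors, so convert once at the end.
    merged.pixel_values = torch.from_numpy(merged.pixel_values)
    return merged


if __name__ == "__main__":
    a = ToyImageInputs(np.zeros((1, 3, 8, 8), dtype=np.float32), [111])
    b = ToyImageInputs(np.ones((2, 3, 8, 8), dtype=np.float32), [222, 333])
    print(merge_batch([a, None, b]).pixel_values.shape)  # torch.Size([3, 3, 8, 8])
```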
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import copy +from typing import Iterable, Optional, Set, Tuple + +import einops +import torch +import torch.nn.functional as F +from torch import nn +from transformers import ( + ROPE_INIT_FUNCTIONS, + AutoModel, + PretrainedConfig, + PreTrainedModel, +) + +from sglang.srt.configs.gemma3 import Gemma3TextConfig +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.layernorm import Gemma3RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.utils import add_prefix, make_layers + + +# Adapted from: +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py +def extract_layer_index(prefix: str) -> int: + """Extract the layer index from a prefix string.""" + parts = prefix.split(".") + for part in parts: + if part.startswith("layers."): + layer_str = part.split(".")[-1] + try: + return int(layer_str) + except ValueError: + continue + return -1 + + +class Gemma3MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + ) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_activation` to " + "`gelu_pytorch_tanh`." 
+ ) + self.act_fn = GeluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma3Attention(nn.Module): + def __init__( + self, + layer_id: int, + config: Gemma3TextConfig, + max_position_embeddings: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.config = config + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + + hidden_size = config.hidden_size + + head_dim = getattr( + config, "head_dim", hidden_size // config.num_attention_heads + ) + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + # Determine if layer uses sliding window based on pattern + self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern) + + # Initialize the rotary embedding. + if self.is_sliding: + # Local attention. Override the values in config.json. + self.rope_theta = config.rope_local_base_freq + self.rope_scaling = {"rope_type": "default"} + # FIXME(mick): idk why vllm does this + # self.sliding_window = config.interleaved_sliding_window + self.sliding_window = config.sliding_window + else: + # Global attention. Use the values in config.json. + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.sliding_window = None + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + logit_cap=getattr(self.config, "attn_logit_softcapping", None), + sliding_window_size=self.sliding_window, + prefix=add_prefix("attn", prefix), + ) + + # Gemma3 adds normalization for q and k + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def naive_attn_with_masks( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + q = q.view(-1, self.num_heads, self.head_dim) + # Expand the key and value to handle GQA. 
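The comment above marks the grouped-query attention (GQA) expansion in `naive_attn_with_masks`: each KV head is repeat-interleaved so every query head sees a matching key/value head. A toy demonstration with made-up head counts (not Gemma 3's real configuration):

```python
import torch

# 8 query heads share 2 KV heads, so each KV head is repeated 4 times along
# the head dimension before the dense attention computation.
num_heads, num_kv_heads, head_dim, seq_len = 8, 2, 16, 5
num_queries_per_kv = num_heads // num_kv_heads  # 4

k = torch.randn(seq_len, num_kv_heads, head_dim)
k_expanded = k.repeat_interleave(num_queries_per_kv, dim=-2)
print(k_expanded.shape)  # torch.Size([5, 8, 16]) -> one KV head per query head
```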
+ num_queries_per_kv = self.num_heads // self.num_kv_heads + k = k.view(-1, self.num_kv_heads, self.head_dim) + k = k.repeat_interleave(num_queries_per_kv, dim=-2) + v = v.view(-1, self.num_kv_heads, self.head_dim) + v = v.repeat_interleave(num_queries_per_kv, dim=-2) + + if self.is_sliding: + attn_masks = kwargs["local_attn_masks"] + else: + attn_masks = kwargs["global_attn_masks"] + + seq_lens = kwargs["seq_lens"] + start_idx = 0 + for seq_len, attn_mask in zip(seq_lens, attn_masks): + end_idx = start_idx + seq_len + query = q[start_idx:end_idx].unsqueeze(0) + key = k[start_idx:end_idx].unsqueeze(0) + value = v[start_idx:end_idx].unsqueeze(0) + + # Transpose. + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask, + self.scaling, + ) + output = output.transpose(1, 2).flatten(-2, -1) + out[start_idx:end_idx] = output + start_idx = end_idx + return out + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + forward_batch: ForwardBatch, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + # [s, h * head_dim] + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # [s, h, head_dim] + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + # -> [h, s, head_dim] + q = q.transpose(0, 1).unsqueeze(0) + q = self.q_norm(q) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + # -> [h, s, head_dim] + k = k.transpose(0, 1).unsqueeze(0) + k = self.k_norm(k) + + # q, k = self.rotary_emb(positions, q, k) + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + # [b, h, s, head_dim] -> [b, s, h, head_dim] + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + + attn_output = self.attn(q, k, v, forward_batch=forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class Gemma3DecoderLayer(nn.Module): + def __init__( + self, + layer_id: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma3Attention( + layer_id=layer_id, + config=config, + max_position_embeddings=config.max_position_embeddings, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + self.input_layernorm = Gemma3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Gemma3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = Gemma3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = Gemma3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.is_sliding = self.self_attn.is_sliding + self.layer_id = layer_id + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs, + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # apply global RoPE to 
non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + position_embeddings=position_embeddings, + forward_batch=forward_batch, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + + +class Gemma3RotaryEmbedding(nn.Module): + def __init__(self, config: Gemma3TextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len + and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Gemma3TextScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: Optional[float] = 1.0, + ): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +class Gemma3TextModel(PreTrainedModel): + def __init__( + self, + config: Gemma3TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 + self.embed_tokens = Gemma3TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + embed_scale=self.config.hidden_size**0.5, + ) + + self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Gemma3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE + config = copy.deepcopy(config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) + + self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: Gemma3DecoderLayer( + layer_id=idx, + config=config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=add_prefix("layers", prefix), + ) + self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_init() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + if len(positions.shape) == 1: + positions = einops.rearrange(positions, "s -> 1 s") + + position_embeddings_global = self.rotary_emb(hidden_states, positions) + position_embeddings_local = self.rotary_emb_local(hidden_states, positions) + for layer in self.layers: + layer_outputs = layer( + positions=positions, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + hidden_states=hidden_states, + forward_batch=forward_batch, + **kwargs, + ) + hidden_states = layer_outputs[0] + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class Gemma3ForCausalLM(PreTrainedModel): + config_class = Gemma3TextConfig + + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = Gemma3TextConfig + base_model_prefix = "language_model" + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + 
bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + # Gemma does not apply LoRA to the embedding layer. + embedding_modules = {} + embedding_padding_modules = [] + supports_lora = True + + def __init__( + self, + config: Gemma3TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + self.model = Gemma3TextModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.logits_processor = LogitsProcessor(config) + + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def dtype(self) -> torch.dtype: + return self.model.layers[0].mlp.gate_up_proj.weight.dtype + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + **kwargs, + ) -> LogitsProcessor: + + hidden_states = self.model( + input_ids, positions, forward_batch, input_embeds, **kwargs + ) + + return self.logits_processor( + input_ids, hidden_states, self.model.embed_tokens, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, shard_name, shard_id in stacked_params_mapping: + # if param_name in name: + # print(f"{param_name} is already in {name}") + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
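The `load_weights` method above routes per-projection checkpoint tensors into the fused `qkv_proj` and `gate_up_proj` parameters via `stacked_params_mapping`. A rough sketch of that renaming step in isolation; the `remap` helper is hypothetical and only meant to show the name/shard-id bookkeeping:

```python
# Hypothetical helper mirroring the stacked-parameters renaming that
# load_weights performs when mapping checkpoint names onto fused weights.
stacked_params_mapping = [
    # (fused_param_name, checkpoint_shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap(name: str):
    """Return (target_param_name, shard_id); shard_id is None for 1:1 weights."""
    for fused_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, fused_name), shard_id
    return name, None


print(remap("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')
print(remap("model.layers.0.mlp.down_proj.weight"))
# ('model.layers.0.mlp.down_proj.weight', None)
```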
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + # unloaded_params = params_dict.keys() - loaded_params + # if unloaded_params: + # logger.warning( + # "Some weights are not initialized from checkpoints: %s", unloaded_params + # ) + return loaded_params + + +EntryClass = Gemma3ForCausalLM +AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True) diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py new file mode 100644 index 000000000..561b7e834 --- /dev/null +++ b/python/sglang/srt/models/gemma3_mm.py @@ -0,0 +1,462 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Adapted from: +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py + +import logging +from functools import lru_cache +from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict + +import torch +from torch import nn +from transformers import AutoModel, PreTrainedModel + +from sglang.srt.configs import Gemma3Config +from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.layers.layernorm import Gemma3RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.multi_modality_padding import ( + MultiModalityDataPaddingPatternTokenPairs, +) +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + +cached_get_processor = lru_cache(get_processor) + + +class Gemma3ImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class Gemma3MultiModalProjector(nn.Module): + """Projector for Gemma3 multimodal.""" + + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros( + config.vision_config.hidden_size, config.text_config.hidden_size + ) + ) + + self.mm_soft_emb_norm = Gemma3RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int( + config.vision_config.image_size // config.vision_config.patch_size + ) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d( + kernel_size=self.kernel_size, stride=self.kernel_size + ) + + def forward(self, vision_outputs: 
torch.Tensor) -> torch.Tensor: + batch_size, seq_length, hidden_size = vision_outputs.shape + + # Reshape for pooling + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, hidden_size, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + # Apply pooling + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + # Apply normalization + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + # Project to text embedding space + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight + ) + + return projected_vision_outputs.type_as(vision_outputs) + + +class Gemma3ForConditionalGeneration(PreTrainedModel): + config_class = Gemma3Config + """Gemma3 multimodal model for conditional generation.""" + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + # Gemma does not apply LoRA to the embedding layer. 
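`Gemma3MultiModalProjector` above average-pools the SigLIP patch grid down to `mm_tokens_per_image` tokens and then projects them into the text hidden size. A small sketch of the shape arithmetic, assuming the 896-pixel images, 14-pixel patches, and 256 image tokens that the released Gemma 3 checkpoints appear to use (treat these numbers as assumptions):

```python
import torch
from torch import nn

# Assumed Gemma 3 vision settings: 896x896 images, 14x14 patches,
# 256 multimodal tokens per image. vision_hidden is a small stand-in.
image_size, patch_size, mm_tokens_per_image = 896, 14, 256
vision_hidden = 32

patches_per_image = image_size // patch_size        # 64 patches per side
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # 16
kernel_size = patches_per_image // tokens_per_side  # 4
avg_pool = nn.AvgPool2d(kernel_size=kernel_size, stride=kernel_size)

# One image worth of vision-tower outputs: (batch, 64*64 patches, hidden).
vision_outputs = torch.randn(1, patches_per_image**2, vision_hidden)
x = vision_outputs.transpose(1, 2).reshape(
    1, vision_hidden, patches_per_image, patches_per_image
)
pooled = avg_pool(x).flatten(2).transpose(1, 2)
print(pooled.shape)  # torch.Size([1, 256, 32]) -> 256 image tokens per image
```

The matmul with `mm_input_projection_weight` then maps each of those pooled tokens into the text model's hidden size.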
+ embedding_modules = {} + embedding_padding_modules = [] + supports_lora = True + + def __init__( + self, + config: Gemma3Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + # Vision components + # TODO: replace with vision attention + # self.vision_tower = SiglipVisionModel( + # config.vision_config, + # quant_config, + # prefix=add_prefix("vision_tower", prefix), + # ) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = Gemma3MultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + # Text model + self.language_model = Gemma3ForCausalLM( + config.text_config, quant_config, prefix=add_prefix("model", prefix) + ) + if self.language_model.logits_processor.logit_scale: + logit_scale = getattr(config, "logit_scale", 1.0) + self.language_model.logits_processor.logit_scale *= logit_scale + self.post_init() + + def pad_input_ids( + self, input_ids: List[int], image_inputs: ImageInputs + ) -> List[int]: + """Pad input IDs with image tokens.""" + # Get special token IDs + im_start_id: int = image_inputs.im_start_id + im_end_id: int = image_inputs.im_end_id + + media_token_pairs = [(im_start_id, im_end_id)] + pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + ids = pattern.pad_input_tokens(input_ids, image_inputs) + return ids + + def prepare_attn_masks( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mask_dtype: torch.dtype, + **kwargs, + ) -> Dict: + """Prepare attention masks for multimodal inputs.""" + kwargs["has_images"] = True + + # Distinguish sequences by position id 0 + start_indices = (positions == 0).cpu().nonzero() + num_seqs = len(start_indices) + seq_lens = [] + + for i in range(num_seqs): + start_idx = start_indices[i].item() + if i < num_seqs - 1: + end_idx = start_indices[i + 1].item() + else: + end_idx = len(input_ids) + seq_lens.append(end_idx - start_idx) + + kwargs["seq_lens"] = seq_lens + + # Create attention masks + global_attn_masks = [] + local_attn_masks = [] + sliding_window = self.config.text_config.interleaved_sliding_window + + start_idx = 0 + for seq_len in seq_lens: + end_idx = start_idx + seq_len + input_token_ids = input_ids[start_idx:end_idx] + start_idx = end_idx + + # Create global causal mask + global_attn_mask = torch.empty( + 1, + 1, + seq_len, + seq_len, + dtype=mask_dtype, + device=input_ids.device, + ) + global_attn_mask.fill_(float("-inf")) + global_attn_mask = global_attn_mask.triu(diagonal=1) + + # Consider bidirectional attention between image tokens + img_mask = torch.zeros_like(global_attn_mask) + img_pos = input_token_ids == self.config.image_token_index + img_mask[:, :, :, img_pos] += 1 + img_mask[:, :, img_pos, :] += 1 + global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) + global_attn_masks.append(global_attn_mask) + + # Create local causal mask with sliding window + local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window) + local_attn_mask = torch.where( + local_attn_mask == 0, global_attn_mask, float("-inf") + ) + local_attn_masks.append(local_attn_mask) + + kwargs["global_attn_masks"] = global_attn_masks + kwargs["local_attn_masks"] = local_attn_masks + return kwargs + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def get_image_features(self, pixel_values: torch.Tensor): 
+ """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + pixel_values = pixel_values.to("cuda") + pixel_values = pixel_values.to(dtype=self.language_model.dtype()) + + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + def embed_image_inputs( + self, + input_ids: torch.Tensor, + forward_batch: ForwardBatch, + image_input: ImageInputs, + ) -> torch.Tensor: + if input_ids is None: + raise ValueError("Unimplemented") + # boolean-masking image tokens + special_image_mask = torch.isin( + input_ids, + torch.tensor(image_input.pad_values, device=input_ids.device), + ).unsqueeze(-1) + num_image_tokens_in_input_ids = special_image_mask.sum() + + inputs_embeds = None + if num_image_tokens_in_input_ids == 0: + inputs_embeds = self.get_input_embeddings()(input_ids) + return inputs_embeds + else: + # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}") + image_features = self.get_image_features(image_input.pixel_values) + + # print(f"image tokens from image embeddings: {image_features.numel()}") + num_image_tokens_in_embedding = ( + image_features.shape[0] * image_features.shape[1] + ) + + if num_image_tokens_in_input_ids != num_image_tokens_in_embedding: + num_image = num_image_tokens_in_input_ids // image_features.shape[1] + image_features = image_features[:num_image, :] + logger.warning( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} " + "tokens from image embeddings." + ) + + # Important: clamp after extracting original image boundaries + input_ids.clamp_(min=0, max=self.vocab_size - 1) + + inputs_embeds = self.get_input_embeddings()(input_ids) + + special_image_mask = special_image_mask.expand_as(inputs_embeds).to( + inputs_embeds.device + ) + + image_features = image_features.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, image_features + ) + + return inputs_embeds + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + **kwargs: object, + ) -> LogitsProcessor: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
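`embed_image_inputs` above splices the projected image features into the text embedding with `masked_scatter`, replacing the placeholder image-token positions row by row. A tiny illustration with made-up sizes:

```python
import torch

# Two placeholder image tokens in the prompt get their embeddings replaced
# by two projected image-feature rows; all sizes here are made up.
hidden, image_token_id = 4, 9
input_ids = torch.tensor([[1, 9, 9, 2]])
inputs_embeds = torch.zeros(1, 4, hidden)   # text embeddings (all zeros)
image_features = torch.ones(2, hidden)      # projected image tokens (ones)

special_image_mask = (input_ids == image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0.])
```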
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration + + >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" + >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + # Important: position_ids in Gemma3 are 1-indexed + # This really does cost me sometime + positions += 1 + + # Replace image id with PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_index >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_index + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + merged_image_input = forward_batch.get_merged_image_inputs() + + if ( + not forward_batch.forward_mode.is_decode() + and merged_image_input is not None + ): + inputs_embeds = self.embed_image_inputs( + input_ids=llm_input_ids, + forward_batch=forward_batch, + image_input=merged_image_input, + ) + else: + llm_input_ids.clamp_(min=0, max=self.vocab_size - 1) + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + outputs = self.language_model( + input_ids=None, + positions=positions, + forward_batch=forward_batch, + input_embeds=inputs_embeds, + **kwargs, + ) + + return outputs + + def tie_weights(self): + return self.language_model.tie_weights() + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights for the model.""" + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "language_model" in name: + # Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)]) + causal_loaded_params = Gemma3ForCausalLM.load_weights( + self, [(name, loaded_weight)] + ) + loaded_params.update(causal_loaded_params) + continue + else: + # Skip lm_head.weight as it's tied with embed_tokens + if "lm_head.weight" in name: + continue + + # Skip loading extra bias for GPTQ models + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + pass + # raise RuntimeError( + # f"Some weights are not initialized from checkpoints: {unloaded_params}") + return loaded_params + + +EntryClass = Gemma3ForConditionalGeneration + +AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True) diff --git 
a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index e1f009b1e..85bf35967 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -41,7 +41,6 @@ from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
 from io import BytesIO
-from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
@@ -454,8 +453,9 @@ def load_image(image_file: Union[str, bytes]):
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, timeout=timeout)
-        image = Image.open(BytesIO(response.content))
+        response = requests.get(image_file, stream=True, timeout=timeout).raw
+        image = Image.open(response)
+        response.close()
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index 3becdf319..935c2057b 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -75,7 +75,8 @@ class TestOpenAIVisionServer(unittest.TestCase):
         assert response.choices[0].message.role == "assistant"
         text = response.choices[0].message.content
         assert isinstance(text, str)
-        assert "man" in text or "person" in text, text
+        # `driver` is for gemma-3-it
+        assert "man" in text or "person" in text or "driver" in text, text
         assert "cab" in text or "taxi" in text or "SUV" in text, text
         assert "iron" in text, text
         assert response.id
@@ -540,5 +541,27 @@ class TestJanusProServer(TestOpenAIVisionServer):
         pass
 
 
+class TestGemma3itServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "google/gemma-3-4b-it"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--chat-template",
+                "gemma-it",
+            ],
+        )
+        cls.base_url += "/v1"
+
+    def test_video_chat_completion(self):
+        pass
+
+
 if __name__ == "__main__":
     unittest.main()
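For completeness, a hedged sketch of exercising the new Gemma 3 vision path by hand, mirroring what `TestGemma3itServer` automates. The port, API key, and image URL are placeholders; the server is assumed to have been launched roughly as the test does, e.g. `python -m sglang.launch_server --model-path google/gemma-3-4b-it --chat-template gemma-it --trust-remote-code --port 30000`:

```python
import openai

# Placeholder endpoint and key; adjust to wherever the server is running.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")

response = client.chat.completions.create(
    model="google/gemma-3-4b-it",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    # Placeholder URL; any reachable image works.
                    "image_url": {"url": "https://example.com/some_image.png"},
                },
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)
```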