Urgent model support: support gemma-3-it (#4424)

This commit is contained in:
Mick
2025-03-17 08:37:32 +08:00
committed by GitHub
parent 402db5c58c
commit 9d02bb3e2a
21 changed files with 2565 additions and 85 deletions

View File

@@ -520,6 +520,14 @@ def match_granite_instruct(model_path: str):
return get_chat_template("granite-3-instruct")
@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
model_path = model_path.lower()
if "gemma-3" in model_path and "1b" not in model_path:
# gemma-3-1b-it is completion model
return get_chat_template("gemma-it")
if __name__ == "__main__":
messages = [
{"role": "system", "content": None}, # None means default

View File

@@ -1,6 +1,7 @@
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.qwen2_5_vl_config import (
Qwen2_5_VLConfig,
@@ -14,4 +15,6 @@ __all__ = [
"Qwen2_5_VLConfig",
"Qwen2_5_VLVisionConfig",
"MultiModalityConfig",
"Gemma3Config",
"Gemma3TextConfig",
]

File diff suppressed because it is too large Load Diff

View File

@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
if config.model_type.startswith("gemma"):
if config.model_type == "gemma":
gemma_version = ""
else:
gemma_version = config.model_type[5]
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16."
)
@@ -453,6 +457,7 @@ multimodal_model_archs = [
"LlavaQwenForCausalLM",
"LlavaMistralForCausalLM",
"LlavaVidForCausalLM",
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"MllamaForConditionalGeneration",

View File

@@ -45,6 +45,7 @@ class SeparatorStyle(IntEnum):
DEEPSEEK_CHAT = auto()
METAMATH = auto()
QWEN2_VL_EMBED = auto()
GEMMA3 = auto()
@dataclasses.dataclass
@@ -285,6 +286,18 @@ class Conversation:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.GEMMA3:
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
if i == 0:
ret += message + self.sep
else:
ret += role + message + self.sep
else:
ret += role
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
@@ -604,6 +617,20 @@ register_conv_template(
)
)
# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
register_conv_template(
Conversation(
name="gemma-it",
system_message="You are a helpful assistant.",
system_template="<bos><start_of_turn>user{system_message}\n\n",
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
sep="<end_of_turn>\n",
sep_style=SeparatorStyle.GEMMA3,
stop_str=["<end_of_turn>"],
image_token="<start_of_image>",
)
)
# Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
register_conv_template(
Conversation(

View File

@@ -34,6 +34,8 @@ from sglang.srt.configs import (
ChatGLMConfig,
DbrxConfig,
ExaoneConfig,
Gemma3Config,
Gemma3TextConfig,
MultiModalityConfig,
Qwen2_5_VLConfig,
)
@@ -46,6 +48,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ExaoneConfig.model_type: ExaoneConfig,
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
MultiModalityConfig.model_type: MultiModalityConfig,
Gemma3Config.model_type: Gemma3Config,
Gemma3TextConfig.model_type: Gemma3TextConfig,
}
for name, cls in _CONFIG_REGISTRY.items():

View File

@@ -19,34 +19,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
from sglang.srt.utils import add_prefix
# Copied from transformers, modeling_qwen2_vl.py
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
class VisionAttention(nn.Module):
r"""
Multi-headed attention without any cache, mostly used for ViT.
@@ -168,7 +144,7 @@ class VisionAttention(nn.Module):
cos, sin = position_embeddings
original_shape = q.shape
q, k = q.view(s, head, -1), k.view(s, head, -1)
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q, k = q.reshape(original_shape), k.reshape(original_shape)
if self.use_qkv_parallel:

View File

@@ -119,6 +119,26 @@ class GemmaRMSNorm(CustomOp):
return out
class Gemma3RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
if not _is_cuda:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."

View File

@@ -1173,6 +1173,37 @@ def get_rope(
return rotary_emb
# Copied from transformers
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim=1,
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
# embedding is performed in float
cos = cos.unsqueeze(unsqueeze_dim).float()
sin = sin.unsqueeze(unsqueeze_dim).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
def get_rope_cpu(
head_size: int,
rotary_dim: int,

View File

@@ -111,7 +111,7 @@ class BaseImageProcessor(ABC):
def load_images(
self,
input_ids: list,
input_ids: list[int],
image_data,
image_token: str,
max_req_input_len: int,
@@ -122,22 +122,21 @@ class BaseImageProcessor(ABC):
Each frame of video/image will be replaced by a single image token
Args:
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
image_hashes, image_sizes = [], []
all_frames = []
new_text_parts = []
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
if return_text:
text_parts = input_text.split(image_token)
import re
pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
# split text into list of normal text and special tokens
text_parts = re.split(pattern, input_text)
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
@@ -145,53 +144,65 @@ class BaseImageProcessor(ABC):
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
_scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
# Process each input with allocated frames
for image_index, (image, estimated_frames) in enumerate(
zip(image_data, estimated_frames_list)
):
if len(all_frames) >= MAX_NUM_FRAMES:
max_frames_to_process = 0
else:
max_frames_to_process = max(1, int(estimated_frames * scaling_factor))
if max_frames_to_process == 0:
frames = []
else:
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = BaseImageProcessor.encode_video(
path, frame_count_limit=max_frames_to_process
)
image_index, audio_index = 0, 0
hashes, image_sizes, images, audios = [], [], [], []
new_text = ""
for index, text_part in enumerate(text_parts):
try:
if text_part == image_token:
# load as image
frames_to_process = estimated_frames_list[image_index]
if frames_to_process == 0:
frames = []
else:
raw_image, _size = load_image(image)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
assert len(frames) != 0
except FileNotFoundError as e:
print(e)
return None
image_file = image_data[image_index]
if isinstance(image_file, str) and image_file.startswith(
"video:"
):
# video
path = image_file[len("video:") :]
frames = self.encode_video(
path, frame_count_limit=frames_to_process
)
else:
# image
raw_image, _size = load_image(image_file)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
if len(frames) == 0:
continue
image_sizes += [frames[0].size] * len(frames)
image_hashes += [hash(image)] * len(frames)
all_frames += frames
image_sizes += frames[0].size * len(frames)
hashes += [hash(image_file)] * len(frames)
images += frames
image_index += 1
if frames_to_process != 0:
new_text += image_token * len(frames)
assert frames_to_process == len(frames)
else:
# TODO(mick): handle video
# normal text
new_text += text_part
if return_text:
new_text_parts.append(text_parts[image_index])
if max_frames_to_process != 0:
new_text_parts.append(image_token * len(frames))
assert max_frames_to_process >= len(frames)
if return_text:
new_text_parts.append(text_parts[-1])
except Exception as e:
import openai
logger.error(f"An exception occurred while loading images: {e}")
raise BadRequestError(
f"An exception occurred while loading images: {e}"
)
continue
input_text = "".join(new_text_parts)
return BaseImageProcessorOutput(
image_hashes, image_sizes, all_frames, input_text
image_hashes=hashes,
image_sizes=image_sizes,
all_frames=images,
input_text=new_text,
)

View File

@@ -0,0 +1,100 @@
import asyncio
from typing import List, Union
from transformers.utils import logging
from sglang.srt.managers.image_processor import (
BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
# will be removed in the future
logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<start_of_image>"
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
result = processor.__call__(
text=[input_text],
images=images,
padding=True,
return_tensors="pt",
# if RGBA, this needs to be set
# images_kwargs={
# "input_data_format": ChannelDimension.FIRST
# }
)
pixel_values = getattr(result, "pixel_values", None)
return {
"input_ids": result.input_ids,
"pixel_values": pixel_values,
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Gemma3SGLangImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
ret = await self._process_images(
input_text=base_output.input_text, images=base_output.all_frames
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
ImageProcessorMapping = {
Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}

View File

@@ -60,7 +60,10 @@ class JanusProProcessor(SGLangBaseImageProcessor):
image_data = [image_data]
base_out = self.load_images(
input_ids, image_data, "<image_placeholder>", max_req_input_len
input_ids=input_ids,
image_data=image_data,
image_token="<image_placeholder>",
max_req_input_len=max_req_input_len,
)
images = base_out.all_frames
res = await self._process_images(images=images, input_text=base_out.input_text)

View File

@@ -52,7 +52,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
image_data = [image_data]
base_output = self.load_images(
input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
input_ids=input_ids,
image_data=image_data,
image_token=self.IMAGE_TOKEN,
max_req_input_len=max_req_input_len,
)
if base_output is None:
return None

View File

@@ -72,10 +72,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids,
image_data,
image_token,
max_req_input_len,
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
max_req_input_len=max_req_input_len,
)
def smart_resize(

View File

@@ -49,7 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend, next_power_of_2
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -207,6 +207,9 @@ class ImageInputs:
return ret
def merge(self, other):
"""
merge image inputs when requests are being merged
"""
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

View File

@@ -33,6 +33,7 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
@@ -331,6 +332,32 @@ class ForwardBatch:
return ret
def get_merged_image_inputs(self) -> Optional[ImageInputs]:
"""
Merge all image inputs in the batch into a single ImageInputs object.
Returns:
if none, current batch contains no image input
"""
if not self.image_inputs or all(x is None for x in self.image_inputs):
return None
# Filter out None values
valid_inputs = [x for x in self.image_inputs if x is not None]
# Start with the first valid image input
merged = valid_inputs[0]
# Merge remaining inputs
for img_input in valid_inputs[1:]:
merged.merge(img_input)
if isinstance(merged.pixel_values, np.ndarray):
merged.pixel_values = torch.from_numpy(merged.pixel_values)
return merged
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):

View File

@@ -0,0 +1,687 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Iterable, Optional, Set, Tuple
import einops
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
ROPE_INIT_FUNCTIONS,
AutoModel,
PretrainedConfig,
PreTrainedModel,
)
from sglang.srt.configs.gemma3 import Gemma3TextConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
def extract_layer_index(prefix: str) -> int:
"""Extract the layer index from a prefix string."""
parts = prefix.split(".")
for part in parts:
if part.startswith("layers."):
layer_str = part.split(".")[-1]
try:
return int(layer_str)
except ValueError:
continue
return -1
class Gemma3MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
)
if hidden_activation != "gelu_pytorch_tanh":
raise ValueError(
"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
"function. Please set `hidden_activation` to "
"`gelu_pytorch_tanh`."
)
self.act_fn = GeluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Gemma3Attention(nn.Module):
def __init__(
self,
layer_id: int,
config: Gemma3TextConfig,
max_position_embeddings: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.config = config
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
hidden_size = config.hidden_size
head_dim = getattr(
config, "head_dim", hidden_size // config.num_attention_heads
)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = config.query_pre_attn_scalar**-0.5
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
# Determine if layer uses sliding window based on pattern
self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
# Initialize the rotary embedding.
if self.is_sliding:
# Local attention. Override the values in config.json.
self.rope_theta = config.rope_local_base_freq
self.rope_scaling = {"rope_type": "default"}
# FIXME(mick): idk why vllm does this
# self.sliding_window = config.interleaved_sliding_window
self.sliding_window = config.sliding_window
else:
# Global attention. Use the values in config.json.
self.rope_theta = config.rope_theta
self.rope_scaling = config.rope_scaling
self.sliding_window = None
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
logit_cap=getattr(self.config, "attn_logit_softcapping", None),
sliding_window_size=self.sliding_window,
prefix=add_prefix("attn", prefix),
)
# Gemma3 adds normalization for q and k
self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
def naive_attn_with_masks(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
**kwargs,
) -> torch.Tensor:
q = q.view(-1, self.num_heads, self.head_dim)
# Expand the key and value to handle GQA.
num_queries_per_kv = self.num_heads // self.num_kv_heads
k = k.view(-1, self.num_kv_heads, self.head_dim)
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
v = v.view(-1, self.num_kv_heads, self.head_dim)
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
if self.is_sliding:
attn_masks = kwargs["local_attn_masks"]
else:
attn_masks = kwargs["global_attn_masks"]
seq_lens = kwargs["seq_lens"]
start_idx = 0
for seq_len, attn_mask in zip(seq_lens, attn_masks):
end_idx = start_idx + seq_len
query = q[start_idx:end_idx].unsqueeze(0)
key = k[start_idx:end_idx].unsqueeze(0)
value = v[start_idx:end_idx].unsqueeze(0)
# Transpose.
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask,
self.scaling,
)
output = output.transpose(1, 2).flatten(-2, -1)
out[start_idx:end_idx] = output
start_idx = end_idx
return out
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
forward_batch: ForwardBatch,
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
# [s, h * head_dim]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# [s, h, head_dim]
q = q.unflatten(-1, (self.num_heads, self.head_dim))
# -> [h, s, head_dim]
q = q.transpose(0, 1).unsqueeze(0)
q = self.q_norm(q)
k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
# -> [h, s, head_dim]
k = k.transpose(0, 1).unsqueeze(0)
k = self.k_norm(k)
# q, k = self.rotary_emb(positions, q, k)
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# [b, h, s, head_dim] -> [b, s, h, head_dim]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
attn_output = self.attn(q, k, v, forward_batch=forward_batch)
output, _ = self.o_proj(attn_output)
return output
class Gemma3DecoderLayer(nn.Module):
def __init__(
self,
layer_id: int,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Gemma3Attention(
layer_id=layer_id,
config=config,
max_position_embeddings=config.max_position_embeddings,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
self.hidden_size = config.hidden_size
self.mlp = Gemma3MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = Gemma3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Gemma3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_feedforward_layernorm = Gemma3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_feedforward_layernorm = Gemma3RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.is_sliding = self.self_attn.is_sliding
self.layer_id = layer_id
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
position_embeddings_global: torch.Tensor,
position_embeddings_local: torch.Tensor,
forward_batch: ForwardBatch,
**kwargs,
) -> tuple[
torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# apply global RoPE to non-sliding layer only
if self.self_attn.is_sliding:
position_embeddings = position_embeddings_local
else:
position_embeddings = position_embeddings_global
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
position_embeddings=position_embeddings,
forward_batch=forward_batch,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
return outputs
class Gemma3RotaryEmbedding(nn.Module):
def __init__(self, config: Gemma3TextConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get(
"rope_type", config.rope_scaling.get("type")
)
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len
)
self.register_buffer(
"inv_freq", inv_freq, persistent=False
) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if (
seq_len < self.original_max_seq_len
and self.max_seq_len_cached > self.original_max_seq_len
): # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
freqs = (
inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Gemma3TextScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
embed_scale: Optional[float] = 1.0,
):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
class Gemma3TextModel(PreTrainedModel):
def __init__(
self,
config: Gemma3TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
self.embed_tokens = Gemma3TextScaledWordEmbedding(
config.vocab_size,
config.hidden_size,
self.padding_idx,
embed_scale=self.config.hidden_size**0.5,
)
self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Gemma3RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
config = copy.deepcopy(config)
config.rope_theta = config.rope_local_base_freq
config.rope_scaling = {"rope_type": "default"}
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Gemma3DecoderLayer(
layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
),
prefix=add_prefix("layers", prefix),
)
self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_init()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
if len(positions.shape) == 1:
positions = einops.rearrange(positions, "s -> 1 s")
position_embeddings_global = self.rotary_emb(hidden_states, positions)
position_embeddings_local = self.rotary_emb_local(hidden_states, positions)
for layer in self.layers:
layer_outputs = layer(
positions=positions,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
hidden_states=hidden_states,
forward_batch=forward_batch,
**kwargs,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
return hidden_states
class Gemma3ForCausalLM(PreTrainedModel):
config_class = Gemma3TextConfig
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = Gemma3TextConfig
base_model_prefix = "language_model"
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
supports_lora = True
def __init__(
self,
config: Gemma3TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
self.model = Gemma3TextModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.logits_processor = LogitsProcessor(config)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def dtype(self) -> torch.dtype:
return self.model.layers[0].mlp.gate_up_proj.weight.dtype
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
**kwargs,
) -> LogitsProcessor:
hidden_states = self.model(
input_ids, positions, forward_batch, input_embeds, **kwargs
)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
# if param_name in name:
# print(f"{param_name} is already in {name}")
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
# unloaded_params = params_dict.keys() - loaded_params
# if unloaded_params:
# logger.warning(
# "Some weights are not initialized from checkpoints: %s", unloaded_params
# )
return loaded_params
EntryClass = Gemma3ForCausalLM
AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)

View File

@@ -0,0 +1,462 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
import logging
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Gemma3ImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3MultiModalProjector(nn.Module):
"""Projector for Gemma3 multimodal."""
def __init__(self, config: Gemma3Config):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.zeros(
config.vision_config.hidden_size, config.text_config.hidden_size
)
)
self.mm_soft_emb_norm = Gemma3RMSNorm(
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
)
self.patches_per_image = int(
config.vision_config.image_size // config.vision_config.patch_size
)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(
kernel_size=self.kernel_size, stride=self.kernel_size
)
def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
batch_size, seq_length, hidden_size = vision_outputs.shape
# Reshape for pooling
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, hidden_size, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
# Apply pooling
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
# Apply normalization
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
# Project to text embedding space
projected_vision_outputs = torch.matmul(
normed_vision_outputs, self.mm_input_projection_weight
)
return projected_vision_outputs.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(PreTrainedModel):
config_class = Gemma3Config
"""Gemma3 multimodal model for conditional generation."""
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
supports_lora = True
def __init__(
self,
config: Gemma3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
# Vision components
# TODO: replace with vision attention
# self.vision_tower = SiglipVisionModel(
# config.vision_config,
# quant_config,
# prefix=add_prefix("vision_tower", prefix),
# )
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
# Text model
self.language_model = Gemma3ForCausalLM(
config.text_config, quant_config, prefix=add_prefix("model", prefix)
)
if self.language_model.logits_processor.logit_scale:
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.logit_scale *= logit_scale
self.post_init()
def pad_input_ids(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""Pad input IDs with image tokens."""
# Get special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
ids = pattern.pad_input_tokens(input_ids, image_inputs)
return ids
def prepare_attn_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
**kwargs,
) -> Dict:
"""Prepare attention masks for multimodal inputs."""
kwargs["has_images"] = True
# Distinguish sequences by position id 0
start_indices = (positions == 0).cpu().nonzero()
num_seqs = len(start_indices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_indices[i].item()
if i < num_seqs - 1:
end_idx = start_indices[i + 1].item()
else:
end_idx = len(input_ids)
seq_lens.append(end_idx - start_idx)
kwargs["seq_lens"] = seq_lens
# Create attention masks
global_attn_masks = []
local_attn_masks = []
sliding_window = self.config.text_config.interleaved_sliding_window
start_idx = 0
for seq_len in seq_lens:
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
start_idx = end_idx
# Create global causal mask
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
global_attn_mask = global_attn_mask.triu(diagonal=1)
# Consider bidirectional attention between image tokens
img_mask = torch.zeros_like(global_attn_mask)
img_pos = input_token_ids == self.config.image_token_index
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
# Create local causal mask with sliding window
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(
local_attn_mask == 0, global_attn_mask, float("-inf")
)
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks
return kwargs
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def get_image_features(self, pixel_values: torch.Tensor):
"""
Projects the last hidden state from the vision model into language model space.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
pixel_values = pixel_values.to("cuda")
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = self.multi_modal_projector(vision_outputs)
return image_features
def embed_image_inputs(
self,
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
image_input: ImageInputs,
) -> torch.Tensor:
if input_ids is None:
raise ValueError("Unimplemented")
# boolean-masking image tokens
special_image_mask = torch.isin(
input_ids,
torch.tensor(image_input.pad_values, device=input_ids.device),
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum()
inputs_embeds = None
if num_image_tokens_in_input_ids == 0:
inputs_embeds = self.get_input_embeddings()(input_ids)
return inputs_embeds
else:
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
image_features = self.get_image_features(image_input.pixel_values)
# print(f"image tokens from image embeddings: {image_features.numel()}")
num_image_tokens_in_embedding = (
image_features.shape[0] * image_features.shape[1]
)
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
num_image = num_image_tokens_in_input_ids // image_features.shape[1]
image_features = image_features[:num_image, :]
logger.warning(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
"tokens from image embeddings."
)
# Important: clamp after extracting original image boundaries
input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
image_features = image_features.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask, image_features
)
return inputs_embeds
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
**kwargs: object,
) -> LogitsProcessor:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
# Important: position_ids in Gemma3 are 1-indexed
# This really does cost me sometime
positions += 1
# Replace image id with PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
merged_image_input = forward_batch.get_merged_image_inputs()
if (
not forward_batch.forward_mode.is_decode()
and merged_image_input is not None
):
inputs_embeds = self.embed_image_inputs(
input_ids=llm_input_ids,
forward_batch=forward_batch,
image_input=merged_image_input,
)
else:
llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
outputs = self.language_model(
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
**kwargs,
)
return outputs
def tie_weights(self):
return self.language_model.tie_weights()
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for the model."""
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "language_model" in name:
# Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)])
causal_loaded_params = Gemma3ForCausalLM.load_weights(
self, [(name, loaded_weight)]
)
loaded_params.update(causal_loaded_params)
continue
else:
# Skip lm_head.weight as it's tied with embed_tokens
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
pass
# raise RuntimeError(
# f"Some weights are not initialized from checkpoints: {unloaded_params}")
return loaded_params
EntryClass = Gemma3ForConditionalGeneration
AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)

View File

@@ -41,7 +41,6 @@ from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
from io import BytesIO
from multiprocessing import Pool
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
@@ -454,8 +453,9 @@ def load_image(image_file: Union[str, bytes]):
image = Image.open(BytesIO(image_file))
elif image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
response = requests.get(image_file, timeout=timeout)
image = Image.open(BytesIO(response.content))
response = requests.get(image_file, stream=True, timeout=timeout).raw
image = Image.open(response)
response.close()
elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
image = Image.open(image_file)
elif image_file.startswith("data:"):