Urgent model support: support gemma-3-it (#4424)
@@ -520,6 +520,14 @@ def match_granite_instruct(model_path: str):
    return get_chat_template("granite-3-instruct")


@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
    model_path = model_path.lower()
    if "gemma-3" in model_path and "1b" not in model_path:
        # gemma-3-1b-it is a text-completion model
        return get_chat_template("gemma-it")


if __name__ == "__main__":
    messages = [
        {"role": "system", "content": None},  # None means default

@@ -1,6 +1,7 @@
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.qwen2_5_vl_config import (
    Qwen2_5_VLConfig,
@@ -14,4 +15,6 @@ __all__ = [
    "Qwen2_5_VLConfig",
    "Qwen2_5_VLVisionConfig",
    "MultiModalityConfig",
    "Gemma3Config",
    "Gemma3TextConfig",
]

python/sglang/srt/configs/gemma3.py (new file, 1086 lines): diff suppressed because it is too large
@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
    dtype = dtype.lower()
    if dtype == "auto":
        if config_dtype == torch.float32:
            if config.model_type == "gemma2":
            if config.model_type.startswith("gemma"):
                if config.model_type == "gemma":
                    gemma_version = ""
                else:
                    gemma_version = config.model_type[5]
                logger.info(
                    "For Gemma 2, we downcast float32 to bfloat16 instead "
                    f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                    "of float16 by default. Please specify `dtype` if you "
                    "want to use float16."
                )
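
The version digit comes straight from indexing the `model_type` string; a small illustration (values assumed, not read from this diff):

    assert "gemma2"[5] == "2"   # config.model_type[5] for Gemma 2
    assert "gemma3"[5] == "3"   # and for Gemma 3
    # plain "gemma" is caught by the explicit branch and logs no version digit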

@@ -453,6 +457,7 @@ multimodal_model_archs = [
    "LlavaQwenForCausalLM",
    "LlavaMistralForCausalLM",
    "LlavaVidForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "MllamaForConditionalGeneration",

@@ -45,6 +45,7 @@ class SeparatorStyle(IntEnum):
    DEEPSEEK_CHAT = auto()
    METAMATH = auto()
    QWEN2_VL_EMBED = auto()
    GEMMA3 = auto()


@dataclasses.dataclass
@@ -285,6 +286,18 @@ class Conversation:
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.GEMMA3:
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if i == 0:
                        ret += message + self.sep
                    else:
                        ret += role + message + self.sep
                else:
                    ret += role
            return ret

        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

@@ -604,6 +617,20 @@ register_conv_template(
    )
)

# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
register_conv_template(
    Conversation(
        name="gemma-it",
        system_message="You are a helpful assistant.",
        system_template="<bos><start_of_turn>user{system_message}\n\n",
        roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
        sep="<end_of_turn>\n",
        sep_style=SeparatorStyle.GEMMA3,
        stop_str=["<end_of_turn>"],
        image_token="<start_of_image>",
    )
)
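
Given the `SeparatorStyle.GEMMA3` branch added above, the registered template renders a one-turn conversation as follows (a sketch of the expected output, not part of the diff):

    # messages: [("<start_of_turn>user\n", "What is 2+2?"), ("<start_of_turn>model\n", None)]
    # get_prompt() yields:
    #   "<bos><start_of_turn>userYou are a helpful assistant.\n\n"
    #   + "What is 2+2?<end_of_turn>\n"      # i == 0: message only, no role prefix
    #   + "<start_of_turn>model\n"           # empty message: role only, ready for generation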

# Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
register_conv_template(
    Conversation(

@@ -34,6 +34,8 @@ from sglang.srt.configs import (
    ChatGLMConfig,
    DbrxConfig,
    ExaoneConfig,
    Gemma3Config,
    Gemma3TextConfig,
    MultiModalityConfig,
    Qwen2_5_VLConfig,
)
@@ -46,6 +48,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ExaoneConfig.model_type: ExaoneConfig,
    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
    MultiModalityConfig.model_type: MultiModalityConfig,
    Gemma3Config.model_type: Gemma3Config,
    Gemma3TextConfig.model_type: Gemma3TextConfig,
}

for name, cls in _CONFIG_REGISTRY.items():

@@ -19,34 +19,10 @@ from sglang.srt.layers.linear import (
    RowParallelLinear,
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
from sglang.srt.utils import add_prefix


# Copied from transformers, modeling_qwen2_vl.py
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()

    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)

    return q_embed, k_embed


class VisionAttention(nn.Module):
    r"""
    Multi-headed attention without any cache, mostly used for ViT.
@@ -168,7 +144,7 @@ class VisionAttention(nn.Module):
            cos, sin = position_embeddings
            original_shape = q.shape
            q, k = q.view(s, head, -1), k.view(s, head, -1)
            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
            q, k = q.reshape(original_shape), k.reshape(original_shape)

        if self.use_qkv_parallel:

@@ -119,6 +119,26 @@ class GemmaRMSNorm(CustomOp):
        return out


class Gemma3RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


if not _is_cuda:
    logger.info(
        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."

@@ -1173,6 +1173,37 @@ def get_rope(
    return rotary_emb


# Copied from transformers
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim=1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()

    # the embedding is performed in float
    cos = cos.unsqueeze(unsqueeze_dim).float()
    sin = sin.unsqueeze(unsqueeze_dim).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)

    return q_embed, k_embed


def get_rope_cpu(
    head_size: int,
    rotary_dim: int,
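
A toy check of the helper added above (a sketch, not part of the diff): because each (i, i + d/2) pair shares one angle, the rotation preserves per-vector norms.

    import torch

    b, h, s, d = 1, 2, 3, 8
    q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
    freqs = torch.rand(b, s, d // 2)
    theta = torch.cat([freqs, freqs], dim=-1)      # duplicated layout, as in RoPE caches
    cos, sin = theta.cos(), theta.sin()            # broadcast over heads via unsqueeze_dim=1
    q_rot, _ = apply_rotary_pos_emb(q, k, cos, sin)
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)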

@@ -111,7 +111,7 @@ class BaseImageProcessor(ABC):

    def load_images(
        self,
        input_ids: list,
        input_ids: list[int],
        image_data,
        image_token: str,
        max_req_input_len: int,
@@ -122,22 +122,21 @@ class BaseImageProcessor(ABC):
        Each frame of a video/image will be replaced by a single image token.

        Args:

            discard_alpha_channel: if True, discards the alpha channel in the returned images

        """
        image_hashes, image_sizes = [], []
        all_frames = []
        new_text_parts = []

        if isinstance(input_ids, list) and return_text:
            assert len(input_ids) and isinstance(input_ids[0], int)
            input_text = self._processor.tokenizer.decode(input_ids)
        else:
            input_text = input_ids

        if return_text:
            text_parts = input_text.split(image_token)
            import re

            pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
            # split text into a list of normal text and special tokens
            text_parts = re.split(pattern, input_text)

        # TODO(mick): load from server_args, env, or sampling_params
        MAX_NUM_FRAMES = 30
@@ -145,53 +144,65 @@ class BaseImageProcessor(ABC):
        total_frame_count = sum(estimated_frames_list)
        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
        _scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))

        assert len(image_data) == len(estimated_frames_list)

        # Process each input with allocated frames
        for image_index, (image, estimated_frames) in enumerate(
            zip(image_data, estimated_frames_list)
        ):
            if len(all_frames) >= MAX_NUM_FRAMES:
                max_frames_to_process = 0
            else:
                max_frames_to_process = max(1, int(estimated_frames * scaling_factor))

            if max_frames_to_process == 0:
                frames = []
            else:
                try:
                    if isinstance(image, str) and image.startswith("video:"):
                        path = image[len("video:") :]
                        frames = BaseImageProcessor.encode_video(
                            path, frame_count_limit=max_frames_to_process
                        )
        image_index, audio_index = 0, 0
        hashes, image_sizes, images, audios = [], [], [], []
        new_text = ""
        for index, text_part in enumerate(text_parts):
            try:
                if text_part == image_token:
                    # load as image
                    frames_to_process = estimated_frames_list[image_index]
                    if frames_to_process == 0:
                        frames = []
                    else:
                        raw_image, _size = load_image(image)
                        if discard_alpha_channel:
                            raw_image = raw_image.convert("RGB")
                        frames = [raw_image]
                    assert len(frames) != 0
                except FileNotFoundError as e:
                    print(e)
                    return None
                        image_file = image_data[image_index]
                        if isinstance(image_file, str) and image_file.startswith(
                            "video:"
                        ):
                            # video
                            path = image_file[len("video:") :]
                            frames = self.encode_video(
                                path, frame_count_limit=frames_to_process
                            )
                        else:
                            # image
                            raw_image, _size = load_image(image_file)
                            if discard_alpha_channel:
                                raw_image = raw_image.convert("RGB")
                            frames = [raw_image]
                    if len(frames) == 0:
                        continue

                    image_sizes += [frames[0].size] * len(frames)
                    image_hashes += [hash(image)] * len(frames)
                    all_frames += frames
                    image_sizes += [frames[0].size] * len(frames)
                    hashes += [hash(image_file)] * len(frames)
                    images += frames
                    image_index += 1
                    if frames_to_process != 0:
                        new_text += image_token * len(frames)
                    assert frames_to_process == len(frames)
                else:
                    # TODO(mick): handle video
                    # normal text
                    new_text += text_part

            if return_text:
                new_text_parts.append(text_parts[image_index])
            if max_frames_to_process != 0:
                new_text_parts.append(image_token * len(frames))
            assert max_frames_to_process >= len(frames)
        if return_text:
            new_text_parts.append(text_parts[-1])
            except Exception as e:
                from openai import BadRequestError

                logger.error(f"An exception occurred while loading images: {e}")
                raise BadRequestError(
                    f"An exception occurred while loading images: {e}"
                )
                continue

        input_text = "".join(new_text_parts)
        return BaseImageProcessorOutput(
            image_hashes, image_sizes, all_frames, input_text
            image_hashes=hashes,
            image_sizes=image_sizes,
            all_frames=images,
            input_text=new_text,
        )
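
The behavioral difference in the refactor above, in isolation (a demo, not from the diff): `re.split` with a capturing group keeps the image token as its own list element, so image and text segments can be walked in order.

    import re

    image_token = "<start_of_image>"
    pattern = "(" + re.escape(image_token) + ")"
    re.split(pattern, "describe <start_of_image> briefly")
    # -> ['describe ', '<start_of_image>', ' briefly']
    # str.split(image_token) would instead drop the token: ['describe ', ' briefly']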

python/sglang/srt/managers/image_processors/gemma3.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import asyncio
from typing import List, Union

from transformers.utils import logging

from sglang.srt.managers.image_processor import (
    BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
# will be removed in the future
logger = logging.get_logger(__name__)


class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<start_of_image>"
        self.IM_START_TOKEN_ID = hf_config.boi_token_index
        self.IM_END_TOKEN_ID = hf_config.eoi_token_index

    @staticmethod
    def _process_images_task(images, input_text, _hf_config):
        if isinstance(images, list) and len(images) == 0:
            images = None
        processor = get_global_processor()
        result = processor.__call__(
            text=[input_text],
            images=images,
            padding=True,
            return_tensors="pt",
            # if RGBA, this needs to be set
            # images_kwargs={
            #     "input_data_format": ChannelDimension.FIRST
            # }
        )

        pixel_values = getattr(result, "pixel_values", None)

        return {
            "input_ids": result.input_ids,
            "pixel_values": pixel_values,
        }

    async def _process_images(self, images, input_text) -> dict:
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                Gemma3SGLangImageProcessor._process_images_task,
                images,
                input_text,
                self.hf_config,
            )
        else:
            return self._process_images_task(images, input_text, self.hf_config)

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
        if isinstance(image_data, str):
            image_data = [image_data]

        image_token = self.IMAGE_TOKEN
        base_output = self.load_images(
            input_ids=input_ids,
            image_data=image_data,
            image_token=image_token,
            max_req_input_len=max_req_input_len,
            discard_alpha_channel=True,
        )

        ret = await self._process_images(
            input_text=base_output.input_text, images=base_output.all_frames
        )

        return {
            "input_ids": ret["input_ids"].flatten().tolist(),
            "pixel_values": ret["pixel_values"],
            "image_hashes": base_output.image_hashes,
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
        }


ImageProcessorMapping = {
    Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}

@@ -60,7 +60,10 @@ class JanusProProcessor(SGLangBaseImageProcessor):
            image_data = [image_data]

        base_out = self.load_images(
            input_ids, image_data, "<image_placeholder>", max_req_input_len
            input_ids=input_ids,
            image_data=image_data,
            image_token="<image_placeholder>",
            max_req_input_len=max_req_input_len,
        )
        images = base_out.all_frames
        res = await self._process_images(images=images, input_text=base_out.input_text)

@@ -52,7 +52,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
            image_data = [image_data]

        base_output = self.load_images(
            input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
            input_ids=input_ids,
            image_data=image_data,
            image_token=self.IMAGE_TOKEN,
            max_req_input_len=max_req_input_len,
        )
        if base_output is None:
            return None

@@ -72,10 +72,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):

        image_token = self.IMAGE_TOKEN
        base_output = self.load_images(
            input_ids,
            image_data,
            image_token,
            max_req_input_len,
            input_ids=input_ids,
            image_data=image_data,
            image_token=image_token,
            max_req_input_len=max_req_input_len,
        )

def smart_resize(

@@ -49,7 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend, next_power_of_2
from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -207,6 +207,9 @@ class ImageInputs:
        return ret

    def merge(self, other):
        """
        Merge image inputs when requests are being merged.
        """
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])


@@ -33,6 +33,7 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch
import triton
import triton.language as tl
@@ -331,6 +332,32 @@ class ForwardBatch:

        return ret

    def get_merged_image_inputs(self) -> Optional[ImageInputs]:
        """
        Merge all image inputs in the batch into a single ImageInputs object.

        Returns:
            None if the current batch contains no image input.
        """
        if not self.image_inputs or all(x is None for x in self.image_inputs):
            return None

        # Filter out None values
        valid_inputs = [x for x in self.image_inputs if x is not None]

        # Start with the first valid image input
        merged = valid_inputs[0]

        # Merge remaining inputs
        for img_input in valid_inputs[1:]:
            merged.merge(img_input)

        if isinstance(merged.pixel_values, np.ndarray):
            merged.pixel_values = torch.from_numpy(merged.pixel_values)

        return merged

    def _compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch
    ):
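
The merge semantics above in miniature (a sketch, not from the diff): per-request pixel tensors are concatenated along the batch axis and converted to torch once at the end.

    import numpy as np
    import torch

    a = np.random.rand(2, 3, 224, 224).astype(np.float32)   # request 1: two images
    b = np.random.rand(1, 3, 224, 224).astype(np.float32)   # request 2: one image
    merged = np.concatenate([a, b])                          # mirrors ImageInputs.merge
    assert torch.from_numpy(merged).shape == (3, 3, 224, 224)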

python/sglang/srt/models/gemma3_causal.py (new file, 687 lines)
@@ -0,0 +1,687 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Iterable, Optional, Set, Tuple

import einops
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    ROPE_INIT_FUNCTIONS,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
)

from sglang.srt.configs.gemma3 import Gemma3TextConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers


# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
def extract_layer_index(prefix: str) -> int:
    """Extract the layer index from a prefix string such as "model.layers.3.self_attn"."""
    parts = prefix.split(".")
    for i, part in enumerate(parts):
        if part == "layers" and i + 1 < len(parts):
            try:
                return int(parts[i + 1])
            except ValueError:
                continue
    return -1


class Gemma3MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
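
    # For reference (a sketch, not part of this file): the gating computed by
    # GeluAndMul on the merged projection, written in plain torch:
    #
    #     gate_up = torch.randn(4, 2 * 16)             # [gate | up] halves from gate_up_proj
    #     gate, up = gate_up.chunk(2, dim=-1)
    #     out = F.gelu(gate, approximate="tanh") * up  # the gelu_pytorch_tanh activation Gemma3 requires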


class Gemma3Attention(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: Gemma3TextConfig,
        max_position_embeddings: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0

        hidden_size = config.hidden_size

        head_dim = getattr(
            config, "head_dim", hidden_size // config.num_attention_heads
        )
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim

        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        # Determine whether the layer uses sliding-window attention based on the pattern
        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)

        # Initialize the rotary embedding.
        if self.is_sliding:
            # Local attention. Override the values in config.json.
            self.rope_theta = config.rope_local_base_freq
            self.rope_scaling = {"rope_type": "default"}
            # FIXME(mick): idk why vllm does this
            # self.sliding_window = config.interleaved_sliding_window
            self.sliding_window = config.sliding_window
        else:
            # Global attention. Use the values in config.json.
            self.rope_theta = config.rope_theta
            self.rope_scaling = config.rope_scaling
            self.sliding_window = None

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
            sliding_window_size=self.sliding_window,
            prefix=add_prefix("attn", prefix),
        )

        # Gemma3 adds normalization for q and k
        self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # Transpose.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                self.scaling,
            )
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        # [s, h * head_dim]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # [s, h, head_dim]
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        # -> [h, s, head_dim]
        q = q.transpose(0, 1).unsqueeze(0)
        q = self.q_norm(q)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        # -> [h, s, head_dim]
        k = k.transpose(0, 1).unsqueeze(0)
        k = self.k_norm(k)

        # q, k = self.rotary_emb(positions, q, k)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # [b, h, s, head_dim] -> [b, s, h, head_dim]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)

        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Gemma3DecoderLayer(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma3Attention(
            layer_id=layer_id,
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.hidden_size = config.hidden_size
        self.mlp = Gemma3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.input_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = Gemma3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.is_sliding = self.self_attn.is_sliding
        self.layer_id = layer_id

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # apply global RoPE to the non-sliding layers only
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            forward_batch=forward_batch,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        return outputs


class Gemma3RotaryEmbedding(nn.Module):
    def __init__(self, config: Gemma3TextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        Dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying with the embedding scale.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class Gemma3TextModel(PreTrainedModel):
    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=self.config.hidden_size**0.5,
        )

        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Create a local RoPE layer. Config defaults should hold values for the global RoPE.
        config = copy.deepcopy(config)
        config.rope_theta = config.rope_local_base_freq
        config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma3DecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )
        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        if len(positions.shape) == 1:
            positions = einops.rearrange(positions, "s -> 1 s")

        position_embeddings_global = self.rotary_emb(hidden_states, positions)
        position_embeddings_local = self.rotary_emb_local(hidden_states, positions)
        for layer in self.layers:
            layer_outputs = layer(
                positions=positions,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        return hidden_states


class Gemma3ForCausalLM(PreTrainedModel):
    config_class = Gemma3TextConfig

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    base_model_prefix = "language_model"

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma3TextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.logits_processor = LogitsProcessor(config)

        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def dtype(self) -> torch.dtype:
        return self.model.layers[0].mlp.gate_up_proj.weight.dtype

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs,
    ) -> LogitsProcessor:

        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, **kwargs
        )

        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                # if param_name in name:
                #     print(f"{param_name} is already in {name}")
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        # unloaded_params = params_dict.keys() - loaded_params
        # if unloaded_params:
        #     logger.warning(
        #         "Some weights are not initialized from checkpoints: %s", unloaded_params
        #     )
        return loaded_params


EntryClass = Gemma3ForCausalLM
AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)

python/sglang/srt/models/gemma3_mm.py (new file, 462 lines)
@@ -0,0 +1,462 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py

import logging
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel

from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
    MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)

cached_get_processor = lru_cache(get_processor)


class Gemma3ImagePixelInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class Gemma3MultiModalProjector(nn.Module):
    """Projector for Gemma3 multimodal."""

    def __init__(self, config: Gemma3Config):
        super().__init__()

        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(
                config.vision_config.hidden_size, config.text_config.hidden_size
            )
        )

        self.mm_soft_emb_norm = Gemma3RMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )

        self.patches_per_image = int(
            config.vision_config.image_size // config.vision_config.patch_size
        )
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(
            kernel_size=self.kernel_size, stride=self.kernel_size
        )

    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
        batch_size, seq_length, hidden_size = vision_outputs.shape

        # Reshape for pooling
        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
            batch_size, hidden_size, self.patches_per_image, self.patches_per_image
        )
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()

        # Apply pooling
        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

        # Apply normalization
        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)

        # Project to text embedding space
        projected_vision_outputs = torch.matmul(
            normed_vision_outputs, self.mm_input_projection_weight
        )

        return projected_vision_outputs.type_as(vision_outputs)
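
    # Shape arithmetic for the projector, assuming the published
    # gemma-3-4b-it vision config (values are assumptions, not read from this diff):
    #   image_size=896, patch_size=14  -> patches_per_image = 64
    #   mm_tokens_per_image=256        -> tokens_per_side = 16, kernel_size = 64 // 16 = 4
    # AvgPool2d(4) therefore reduces 64*64 = 4096 vision patches
    # to 16*16 = 256 soft tokens per image before the projection.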


class Gemma3ForConditionalGeneration(PreTrainedModel):
    """Gemma3 multimodal model for conditional generation."""

    config_class = Gemma3Config

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        # Vision components
        # TODO: replace with vision attention
        # self.vision_tower = SiglipVisionModel(
        #     config.vision_config,
        #     quant_config,
        #     prefix=add_prefix("vision_tower", prefix),
        # )
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        # Text model
        self.language_model = Gemma3ForCausalLM(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )
        if self.language_model.logits_processor.logit_scale:
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.language_model.logits_processor.logit_scale *= logit_scale
        self.post_init()

    def pad_input_ids(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """Pad input IDs with image tokens."""
        # Get special token IDs
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
        ids = pattern.pad_input_tokens(input_ids, image_inputs)
        return ids

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ) -> Dict:
        """Prepare attention masks for multimodal inputs."""
        kwargs["has_images"] = True

        # Distinguish sequences by position id 0
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []

        for i in range(num_seqs):
            start_idx = start_indices[i].item()
            if i < num_seqs - 1:
                end_idx = start_indices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)

        kwargs["seq_lens"] = seq_lens

        # Create attention masks
        global_attn_masks = []
        local_attn_masks = []
        sliding_window = self.config.text_config.interleaved_sliding_window

        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx

            # Create a global causal mask
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Consider bidirectional attention between image tokens
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = input_token_ids == self.config.image_token_index
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # Create a local causal mask with the sliding window
            local_attn_mask = torch.ones_like(global_attn_mask)
            local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
            local_attn_mask = torch.where(
                local_attn_mask == 0, global_attn_mask, float("-inf")
            )
            local_attn_masks.append(local_attn_mask)

        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs
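
    # A toy rendering of the image-token rule above (a sketch, not from this diff):
    #
    #     ids = torch.tensor([5, 99, 99, 7])           # 99: hypothetical image token id
    #     causal = torch.full((4, 4), float("-inf")).triu(diagonal=1)
    #     img = ids == 99
    #     pair = img.unsqueeze(0).int() + img.unsqueeze(1).int()
    #     mask = torch.where(pair == 2, torch.tensor(0.0), causal)
    #
    # mask[1, 2] becomes 0: image tokens attend to each other bidirectionally,
    # while all other positions keep the causal -inf pattern.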

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def get_image_features(self, pixel_values: torch.Tensor):
        """
        Projects the last hidden state from the vision model into the language model space.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        pixel_values = pixel_values.to("cuda")
        pixel_values = pixel_values.to(dtype=self.language_model.dtype())

        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def embed_image_inputs(
        self,
        input_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        image_input: ImageInputs,
    ) -> torch.Tensor:
        if input_ids is None:
            raise ValueError("Unimplemented")
        # boolean mask of image tokens
        special_image_mask = torch.isin(
            input_ids,
            torch.tensor(image_input.pad_values, device=input_ids.device),
        ).unsqueeze(-1)
        num_image_tokens_in_input_ids = special_image_mask.sum()

        inputs_embeds = None
        if num_image_tokens_in_input_ids == 0:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            return inputs_embeds
        else:
            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
            image_features = self.get_image_features(image_input.pixel_values)

            # print(f"image tokens from image embeddings: {image_features.numel()}")
            num_image_tokens_in_embedding = (
                image_features.shape[0] * image_features.shape[1]
            )

            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
                image_features = image_features[:num_image, :]
                logger.warning(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
                    "tokens from image embeddings."
                )

            # Important: clamp after extracting the original image boundaries
            input_ids.clamp_(min=0, max=self.vocab_size - 1)

            inputs_embeds = self.get_input_embeddings()(input_ids)

            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )

            image_features = image_features.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features
            )

            return inputs_embeds

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs: object,
    ) -> LogitsProcessor:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (a single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
        >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

        >>> prompt = "answer en Where is the cow standing?"
        >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "answer en Where is the cow standing?\nbeach"
        ```"""

        # Important: position_ids in Gemma3 are 1-indexed
        # (this cost some debugging time)
        positions += 1

        # Replace the image token id with PAD if it is out of vocabulary, to avoid index errors
        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        merged_image_input = forward_batch.get_merged_image_inputs()

        if (
            not forward_batch.forward_mode.is_decode()
            and merged_image_input is not None
        ):
            inputs_embeds = self.embed_image_inputs(
                input_ids=llm_input_ids,
                forward_batch=forward_batch,
                image_input=merged_image_input,
            )
        else:
            llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        outputs = self.language_model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
            **kwargs,
        )

        return outputs

    def tie_weights(self):
        return self.language_model.tie_weights()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model."""
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "language_model" in name:
                # Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)])
                causal_loaded_params = Gemma3ForCausalLM.load_weights(
                    self, [(name, loaded_weight)]
                )
                loaded_params.update(causal_loaded_params)
                continue
            else:
                # Skip lm_head.weight as it is tied with embed_tokens
                if "lm_head.weight" in name:
                    continue

                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remap the name of the FP8 kv-scale
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            pass
            # raise RuntimeError(
            #     f"Some weights are not initialized from checkpoints: {unloaded_params}")
        return loaded_params


EntryClass = Gemma3ForConditionalGeneration

AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)

@@ -41,7 +41,6 @@ from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
from io import BytesIO
from multiprocessing import Pool
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
@@ -454,8 +453,9 @@ def load_image(image_file: Union[str, bytes]):
        image = Image.open(BytesIO(image_file))
    elif image_file.startswith("http://") or image_file.startswith("https://"):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        image = Image.open(BytesIO(response.content))
        response = requests.get(image_file, stream=True, timeout=timeout).raw
        image = Image.open(response)
        response.close()
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    elif image_file.startswith("data:"):