# sglang/python/sglang/srt/models/gemma3n_mm.py
import logging
import re
from functools import lru_cache
from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
from transformers import (
Gemma3nAudioConfig,
Gemma3nConfig,
Gemma3nTextConfig,
Gemma3nVisionConfig,
PreTrainedModel,
)
from transformers.models.auto.modeling_auto import AutoModel
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
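# Overview (added note): this module wires Gemma3n's multimodal stack together for
# sglang serving. The vision tower is instantiated from the HF config via AutoModel,
# while the audio encoder and text decoder are sglang-native implementations; both
# soft-token streams are projected into the text model's embedding space below.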
class Gemma3nImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3nAudioInputs(TypedDict):
input_features: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
input_features_mask: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length)`"""
class Gemma3nMultimodalEmbedder(nn.Module):
"""Embeds token ids or soft tokens for multimodal content into language model space."""
def __init__(
self,
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
text_config: Gemma3nTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.multimodal_hidden_size = multimodal_config.hidden_size
self.eps = multimodal_config.rms_norm_eps
self.vocab_offset = multimodal_config.vocab_offset
self.vocab_size = multimodal_config.vocab_size
self.text_hidden_size = text_config.hidden_size
self.embedding = VocabParallelEmbedding(
self.vocab_size,
self.multimodal_hidden_size,
quant_config=quant_config,
prefix=add_prefix("embedding", prefix),
)
self.hard_embedding_norm = Gemma3nRMSNorm(
self.multimodal_hidden_size,
eps=self.eps,
)
self.soft_embedding_norm = Gemma3nRMSNorm(
self.multimodal_hidden_size,
eps=self.eps,
)
self.embedding_projection = RowParallelLinear(
self.multimodal_hidden_size,
self.text_hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("embedding_projection", prefix),
)
self.embedding_post_projection_norm = Gemma3nRMSNorm(
self.text_hidden_size,
eps=self.eps,
with_scale=False,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Embeds token ids or soft tokens for multimodal content into language model space.
Args:
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
`[vocab_offset, vocab_offset + vocab_size)`.
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
Returns:
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if inputs_embeds is not None:
emb_norm = self.soft_embedding_norm(inputs_embeds)
else:
# Handle out of vocab ids to prevent CUDA assertion failures
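# Multimodal ids occupy [vocab_offset, vocab_offset + vocab_size) in the combined
# vocabulary; subtracting vocab_offset maps them onto this embedder's local table.
# Illustrative example (actual offsets are config-dependent): an id of
# vocab_offset + 5 becomes local index 5, and any id outside the valid range is
# clamped to the last local index instead of triggering a device-side assert.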
out_of_vocab_id = self.vocab_size - 1
adjusted_ids = input_ids - self.vocab_offset
adjusted_ids = torch.where(adjusted_ids < 0, out_of_vocab_id, adjusted_ids)
adjusted_ids = torch.where(
adjusted_ids >= self.vocab_size, out_of_vocab_id, adjusted_ids
)
hard_emb = self.embedding(adjusted_ids)
emb_norm = self.hard_embedding_norm(hard_emb)
emb_norm_proj, _ = self.embedding_projection(emb_norm)
return self.embedding_post_projection_norm(emb_norm_proj)
class Gemma3nForConditionalGeneration(PreTrainedModel):
"""Gemma3n multimodal model for conditional generation."""
config_class = Gemma3nConfig
# BitsAndBytes-specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
".out_proj.",
]
bitsandbytes_stacked_params_mapping = {
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
"out_proj": ("proj", 0),
}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer
embedding_modules = {}
embedding_padding_modules = []
supports_lora = True
def __init__(
self,
config: Gemma3nConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
prefix = add_prefix("model", prefix)
# Vision components
# TODO: Use sglang's vision model
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.embed_vision = Gemma3nMultimodalEmbedder(
config.vision_config,
config.text_config,
quant_config=quant_config,
prefix=add_prefix("embed_vision", prefix),
)
# Audio components
self.embed_audio = Gemma3nMultimodalEmbedder(
config.audio_config,
config.text_config,
quant_config=quant_config,
prefix=add_prefix("embed_audio", prefix),
)
self.audio_tower = Gemma3nAudioEncoder(
config.audio_config,
quant_config=quant_config,
prefix=add_prefix("audio_tower", prefix),
)
self.vocab_size = config.text_config.vocab_size
self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
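# Gemma3n additionally feeds per-layer input embeddings (drawn from a smaller
# vocab_size_per_layer_input table) into every decoder layer; see get_per_layer_inputs.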
# Text model
self.language_model = Gemma3nTextModel(
config.text_config,
quant_config,
prefix=add_prefix("language_model", prefix),
)
# Create logits processor for the multimodal model
self.logits_processor = LogitsProcessor(config.text_config)
self.post_init()
def pad_input_ids(
self,
input_ids: List[int],
mm_inputs: MultimodalInputs,
) -> List[int]:
"""Pad input IDs with image and audio tokens."""
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings()
def get_attention_sliding_window_size(self):
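# Note: returns sliding_window - 1, presumably because the attention backend measures
# the window over preceding tokens only (excluding the current position).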
return self.config.text_config.sliding_window - 1
def get_image_feature(self, items: List[MultimodalDataItem]):
"""
Projects the last hidden state from the vision model into language model space.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
# Process images one by one to handle flatten_batch=True constraint in vision_tower
all_pixel_values = flatten_nested_list([item.feature for item in items])
vision_outputs_list = []
for pixel_values_batch in all_pixel_values:
# Normalize input shape to [batch_size, channels, height, width]
if pixel_values_batch.dim() == 5:
pixel_values_batch = pixel_values_batch.squeeze(0)
elif pixel_values_batch.dim() == 3:
pixel_values_batch = pixel_values_batch.unsqueeze(0)
elif pixel_values_batch.dim() != 4:
raise ValueError(
f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
)
# Process each image in the batch
batch_size = pixel_values_batch.shape[0]
for i in range(batch_size):
pixel_value = pixel_values_batch[i : i + 1] # Keep batch dimension as 1
pixel_value = pixel_value.to(
device=self.vision_tower.device, dtype=self.language_model.dtype()
)
vision_outputs = self.vision_tower(
pixel_values=pixel_value, do_pooling=False, return_dict=True
).last_hidden_state
vision_outputs_list.append(vision_outputs)
# Concatenate all vision outputs
vision_outputs = torch.cat(vision_outputs_list, dim=0)
# Convert from (batch, channels, height, width) to (batch, height * width, channels)
vision_outputs = vision_outputs.reshape(
vision_outputs.shape[0],
self.config.vision_config.hidden_size,
self.config.vision_soft_tokens_per_image,
).permute(0, 2, 1)
# Scale by sqrt(hidden_size), then RMS-normalize and project the soft tokens into language model space
vision_outputs *= self.config.vision_config.hidden_size**0.5
return self.embed_vision(inputs_embeds=vision_outputs)
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
"""
Projects the last hidden state from the audio encoder into language model space.
Args:
items: List of multimodal data items containing audio data.
Returns:
audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`.
"""
# Extract audio features and masks from items
all_input_features = flatten_nested_list([item.feature for item in items])
all_input_features_mask = flatten_nested_list(
[~item.input_features_mask for item in items]
) # Note(Xinyuan): reverse the mask according to the HF implementation
# Process audio features one by one
audio_features_list = []
for input_features, input_features_mask in zip(
all_input_features, all_input_features_mask
):
# Ensure proper tensor format
if input_features.dim() == 2:
input_features = input_features.unsqueeze(0)
if input_features_mask.dim() == 1:
input_features_mask = input_features_mask.unsqueeze(0)
# Move to device and dtype
input_features = input_features.to(
device=next(self.audio_tower.parameters()).device,
dtype=self.language_model.dtype(),
)
input_features_mask = input_features_mask.to(device=input_features.device)
# Process through audio tower
audio_outputs, audio_mask = self.audio_tower(
input_features, input_features_mask
)
# Embed the audio outputs
audio_embeds = self.embed_audio(inputs_embeds=audio_outputs)
audio_features_list.append(audio_embeds)
# Concatenate all audio features
if audio_features_list:
audio_features = torch.cat(audio_features_list, dim=0)
# The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens into the
# text to account for this. However, the audio preprocessing and encoder do not guarantee they will
# produce 188 soft tokens; they will produce at most that many tokens, but may produce fewer
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
audio_padding_toks = torch.tensor(
[[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
)
audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
audio_features = torch.where(
audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = (
self.config.audio_soft_tokens_per_image - audio_seq_len
)
extra_padding_features = audio_padding_embs.expand(
audio_batch_size, extra_padding_tokens, audio_embed_dim
)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
return audio_features
else:
return torch.empty(
0,
0,
self.language_model.config.hidden_size,
device=next(self.parameters()).device,
dtype=self.language_model.dtype(),
)
def get_per_layer_inputs(
self, input_ids: torch.LongTensor
) -> Optional[torch.Tensor]:
return self.language_model.get_per_layer_inputs(input_ids)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.language_model.project_per_layer_inputs(
inputs_embeds, per_layer_inputs
)
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
**kwargs: object,
) -> LogitsProcessor:
"""Forward pass for multimodal Gemma3n."""
if (input_ids is None) ^ (input_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or input_embeds"
)
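# sglang supplies 0-based positions, while the Gemma3n implementation expects
# 1-based position ids, hence the shift below.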
positions += 1
if input_ids is not None:
# Prepare per-layer inputs from input_ids
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0, input_ids < self.vocab_size_per_layer_input
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
per_layer_inputs = self.language_model.get_per_layer_inputs(
per_layer_inputs_tokens
)
else:
# Ensure per_layer_inputs is defined when only input_embeds is provided
per_layer_inputs = None
# Use general_mm_embed_routine for handling multimodal data
# This will automatically handle text, image, and audio embeddings
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
Modality.AUDIO: self.get_audio_feature,
},
positions=positions,
per_layer_inputs=per_layer_inputs,
)
# Process hidden states through logits processor
return self.logits_processor(
input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
)
def tie_weights(self):
return self.language_model.tie_weights()
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for the model."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".up_proj", 1),
(".gate_up_proj", ".gate_proj", 0),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
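# Checkpoint names arrive prefixed with "model."; after stripping it, each weight is
# either folded into a fused parameter via stacked_params_mapping or, in the
# for-else branch, loaded directly (with vision attention names remapped).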
for name, loaded_weight in weights:
name = re.sub(r"^model\.", "", name)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "vision_model" in name:
# adapt to VisionAttention
name = name.replace(".self_attn.out_proj", ".self_attn.proj")
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
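# LoRA is applied only to the attention/MLP projections of the text decoder layers;
# should_apply_lora checks module names against this pattern.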
lora_pattern = re.compile(
r"^language_model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)
def should_apply_lora(self, module_name: str) -> bool:
return bool(self.lora_pattern.match(module_name))
def get_hidden_dim(self, module_name):
# return input_dim, output_dim
if module_name == "qkv_proj":
return (
self.config.hidden_size,
self.config.head_dim
* (
self.config.num_attention_heads
+ self.config.num_key_value_heads * 2
),
)
elif module_name == "o_proj":
return (
self.config.head_dim * self.config.num_attention_heads,
self.config.hidden_size,
)
elif module_name == "gate_up_proj":
assert len(set(self.config.intermediate_size)) == 1, (
"Currently SGLang requires uniform intermediate size for all layers. "
"Please file an issue if you need support for non-uniform intermediate sizes."
)
return self.config.hidden_size, self.config.intermediate_size[0] * 2
elif module_name == "down_proj":
assert len(set(self.config.intermediate_size)) == 1, (
"Currently SGLang requires uniform intermediate size for all layers. "
"Please file an issue if you need support for non-uniform intermediate sizes."
)
return self.config.intermediate_size[0], self.config.hidden_size
else:
raise NotImplementedError()
EntryClass = Gemma3nForConditionalGeneration