diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 6ddd2484a..63412791d 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -565,6 +565,7 @@ multimodal_model_archs = [ "CLIPModel", "DeepseekVL2ForCausalLM", "Gemma3ForConditionalGeneration", + "Gemma3nForConditionalGeneration", "Grok1VForCausalLM", "Grok1AForCausalLM", "LlavaLlamaForCausalLM", diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 661f47700..ad292bbd9 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -823,6 +823,7 @@ register_conv_template( sep_style=SeparatorStyle.GEMMA3, stop_str=["<end_of_turn>"], image_token="<start_of_image>", + audio_token="<start_of_audio>", ) ) diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index 618f66a2f..4b8732627 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum): RAW_IMAGES = "raw_images" PRECOMPUTED_FEATURES = "precomputed_features" PIXEL_VALUES = "pixel_values" + AUDIO = "audio" @dataclasses.dataclass @@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC): has_image = False has_pixel_values = False has_precomputed_features = False + has_audio = False for mm_input in mm_inputs: if isinstance(mm_input, Image.Image): has_image = True + elif isinstance(mm_input, np.ndarray): + has_audio = True elif isinstance(mm_input, dict): if mm_input.get("precomputed_features", None) is not None: has_precomputed_features = True @@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC): # Validate format consistency format_count = sum( - [has_image, has_pixel_values, has_precomputed_features] + [has_image, has_pixel_values, has_precomputed_features, has_audio] ) if format_count > 1: raise ValueError( "Unsupported: mixture of multimodal input formats. 
" f"Found formats: image={has_image}, pixel_values={has_pixel_values}, " - f"precomputed_features={has_precomputed_features}" + f"precomputed_features={has_precomputed_features}, audio={has_audio}" ) if has_image: @@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC): return MultimodalInputFormat.PRECOMPUTED_FEATURES elif has_pixel_values: return MultimodalInputFormat.PIXEL_VALUES + elif has_audio: + return MultimodalInputFormat.AUDIO else: raise ValueError("No valid multimodal input format found") except Exception as e: @@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC): input_ids = tokenize_text(base_output.input_text) return combined_mm_item, input_ids + def process_audio( + base_output: BaseMultiModalProcessorOutput, + ) -> Tuple[MultimodalDataItem, torch.Tensor]: + """Process inputs with audio.""" + ret = self.process_mm_data( + input_text=base_output.input_text, + audio=base_output.audios, # Note: "audio" is for gemma3n only + ) + combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO) + for key, value in ret.items(): + if key != "input_ids" and hasattr(combined_mm_item, key): + setattr(combined_mm_item, key, value) + input_ids = ret["input_ids"].flatten() + return combined_mm_item, input_ids + def finalize_mm_item( combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor ) -> MultimodalDataItem: """Apply common post-processing to the multimodal item.""" - combined_mm_item.image_offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.IM_TOKEN_ID, - ) + if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]: + combined_mm_item.image_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=self.IM_TOKEN_ID, + ) + elif combined_mm_item.modality == Modality.AUDIO: + combined_mm_item.audio_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=self.AUDIO_TOKEN_ID, + ) + elif combined_mm_item.modality == Modality.VIDEO: + combined_mm_item.video_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=self.VIDEO_TOKEN_ID, + ) + else: + raise ValueError(f"Unknown modality: {combined_mm_item.modality}") return combined_mm_item - # Main logic - mm_inputs = base_output.images + # Main logic - determine input type and handle text-only case + mm_inputs = base_output.images or base_output.audios if not mm_inputs: - # Return text-only case input_ids = tokenize_text(base_output.input_text) return None, input_ids @@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC): combined_mm_item, input_ids = process_precomputed_features(base_output) elif input_format == MultimodalInputFormat.PIXEL_VALUES: combined_mm_item, input_ids = process_pixel_values(base_output) + elif input_format == MultimodalInputFormat.AUDIO: + combined_mm_item, input_ids = process_audio(base_output) else: raise ValueError(f"Unknown input format: {input_format}") diff --git a/python/sglang/srt/managers/multimodal_processors/gemma3n.py b/python/sglang/srt/managers/multimodal_processors/gemma3n.py new file mode 100644 index 000000000..319d4eb40 --- /dev/null +++ b/python/sglang/srt/managers/multimodal_processors/gemma3n.py @@ -0,0 +1,97 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import re +from typing import Dict, List, Optional, Union + +from sglang.srt.managers.multimodal_processor import ( + BaseMultimodalProcessor as SGLangBaseProcessor, +) +from sglang.srt.managers.multimodal_processors.base_processor import ( + MultimodalSpecialTokens, +) +from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration + + +class Gemma3nSGLangProcessor(SGLangBaseProcessor): + """Multimodal processor for Gemma3n supporting image and audio inputs.""" + + models = [Gemma3nForConditionalGeneration] + + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + + self.IMAGE_TOKEN = "<image_soft_token>" + self.IMAGE_TOKEN_REGEX = re.compile( + r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?" + ) + + self.AUDIO_TOKEN = "<audio_soft_token>" + self.AUDIO_TOKEN_REGEX = re.compile( + r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?" + ) + + self.IM_TOKEN_ID = hf_config.image_token_id + self.IM_START_TOKEN_ID = hf_config.boi_token_id + self.IM_END_TOKEN_ID = hf_config.eoi_token_id + + self.AUDIO_TOKEN_ID = hf_config.audio_token_id + self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id + self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id + + async def process_mm_data_async( + self, + image_data: Optional[List[Union[str, bytes, Dict]]] = None, + audio_data: Optional[List[Union[str, bytes, Dict]]] = None, + input_text: str = "", + request_obj=None, + max_req_input_len: int = 0, + *args, + **kwargs, + ): + """Process multimodal data including images and audio.""" + + audio_data = request_obj.audio_data + if not image_data and not audio_data: + return None + + if isinstance(image_data, str): + image_data = [image_data] + + if isinstance(audio_data, str): + audio_data = [audio_data] + + base_output = self.load_mm_data( + prompt=input_text, + image_data=image_data, + audio_data=audio_data, + max_req_input_len=max_req_input_len, + multimodal_tokens=MultimodalSpecialTokens( + image_token=self.IMAGE_TOKEN, + image_token_regex=self.IMAGE_TOKEN_REGEX, + audio_token=self.AUDIO_TOKEN, + audio_token_regex=self.AUDIO_TOKEN_REGEX, + ), + ) + + combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output) + + return { + "input_ids": input_ids.tolist(), + "mm_items": [combined_mm_item] if combined_mm_item is not None else [], + "im_start_id": self.IM_START_TOKEN_ID, + "im_end_id": self.IM_END_TOKEN_ID, + "audio_start_id": self.AUDIO_START_TOKEN_ID, + "audio_end_id": self.AUDIO_END_TOKEN_ID, + } diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e19707340..85e8500dc 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -214,6 +214,10 @@ class MultimodalDataItem: audio_feature_lens: Optional[List[torch.Tensor]] = None audio_offsets: Optional[List[Tuple[int, int]]] = None + # gemma3n related + input_features: Optional[torch.Tensor] = None + input_features_mask: Optional[torch.Tensor] = None + precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None @staticmethod @@ -277,7 +281,10 @@ class MultimodalDataItem: if 
self.precomputed_features is not None: self.hash = hash_feature(self.precomputed_features) elif self.is_audio(): - self.hash = hash_feature(self.audio_features) + if self.audio_features is not None: + self.hash = hash_feature(self.audio_features) + elif self.input_features is not None: + self.hash = hash_feature(self.input_features) else: self.hash = hash_feature(self.pixel_values) @@ -288,6 +295,7 @@ class MultimodalDataItem: return (self.modality == Modality.AUDIO) and ( self.precomputed_features is not None or not MultimodalDataItem.is_empty_list(self.audio_features) + or not MultimodalDataItem.is_empty_list(self.input_features) ) def is_image(self): diff --git a/python/sglang/srt/models/gemma3n_audio.py b/python/sglang/srt/models/gemma3n_audio.py new file mode 100644 index 000000000..57a743c9f --- /dev/null +++ b/python/sglang/srt/models/gemma3n_audio.py @@ -0,0 +1,949 @@ +import math +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Gemma3nAudioConfig, PreTrainedModel + +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm +from sglang.srt.utils import add_prefix, make_layers + + +class Gemma3nCumulativeGroupNorm(nn.Module): + """Applies Group Normalization cumulatively over the time dimension. + + This layer normalizes the input by calculating the mean and variance + cumulatively over the time dimension (dim 1). The statistics are computed + over all feature dimensions (specified by `feature_dims` and `num_channels`) + for elements marked as valid by the optional `mask`. + + If a `mask` is provided (True for valid, False for invalid/padded), + invalid time steps do not contribute to the statistics calculation, and + their corresponding output values are zeroed out. + + Scale and bias, if enabled, are applied per-channel (last dimension). + This behavior is similar to JAX's `GroupNormalization` with `num_groups=1` + and `cumulative=True`. + """ + + def __init__( + self, + num_channels: int, # Number of channels (size of the last dimension) + feature_dims: Sequence[ + int + ], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C] + eps: float = 1e-3, + ): + super().__init__() + self.num_channels = num_channels + self.feature_dims = tuple(feature_dims) + self.eps = eps + + # Scale parameter depends only on the channel dimension + self.weight = nn.Parameter(torch.ones(num_channels)) + + # Axes for normalization: all dimensions except Batch (0) and Time (1). + # For input [B, T, *feature_dims, C], these are dims from 2 onwards. + self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1)) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Applies cumulative group norm, optionally using a mask. + + Args: + x: Input tensor, shape [B, T, *feature_dims, C]. + mask: Optional boolean mask, shape [B, T]. True indicates a valid + (non-padded) time step. If None, all time steps are considered valid. + + Returns: + Normalized tensor with the same shape as x. 
+ """ + expected_input_suffix = self.feature_dims + (self.num_channels,) + if x.shape[2:] != expected_input_suffix: + raise ValueError( + f"Input tensor shape suffix {x.shape[2:]} does not match expected" + f" suffix (feature_dims + num_channels) {expected_input_suffix}" + ) + + input_dtype = x.dtype + # Calculations are performed in float32 for numerical stability. + calc_dtype = torch.float32 + x_calc = x.to(calc_dtype) + + # Prepare a broadcastable mask (`mask_calc`). + # If no mask is provided, treat all elements as valid + # (mask_calc is all ones). + # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting. + mask_calc = torch.ones_like(x_calc, dtype=calc_dtype) + + # Cumulative Statistics Calculation + # 1. Sum of values over reduction axes at each time step. + sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True) + # 2. Cumulative sum of values over time. + cum_sum_values = torch.cumsum(sum_values_at_t, dim=1) + + # 3. Count of valid elements in the normalization group at each time step. + # (A "group" here consists of all features at a given Batch, Time). + elements_in_group_at_t = torch.sum( + mask_calc, dim=self.reduction_axes, keepdim=True + ) + # 4. Cumulative count of valid elements over time. + cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1) + # Avoid division by zero if all preceding elements were masked. + safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0) + + # 5. Cumulative mean. + cum_mean = cum_sum_values / safe_cum_count_elements + + # 6. Sum of squared differences from the cumulative mean. + # Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc. + # Using x_calc here for the difference, as cum_mean already accounts for masking. + squared_diff_from_mean = (x_calc - cum_mean).pow(2) + sum_sq_diff_at_t = torch.sum( + squared_diff_from_mean, dim=self.reduction_axes, keepdim=True + ) + + # 7. Cumulative sum of squared differences over time. + cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1) + + # 8. Cumulative variance. + cum_variance = cum_sum_sq_diff / safe_cum_count_elements + + # Normalize the input using the calculated cumulative statistics: + # (x - E[x]) / sqrt(Var[x] + eps) + normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps) + + # Apply affine transformation (scale and bias) if enabled. + # Scale and bias are applied per-channel (last dimension). + scale = self.weight.to(calc_dtype) + # Reshape for broadcasting: [C] -> [1, ..., 1, C] + scale_view_shape = [1] * (x.dim() - 1) + [self.num_channels] + normalized_x = normalized_x * scale.view(scale_view_shape) + + # Zero out outputs for time steps that were originally masked (where mask_calc is 0). + # This ensures padded/invalid positions in the input result in zero output. 
+ final_output = normalized_x * mask_calc + + return final_output.to(input_dtype) + + +class Gemma3nAudioRelativePositionEmbedding(nn.Module): + def __init__( + self, + config: Gemma3nAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.num_heads = self.config.conf_num_attention_heads + self.channels = self.config.hidden_size + self.head_dim = self.channels // self.num_heads + self.max_backward = max(0, self.config.conf_attention_context_left - 1) + self.max_forward = self.config.conf_attention_context_right + + self.pos_proj = ColumnParallelLinear( + self.channels, + self.num_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("pos_proj", prefix), + ) + + min_timescale = 1.0 + max_timescale = 1.0e4 + num_timescales = self.channels // 2 + log_timescale_increment = math.log( + float(max_timescale) / float(min_timescale) + ) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales) * -log_timescale_increment + ) + self.register_buffer( + "inv_timescales", + inv_timescales.float().unsqueeze(0).unsqueeze(0), + persistent=False, + ) + + def _get_timing_signal_1d_pos( + self, position: torch.Tensor, dtype: torch.dtype + ) -> torch.Tensor: + assert position.ndim == 2 + position = position.float().unsqueeze(-1) + scaled_time = position * self.inv_timescales.to( + device=position.device, dtype=torch.float32 + ) + timing_signal = torch.cat( + [torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1 + ) + return timing_signal.type(dtype) + + def _relative_shift( + self, + term_bd_before_shift: torch.Tensor, + batch_size: int, + num_heads: int, + num_query_blocks: int, + query_block_size: int, + key_context_size: int, + max_span_plus_1: int, + ) -> torch.Tensor: + """Performs the relative shift.""" + pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1 + padding_tuple = (0, pad_amount_last_dim) + + term_bd_padded = F.pad(term_bd_before_shift, padding_tuple) + term_bd_reshaped = term_bd_padded.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size * (key_context_size + 1), + ) + ) + term_bd_sliced = term_bd_reshaped[ + :, :, :, : query_block_size * key_context_size + ] + term_bd_shifted = term_bd_sliced.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + ) + ) + return term_bd_shifted + + def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor: + batch_size, num_query_blocks, query_block_size, num_heads, head_dim = ( + queries.shape + ) + _, _, key_context_size, _, _ = keys.shape + + pos_indices = torch.arange( + self.max_backward, -self.max_forward - 1, -1, device=queries.device + ).unsqueeze(0) + max_span_plus_1 = pos_indices.shape[1] + + sin_emb_timing_signal = self._get_timing_signal_1d_pos( + pos_indices, dtype=queries.dtype + ) + projected_sin_emb, _ = self.pos_proj(sin_emb_timing_signal) + sin_emb = projected_sin_emb.reshape( + 1, max_span_plus_1, self.num_heads, self.head_dim + ).squeeze(0) + + queries_p = queries.permute(0, 3, 1, 2, 4) + keys_p_t = keys.permute(0, 3, 1, 4, 2) + term_ac = torch.matmul(queries_p, keys_p_t) + + q_permuted = queries.permute(0, 3, 1, 2, 4) + s_permuted = sin_emb.permute(1, 2, 0) + q_reshaped = q_permuted.reshape( + batch_size, num_heads, num_query_blocks * query_block_size, head_dim + ) + term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted) + term_bd_unshifed = 
term_bd_unshifed_matmul.reshape( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + max_span_plus_1, + ) + + term_bd_shifted = self._relative_shift( + term_bd_unshifed, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + max_span_plus_1, + ) + + return term_ac + term_bd_shifted + + +class Gemma3nAudioAttention(nn.Module): + """Local dot product self-attention for audio.""" + + def __init__( + self, + config: Gemma3nAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.num_heads = self.config.conf_num_attention_heads + self.hidden_size = self.config.hidden_size + self.head_dim = self.hidden_size // self.num_heads + + self.chunk_size = self.config.conf_attention_chunk_size + self.max_future_horizon = self.config.conf_attention_context_right + self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1) + self.attention_logits_soft_cap = self.config.conf_attention_logit_cap + self.context_size = ( + self.chunk_size + self.max_past_horizon + self.max_future_horizon + ) + + self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding( + config, + quant_config, + prefix=add_prefix("relative_position_embedding", prefix), + ) + self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,))) + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_heads, + bias=False, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + + q_scale = self.head_dim**-0.5 + r_softplus_0 = 1.0 / F.softplus(torch.tensor(0.0)) + self.register_buffer( + "q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False + ) + + # Create local causal mask + lower_causal_mask = torch.tril( + torch.ones((self.context_size, self.chunk_size), dtype=torch.bool), + diagonal=0, + ).T + upper_causal_mask = torch.tril( + torch.ones((self.chunk_size, self.context_size), dtype=torch.bool), + diagonal=self.max_past_horizon + self.max_future_horizon, + ) + local_causal_valid_mask = torch.ones( + (self.chunk_size, self.context_size), dtype=torch.bool + ) + local_causal_valid_mask = ( + local_causal_valid_mask * lower_causal_mask * upper_causal_mask + ) + self.register_buffer( + "local_causal_valid_mask", local_causal_valid_mask, persistent=False + ) + + self.register_buffer( + "softcap", + torch.tensor(self.attention_logits_soft_cap).float(), + persistent=False, + ) + + def _pad_dim1( + self, x: torch.Tensor, dim10_val: int, dim11_val: int + ) -> torch.Tensor: + padding_tuple = [0] * x.ndim * 2 + dim_idx_from_end = x.ndim - 2 + start_idx_for_dim = 2 * dim_idx_from_end + padding_tuple[start_idx_for_dim] = dim10_val + padding_tuple[start_idx_for_dim + 1] = dim11_val + return F.pad(x, tuple(padding_tuple)) + + def _convert_to_block(self, x: torch.Tensor) -> torch.Tensor: + """Turns a sequence to non overlapping blocks.""" + shape = x.shape + b, t = shape[:2] + num_blocks = (t + self.chunk_size - 1) // self.chunk_size + + if (padding_len := num_blocks * self.chunk_size - t) > 0: + x = self._pad_dim1(x, 0, padding_len) + + permute_dims = (b, num_blocks, self.chunk_size) + shape[2:] + x = x.reshape(permute_dims).contiguous() + return x + + def _extract_block_context(self, x: torch.Tensor) -> torch.Tensor: + """Extracts temporal context for every block.""" + pad_left = self.max_past_horizon + pad_right = self.max_future_horizon + self.chunk_size - 1 + x = self._pad_dim1(x, pad_left, pad_right) + + 
frame_len = self.context_size + frame_step = self.chunk_size + + x_unfolded = x.unfold(dimension=1, size=frame_len, step=frame_step) + + if x.ndim > 2 and x_unfolded.ndim > 3: + x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2) + + return x_unfolded.contiguous() + + def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + # Project to Q, K, V + qkv, _ = self.qkv_proj(x) + query_states, key_states, value_states = qkv.chunk(chunks=3, dim=-1) + + # Reshape + query_states = query_states.reshape( + *x.shape[:-1], self.num_heads, self.head_dim + ).contiguous() + key_states = key_states.reshape( + *x.shape[:-1], self.num_heads, self.head_dim + ).contiguous() + value_states = value_states.reshape( + *x.shape[:-1], self.num_heads, self.head_dim + ).contiguous() + + # Apply per-dim scale + per_dim_scale_sp = F.softplus(self.per_dim_scale) + broadcast_shape = (1, 1, 1, self.head_dim) + per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape) + query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast + + batch_size, q_time = query_states.shape[:2] + + # Convert to blocks + query_blocks = self._convert_to_block(query_states) + key_blocks = self._extract_block_context(key_states) + value_blocks = self._extract_block_context(value_states) + num_query_blocks = query_blocks.shape[1] + + # Create mask for valid positions + original_valid_mask = ~mask + extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask) + + if ( + extracted_valid_mask_blocks.ndim == 4 + and extracted_valid_mask_blocks.shape[0] == batch_size + and extracted_valid_mask_blocks.shape[1] == num_query_blocks + and extracted_valid_mask_blocks.shape[2] + * extracted_valid_mask_blocks.shape[3] + == self.context_size + ): + extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape( + batch_size, num_query_blocks, self.context_size + ) + + condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze( + 1 + ).unsqueeze(-2) + condition_from_causality = ( + self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) + ) + + final_condition_for_where = torch.logical_and( + condition_from_input_validity, + condition_from_causality.to(condition_from_input_validity.device), + ) + + # Compute attention scores + logits = self.relative_position_embedding(query_blocks, key_blocks) + + # Apply attention logit softcap + softcap_val = self.softcap.to(logits.device) + logits = logits / softcap_val + logits = torch.tanh(logits) + logits = logits * softcap_val + + # Apply the combined mask. + # final_condition_for_where will broadcast with logits [B,N,U,W,C] + logits = torch.where( + final_condition_for_where, logits, torch.finfo(logits.dtype).min + ) + + probabilities = F.softmax(logits, dim=-1, dtype=torch.float32).to( + dtype=value_blocks.dtype + ) + + # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...) 
+ b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape + h_dim = value_blocks.shape[-1] + prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim) + v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim) + result_bmm = torch.bmm(prob_bun, v_bun) + context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute( + 0, 1, 3, 2, 4 + ) + context_vectors = context_vectors.reshape( + ( + batch_size, + num_query_blocks * self.chunk_size, + self.num_heads, + self.head_dim, + ) + ) + context_vectors = context_vectors[:, :q_time] + + return context_vectors + + +class Gemma3nAudioSSCPConvBlock(nn.Module): + """A single convolution block for the SubSampleConvProjection.""" + + def __init__( + self, + config: Gemma3nAudioConfig, + idx: int, + input_freq_dim: int, + manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0), + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.manual_padding = manual_padding + + in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1] + out_channels = self.config.sscp_conv_channel_size[idx] + kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx] + stride_h, stride_w = self.config.sscp_conv_stride_size[idx] + + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_h, kernel_w), + stride=(stride_h, stride_w), + padding=(0, 0), # Manual padding is used + bias=False, + ) + + f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1] + f_out_conv = (f_in_padded - kernel_w) // stride_w + 1 + + self.norm = Gemma3nCumulativeGroupNorm( + num_channels=out_channels, + feature_dims=(f_out_conv,), + eps=self.config.sscp_conv_group_norm_eps, + ) + + self.activation = nn.ReLU() + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + audio_encodings_padded = F.pad( + audio_encodings, self.manual_padding, mode="constant", value=0.0 + ) + audio_encodings_conv = self.conv(audio_encodings_padded) + x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous() + x_normed = self.norm(x_for_norm) + audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous() + return self.activation(audio_encodings_normed) + + +class Gemma3nAudioSubSampleConvProjection(nn.Module): + def __init__( + self, + config: Gemma3nAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + current_f_for_block_input = config.input_feat_size + calculated_block_padding = [] + calculated_f_out_dims = [] + + for i in range(2): # Assuming 2 conv layers + kernel_h, kernel_w = config.sscp_conv_kernel_size[i] + stride_h, stride_w = config.sscp_conv_stride_size[i] + + # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like + pad_t_top = 0 + pad_t_bottom = kernel_h - 1 + + # Frequency Padding (Width for Conv2d) + pad_f_left = 1 + pad_f_right = 1 + + manual_padding_tuple = (pad_f_left, pad_f_right, pad_t_top, pad_t_bottom) + calculated_block_padding.append(manual_padding_tuple) + + f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right + f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 + calculated_f_out_dims.append(f_out_after_conv) + current_f_for_block_input = f_out_after_conv + + self.conv_0 = Gemma3nAudioSSCPConvBlock( + idx=0, + input_freq_dim=config.input_feat_size, + config=config, + manual_padding=calculated_block_padding[0], + quant_config=quant_config, + 
prefix=add_prefix("conv_0", prefix), + ) + self.conv_1 = Gemma3nAudioSSCPConvBlock( + idx=1, + input_freq_dim=calculated_f_out_dims[0], + config=config, + manual_padding=calculated_block_padding[1], + quant_config=quant_config, + prefix=add_prefix("conv_1", prefix), + ) + + final_c_out = config.sscp_conv_channel_size[-1] + final_f_out = calculated_f_out_dims[-1] + self.input_proj_in_features = final_c_out * final_f_out + + self.input_proj_linear = RowParallelLinear( + self.input_proj_in_features, + self.config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("input_proj_linear", prefix), + ) + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + audio_encodings_reshaped = audio_encodings.unsqueeze(1) + x = self.conv_0(audio_encodings_reshaped) + x = self.conv_1(x) + b, c_out, t_out, f_out = x.shape + x_permuted = x.permute(0, 2, 3, 1).contiguous() + output_flattened = x_permuted.view(b, t_out, f_out * c_out) + output, _ = self.input_proj_linear(output_flattened) + return output + + +class Gemma3nAudioConformerAttention(nn.Module): + def __init__( + self, + config: Gemma3nAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + head_dim = self.config.hidden_size // self.config.conf_num_attention_heads + self.post_in_shape = (self.config.conf_num_attention_heads, head_dim) + self.post_in_features = self.config.hidden_size + + self.register_buffer( + "gradient_clipping", + torch.tensor(self.config.gradient_clipping), + persistent=False, + ) + + self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.attn = Gemma3nAudioAttention( + config, quant_config, prefix=add_prefix("attn", prefix) + ) + self.post = RowParallelLinear( + self.post_in_features, + self.config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("post", prefix), + ) + self.post_norm = Gemma3nRMSNorm(self.config.hidden_size) + + def forward( + self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor + ) -> torch.Tensor: + audio_encodings_input_to_attn = audio_encodings + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings_norm = self.pre_attn_norm(audio_encodings) + audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask) + + b, t, num_heads, head_dim = audio_encodings_attn_out.shape + audio_encodings_reshaped = audio_encodings_attn_out.reshape( + b, t, num_heads * head_dim + ) + + audio_encodings, _ = self.post(audio_encodings_reshaped) + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + return audio_encodings_input_to_attn + self.post_norm(audio_encodings) + + +class Gemma3nAudioConformerFeedForward(nn.Module): + def __init__( + self, + config: Gemma3nAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.register_buffer( + "gradient_clipping", + torch.tensor(self.config.gradient_clipping), + persistent=False, + ) + + self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.ffw_layer_1 = ColumnParallelLinear( + self.config.hidden_size, + self.config.hidden_size * 4, + bias=False, + quant_config=quant_config, + prefix=add_prefix("ffw_layer_1", prefix), + ) + self.ffw_layer_2 = RowParallelLinear( + self.config.hidden_size * 4, + self.config.hidden_size, + bias=False, + quant_config=quant_config, + 
prefix=add_prefix("ffw_layer_2", prefix), + ) + self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.post_layer_scale = torch.tensor(self.config.conf_residual_weight) + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + residual = audio_encodings + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings, _ = self.ffw_layer_1(audio_encodings) + audio_encodings = F.silu(audio_encodings) + audio_encodings, _ = self.ffw_layer_2(audio_encodings) + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.post_layer_norm(audio_encodings) + return residual + (audio_encodings * self.post_layer_scale) + + +class Gemma3nAudioConformerLightConv1d(nn.Module): + def __init__( + self, + config: Gemma3nAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.pre_layer_norm = Gemma3nRMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) + self.linear_start = ColumnParallelLinear( + self.config.hidden_size, + self.config.hidden_size * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("linear_start", prefix), + ) + + self.depthwise_conv1d = nn.Conv1d( + in_channels=self.config.hidden_size, + out_channels=self.config.hidden_size, + kernel_size=self.config.conf_conv_kernel_size, + stride=1, + padding=0, # Manual causal padding + groups=self.config.hidden_size, # Depthwise + bias=False, + ) + self.register_buffer( + "gradient_clipping", + torch.tensor(self.config.gradient_clipping), + persistent=False, + ) + self.conv_norm = Gemma3nRMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) + self.linear_end = RowParallelLinear( + self.config.hidden_size, + self.config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("linear_end", prefix), + ) + + self.causal_padding = self.config.conf_conv_kernel_size - 1 + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + audio_encodings_residual = audio_encodings # Save for residual connection + + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings, _ = self.linear_start(audio_encodings) + audio_encodings = F.glu(audio_encodings, dim=-1) + + # Permute for Conv1d: [B, T, D] -> [B, D, T] + audio_encodings_permuted = audio_encodings.permute(0, 2, 1) + # Apply manual causal padding + audio_encodings_permuted_padded = F.pad( + audio_encodings_permuted, (self.causal_padding, 0) + ) + audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded) + # Permute back: [B, D, T_out] -> [B, T_out, D] + audio_encodings = audio_encodings.permute(0, 2, 1) + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings = self.conv_norm(audio_encodings) + audio_encodings = F.silu(audio_encodings) + audio_encodings, _ = self.linear_end(audio_encodings) + output = audio_encodings + audio_encodings_residual + return output + + +class Gemma3nAudioConformerBlock(nn.Module): + def __init__( + self, + config: Gemma3nAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.ffw_layer_start = Gemma3nAudioConformerFeedForward( + config, quant_config, prefix=add_prefix("ffw_layer_start", prefix) + ) + self.attention = 
Gemma3nAudioConformerAttention( + config, quant_config, prefix=add_prefix("attention", prefix) + ) + self.lconv1d = Gemma3nAudioConformerLightConv1d( + config, quant_config, prefix=add_prefix("lconv1d", prefix) + ) + self.ffw_layer_end = Gemma3nAudioConformerFeedForward( + config, quant_config, prefix=add_prefix("ffw_layer_end", prefix) + ) + self.register_buffer( + "gradient_clipping", + torch.tensor(self.config.gradient_clipping), + persistent=False, + ) + self.norm = Gemma3nRMSNorm(self.config.hidden_size) + + def forward( + self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor + ) -> torch.Tensor: + audio_encodings = self.ffw_layer_start(audio_encodings) + audio_encodings = self.attention(audio_encodings, audio_mel_mask) + validity_mask_for_lconv = ~audio_mel_mask # True for valid + audio_encodings_for_lconv_input = ( + audio_encodings + * validity_mask_for_lconv.unsqueeze(-1).to(audio_encodings.dtype) + ) + audio_encodings = self.lconv1d(audio_encodings_for_lconv_input) + + audio_encodings = self.ffw_layer_end(audio_encodings) + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + output = self.norm(audio_encodings) + return output + + +class Gemma3nAudioEncoder(PreTrainedModel): + """A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037""" + + config_class = Gemma3nAudioConfig + + def __init__( + self, + config: Gemma3nAudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__(config) + self.config = config + + self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection( + config, quant_config, prefix=add_prefix("subsample_conv_projection", prefix) + ) + self.conformer = make_layers( + config.conf_num_hidden_layers, + lambda idx, prefix: Gemma3nAudioConformerBlock( + config=config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=add_prefix("conformer", prefix), + ) + + def forward( + self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor + ) -> Tuple[torch.Tensor, torch.BoolTensor]: + """Encodes a batch of MELs. + + Args: + audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins]. + audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]. + + Returns: + audio_encodings: a torch.Tensor of shape + `[batch_size, reduced_time_frames, hidden_size]` + audio_mel_mask: a torch.BoolTensor of shape [batch, reduced_time_frames]. + """ + audio_encodings = self.subsample_conv_projection( + audio_mel + ) # audio_encodings: [B, T_sub, D] + + # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub) + t_sub = audio_encodings.shape[1] + + time_stride_product = 1 + for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)): + time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0] + + # Create indices for gathering from the original mask. + # These indices map to original time steps corresponding to the start of each + # receptive field in the subsampled output. + indices = ( + torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product + ) + indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) + + # Expand indices for batch compatibility if B > 1 and indices is 1D. 
+ if audio_mel_mask.ndim > 1 and indices.ndim == 1: + indices = indices.unsqueeze(0).expand( + audio_mel_mask.shape[0], -1 + ) # [B, T_sub] + elif ( + audio_mel_mask.ndim == indices.ndim + and audio_mel_mask.shape[0] == 1 + and indices.shape[0] != 1 + and t_sub == indices.shape[0] + ): + # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub] + indices = indices.unsqueeze(0) + + current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub] + + # Fallback: Ensure mask length matches feature length after gather. + if current_mask.shape[1] != t_sub: + if current_mask.shape[1] > t_sub: + current_mask = current_mask[:, :t_sub] + else: # current_mask.shape[1] < t_sub + padding_needed = t_sub - current_mask.shape[1] + current_mask = F.pad( + current_mask, (0, padding_needed), value=True + ) # Pad with True (masked) + + for i, block in enumerate(self.conformer): + audio_encodings = block( + audio_encodings, current_mask + ) # Pass the processed mask + + if self.config.conf_reduction_factor > 1: + audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor] + # Reduce the mask as well + current_mask = current_mask[:, :: self.config.conf_reduction_factor] + + # Final masking of audio_encodings based on the final current_mask + # Ensure current_mask length matches the finally reduced audio_encodings length + if current_mask.shape[1] != audio_encodings.shape[1]: + target_len = audio_encodings.shape[1] + mask_current_len = current_mask.shape[1] + if target_len > mask_current_len: + padding_needed = target_len - mask_current_len + current_mask = F.pad(current_mask, (0, padding_needed), value=True) + elif mask_current_len > target_len: # mask is longer + current_mask = current_mask[:, :target_len] + + audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0) + return audio_encodings, current_mask diff --git a/python/sglang/srt/models/gemma3n_causal.py b/python/sglang/srt/models/gemma3n_causal.py new file mode 100644 index 000000000..802cb9fc5 --- /dev/null +++ b/python/sglang/srt/models/gemma3n_causal.py @@ -0,0 +1,1009 @@ +from typing import Iterable, Optional, Set, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoModel, Gemma3nTextConfig, PretrainedConfig, PreTrainedModel + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.models.gemma3_causal import Gemma3TextScaledWordEmbedding +from sglang.srt.utils import add_prefix, make_layers + + +# Aligned with HF's implementation, using sliding window inclusive with the last token +# SGLang assumes exclusive +def get_attention_sliding_window_size(config): + return config.sliding_window - 1 + + +class Gemma3nRMSNorm(RMSNorm): + def __init__( + self, + dim: int, + eps: 
float = 1e-6, + with_scale: bool = True, + ) -> None: + super().__init__(dim, eps=eps) + if not with_scale: + del self.weight + self.register_buffer( + "weight", + torch.ones(dim, dtype=torch.get_default_dtype()), + persistent=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + original_shape = x.shape + x_2d = x.contiguous().reshape(-1, original_shape[-1]) + x_2d = super().forward(x_2d) + x = x_2d.reshape(original_shape) + return x + + +class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): + pass + + +class Gemma3nMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + activation_sparsity: float = 0.0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + ) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3n uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_activation` to " + "`gelu_pytorch_tanh`." + ) + # Use proper GELU with tanh approximation as specified + self.act_fn = GeluAndMul() + self.activation_sparsity = activation_sparsity + self.register_buffer( + "target_sparsity_tensor", + torch.tensor(self.activation_sparsity, dtype=torch.float32), + persistent=False, + ) # moved from _gaussian_topk for cuda graph + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + + # Split gate and up projections + gate_proj, up_proj = gate_up.chunk(2, dim=-1) + + # Apply activation sparsity if needed + if self.activation_sparsity > 0.0: + gate_proj = self._gaussian_topk(gate_proj) + + gate_up = torch.cat([gate_proj, up_proj], dim=-1) + + # Apply GELU activation to gate projection and multiply with up projection + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor: + normal_dist = torch.distributions.normal.Normal(0, 1) + std_multiplier = normal_dist.icdf(self.target_sparsity_tensor) + std_multiplier = std_multiplier.type(inputs.dtype) + inputs_mean = torch.mean(inputs, dim=-1, keepdim=True) + inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False) + cutoff_x = inputs_mean + inputs_std * std_multiplier + return F.relu(inputs - cutoff_x) + + +class Gemma3nLaurelBlock(nn.Module): + """Learned Augmented Residual Layer""" + + def __init__( + self, + config: Gemma3nTextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.linear_left = ColumnParallelLinear( + config.hidden_size, + config.laurel_rank, + bias=False, + quant_config=quant_config, + prefix=add_prefix("linear_left", prefix), + ) + self.linear_right = RowParallelLinear( + config.laurel_rank, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("linear_right", prefix), + ) + self.post_laurel_norm = Gemma3nRMSNorm( + dim=config.hidden_size, + eps=config.rms_norm_eps, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # [num_tokens, hidden_size] + laurel_x, _ = self.linear_left(x) + laurel_x, _ = self.linear_right(laurel_x) + normed_laurel_x = 
self.post_laurel_norm(laurel_x) + return x + normed_laurel_x + + +class Gemma3nAltUp(nn.Module): + """Alternating Updates (AltUp)""" + + def __init__( + self, + config: Gemma3nTextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.correct_output_scale = nn.Parameter( + torch.zeros(config.hidden_size, dtype=torch.float32) + ) + self.correction_coefs = ColumnParallelLinear( + config.altup_num_inputs, + config.altup_num_inputs, + bias=False, + quant_config=quant_config, + prefix=add_prefix("correction_coefs", prefix), + ) + self.prediction_coefs = ColumnParallelLinear( + config.altup_num_inputs, + config.altup_num_inputs**2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("prediction_coefs", prefix), + ) + self.modality_router = ColumnParallelLinear( + config.hidden_size, + config.altup_num_inputs, + bias=False, + quant_config=quant_config, + prefix=add_prefix("modality_router", prefix), + ) + + self.router_norm = Gemma3nRMSNorm( + dim=config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.register_buffer( + "router_input_scale", + torch.tensor(config.hidden_size**-1.0), + persistent=False, + ) + + def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: + # x : [num_tokens, hidden_size] + router_inputs = self.router_norm(x) * self.router_input_scale.to( + self.router_norm.weight.dtype + ) + # router_inputs : [num_tokens, hidden_size] + routed, _ = self.modality_router(router_inputs) + + # routed : [num_tokens, altup_num_inputs] + return torch.tanh(routed.float()).type_as(routed) + + def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Predicts the output of a layer using a trainable map. + hidden_states: [num_altup_inputs, num_tokens, hidden_size] + """ + modalities = self.compute_router_modalities( + hidden_states[self.config.altup_active_idx] + ) # (n_tokens, altup_num_inputs) + # TODO: CHECK DO WE NEED THIS: self.prediction_coefs.float() # Force computation in float32, in-place operation + + if self.config.altup_coef_clip is not None: + self.prediction_coefs.weight.data.clamp_( + -self.config.altup_coef_clip, self.config.altup_coef_clip + ) + + all_coefs, _ = self.prediction_coefs( + modalities + ) # (n_tokens, altup_num_inputs) -> (n_tokens, altup_num_inputs**2) + + all_coefs = all_coefs.reshape( + *modalities.shape[:-1], + self.config.altup_num_inputs, + self.config.altup_num_inputs, + ).permute(0, 2, 1) + + # permute hidden_states from [num_altup_inputs, num_tokens, hidden_size] to [num_tokens, hidden_size, altup_num_inputs] + predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs) + predictions = predictions.permute(2, 0, 1) # undo the permute + predictions += hidden_states # add the original input + return predictions.contiguous().type_as( + hidden_states + ) # [num_altup_inputs, num_tokens, hidden_size] + + def correct( + self, predictions: torch.Tensor, activated: torch.Tensor + ) -> torch.Tensor: + """Corrects the predictions relative to the activated inputs.""" + # prediction : [num_altup_inputs, num_tokens, hidden_size] + # activated : [num_tokens, hidden_size] + modalities = self.compute_router_modalities( + activated + ) # [num_tokens, altup_num_inputs] + innovation = ( + activated - predictions[self.config.altup_active_idx] + ) # [num_tokens, hidden_size] + innovation = innovation.repeat( + self.config.altup_num_inputs, 1, 1 + ) # (self.config.altup_num_inputs, num_tokens, hidden_size) + + if self.config.altup_coef_clip is not 
None: + self.correction_coefs.weight.data.clamp_( + -self.config.altup_coef_clip, self.config.altup_coef_clip + ) + + all_coefs, _ = self.correction_coefs( + modalities + ) # [num_tokens, altup_num_inputs] + all_coefs = (all_coefs + 1.0).permute(1, 0).unsqueeze(-1) + # # [num_tokens, altup_num_inputs, 1] + + corrected = torch.mul(innovation, all_coefs) + corrected += predictions + return corrected.contiguous().type_as(activated) + + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: + """Scales the provided 3D tensor.""" + return corrected * self.correct_output_scale.to(corrected.dtype) + + def forward( + self, hidden_states: torch.Tensor, activated: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts, correct, and optionally scales the output of a layer using trainable maps. + + hidden_states: [num_altup_inputs, num_tokens, hidden_size] + """ + + predictions = self.predict(hidden_states) + corrected = self.correct(predictions=predictions, activated=activated) + output = corrected[self.config.altup_active_idx] + if self.config.altup_correct_scale: + output = self.scale_corrected_output(output) + return corrected, output + + +class Gemma3nAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + layer_id: int, + config: Gemma3nTextConfig, + max_position_embeddings: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.config = config + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + hidden_size = config.hidden_size + head_dim = getattr( + config, "head_dim", hidden_size // config.num_attention_heads + ) + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # self.scaling = config.query_rescale_scalar / config.query_pre_attn_scalar + self.scaling = 1.0 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + # Determine if layer uses sliding window based on pattern + self.is_sliding = config.layer_types[layer_id] == "sliding_attention" + + # Check if this is a KV shared layer + first_kv_shared_layer_idx = ( + config.num_hidden_layers - config.num_kv_shared_layers + ) + self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx + + # Compute the layer index from which shared KV cache values will be retrieved + if not self.is_kv_shared_layer: + self.kv_shared_layer_index = None + elif self.is_sliding: + self.kv_shared_layer_index = first_kv_shared_layer_idx - 2 + else: + self.kv_shared_layer_index = first_kv_shared_layer_idx - 1 + + if self.is_sliding: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + 
max_position=config.max_position_embeddings, + base=config.rope_local_base_freq, + rope_scaling={"rope_type": "default"}, + ) + else: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + rope_scaling=config.rope_scaling, + ) + + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=( + layer_id if not self.is_kv_shared_layer else self.kv_shared_layer_index + ), + logit_cap=0.0, + sliding_window_size=self.sliding_window, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + # Gemma3n adds normalization for q, k, v + self.q_norm = Gemma3nRMSNorm( + dim=config.head_dim, + eps=config.rms_norm_eps, + ) + self.k_norm = Gemma3nRMSNorm( + dim=config.head_dim, + eps=config.rms_norm_eps, + ) + self.v_norm = Gemma3nRMSNorm( + dim=config.head_dim, + eps=config.rms_norm_eps, + with_scale=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: Tuple[torch.Tensor, torch.Tensor], + forward_batch: ForwardBatch, + **kwargs, + ) -> torch.Tensor: + + qkv, _ = self.qkv_proj(hidden_states) + # TODO: for first 20 layers, we use QKVParallelLinear + # for others, we only calc Q. + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Apply normalization to q, k, v + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + + # Check if we should use shared KV cache + if self.is_kv_shared_layer and self.kv_shared_layer_index is not None: + # For KV shared layers, we skip K/V computation and normalization + # The RadixAttention will handle retrieving shared KV from cache + k = None + v = None + else: + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + + v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(v) + + # Flatten back for rotary embedding + q = q.flatten(-2, -1) + + # Apply rotary embedding + if k is not None: + k = k.flatten(-2, -1) + q, k = self.rotary_emb(positions, q, k) + # Reshape k back to head format for attention + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + else: + # For shared KV layers, create a dummy key for rotary embedding and discard it + dummy_k = torch.zeros_like( + q[:, : self.kv_size] + ) # Create dummy key with same shape as needed + q, _ = self.rotary_emb(positions, q, dummy_k) + + # Reshape q back to head format for attention + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + + attn_output = self.attn( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=not self.is_kv_shared_layer, + ) + + output, _ = self.o_proj(attn_output) + return output + + +class Gemma3nDecoderLayer(nn.Module): + def __init__( + self, + layer_id: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.layer_id = layer_id + self.attention_type = config.layer_types[layer_id] + self.config = config + + self.self_attn = Gemma3nAttention( + layer_id=layer_id, + config=config, + max_position_embeddings=config.max_position_embeddings, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + + activation_sparsity = config.activation_sparsity_pattern[layer_id] + self.mlp = Gemma3nMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + 
hidden_activation=config.hidden_activation, + activation_sparsity=activation_sparsity, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma3nRMSNorm( + self.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = Gemma3nRMSNorm( + self.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = Gemma3nRMSNorm( + self.hidden_size, eps=config.rms_norm_eps + ) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + + self.altup = Gemma3nAltUp( + config, quant_config, prefix=add_prefix("altup", prefix) + ) + self.laurel = Gemma3nLaurelBlock( + config, quant_config, prefix=add_prefix("laurel", prefix) + ) + + self.per_layer_input_gate = ColumnParallelLinear( + self.hidden_size, + self.hidden_size_per_layer_input, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_input_gate", prefix), + ) + self.per_layer_projection = RowParallelLinear( + self.hidden_size_per_layer_input, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_projection", prefix), + ) + self.post_per_layer_input_norm = Gemma3nRMSNorm( + self.hidden_size, eps=config.rms_norm_eps + ) + self.is_sliding = self.self_attn.is_sliding + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_input: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs, + ) -> torch.Tensor: + predictions = self.altup.predict( + hidden_states + ) # [num_altup_inputs, num_tokens, hidden_size] + active_prediction = predictions[self.config.altup_active_idx] + + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel( + active_prediction_normed + ) # laurel_output: [num_tokens, hidden_size] + # active_prediction: [num_tokens, hidden_size] + + attn = self.self_attn( + positions=positions, + hidden_states=active_prediction_normed, + forward_batch=forward_batch, + **kwargs, + ) + attn = self.post_attention_layernorm(attn) # [num_tokens, hidden_size] + + attn_gated = active_prediction + attn # [num_tokens, hidden_size] + attn_laurel = (attn_gated + laurel_output) / torch.sqrt(torch.tensor(2.0)) + + attn_norm = self.pre_feedforward_layernorm( + attn_laurel + ) # [num_tokens, hidden_size] + attn_ffw = self.mlp(attn_norm) # [num_tokens, hidden_size] + attn_ffw_norm = self.post_feedforward_layernorm( + attn_ffw + ) # [num_tokens, hidden_size] + attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm # [num_tokens, hidden_size] + corrected_predictions = self.altup.correct( + predictions, attn_ffw_laurel_gated + ) # prediction : [num_altup_inputs, num_tokens, hidden_size] + # attn_ffw_laurel_gated: [num_tokens, hidden_size] + first_prediction = corrected_predictions[self.config.altup_active_idx] + + if self.config.altup_correct_scale: + first_prediction = self.altup.scale_corrected_output(first_prediction) + + # per_layer_input_gate + first_prediction = first_prediction.to(self.per_layer_input_gate.weight.dtype) + first_prediction, _ = self.per_layer_input_gate(first_prediction) + first_prediction = F.gelu(first_prediction, approximate="tanh") + first_prediction = torch.multiply(first_prediction, per_layer_input) + + # per_layer_projection + first_prediction, _ = self.per_layer_projection(first_prediction) + first_prediction = self.post_per_layer_input_norm(first_prediction) + corrected_predictions[1:] += first_prediction + + return 
corrected_predictions + + +class Gemma3nTextModel(PreTrainedModel): + def __init__( + self, + config: Gemma3nTextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + self.padding_idx = config.pad_token_id + + # Gemma3n downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + self.embed_tokens = Gemma3nTextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + embed_scale=self.config.hidden_size**0.5, + ) + + self.norm = Gemma3nRMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: Gemma3nDecoderLayer( + layer_id=idx, + config=config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=add_prefix("layers", prefix), + ) + + # Per-layer input embeddings + self.hidden_size = config.hidden_size + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + + self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding( + config.vocab_size_per_layer_input, + config.num_hidden_layers * config.hidden_size_per_layer_input, + self.padding_idx, + embed_scale=self.config.hidden_size_per_layer_input**0.5, + ) + + self.per_layer_model_projection = ColumnParallelLinear( + self.hidden_size, + config.num_hidden_layers * config.hidden_size_per_layer_input, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_model_projection", prefix), + ) + + self.per_layer_projection_norm = Gemma3nRMSNorm( + dim=config.hidden_size_per_layer_input, + eps=config.rms_norm_eps, + ) + + self.altup_projections = make_layers( + self.config.altup_num_inputs - 1, + lambda idx, prefix: ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=prefix, + ), + prefix=add_prefix("altup_projections", prefix), + ) + + self.altup_unembed_projections = make_layers( + self.config.altup_num_inputs - 1, + lambda idx, prefix: ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=prefix, + ), + prefix=add_prefix("altup_unembed_projections", prefix), + ) + + self.register_buffer( + "per_layer_projection_scale", + torch.tensor(self.hidden_size**-0.5), + persistent=False, + ) + self.register_buffer( + "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False + ) + + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + return self.embed_tokens + + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor: + embeddings = self.embed_tokens_per_layer(input_ids) + return embeddings.reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + per_layer_projection, _ = self.per_layer_model_projection(inputs_embeds) + per_layer_projection *= self.per_layer_projection_scale.type( + inputs_embeds.dtype + ) + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if per_layer_inputs is None: + return 
per_layer_projection + + if per_layer_projection.shape != per_layer_inputs.shape: + # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings + per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] + + return ( + per_layer_projection + per_layer_inputs + ) * self.per_layer_input_scale.type(inputs_embeds.dtype) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + if (input_ids is None) ^ (input_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if input_ids is not None: + input_embeds = self.embed_tokens(input_ids) + per_layer_inputs = self.get_per_layer_inputs(input_ids) + + per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_inputs) + + if positions.dim() == 1: + positions = positions.unsqueeze(0) + + # Expand hidden_states to support per-layer inputs + target_magnitude = torch.mean(input_embeds**2, dim=-1, keepdim=True) ** 0.5 + epsilon_tensor = torch.tensor(torch.finfo(input_embeds.dtype).min) + + # embed positions + hidden_states_0 = input_embeds + temp_hidden_states = [hidden_states_0] + + for i in range(1, self.config.altup_num_inputs): + altup_proj, _ = self.altup_projections[i - 1](hidden_states_0) + current_hidden_state = altup_proj.type(hidden_states_0.dtype) + new_magnitude = ( + torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 + ) + current_hidden_state = current_hidden_state * ( + target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) + ) + temp_hidden_states.append(current_hidden_state) + + hidden_states = torch.stack( + temp_hidden_states, dim=0 + ) # [num_altup_inputs, n_tokens, hidden_size] + + for layer_idx, layer in enumerate(self.layers): + per_layer_input = per_layer_inputs[:, layer_idx, :] + hidden_states = layer( + positions=positions, + per_layer_input=per_layer_input, + hidden_states=hidden_states, + forward_batch=forward_batch, + **kwargs, + ) + + # Per-layer inputs to single output + target_magnitude = ( + torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5 + ) + + temp_hidden_states = [hidden_states[0]] + + for i in range(1, self.config.altup_num_inputs): + # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) 
+ altup_unemb_proj, _ = self.altup_unembed_projections[i - 1]( + hidden_states[i] + ) + current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) + new_magnitude = ( + torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 + ) + current_hidden_state = current_hidden_state * ( + target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) + ) + temp_hidden_states.append(current_hidden_state) + + hidden_states = torch.stack(temp_hidden_states) + hidden_states = torch.mean(hidden_states, dim=0) + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class Gemma3nForCausalLM(PreTrainedModel): + config_class = Gemma3nTextConfig + + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = Gemma3nTextConfig + base_model_prefix = "language_model" + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + ".q_proj": (".qkv_proj", 0), + ".k_proj": (".qkv_proj", 1), + ".v_proj": (".qkv_proj", 2), + ".gate_proj": (".gate_up_proj", 0), + ".up_proj": (".gate_up_proj", 1), + } + + packed_modules_mapping = { + ".qkv_proj": [ + ".q_proj", + ".k_proj", + ".v_proj", + ], + ".gate_up_proj": [ + ".gate_proj", + ".up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + ".qkv_proj", + ".o_proj", + ".gate_up_proj", + ".down_proj", + ] + # Gemma does not apply LoRA to the embedding layer + embedding_modules = {} + embedding_padding_modules = [] + supports_lora = True + + def __init__( + self, + config: Gemma3nTextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + self.model = Gemma3nTextModel( + config=config, + quant_config=quant_config, + prefix=add_prefix("model", prefix), + ) + self.logits_processor = LogitsProcessor(config) + + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + + def get_attention_sliding_window_size(self): + return get_attention_sliding_window_size(self.config) + + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> LogitsProcessor: + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + per_layer_inputs, + **kwargs, + ) + + return self.logits_processor( + input_ids, hidden_states, self.model.embed_tokens, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + 
name = name.replace("model.language_model.", "model.") + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + # Skip loading weights that are not in the model + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + # Skip loading weights that are not in the model + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +EntryClass = Gemma3nForCausalLM +AutoModel.register(Gemma3nTextConfig, Gemma3nForCausalLM, exist_ok=True) diff --git a/python/sglang/srt/models/gemma3n_mm.py b/python/sglang/srt/models/gemma3n_mm.py new file mode 100644 index 000000000..5100a37ed --- /dev/null +++ b/python/sglang/srt/models/gemma3n_mm.py @@ -0,0 +1,511 @@ +import logging +import re +from functools import lru_cache +from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union + +import torch +from torch import nn +from transformers import ( + Gemma3nAudioConfig, + Gemma3nConfig, + Gemma3nTextConfig, + Gemma3nVisionConfig, + PreTrainedModel, +) +from transformers.models.auto.modeling_auto import AutoModel + +from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternTokenPairs, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import ( + MultimodalDataItem, + MultimodalInputs, + flatten_nested_list, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder +from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + +cached_get_processor = lru_cache(get_processor) + + +class Gemma3nImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class Gemma3nAudioInputs(TypedDict): + input_features: torch.Tensor + """Shape: `(batch_size * num_audio, seq_length, num_features)`""" + input_features_mask: torch.Tensor + """Shape: `(batch_size * num_audio, seq_length)`""" + + +class Gemma3nMultimodalEmbedder(nn.Module): + """Embeds token ids or soft tokens for multimodal content into language model space.""" + + def __init__( + self, + multimodal_config: 
Union[Gemma3nAudioConfig, Gemma3nVisionConfig], + text_config: Gemma3nTextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.multimodal_hidden_size = multimodal_config.hidden_size + self.eps = multimodal_config.rms_norm_eps + self.vocab_offset = multimodal_config.vocab_offset + self.vocab_size = multimodal_config.vocab_size + self.text_hidden_size = text_config.hidden_size + + self.embedding = VocabParallelEmbedding( + self.vocab_size, + self.multimodal_hidden_size, + quant_config=quant_config, + prefix=add_prefix("embedding", prefix), + ) + + self.hard_embedding_norm = Gemma3nRMSNorm( + self.multimodal_hidden_size, + eps=self.eps, + ) + + self.soft_embedding_norm = Gemma3nRMSNorm( + self.multimodal_hidden_size, + eps=self.eps, + ) + + self.embedding_projection = RowParallelLinear( + self.multimodal_hidden_size, + self.text_hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("embedding_projection", prefix), + ) + + self.embedding_post_projection_norm = Gemma3nRMSNorm( + self.text_hidden_size, + eps=self.eps, + with_scale=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Embeds token ids or soft tokens for multimodal content into language model space. + + Args: + input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range + `[vocab_offset, vocab_offset + vocab_size)`. + inputs_embeds: A torch.Tensor containing the soft tokens to embed. + + Returns: + A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is not None: + emb_norm = self.soft_embedding_norm(inputs_embeds) + else: + # Handle out of vocab ids to prevent CUDA assertion failures + out_of_vocab_id = self.vocab_size - 1 + adjusted_ids = input_ids - self.vocab_offset + adjusted_ids = torch.where(adjusted_ids < 0, out_of_vocab_id, adjusted_ids) + adjusted_ids = torch.where( + adjusted_ids >= self.vocab_size, out_of_vocab_id, adjusted_ids + ) + hard_emb = self.embedding(adjusted_ids) + emb_norm = self.hard_embedding_norm(hard_emb) + + emb_norm_proj, _ = self.embedding_projection(emb_norm) + return self.embedding_post_projection_norm(emb_norm_proj) + + +class Gemma3nForConditionalGeneration(PreTrainedModel): + config_class = Gemma3nConfig + """Gemma3n multimodal model for conditional generation.""" + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ".out_proj.", + ] + bitsandbytes_stacked_params_mapping = { + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + "out_proj": ("proj", 0), + } + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + # Gemma does not apply LoRA to the embedding layer + embedding_modules = {} + embedding_padding_modules = [] + supports_lora = True + + def __init__( + self, + config: Gemma3nConfig, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + + prefix = add_prefix("model", prefix) + + # Vision components + # TODO: Use sglang's vision model + self.vision_tower = AutoModel.from_config(config=config.vision_config) + + self.embed_vision = Gemma3nMultimodalEmbedder( + config.vision_config, + config.text_config, + quant_config=quant_config, + prefix=add_prefix("embed_vision", prefix), + ) + + # Audio components + self.embed_audio = Gemma3nMultimodalEmbedder( + config.audio_config, + config.text_config, + quant_config=quant_config, + prefix=add_prefix("embed_audio", prefix), + ) + + self.audio_tower = Gemma3nAudioEncoder( + config.audio_config, + quant_config=quant_config, + prefix=add_prefix("audio_tower", prefix), + ) + + self.vocab_size = config.text_config.vocab_size + self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input + + # Text model + self.language_model = Gemma3nTextModel( + config.text_config, + quant_config, + prefix=add_prefix("language_model", prefix), + ) + + # Create logits processor for the multimodal model + self.logits_processor = LogitsProcessor(config.text_config) + + self.post_init() + + def pad_input_ids( + self, + input_ids: List[int], + mm_inputs: Optional[MultimodalInputs] = None, + ) -> List[int]: + """Pad input IDs with image and audio tokens.""" + if mm_inputs is None: + return input_ids + + # Collect available media token pairs + media_token_pairs = [] + for attr_name in ["im_start_id", "audio_start_id"]: + if hasattr(mm_inputs, attr_name): + start_id = getattr(mm_inputs, attr_name) + end_id = getattr(mm_inputs, attr_name.replace("start", "end")) + media_token_pairs.append((start_id, end_id)) + + # Apply padding pattern if we have media tokens + if media_token_pairs: + pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + return pattern.pad_input_tokens(input_ids, mm_inputs) + + return input_ids + + def get_input_embeddings(self) -> nn.Embedding: + return self.language_model.get_input_embeddings() + + def get_attention_sliding_window_size(self): + return self.config.text_config.sliding_window - 1 + + def get_image_feature(self, items: List[MultimodalDataItem]): + """ + Projects the last hidden state from the vision model into language model space. + + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
+ """ + # Process images one by one to handle flatten_batch=True constraint in vision_tower + all_pixel_values = flatten_nested_list([item.pixel_values for item in items]) + vision_outputs_list = [] + + for pixel_values_batch in all_pixel_values: + # Normalize input shape to [batch_size, channels, height, width] + if pixel_values_batch.dim() == 5: + pixel_values_batch = pixel_values_batch.squeeze(0) + elif pixel_values_batch.dim() == 3: + pixel_values_batch = pixel_values_batch.unsqueeze(0) + elif pixel_values_batch.dim() != 4: + raise ValueError( + f"Unexpected pixel_values shape: {pixel_values_batch.shape}" + ) + + # Process each image in the batch + batch_size = pixel_values_batch.shape[0] + for i in range(batch_size): + pixel_value = pixel_values_batch[i : i + 1] # Keep batch dimension as 1 + pixel_value = pixel_value.to( + device=self.vision_tower.device, dtype=self.language_model.dtype() + ) + vision_outputs = self.vision_tower( + pixel_values=pixel_value, do_pooling=False, return_dict=True + ).last_hidden_state + vision_outputs_list.append(vision_outputs) + + # Concatenate all vision outputs + vision_outputs = torch.cat(vision_outputs_list, dim=0) + + # Convert from (batch, channels, height, width) to (batch, height * width, channels) + vision_outputs = vision_outputs.reshape( + vision_outputs.shape[0], + self.config.vision_config.hidden_size, + self.config.vision_soft_tokens_per_image, + ).permute(0, 2, 1) + + # Normalize and embed the soft tokens into language model space + vision_outputs *= self.config.vision_config.hidden_size**0.5 + return self.embed_vision(inputs_embeds=vision_outputs) + + def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + """ + Projects the last hidden state from the audio encoder into language model space. + + Args: + items: List of multimodal data items containing audio data. + + Returns: + audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`). + """ + # Extract audio features and masks from items + all_input_features = flatten_nested_list( + [item.input_features for item in items] + ) + all_input_features_mask = flatten_nested_list( + [~item.input_features_mask for item in items] + ) # Note(Xinyuan): reverse the mask according to the HF implementation + + # Process audio features one by one + audio_features_list = [] + + for input_features, input_features_mask in zip( + all_input_features, all_input_features_mask + ): + # Ensure proper tensor format + if input_features.dim() == 2: + input_features = input_features.unsqueeze(0) + if input_features_mask.dim() == 1: + input_features_mask = input_features_mask.unsqueeze(0) + + # Move to device and dtype + input_features = input_features.to( + device=next(self.audio_tower.parameters()).device, + dtype=self.language_model.dtype(), + ) + input_features_mask = input_features_mask.to(device=input_features.device) + + # Process through audio tower + audio_outputs, audio_mask = self.audio_tower( + input_features, input_features_mask + ) + + # Embed the audio outputs + audio_embeds = self.embed_audio(inputs_embeds=audio_outputs) + audio_features_list.append(audio_embeds) + + # Concatenate all audio features + if audio_features_list: + audio_features = torch.cat(audio_features_list, dim=0) + + # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the + # text to account for this. 
+            # However, the audio preprocessing and encoder do not guarantee they will
+            # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
+            # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
+            # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+            audio_padding_toks = torch.tensor(
+                [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
+            )
+            audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
+            audio_features = torch.where(
+                audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
+            )
+
+            audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+            extra_padding_tokens = (
+                self.config.audio_soft_tokens_per_image - audio_seq_len
+            )
+            extra_padding_features = audio_padding_embs.expand(
+                audio_batch_size, extra_padding_tokens, audio_embed_dim
+            )
+
+            audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+            return audio_features
+        else:
+            return torch.empty(
+                0,
+                0,
+                self.language_model.config.hidden_size,
+                device=next(self.parameters()).device,
+                dtype=self.language_model.dtype(),
+            )
+
+    def get_per_layer_inputs(
+        self, input_ids: torch.LongTensor
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.get_per_layer_inputs(input_ids)
+
+    def project_per_layer_inputs(
+        self,
+        inputs_embeds: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.language_model.project_per_layer_inputs(
+            inputs_embeds, per_layer_inputs
+        )
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        **kwargs: object,
+    ) -> LogitsProcessor:
+        """Forward pass for multimodal Gemma3n."""
+        if (input_ids is None) ^ (input_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
+
+        positions += 1
+
+        if input_ids is not None:
+            # Prepare per-layer inputs from input_ids
+            per_layer_inputs_mask = torch.logical_and(
+                input_ids >= 0, input_ids < self.vocab_size_per_layer_input
+            )
+            per_layer_inputs_tokens = torch.where(
+                per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
+            )
+            per_layer_inputs = self.language_model.get_per_layer_inputs(
+                per_layer_inputs_tokens
+            )
+
+        # Use general_mm_embed_routine for handling multimodal data
+        # This will automatically handle text, image, and audio embeddings
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            audio_data_embedding_func=self.get_audio_feature,
+            positions=positions,
+            per_layer_inputs=per_layer_inputs,
+        )
+
+        # Process hidden states through logits processor
+        return self.logits_processor(
+            input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
+        )
+
+    def tie_weights(self):
+        return self.language_model.tie_weights()
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model."""
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".up_proj", 1),
+            (".gate_up_proj", ".gate_proj", 0),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for
name, loaded_weight in weights: + name = re.sub(r"^model\.", "", name) + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "vision_model" in name: + # adapt to VisionAttention + name = name.replace(".self_attn.out_proj", ".self_attn.proj") + # Skip loading extra bias for GPTQ models + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +EntryClass = Gemma3nForConditionalGeneration
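
As a reviewer-facing aside (not part of the diff): a minimal smoke-test sketch for the new Gemma3nForConditionalGeneration entry class. It assumes the server is started with something like `python -m sglang.launch_server --model-path google/gemma-3n-E4B-it --port 30000` and that the OpenAI-compatible /v1/chat/completions endpoint accepts image content parts; the model path, port, and image URL are placeholders, and the audio path would be exercised analogously once audio inputs are wired through the processor.

# Hypothetical end-to-end check for the Gemma3n multimodal path (assumptions noted above).
import requests

payload = {
    "model": "google/gemma-3n-E4B-it",  # placeholder model path
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.jpg"},  # placeholder image URL
                },
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
    "max_tokens": 64,
    "temperature": 0,
}

resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])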