chore: remove vlm unnecessary import (#7541)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
@@ -565,6 +565,7 @@ multimodal_model_archs = [
     "CLIPModel",
     "DeepseekVL2ForCausalLM",
     "Gemma3ForConditionalGeneration",
+    "Gemma3nForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "LlavaLlamaForCausalLM",
@@ -823,6 +823,7 @@ register_conv_template(
         sep_style=SeparatorStyle.GEMMA3,
         stop_str=["<end_of_turn>"],
         image_token="<start_of_image>",
+        audio_token="<start_of_audio>",
     )
 )
@@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum):
     RAW_IMAGES = "raw_images"
     PRECOMPUTED_FEATURES = "precomputed_features"
     PIXEL_VALUES = "pixel_values"
+    AUDIO = "audio"


 @dataclasses.dataclass
@@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC):
             has_image = False
             has_pixel_values = False
             has_precomputed_features = False
+            has_audio = False

             for mm_input in mm_inputs:
                 if isinstance(mm_input, Image.Image):
                     has_image = True
+                elif isinstance(mm_input, np.ndarray):
+                    has_audio = True
                 elif isinstance(mm_input, dict):
                     if mm_input.get("precomputed_features", None) is not None:
                         has_precomputed_features = True
@@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC):

             # Validate format consistency
             format_count = sum(
-                [has_image, has_pixel_values, has_precomputed_features]
+                [has_image, has_pixel_values, has_precomputed_features, has_audio]
             )
             if format_count > 1:
                 raise ValueError(
                     "Unsupported: mixture of multimodal input formats. "
                     f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                    f"precomputed_features={has_precomputed_features}"
+                    f"precomputed_features={has_precomputed_features}, audio={has_audio}"
                 )

             if has_image:
@@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC):
                 return MultimodalInputFormat.PRECOMPUTED_FEATURES
             elif has_pixel_values:
                 return MultimodalInputFormat.PIXEL_VALUES
+            elif has_audio:
+                return MultimodalInputFormat.AUDIO
             else:
                 raise ValueError("No valid multimodal input format found")
         except Exception as e:
@@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC):
             input_ids = tokenize_text(base_output.input_text)
             return combined_mm_item, input_ids

+        def process_audio(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process inputs with audio."""
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                audio=base_output.audios,  # Note: "audio" is for gemma3n only
+            )
+            combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
+            for key, value in ret.items():
+                if key != "input_ids" and hasattr(combined_mm_item, key):
+                    setattr(combined_mm_item, key, value)
+            input_ids = ret["input_ids"].flatten()
+            return combined_mm_item, input_ids
+
         def finalize_mm_item(
             combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
         ) -> MultimodalDataItem:
             """Apply common post-processing to the multimodal item."""
-            combined_mm_item.image_offsets = self.get_mm_items_offset(
-                input_ids=input_ids,
-                mm_token_id=self.IM_TOKEN_ID,
-            )
+            if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                combined_mm_item.image_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.IM_TOKEN_ID,
+                )
+            elif combined_mm_item.modality == Modality.AUDIO:
+                combined_mm_item.audio_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.AUDIO_TOKEN_ID,
+                )
+            elif combined_mm_item.modality == Modality.VIDEO:
+                combined_mm_item.video_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.VIDEO_TOKEN_ID,
+                )
+            else:
+                raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
             return combined_mm_item

-        # Main logic
-        mm_inputs = base_output.images
+        # Main logic - determine input type and handle text-only case
+        mm_inputs = base_output.images or base_output.audios
         if not mm_inputs:
             # Return text-only case
             input_ids = tokenize_text(base_output.input_text)
             return None, input_ids

@@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC):
             combined_mm_item, input_ids = process_precomputed_features(base_output)
         elif input_format == MultimodalInputFormat.PIXEL_VALUES:
             combined_mm_item, input_ids = process_pixel_values(base_output)
+        elif input_format == MultimodalInputFormat.AUDIO:
+            combined_mm_item, input_ids = process_audio(base_output)
         else:
             raise ValueError(f"Unknown input format: {input_format}")

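The new audio branch in finalize_mm_item above pairs each run of audio soft tokens with (start, end) offsets via get_mm_items_offset. A minimal standalone sketch of that idea, not part of the commit; the real helper lives on the base processor and may differ in details:

import torch

def mm_items_offset(input_ids: torch.Tensor, mm_token_id: int):
    """Return (start, end) index pairs of each contiguous run of mm_token_id."""
    positions = (input_ids == mm_token_id).nonzero(as_tuple=True)[0].tolist()
    runs = []
    for pos in positions:
        if runs and pos == runs[-1][1] + 1:
            runs[-1] = (runs[-1][0], pos)   # extend the current run
        else:
            runs.append((pos, pos))         # start a new run
    return runs

# e.g. ids = [5, 9, 9, 9, 7, 9, 9] with mm_token_id=9 -> [(1, 3), (5, 6)]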
python/sglang/srt/managers/multimodal_processors/gemma3n.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import re
from typing import Dict, List, Optional, Union

from sglang.srt.managers.multimodal_processor import (
    BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
    MultimodalSpecialTokens,
)
from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration


class Gemma3nSGLangProcessor(SGLangBaseProcessor):
    """Multimodal processor for Gemma3n supporting image and audio inputs."""

    models = [Gemma3nForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

        self.IMAGE_TOKEN = "<image_soft_token>"
        self.IMAGE_TOKEN_REGEX = re.compile(
            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
        )

        self.AUDIO_TOKEN = "<audio_soft_token>"
        self.AUDIO_TOKEN_REGEX = re.compile(
            r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
        )

        self.IM_TOKEN_ID = hf_config.image_token_id
        self.IM_START_TOKEN_ID = hf_config.boi_token_id
        self.IM_END_TOKEN_ID = hf_config.eoi_token_id

        self.AUDIO_TOKEN_ID = hf_config.audio_token_id
        self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
        self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id

    async def process_mm_data_async(
        self,
        image_data: Optional[List[Union[str, bytes, Dict]]] = None,
        audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
        input_text: str = "",
        request_obj=None,
        max_req_input_len: int = 0,
        *args,
        **kwargs,
    ):
        """Process multimodal data including images and audio."""

        audio_data = request_obj.audio_data
        if not image_data and not audio_data:
            return None

        if isinstance(image_data, str):
            image_data = [image_data]

        if isinstance(audio_data, str):
            audio_data = [audio_data]

        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            audio_data=audio_data,
            max_req_input_len=max_req_input_len,
            multimodal_tokens=MultimodalSpecialTokens(
                image_token=self.IMAGE_TOKEN,
                image_token_regex=self.IMAGE_TOKEN_REGEX,
                audio_token=self.AUDIO_TOKEN,
                audio_token_regex=self.AUDIO_TOKEN_REGEX,
            ),
        )

        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
            "audio_start_id": self.AUDIO_START_TOKEN_ID,
            "audio_end_id": self.AUDIO_END_TOKEN_ID,
        }
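A quick standalone check, not part of the commit, of what the audio placeholder regex in the processor above accepts: both the bare begin-of-audio marker and a fully expanded soft-token span.

import re

AUDIO_TOKEN_REGEX = re.compile(
    r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
)
# Matches the bare marker before expansion...
assert AUDIO_TOKEN_REGEX.match("<start_of_audio>")
# ...and a fully expanded span of soft tokens.
assert AUDIO_TOKEN_REGEX.match(
    "<start_of_audio>" + "<audio_soft_token>" * 3 + "<end_of_audio>"
)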
@@ -214,6 +214,10 @@ class MultimodalDataItem:
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None

+    # gemma3n related
+    input_features: Optional[torch.Tensor] = None
+    input_features_mask: Optional[torch.Tensor] = None
+
     precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

     @staticmethod
@@ -277,7 +281,10 @@ class MultimodalDataItem:
         if self.precomputed_features is not None:
             self.hash = hash_feature(self.precomputed_features)
         elif self.is_audio():
-            self.hash = hash_feature(self.audio_features)
+            if self.audio_features is not None:
+                self.hash = hash_feature(self.audio_features)
+            elif self.input_features is not None:
+                self.hash = hash_feature(self.input_features)
         else:
             self.hash = hash_feature(self.pixel_values)

@@ -288,6 +295,7 @@ class MultimodalDataItem:
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.audio_features)
+            or not MultimodalDataItem.is_empty_list(self.input_features)
         )

     def is_image(self):
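Sketch, not part of the commit, of the hash fallback order the MultimodalDataItem change above introduces for audio items; hash_feature is the existing helper referenced in the diff and is only passed in here.

def audio_item_hash(item, hash_feature):
    # Precomputed features win, then raw audio features, then gemma3n's
    # input_features; nothing is hashed if all three are absent.
    if item.precomputed_features is not None:
        return hash_feature(item.precomputed_features)
    if item.audio_features is not None:
        return hash_feature(item.audio_features)
    if item.input_features is not None:
        return hash_feature(item.input_features)
    return None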
python/sglang/srt/models/gemma3n_audio.py (new file, 949 lines)
@@ -0,0 +1,949 @@
import math
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import Gemma3nAudioConfig, PreTrainedModel
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
|
||||
class Gemma3nCumulativeGroupNorm(nn.Module):
|
||||
"""Applies Group Normalization cumulatively over the time dimension.
|
||||
|
||||
This layer normalizes the input by calculating the mean and variance
|
||||
cumulatively over the time dimension (dim 1). The statistics are computed
|
||||
over all feature dimensions (specified by `feature_dims` and `num_channels`)
|
||||
for elements marked as valid by the optional `mask`.
|
||||
|
||||
If a `mask` is provided (True for valid, False for invalid/padded),
|
||||
invalid time steps do not contribute to the statistics calculation, and
|
||||
their corresponding output values are zeroed out.
|
||||
|
||||
Scale and bias, if enabled, are applied per-channel (last dimension).
|
||||
This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
|
||||
and `cumulative=True`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int, # Number of channels (size of the last dimension)
|
||||
feature_dims: Sequence[
|
||||
int
|
||||
], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
|
||||
eps: float = 1e-3,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.feature_dims = tuple(feature_dims)
|
||||
self.eps = eps
|
||||
|
||||
# Scale parameter depends only on the channel dimension
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
|
||||
# Axes for normalization: all dimensions except Batch (0) and Time (1).
|
||||
# For input [B, T, *feature_dims, C], these are dims from 2 onwards.
|
||||
self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Applies cumulative group norm, optionally using a mask.
|
||||
|
||||
Args:
|
||||
x: Input tensor, shape [B, T, *feature_dims, C].
|
||||
mask: Optional boolean mask, shape [B, T]. True indicates a valid
|
||||
(non-padded) time step. If None, all time steps are considered valid.
|
||||
|
||||
Returns:
|
||||
Normalized tensor with the same shape as x.
|
||||
"""
|
||||
expected_input_suffix = self.feature_dims + (self.num_channels,)
|
||||
if x.shape[2:] != expected_input_suffix:
|
||||
raise ValueError(
|
||||
f"Input tensor shape suffix {x.shape[2:]} does not match expected"
|
||||
f" suffix (feature_dims + num_channels) {expected_input_suffix}"
|
||||
)
|
||||
|
||||
input_dtype = x.dtype
|
||||
# Calculations are performed in float32 for numerical stability.
|
||||
calc_dtype = torch.float32
|
||||
x_calc = x.to(calc_dtype)
|
||||
|
||||
# Prepare a broadcastable mask (`mask_calc`).
|
||||
# If no mask is provided, treat all elements as valid
|
||||
# (mask_calc is all ones).
|
||||
# Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
|
||||
mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
|
||||
|
||||
# Cumulative Statistics Calculation
|
||||
# 1. Sum of values over reduction axes at each time step.
|
||||
sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
|
||||
# 2. Cumulative sum of values over time.
|
||||
cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
|
||||
|
||||
# 3. Count of valid elements in the normalization group at each time step.
|
||||
# (A "group" here consists of all features at a given Batch, Time).
|
||||
elements_in_group_at_t = torch.sum(
|
||||
mask_calc, dim=self.reduction_axes, keepdim=True
|
||||
)
|
||||
# 4. Cumulative count of valid elements over time.
|
||||
cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
|
||||
# Avoid division by zero if all preceding elements were masked.
|
||||
safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
|
||||
|
||||
# 5. Cumulative mean.
|
||||
cum_mean = cum_sum_values / safe_cum_count_elements
|
||||
|
||||
# 6. Sum of squared differences from the cumulative mean.
|
||||
# Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
|
||||
# Using x_calc here for the difference, as cum_mean already accounts for masking.
|
||||
squared_diff_from_mean = (x_calc - cum_mean).pow(2)
|
||||
sum_sq_diff_at_t = torch.sum(
|
||||
squared_diff_from_mean, dim=self.reduction_axes, keepdim=True
|
||||
)
|
||||
|
||||
# 7. Cumulative sum of squared differences over time.
|
||||
cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
|
||||
|
||||
# 8. Cumulative variance.
|
||||
cum_variance = cum_sum_sq_diff / safe_cum_count_elements
|
||||
|
||||
# Normalize the input using the calculated cumulative statistics:
|
||||
# (x - E[x]) / sqrt(Var[x] + eps)
|
||||
normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
|
||||
|
||||
# Apply affine transformation (scale and bias) if enabled.
|
||||
# Scale and bias are applied per-channel (last dimension).
|
||||
scale = self.weight.to(calc_dtype)
|
||||
# Reshape for broadcasting: [C] -> [1, ..., 1, C]
|
||||
scale_view_shape = [1] * (x.dim() - 1) + [self.num_channels]
|
||||
normalized_x = normalized_x * scale.view(scale_view_shape)
|
||||
|
||||
# Zero out outputs for time steps that were originally masked (where mask_calc is 0).
|
||||
# This ensures padded/invalid positions in the input result in zero output.
|
||||
final_output = normalized_x * mask_calc
|
||||
|
||||
return final_output.to(input_dtype)
|
||||
|
||||
|
||||
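A small check, not part of the commit, of the cumulative statistics Gemma3nCumulativeGroupNorm accumulates above (unmasked case): the running mean at step t is the mean over all features of frames 0..t; the variance is accumulated analogously against that running mean.

import torch

x = torch.randn(1, 4, 3, 8)                          # [B, T, feature_dim, C]
sum_t = x.sum(dim=(2, 3), keepdim=True)               # per-step feature sums
cnt_t = torch.full_like(sum_t, x.shape[2] * x.shape[3])
cum_mean = sum_t.cumsum(dim=1) / cnt_t.cumsum(dim=1)
# cum_mean[:, t] equals the mean over every prefix x[:, : t + 1]:
for t in range(x.shape[1]):
    assert torch.allclose(cum_mean[0, t, 0, 0], x[0, : t + 1].mean(), atol=1e-5)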
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.num_heads = self.config.conf_num_attention_heads
|
||||
self.channels = self.config.hidden_size
|
||||
self.head_dim = self.channels // self.num_heads
|
||||
self.max_backward = max(0, self.config.conf_attention_context_left - 1)
|
||||
self.max_forward = self.config.conf_attention_context_right
|
||||
|
||||
self.pos_proj = ColumnParallelLinear(
|
||||
self.channels,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("pos_proj", prefix),
|
||||
)
|
||||
|
||||
min_timescale = 1.0
|
||||
max_timescale = 1.0e4
|
||||
num_timescales = self.channels // 2
|
||||
log_timescale_increment = math.log(
|
||||
float(max_timescale) / float(min_timescale)
|
||||
) / max(num_timescales - 1, 1)
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales) * -log_timescale_increment
|
||||
)
|
||||
self.register_buffer(
|
||||
"inv_timescales",
|
||||
inv_timescales.float().unsqueeze(0).unsqueeze(0),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def _get_timing_signal_1d_pos(
|
||||
self, position: torch.Tensor, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
assert position.ndim == 2
|
||||
position = position.float().unsqueeze(-1)
|
||||
scaled_time = position * self.inv_timescales.to(
|
||||
device=position.device, dtype=torch.float32
|
||||
)
|
||||
timing_signal = torch.cat(
|
||||
[torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1
|
||||
)
|
||||
return timing_signal.type(dtype)
|
||||
|
||||
def _relative_shift(
|
||||
self,
|
||||
term_bd_before_shift: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
num_query_blocks: int,
|
||||
query_block_size: int,
|
||||
key_context_size: int,
|
||||
max_span_plus_1: int,
|
||||
) -> torch.Tensor:
|
||||
"""Performs the relative shift."""
|
||||
pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
|
||||
padding_tuple = (0, pad_amount_last_dim)
|
||||
|
||||
term_bd_padded = F.pad(term_bd_before_shift, padding_tuple)
|
||||
term_bd_reshaped = term_bd_padded.reshape(
|
||||
(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_query_blocks,
|
||||
query_block_size * (key_context_size + 1),
|
||||
)
|
||||
)
|
||||
term_bd_sliced = term_bd_reshaped[
|
||||
:, :, :, : query_block_size * key_context_size
|
||||
]
|
||||
term_bd_shifted = term_bd_sliced.reshape(
|
||||
(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_query_blocks,
|
||||
query_block_size,
|
||||
key_context_size,
|
||||
)
|
||||
)
|
||||
return term_bd_shifted
|
||||
|
||||
def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_query_blocks, query_block_size, num_heads, head_dim = (
|
||||
queries.shape
|
||||
)
|
||||
_, _, key_context_size, _, _ = keys.shape
|
||||
|
||||
pos_indices = torch.arange(
|
||||
self.max_backward, -self.max_forward - 1, -1, device=queries.device
|
||||
).unsqueeze(0)
|
||||
max_span_plus_1 = pos_indices.shape[1]
|
||||
|
||||
sin_emb_timing_signal = self._get_timing_signal_1d_pos(
|
||||
pos_indices, dtype=queries.dtype
|
||||
)
|
||||
projected_sin_emb, _ = self.pos_proj(sin_emb_timing_signal)
|
||||
sin_emb = projected_sin_emb.reshape(
|
||||
1, max_span_plus_1, self.num_heads, self.head_dim
|
||||
).squeeze(0)
|
||||
|
||||
queries_p = queries.permute(0, 3, 1, 2, 4)
|
||||
keys_p_t = keys.permute(0, 3, 1, 4, 2)
|
||||
term_ac = torch.matmul(queries_p, keys_p_t)
|
||||
|
||||
q_permuted = queries.permute(0, 3, 1, 2, 4)
|
||||
s_permuted = sin_emb.permute(1, 2, 0)
|
||||
q_reshaped = q_permuted.reshape(
|
||||
batch_size, num_heads, num_query_blocks * query_block_size, head_dim
|
||||
)
|
||||
term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
|
||||
term_bd_unshifed = term_bd_unshifed_matmul.reshape(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_query_blocks,
|
||||
query_block_size,
|
||||
max_span_plus_1,
|
||||
)
|
||||
|
||||
term_bd_shifted = self._relative_shift(
|
||||
term_bd_unshifed,
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_query_blocks,
|
||||
query_block_size,
|
||||
key_context_size,
|
||||
max_span_plus_1,
|
||||
)
|
||||
|
||||
return term_ac + term_bd_shifted
|
||||
|
||||
|
||||
class Gemma3nAudioAttention(nn.Module):
|
||||
"""Local dot product self-attention for audio."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.num_heads = self.config.conf_num_attention_heads
|
||||
self.hidden_size = self.config.hidden_size
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
|
||||
self.chunk_size = self.config.conf_attention_chunk_size
|
||||
self.max_future_horizon = self.config.conf_attention_context_right
|
||||
self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
|
||||
self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
|
||||
self.context_size = (
|
||||
self.chunk_size + self.max_past_horizon + self.max_future_horizon
|
||||
)
|
||||
|
||||
self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=add_prefix("relative_position_embedding", prefix),
|
||||
)
|
||||
self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
self.num_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
q_scale = self.head_dim**-0.5
|
||||
r_softplus_0 = 1.0 / F.softplus(torch.tensor(0.0))
|
||||
self.register_buffer(
|
||||
"q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False
|
||||
)
|
||||
|
||||
# Create local causal mask
|
||||
lower_causal_mask = torch.tril(
|
||||
torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
|
||||
diagonal=0,
|
||||
).T
|
||||
upper_causal_mask = torch.tril(
|
||||
torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
|
||||
diagonal=self.max_past_horizon + self.max_future_horizon,
|
||||
)
|
||||
local_causal_valid_mask = torch.ones(
|
||||
(self.chunk_size, self.context_size), dtype=torch.bool
|
||||
)
|
||||
local_causal_valid_mask = (
|
||||
local_causal_valid_mask * lower_causal_mask * upper_causal_mask
|
||||
)
|
||||
self.register_buffer(
|
||||
"local_causal_valid_mask", local_causal_valid_mask, persistent=False
|
||||
)
|
||||
|
||||
self.register_buffer(
|
||||
"softcap",
|
||||
torch.tensor(self.attention_logits_soft_cap).float(),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def _pad_dim1(
|
||||
self, x: torch.Tensor, dim10_val: int, dim11_val: int
|
||||
) -> torch.Tensor:
|
||||
padding_tuple = [0] * x.ndim * 2
|
||||
dim_idx_from_end = x.ndim - 2
|
||||
start_idx_for_dim = 2 * dim_idx_from_end
|
||||
padding_tuple[start_idx_for_dim] = dim10_val
|
||||
padding_tuple[start_idx_for_dim + 1] = dim11_val
|
||||
return F.pad(x, tuple(padding_tuple))
|
||||
|
||||
def _convert_to_block(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Turns a sequence to non overlapping blocks."""
|
||||
shape = x.shape
|
||||
b, t = shape[:2]
|
||||
num_blocks = (t + self.chunk_size - 1) // self.chunk_size
|
||||
|
||||
if (padding_len := num_blocks * self.chunk_size - t) > 0:
|
||||
x = self._pad_dim1(x, 0, padding_len)
|
||||
|
||||
permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
|
||||
x = x.reshape(permute_dims).contiguous()
|
||||
return x
|
||||
|
||||
def _extract_block_context(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extracts temporal context for every block."""
|
||||
pad_left = self.max_past_horizon
|
||||
pad_right = self.max_future_horizon + self.chunk_size - 1
|
||||
x = self._pad_dim1(x, pad_left, pad_right)
|
||||
|
||||
frame_len = self.context_size
|
||||
frame_step = self.chunk_size
|
||||
|
||||
x_unfolded = x.unfold(dimension=1, size=frame_len, step=frame_step)
|
||||
|
||||
if x.ndim > 2 and x_unfolded.ndim > 3:
|
||||
x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
|
||||
|
||||
return x_unfolded.contiguous()
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
|
||||
# Project to Q, K, V
|
||||
qkv, _ = self.qkv_proj(x)
|
||||
query_states, key_states, value_states = qkv.chunk(chunks=3, dim=-1)
|
||||
|
||||
# Reshape
|
||||
query_states = query_states.reshape(
|
||||
*x.shape[:-1], self.num_heads, self.head_dim
|
||||
).contiguous()
|
||||
key_states = key_states.reshape(
|
||||
*x.shape[:-1], self.num_heads, self.head_dim
|
||||
).contiguous()
|
||||
value_states = value_states.reshape(
|
||||
*x.shape[:-1], self.num_heads, self.head_dim
|
||||
).contiguous()
|
||||
|
||||
# Apply per-dim scale
|
||||
per_dim_scale_sp = F.softplus(self.per_dim_scale)
|
||||
broadcast_shape = (1, 1, 1, self.head_dim)
|
||||
per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
|
||||
query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
|
||||
|
||||
batch_size, q_time = query_states.shape[:2]
|
||||
|
||||
# Convert to blocks
|
||||
query_blocks = self._convert_to_block(query_states)
|
||||
key_blocks = self._extract_block_context(key_states)
|
||||
value_blocks = self._extract_block_context(value_states)
|
||||
num_query_blocks = query_blocks.shape[1]
|
||||
|
||||
# Create mask for valid positions
|
||||
original_valid_mask = ~mask
|
||||
extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
|
||||
|
||||
if (
|
||||
extracted_valid_mask_blocks.ndim == 4
|
||||
and extracted_valid_mask_blocks.shape[0] == batch_size
|
||||
and extracted_valid_mask_blocks.shape[1] == num_query_blocks
|
||||
and extracted_valid_mask_blocks.shape[2]
|
||||
* extracted_valid_mask_blocks.shape[3]
|
||||
== self.context_size
|
||||
):
|
||||
extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
|
||||
batch_size, num_query_blocks, self.context_size
|
||||
)
|
||||
|
||||
condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(
|
||||
1
|
||||
).unsqueeze(-2)
|
||||
condition_from_causality = (
|
||||
self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
|
||||
)
|
||||
|
||||
final_condition_for_where = torch.logical_and(
|
||||
condition_from_input_validity,
|
||||
condition_from_causality.to(condition_from_input_validity.device),
|
||||
)
|
||||
|
||||
# Compute attention scores
|
||||
logits = self.relative_position_embedding(query_blocks, key_blocks)
|
||||
|
||||
# Apply attention logit softcap
|
||||
softcap_val = self.softcap.to(logits.device)
|
||||
logits = logits / softcap_val
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * softcap_val
|
||||
|
||||
# Apply the combined mask.
|
||||
# final_condition_for_where will broadcast with logits [B,N,U,W,C]
|
||||
logits = torch.where(
|
||||
final_condition_for_where, logits, torch.finfo(logits.dtype).min
|
||||
)
|
||||
|
||||
probabilities = F.softmax(logits, dim=-1, dtype=torch.float32).to(
|
||||
dtype=value_blocks.dtype
|
||||
)
|
||||
|
||||
# context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
|
||||
b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
|
||||
h_dim = value_blocks.shape[-1]
|
||||
prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
|
||||
v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
|
||||
result_bmm = torch.bmm(prob_bun, v_bun)
|
||||
context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(
|
||||
0, 1, 3, 2, 4
|
||||
)
|
||||
context_vectors = context_vectors.reshape(
|
||||
(
|
||||
batch_size,
|
||||
num_query_blocks * self.chunk_size,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
context_vectors = context_vectors[:, :q_time]
|
||||
|
||||
return context_vectors
|
||||
|
||||
|
||||
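The attention logit soft-cap applied above keeps logits inside (-cap, cap) while remaining smooth. A tiny illustration, not part of the commit, with an illustrative cap value (the real value comes from conf_attention_logit_cap):

import torch

cap = torch.tensor(50.0)
logits = torch.tensor([-1000.0, -5.0, 0.0, 5.0, 1000.0])
capped = cap * torch.tanh(logits / cap)
print(capped)  # approximately [-50.0, -4.98, 0.0, 4.98, 50.0]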
class Gemma3nAudioSSCPConvBlock(nn.Module):
|
||||
"""A single convolution block for the SubSampleConvProjection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
idx: int,
|
||||
input_freq_dim: int,
|
||||
manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.manual_padding = manual_padding
|
||||
|
||||
in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
|
||||
out_channels = self.config.sscp_conv_channel_size[idx]
|
||||
kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
|
||||
stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(kernel_h, kernel_w),
|
||||
stride=(stride_h, stride_w),
|
||||
padding=(0, 0), # Manual padding is used
|
||||
bias=False,
|
||||
)
|
||||
|
||||
f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
|
||||
f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
|
||||
|
||||
self.norm = Gemma3nCumulativeGroupNorm(
|
||||
num_channels=out_channels,
|
||||
feature_dims=(f_out_conv,),
|
||||
eps=self.config.sscp_conv_group_norm_eps,
|
||||
)
|
||||
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
|
||||
audio_encodings_padded = F.pad(
|
||||
audio_encodings, self.manual_padding, mode="constant", value=0.0
|
||||
)
|
||||
audio_encodings_conv = self.conv(audio_encodings_padded)
|
||||
x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
|
||||
x_normed = self.norm(x_for_norm)
|
||||
audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
|
||||
return self.activation(audio_encodings_normed)
|
||||
|
||||
|
||||
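Worked example, not part of the commit, of the frequency-dimension bookkeeping above (f_out_conv = (f_in_padded - kernel_w) // stride_w + 1), with illustrative config values:

# e.g. 128 mel bins, kernel width 3, stride 2, one column of padding per side
f_in_padded = 128 + 1 + 1
f_out_conv = (f_in_padded - 3) // 2 + 1
print(f_out_conv)  # 64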
class Gemma3nAudioSubSampleConvProjection(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
current_f_for_block_input = config.input_feat_size
|
||||
calculated_block_padding = []
|
||||
calculated_f_out_dims = []
|
||||
|
||||
for i in range(2): # Assuming 2 conv layers
|
||||
kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
|
||||
stride_h, stride_w = config.sscp_conv_stride_size[i]
|
||||
|
||||
# Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
|
||||
pad_t_top = 0
|
||||
pad_t_bottom = kernel_h - 1
|
||||
|
||||
# Frequency Padding (Width for Conv2d)
|
||||
pad_f_left = 1
|
||||
pad_f_right = 1
|
||||
|
||||
manual_padding_tuple = (pad_f_left, pad_f_right, pad_t_top, pad_t_bottom)
|
||||
calculated_block_padding.append(manual_padding_tuple)
|
||||
|
||||
f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
|
||||
f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1
|
||||
calculated_f_out_dims.append(f_out_after_conv)
|
||||
current_f_for_block_input = f_out_after_conv
|
||||
|
||||
self.conv_0 = Gemma3nAudioSSCPConvBlock(
|
||||
idx=0,
|
||||
input_freq_dim=config.input_feat_size,
|
||||
config=config,
|
||||
manual_padding=calculated_block_padding[0],
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("conv_0", prefix),
|
||||
)
|
||||
self.conv_1 = Gemma3nAudioSSCPConvBlock(
|
||||
idx=1,
|
||||
input_freq_dim=calculated_f_out_dims[0],
|
||||
config=config,
|
||||
manual_padding=calculated_block_padding[1],
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("conv_1", prefix),
|
||||
)
|
||||
|
||||
final_c_out = config.sscp_conv_channel_size[-1]
|
||||
final_f_out = calculated_f_out_dims[-1]
|
||||
self.input_proj_in_features = final_c_out * final_f_out
|
||||
|
||||
self.input_proj_linear = RowParallelLinear(
|
||||
self.input_proj_in_features,
|
||||
self.config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("input_proj_linear", prefix),
|
||||
)
|
||||
|
||||
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
|
||||
audio_encodings_reshaped = audio_encodings.unsqueeze(1)
|
||||
x = self.conv_0(audio_encodings_reshaped)
|
||||
x = self.conv_1(x)
|
||||
b, c_out, t_out, f_out = x.shape
|
||||
x_permuted = x.permute(0, 2, 3, 1).contiguous()
|
||||
output_flattened = x_permuted.view(b, t_out, f_out * c_out)
|
||||
output, _ = self.input_proj_linear(output_flattened)
|
||||
return output
|
||||
|
||||
|
||||
class Gemma3nAudioConformerAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
head_dim = self.config.hidden_size // self.config.conf_num_attention_heads
|
||||
self.post_in_shape = (self.config.conf_num_attention_heads, head_dim)
|
||||
self.post_in_features = self.config.hidden_size
|
||||
|
||||
self.register_buffer(
|
||||
"gradient_clipping",
|
||||
torch.tensor(self.config.gradient_clipping),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
||||
self.attn = Gemma3nAudioAttention(
|
||||
config, quant_config, prefix=add_prefix("attn", prefix)
|
||||
)
|
||||
self.post = RowParallelLinear(
|
||||
self.post_in_features,
|
||||
self.config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("post", prefix),
|
||||
)
|
||||
self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor
|
||||
) -> torch.Tensor:
|
||||
audio_encodings_input_to_attn = audio_encodings
|
||||
audio_encodings = torch.clamp(
|
||||
audio_encodings, -self.gradient_clipping, self.gradient_clipping
|
||||
)
|
||||
audio_encodings_norm = self.pre_attn_norm(audio_encodings)
|
||||
audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
|
||||
|
||||
b, t, num_heads, head_dim = audio_encodings_attn_out.shape
|
||||
audio_encodings_reshaped = audio_encodings_attn_out.reshape(
|
||||
b, t, num_heads * head_dim
|
||||
)
|
||||
|
||||
audio_encodings, _ = self.post(audio_encodings_reshaped)
|
||||
audio_encodings = torch.clamp(
|
||||
audio_encodings, -self.gradient_clipping, self.gradient_clipping
|
||||
)
|
||||
return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
|
||||
|
||||
|
||||
class Gemma3nAudioConformerFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.register_buffer(
|
||||
"gradient_clipping",
|
||||
torch.tensor(self.config.gradient_clipping),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
||||
self.ffw_layer_1 = ColumnParallelLinear(
|
||||
self.config.hidden_size,
|
||||
self.config.hidden_size * 4,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("ffw_layer_1", prefix),
|
||||
)
|
||||
self.ffw_layer_2 = RowParallelLinear(
|
||||
self.config.hidden_size * 4,
|
||||
self.config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("ffw_layer_2", prefix),
|
||||
)
|
||||
self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
||||
self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
|
||||
|
||||
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
|
||||
residual = audio_encodings
|
||||
audio_encodings = torch.clamp(
|
||||
audio_encodings, -self.gradient_clipping, self.gradient_clipping
|
||||
)
|
||||
audio_encodings = self.pre_layer_norm(audio_encodings)
|
||||
audio_encodings, _ = self.ffw_layer_1(audio_encodings)
|
||||
audio_encodings = F.silu(audio_encodings)
|
||||
audio_encodings, _ = self.ffw_layer_2(audio_encodings)
|
||||
audio_encodings = torch.clamp(
|
||||
audio_encodings, -self.gradient_clipping, self.gradient_clipping
|
||||
)
|
||||
audio_encodings = self.post_layer_norm(audio_encodings)
|
||||
return residual + (audio_encodings * self.post_layer_scale)
|
||||
|
||||
|
||||
class Gemma3nAudioConformerLightConv1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.pre_layer_norm = Gemma3nRMSNorm(
|
||||
self.config.hidden_size, eps=self.config.rms_norm_eps
|
||||
)
|
||||
self.linear_start = ColumnParallelLinear(
|
||||
self.config.hidden_size,
|
||||
self.config.hidden_size * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("linear_start", prefix),
|
||||
)
|
||||
|
||||
self.depthwise_conv1d = nn.Conv1d(
|
||||
in_channels=self.config.hidden_size,
|
||||
out_channels=self.config.hidden_size,
|
||||
kernel_size=self.config.conf_conv_kernel_size,
|
||||
stride=1,
|
||||
padding=0, # Manual causal padding
|
||||
groups=self.config.hidden_size, # Depthwise
|
||||
bias=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"gradient_clipping",
|
||||
torch.tensor(self.config.gradient_clipping),
|
||||
persistent=False,
|
||||
)
|
||||
self.conv_norm = Gemma3nRMSNorm(
|
||||
self.config.hidden_size, eps=self.config.rms_norm_eps
|
||||
)
|
||||
self.linear_end = RowParallelLinear(
|
||||
self.config.hidden_size,
|
||||
self.config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("linear_end", prefix),
|
||||
)
|
||||
|
||||
self.causal_padding = self.config.conf_conv_kernel_size - 1
|
||||
|
||||
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
|
||||
audio_encodings_residual = audio_encodings # Save for residual connection
|
||||
|
||||
audio_encodings = self.pre_layer_norm(audio_encodings)
|
||||
audio_encodings, _ = self.linear_start(audio_encodings)
|
||||
audio_encodings = F.glu(audio_encodings, dim=-1)
|
||||
|
||||
# Permute for Conv1d: [B, T, D] -> [B, D, T]
|
||||
audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
|
||||
# Apply manual causal padding
|
||||
audio_encodings_permuted_padded = F.pad(
|
||||
audio_encodings_permuted, (self.causal_padding, 0)
|
||||
)
|
||||
audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
|
||||
# Permute back: [B, D, T_out] -> [B, T_out, D]
|
||||
audio_encodings = audio_encodings.permute(0, 2, 1)
|
||||
audio_encodings = torch.clamp(
|
||||
audio_encodings, -self.gradient_clipping, self.gradient_clipping
|
||||
)
|
||||
audio_encodings = self.conv_norm(audio_encodings)
|
||||
audio_encodings = F.silu(audio_encodings)
|
||||
audio_encodings, _ = self.linear_end(audio_encodings)
|
||||
output = audio_encodings + audio_encodings_residual
|
||||
return output
|
||||
|
||||
|
||||
class Gemma3nAudioConformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.ffw_layer_start = Gemma3nAudioConformerFeedForward(
|
||||
config, quant_config, prefix=add_prefix("ffw_layer_start", prefix)
|
||||
)
|
||||
self.attention = Gemma3nAudioConformerAttention(
|
||||
config, quant_config, prefix=add_prefix("attention", prefix)
|
||||
)
|
||||
self.lconv1d = Gemma3nAudioConformerLightConv1d(
|
||||
config, quant_config, prefix=add_prefix("lconv1d", prefix)
|
||||
)
|
||||
self.ffw_layer_end = Gemma3nAudioConformerFeedForward(
|
||||
config, quant_config, prefix=add_prefix("ffw_layer_end", prefix)
|
||||
)
|
||||
self.register_buffer(
|
||||
"gradient_clipping",
|
||||
torch.tensor(self.config.gradient_clipping),
|
||||
persistent=False,
|
||||
)
|
||||
self.norm = Gemma3nRMSNorm(self.config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor
|
||||
) -> torch.Tensor:
|
||||
audio_encodings = self.ffw_layer_start(audio_encodings)
|
||||
audio_encodings = self.attention(audio_encodings, audio_mel_mask)
|
||||
validity_mask_for_lconv = ~audio_mel_mask # True for valid
|
||||
audio_encodings_for_lconv_input = (
|
||||
audio_encodings
|
||||
* validity_mask_for_lconv.unsqueeze(-1).to(audio_encodings.dtype)
|
||||
)
|
||||
audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
|
||||
|
||||
audio_encodings = self.ffw_layer_end(audio_encodings)
|
||||
audio_encodings = torch.clamp(
|
||||
audio_encodings, -self.gradient_clipping, self.gradient_clipping
|
||||
)
|
||||
output = self.norm(audio_encodings)
|
||||
return output
|
||||
|
||||
|
||||
class Gemma3nAudioEncoder(PreTrainedModel):
|
||||
"""A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
|
||||
|
||||
config_class = Gemma3nAudioConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nAudioConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(
|
||||
config, quant_config, prefix=add_prefix("subsample_conv_projection", prefix)
|
||||
)
|
||||
self.conformer = make_layers(
|
||||
config.conf_num_hidden_layers,
|
||||
lambda idx, prefix: Gemma3nAudioConformerBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=add_prefix("conformer", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
|
||||
) -> Tuple[torch.Tensor, torch.BoolTensor]:
|
||||
"""Encodes a batch of MELs.
|
||||
|
||||
Args:
|
||||
audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins].
|
||||
audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
|
||||
|
||||
Returns:
|
||||
audio_encodings: a torch.Tensor of shape
|
||||
`[batch_size, reduced_time_frames, hidden_size]`
|
||||
audio_mel_mask: a torch.BoolTensor of shape [batch, reduced_time_frames].
|
||||
"""
|
||||
audio_encodings = self.subsample_conv_projection(
|
||||
audio_mel
|
||||
) # audio_encodings: [B, T_sub, D]
|
||||
|
||||
# Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
|
||||
t_sub = audio_encodings.shape[1]
|
||||
|
||||
time_stride_product = 1
|
||||
for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
|
||||
time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
|
||||
|
||||
# Create indices for gathering from the original mask.
|
||||
# These indices map to original time steps corresponding to the start of each
|
||||
# receptive field in the subsampled output.
|
||||
indices = (
|
||||
torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
|
||||
)
|
||||
indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1)
|
||||
|
||||
# Expand indices for batch compatibility if B > 1 and indices is 1D.
|
||||
if audio_mel_mask.ndim > 1 and indices.ndim == 1:
|
||||
indices = indices.unsqueeze(0).expand(
|
||||
audio_mel_mask.shape[0], -1
|
||||
) # [B, T_sub]
|
||||
elif (
|
||||
audio_mel_mask.ndim == indices.ndim
|
||||
and audio_mel_mask.shape[0] == 1
|
||||
and indices.shape[0] != 1
|
||||
and t_sub == indices.shape[0]
|
||||
):
|
||||
# Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
|
||||
indices = indices.unsqueeze(0)
|
||||
|
||||
current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
|
||||
|
||||
# Fallback: Ensure mask length matches feature length after gather.
|
||||
if current_mask.shape[1] != t_sub:
|
||||
if current_mask.shape[1] > t_sub:
|
||||
current_mask = current_mask[:, :t_sub]
|
||||
else: # current_mask.shape[1] < t_sub
|
||||
padding_needed = t_sub - current_mask.shape[1]
|
||||
current_mask = F.pad(
|
||||
current_mask, (0, padding_needed), value=True
|
||||
) # Pad with True (masked)
|
||||
|
||||
for i, block in enumerate(self.conformer):
|
||||
audio_encodings = block(
|
||||
audio_encodings, current_mask
|
||||
) # Pass the processed mask
|
||||
|
||||
if self.config.conf_reduction_factor > 1:
|
||||
audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
|
||||
# Reduce the mask as well
|
||||
current_mask = current_mask[:, :: self.config.conf_reduction_factor]
|
||||
|
||||
# Final masking of audio_encodings based on the final current_mask
|
||||
# Ensure current_mask length matches the finally reduced audio_encodings length
|
||||
if current_mask.shape[1] != audio_encodings.shape[1]:
|
||||
target_len = audio_encodings.shape[1]
|
||||
mask_current_len = current_mask.shape[1]
|
||||
if target_len > mask_current_len:
|
||||
padding_needed = target_len - mask_current_len
|
||||
current_mask = F.pad(current_mask, (0, padding_needed), value=True)
|
||||
elif mask_current_len > target_len: # mask is longer
|
||||
current_mask = current_mask[:, :target_len]
|
||||
|
||||
audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
|
||||
return audio_encodings, current_mask
|
||||
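Sketch, not part of the commit, of how Gemma3nAudioEncoder.forward above subsamples the mel-frame mask to the encoder's reduced time axis by gathering every time_stride_product-th frame; the stride values here are illustrative.

import torch

audio_mel_mask = torch.zeros(1, 100, dtype=torch.bool)   # True = padded frame
audio_mel_mask[:, 80:] = True                             # last 20 frames are padding
time_stride_product = 2 * 2                               # e.g. two conv layers with time stride 2
t_sub = 25
indices = (torch.arange(t_sub) * time_stride_product).clamp(max=99).unsqueeze(0)
current_mask = torch.gather(audio_mel_mask, 1, indices)   # [1, 25]
assert current_mask[0, 20:].all() and not current_mask[0, :20].any()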
python/sglang/srt/models/gemma3n_causal.py (new file, 1009 lines)
File diff suppressed because it is too large.
python/sglang/srt/models/gemma3n_mm.py (new file, 511 lines)
@@ -0,0 +1,511 @@
import logging
import re
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union

import torch
from torch import nn
from transformers import (
    Gemma3nAudioConfig,
    Gemma3nConfig,
    Gemma3nTextConfig,
    Gemma3nVisionConfig,
    PreTrainedModel,
)
from transformers.models.auto.modeling_auto import AutoModel

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
    MultimodalDataItem,
    MultimodalInputs,
    flatten_nested_list,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)

cached_get_processor = lru_cache(get_processor)


class Gemma3nImagePixelInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class Gemma3nAudioInputs(TypedDict):
    input_features: torch.Tensor
    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
    input_features_mask: torch.Tensor
    """Shape: `(batch_size * num_audio, seq_length)`"""


class Gemma3nMultimodalEmbedder(nn.Module):
    """Embeds token ids or soft tokens for multimodal content into language model space."""

    def __init__(
        self,
        multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
        text_config: Gemma3nTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.multimodal_hidden_size = multimodal_config.hidden_size
        self.eps = multimodal_config.rms_norm_eps
        self.vocab_offset = multimodal_config.vocab_offset
        self.vocab_size = multimodal_config.vocab_size
        self.text_hidden_size = text_config.hidden_size

        self.embedding = VocabParallelEmbedding(
            self.vocab_size,
            self.multimodal_hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embedding", prefix),
        )

        self.hard_embedding_norm = Gemma3nRMSNorm(
            self.multimodal_hidden_size,
            eps=self.eps,
        )

        self.soft_embedding_norm = Gemma3nRMSNorm(
            self.multimodal_hidden_size,
            eps=self.eps,
        )

        self.embedding_projection = RowParallelLinear(
            self.multimodal_hidden_size,
            self.text_hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("embedding_projection", prefix),
        )

        self.embedding_post_projection_norm = Gemma3nRMSNorm(
            self.text_hidden_size,
            eps=self.eps,
            with_scale=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embeds token ids or soft tokens for multimodal content into language model space.

        Args:
            input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
                `[vocab_offset, vocab_offset + vocab_size)`.
            inputs_embeds: A torch.Tensor containing the soft tokens to embed.

        Returns:
            A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is not None:
            emb_norm = self.soft_embedding_norm(inputs_embeds)
        else:
            # Handle out of vocab ids to prevent CUDA assertion failures
            out_of_vocab_id = self.vocab_size - 1
            adjusted_ids = input_ids - self.vocab_offset
            adjusted_ids = torch.where(adjusted_ids < 0, out_of_vocab_id, adjusted_ids)
            adjusted_ids = torch.where(
                adjusted_ids >= self.vocab_size, out_of_vocab_id, adjusted_ids
            )
            hard_emb = self.embedding(adjusted_ids)
            emb_norm = self.hard_embedding_norm(hard_emb)

        emb_norm_proj, _ = self.embedding_projection(emb_norm)
        return self.embedding_post_projection_norm(emb_norm_proj)

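Sketch, not part of the commit, of the id remapping in the hard-embedding path of Gemma3nMultimodalEmbedder above: multimodal ids occupy [vocab_offset, vocab_offset + vocab_size) in the combined vocabulary and are shifted into the local range, with out-of-range ids clamped to the last local id. The numbers are illustrative.

import torch

vocab_offset, vocab_size = 262144, 128
input_ids = torch.tensor([262144, 262200, 5])        # the last id is out of range
adjusted = input_ids - vocab_offset
adjusted = torch.where(adjusted < 0, vocab_size - 1, adjusted)
adjusted = torch.where(adjusted >= vocab_size, vocab_size - 1, adjusted)
print(adjusted)  # tensor([  0,  56, 127])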
class Gemma3nForConditionalGeneration(PreTrainedModel):
|
||||
config_class = Gemma3nConfig
|
||||
"""Gemma3n multimodal model for conditional generation."""
|
||||
|
||||
# BitandBytes specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
".down_proj.",
|
||||
".up_proj.",
|
||||
".q_proj.",
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
".out_proj.",
|
||||
]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
"out_proj": ("proj", 0),
|
||||
}
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj",
|
||||
"o_proj",
|
||||
"gate_up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
# Gemma does not apply LoRA to the embedding layer
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
supports_lora = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
prefix = add_prefix("model", prefix)
|
||||
|
||||
# Vision components
|
||||
# TODO: Use sglang's vision model
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
|
||||
self.embed_vision = Gemma3nMultimodalEmbedder(
|
||||
config.vision_config,
|
||||
config.text_config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("embed_vision", prefix),
|
||||
)
|
||||
|
||||
# Audio components
|
||||
self.embed_audio = Gemma3nMultimodalEmbedder(
|
||||
config.audio_config,
|
||||
config.text_config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("embed_audio", prefix),
|
||||
)
|
||||
|
||||
self.audio_tower = Gemma3nAudioEncoder(
|
||||
config.audio_config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("audio_tower", prefix),
|
||||
)
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
|
||||
|
||||
# Text model
|
||||
self.language_model = Gemma3nTextModel(
|
||||
config.text_config,
|
||||
quant_config,
|
||||
prefix=add_prefix("language_model", prefix),
|
||||
)
|
||||
|
||||
# Create logits processor for the multimodal model
|
||||
self.logits_processor = LogitsProcessor(config.text_config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def pad_input_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
mm_inputs: Optional[MultimodalInputs] = None,
|
||||
) -> List[int]:
|
||||
"""Pad input IDs with image and audio tokens."""
|
||||
if mm_inputs is None:
|
||||
return input_ids
|
||||
|
||||
# Collect available media token pairs
|
||||
media_token_pairs = []
|
||||
for attr_name in ["im_start_id", "audio_start_id"]:
|
||||
if hasattr(mm_inputs, attr_name):
|
||||
start_id = getattr(mm_inputs, attr_name)
|
||||
end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
|
||||
media_token_pairs.append((start_id, end_id))
|
||||
|
||||
# Apply padding pattern if we have media tokens
|
||||
if media_token_pairs:
|
||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
return input_ids
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def get_attention_sliding_window_size(self):
|
||||
return self.config.text_config.sliding_window - 1
|
||||
|
||||
    def get_image_feature(self, items: List[MultimodalDataItem]):
        """
        Projects the last hidden state from the vision model into language model space.

        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        # Process images one by one to handle flatten_batch=True constraint in vision_tower
        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
        vision_outputs_list = []

        for pixel_values_batch in all_pixel_values:
            # Normalize input shape to [batch_size, channels, height, width]
            if pixel_values_batch.dim() == 5:
                pixel_values_batch = pixel_values_batch.squeeze(0)
            elif pixel_values_batch.dim() == 3:
                pixel_values_batch = pixel_values_batch.unsqueeze(0)
            elif pixel_values_batch.dim() != 4:
                raise ValueError(
                    f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
                )

            # Process each image in the batch
            batch_size = pixel_values_batch.shape[0]
            for i in range(batch_size):
                pixel_value = pixel_values_batch[i : i + 1]  # Keep batch dimension as 1
                pixel_value = pixel_value.to(
                    device=self.vision_tower.device, dtype=self.language_model.dtype()
                )
                vision_outputs = self.vision_tower(
                    pixel_values=pixel_value, do_pooling=False, return_dict=True
                ).last_hidden_state
                vision_outputs_list.append(vision_outputs)

        # Concatenate all vision outputs
        vision_outputs = torch.cat(vision_outputs_list, dim=0)

        # Reshape from (batch, channels, height, width) to (batch, height * width, channels)
        vision_outputs = vision_outputs.reshape(
            vision_outputs.shape[0],
            self.config.vision_config.hidden_size,
            self.config.vision_soft_tokens_per_image,
        ).permute(0, 2, 1)

        # Normalize and embed the soft tokens into language model space
        vision_outputs *= self.config.vision_config.hidden_size**0.5
        return self.embed_vision(inputs_embeds=vision_outputs)

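    # Shape sketch with illustrative numbers (not taken from the config): if the
    # vision tower returns (num_images, C, H, W) with C == vision_config.hidden_size
    # and H * W == vision_soft_tokens_per_image, e.g. (2, 2048, 16, 16), the reshape
    # and permute above yield (2, 256, 2048), i.e. 256 soft tokens per image before
    # projection by embed_vision.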
    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """
        Projects the last hidden state from the audio encoder into language model space.

        Args:
            items: List of multimodal data items containing audio data.

        Returns:
            audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`.
        """
        # Extract audio features and masks from items
        all_input_features = flatten_nested_list(
            [item.input_features for item in items]
        )
        all_input_features_mask = flatten_nested_list(
            [~item.input_features_mask for item in items]
        )  # Note(Xinyuan): invert the mask to match the HF implementation

        # Process audio features one by one
        audio_features_list = []

        for input_features, input_features_mask in zip(
            all_input_features, all_input_features_mask
        ):
            # Ensure proper tensor format
            if input_features.dim() == 2:
                input_features = input_features.unsqueeze(0)
            if input_features_mask.dim() == 1:
                input_features_mask = input_features_mask.unsqueeze(0)

            # Move to device and dtype
            input_features = input_features.to(
                device=next(self.audio_tower.parameters()).device,
                dtype=self.language_model.dtype(),
            )
            input_features_mask = input_features_mask.to(device=input_features.device)

            # Process through audio tower
            audio_outputs, audio_mask = self.audio_tower(
                input_features, input_features_mask
            )

            # Embed the audio outputs
            audio_embeds = self.embed_audio(inputs_embeds=audio_outputs)
            audio_features_list.append(audio_embeds)

        # Concatenate all audio features
        if audio_features_list:
            audio_features = torch.cat(audio_features_list, dim=0)

            # The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens
            # into the text to account for this. However, the audio preprocessing and encoder do not
            # guarantee they will produce 188 soft tokens; they produce at most that many, but may produce
            # fewer depending on the length of the longest audio input in the batch. When that happens, we
            # pad the audio features out to 188 soft tokens with the embedding of the last token in the
            # embed_audio vocab.
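            # Illustrative arithmetic (assumed values): with
            # audio_soft_tokens_per_image == 188 and an encoder output of 176 frames,
            # extra_padding_tokens == 12 below, so 12 copies of the padding embedding
            # are appended to fill the 188 reserved slots.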
            audio_padding_toks = torch.tensor(
                [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
            )
            audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
            audio_features = torch.where(
                audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
            )

            audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
            extra_padding_tokens = (
                self.config.audio_soft_tokens_per_image - audio_seq_len
            )
            extra_padding_features = audio_padding_embs.expand(
                audio_batch_size, extra_padding_tokens, audio_embed_dim
            )

            audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
            return audio_features
        else:
            return torch.empty(
                0,
                0,
                self.language_model.config.hidden_size,
                device=next(self.parameters()).device,
                dtype=self.language_model.dtype(),
            )

    def get_per_layer_inputs(
        self, input_ids: torch.LongTensor
    ) -> Optional[torch.Tensor]:
        return self.language_model.get_per_layer_inputs(input_ids)

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.language_model.project_per_layer_inputs(
            inputs_embeds, per_layer_inputs
        )

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs: object,
    ) -> LogitsProcessor:
        """Forward pass for multimodal Gemma3n."""
        if (input_ids is None) ^ (input_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or input_embeds"
            )

        positions += 1

        if input_ids is not None:
            # Prepare per-layer inputs from input_ids
            per_layer_inputs_mask = torch.logical_and(
                input_ids >= 0, input_ids < self.vocab_size_per_layer_input
            )
            per_layer_inputs_tokens = torch.where(
                per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
            )
            per_layer_inputs = self.language_model.get_per_layer_inputs(
                per_layer_inputs_tokens
            )

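        # Reading of the masking above (an assumption, not asserted by the original
        # code): the per-layer embedding table only covers ids below
        # vocab_size_per_layer_input, so out-of-range ids (e.g. multimodal
        # placeholder values) are mapped to token 0 before the per-layer lookup.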
        # Use general_mm_embed_routine for handling multimodal data
        # This will automatically handle text, image, and audio embeddings
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            image_data_embedding_func=self.get_image_feature,
            audio_data_embedding_func=self.get_audio_feature,
            positions=positions,
            per_layer_inputs=per_layer_inputs,
        )

        # Process hidden states through logits processor
        return self.logits_processor(
            input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
        )

    def tie_weights(self):
        return self.language_model.tie_weights()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".up_proj", 1),
            (".gate_up_proj", ".gate_proj", 0),
        ]
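        # Example of the remapping below, using a hypothetical checkpoint name:
        # "model.language_model.layers.0.self_attn.q_proj.weight" first has the
        # leading "model." stripped, then ".q_proj" is rewritten to ".qkv_proj" and
        # the tensor is loaded into the "q" shard of the fused qkv_proj parameter.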
"""Load weights for the model."""
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
        for name, loaded_weight in weights:
            name = re.sub(r"^model\.", "", name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "vision_model" in name:
                    # Adapt to VisionAttention
                    name = name.replace(".self_attn.out_proj", ".self_attn.proj")
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remap the FP8 kv-scale name
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


EntryClass = Gemma3nForConditionalGeneration