From b7e951a6dbcd64a1c011f276c57ab84fb7fa76f0 Mon Sep 17 00:00:00 2001 From: Binyao Jiang Date: Fri, 18 Jul 2025 21:03:53 -0700 Subject: [PATCH] Feat: Support audio in Phi4-mm model (#8048) --- .../multimodal_language_models.md | 2 +- python/sglang/srt/conversation.py | 1 + python/sglang/srt/managers/schedule_batch.py | 4 + python/sglang/srt/models/phi4mm.py | 41 +- python/sglang/srt/models/phi4mm_audio.py | 1260 +++++++++++ python/sglang/srt/models/phi4mm_utils.py | 1917 +++++++++++++++++ .../multimodal/processors/base_processor.py | 14 +- .../srt/multimodal/processors/phi4mm.py | 97 +- python/sglang/srt/utils.py | 7 +- test/srt/test_vision_openai_server_b.py | 22 +- test/srt/test_vision_openai_server_common.py | 22 +- 11 files changed, 3333 insertions(+), 54 deletions(-) create mode 100644 python/sglang/srt/models/phi4mm_audio.py create mode 100644 python/sglang/srt/models/phi4mm_utils.py diff --git a/docs/supported_models/multimodal_language_models.md b/docs/supported_models/multimodal_language_models.md index 665d8de7e..66de3d8a1 100644 --- a/docs/supported_models/multimodal_language_models.md +++ b/docs/supported_models/multimodal_language_models.md @@ -37,5 +37,5 @@ in the GitHub search bar. | **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3's larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. | | **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | `kimi-vl` | Kimi-VL is a multimodal model that can understand and generate text from images. | | **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | `mistral` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. | -| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. Currently, it supports only text and vision modalities in SGLang. | +| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. | | **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | `mimo-vl` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. | diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index c085c4423..cb4bdbc44 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -729,6 +729,7 @@ register_conv_template( sep="<|end|>", stop_str="<|end|>", image_token="<|endoftext10|>", + audio_token="<|endoftext11|>", ) ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 01da558b7..a9ed66f9a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -239,6 +239,10 @@ class MultimodalDataItem: # For gemma3n input_features_mask: Optional[torch.Tensor] = None + # For phi4-mm + image_attention_mask: Optional[torch.Tensor] = None + audio_attention_mask: Optional[torch.Tensor] = None + @staticmethod def is_empty_list(l): if l is None: diff --git a/python/sglang/srt/models/phi4mm.py b/python/sglang/srt/models/phi4mm.py index 8a74888ac..b7997fc0a 100644 --- a/python/sglang/srt/models/phi4mm.py +++ b/python/sglang/srt/models/phi4mm.py @@ -40,6 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.idefics2 import Idefics2VisionTransformer from sglang.srt.models.llama import LlamaForCausalLM +from sglang.srt.models.phi4mm_audio import AudioEmbedding logger = logging.getLogger(__name__) @@ -420,16 +421,49 @@ class Phi4MMForCausalLM(nn.Module): model_dir=config._name_or_path, ) + if isinstance(config.embd_layer["audio_embd_layer"], dict): + embedding_config = { + "embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"], + **config.embd_layer["audio_embd_layer"], + } + else: + embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"]} + + self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: dtype = next(self.vision_encoder.parameters()).dtype pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype) - image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0) + image_attention_mask = torch.cat( + [item.image_attention_mask for item in items], dim=0 + ) image_sizes = torch.cat([item.image_sizes for item in items], dim=0) image_embeds = self.vision_encoder( pixel_values, image_sizes, image_attention_mask ) return torch.cat(image_embeds).type(dtype) + def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + # (e.g. multiple examples) and the second dim is the multi-audio dim + # (e.g. multiple audios in the same example) + embed_tokens_extend_param = next(self.embed_tokens_extend.parameters()) + device = embed_tokens_extend_param.device + dtype = embed_tokens_extend_param.dtype + audio_embeds = [ + self.embed_tokens_extend( + # item.feature: (num_audios_in_a_sequence, T, D) + # item.audio_attention_mask: (num_audios_in_a_sequence, T, D) BoolTensor or None + audio_features=item.feature.to(device).type(dtype), + audio_attention_mask=( + item.audio_attention_mask.to(device) + if item.audio_attention_mask is not None + else None + ), + ) + for item in items + ] + return torch.cat(audio_embeds).type(dtype) + def forward( self, input_ids: torch.Tensor, @@ -443,6 +477,7 @@ class Phi4MMForCausalLM(nn.Module): language_model=self.language_model, data_embedding_funcs={ Modality.IMAGE: self.get_image_feature, + Modality.AUDIO: self.get_audio_feature, }, positions=positions, ) @@ -464,6 +499,9 @@ class Phi4MMForCausalLM(nn.Module): (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), ] prefix_mapping = { + "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.", + "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.", + "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.", "model.embed_tokens_extend.image_embed.": "vision_encoder.", "model.": "language_model.model.", } @@ -472,7 +510,6 @@ class Phi4MMForCausalLM(nn.Module): "img_processor.encoder.layers.26", "img_processor.head", "img_processor.post_layernorm", - "audio", ] def _should_skip(name: str) -> bool: diff --git a/python/sglang/srt/models/phi4mm_audio.py b/python/sglang/srt/models/phi4mm_audio.py new file mode 100644 index 000000000..fd199836e --- /dev/null +++ b/python/sglang/srt/models/phi4mm_audio.py @@ -0,0 +1,1260 @@ +# Copyright 2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +#!/usr/bin/env python3 +import abc +import math +from typing import Literal, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel +from transformers import PretrainedConfig + +from sglang.srt.models.phi4mm_utils import ( + AbsolutePositionalEncoding, + ConvModule, + FeedForward, + MeanVarianceNormLayer, + MultiHeadedAttention, + MultiSequential, + NemoConvSubsampling, + T5RelativeAttentionLogitBias, + adaptive_enc_mask, + get_offset, + unfold_tensor, +) + +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> + + +class ConformerEncoderLayer(nn.Module): + """ConformerEncoder Layer module. + for more details see conformer paper: + https://arxiv.org/abs/2005.08100 + This module implement the Conformer block layer. + + Args: + d_model: int + attention dim. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a + channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + n_head: int + the number of heads for multihead attention module. + d_ffn: int + output size of the feed_forward blocks. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + activation: str, optional + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "relu". + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + chunk_size: int, optional + chunk_size for cnn. default 18 + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation function used for the glu inside + the ConvModule part of the conformer. + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_inner_dim: int, optional + if equal to -1, attention dim for linears k/q/v is + equal to d_model. otherwise attention_inner_dim is used. + default -1. + attention_glu_type: str, optional + activation function for glu used in the multihead attention, + default "swish". + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + use_pt_scaled_dot_product_attention: bool, optional + if set to True, use pytorch's scaled dot product attention + implementation in training. + attn_group_sizes: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attn_group_sizes < attention_heads = Grouped-Query Attention + attn_group_sizes = attention_heads = Multi-Query Attention + """ + + def __init__( + self, + d_model=512, + ext_pw_out_channel=0, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + n_head=4, + d_ffn=2048, + ext_pw_kernel_size=1, + kernel_size=3, + dropout_rate=0.1, + causal=False, + batch_norm=False, + activation="relu", + chunk_se=0, + chunk_size=18, + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_inner_dim=-1, + attention_glu_type="swish", + activation_checkpointing="", + export=False, + use_pt_scaled_dot_product_attention=False, + attn_group_sizes: int = 1, + ): + super().__init__() + + self.feed_forward_in = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.self_attn = MultiHeadedAttention( + n_head, + d_model, + dropout_rate, + attention_inner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + ) + self.conv = ConvModule( + d_model, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal, + batch_norm, + chunk_se, + chunk_size, + conv_activation, + conv_glu_type, + bias_in_glu, + linear_glu_in_convm, + export=export, + ) + + self.feed_forward_out = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.layer_norm_att = nn.LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x, + pos_k, + pos_v, + mask, + relative_attention_bias: Optional[Tensor] = None, + ): + """ConformerEncoder forward. + + Args: + x: torch.Tensor + input feature of shape (batch, max_time_in, size) + pos_k: torch.Tensor + positional key embedding. + mask: torch.Tensor + mask for x (batch, max_time_in) + relative_attention_bias: Optional[torch.Tensor] + bias added to attention logits w.r.t. relative positions + (1, n_head, time1, time2) + """ + x = x + 0.5 * self.feed_forward_in(x) + norm_x = self.layer_norm_att(x) + + x = x + self.self_attn( + norm_x, + norm_x, + norm_x, + pos_k, + pos_v, + mask, + relative_attention_bias=relative_attention_bias, + ) + x = x + self.conv(x) + x = x + 0.5 * self.feed_forward_out(x) + + out = self.layer_norm(x) + + return out, pos_k, pos_v, mask + + +class TransformerEncoderBase(abc.ABC, nn.Module): + """The Base class for Transformer based encoders + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + time_reduction: int, optional + time reduction factor + default 4 + dropout_rate: float, optional + dropout rate. default 0.1 + padding_idx: int, optional + padding index for input_layer=embed + default -1 + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see + transformer_base.py) + positional_dropout_rate: float, optional + dropout rate after positional encoding. default 0.0 + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default None + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True). + if True or feat_time, the extra padding is added into non full + supraframe utts in batch. + Default: none + attention_group_size: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query + Attention + attention_group_size = attention_heads = Multi-Query Attention + """ + + def __init__( + self, + input_size, + chunk_size, + left_chunk, + attention_dim=256, + attention_heads=4, + input_layer="nemo_conv", + cnn_out=-1, + cnn_layer_norm=False, + time_reduction=4, + dropout_rate=0.0, + padding_idx=-1, + relative_attention_bias_args=None, + positional_dropout_rate=0.0, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__() + self.input_size = input_size + self.input_layer = input_layer + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.attention_dim = attention_dim + self.num_heads = attention_heads + self.attention_group_size = attention_group_size + self.time_reduction = time_reduction + self.nemo_conv_settings = nemo_conv_settings + self.encoder_embedding_config = encoder_embedding_config + + if self.input_layer == "nemo_conv": + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.time_reduction, + "feat_in": input_size, + "feat_out": attention_dim, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.embed = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.pos_emb = AbsolutePositionalEncoding( + attention_dim, positional_dropout_rate + ) + + self.relative_attention_bias_type = ( + relative_attention_bias_args.get("type") + if relative_attention_bias_args + else None + ) + if self.relative_attention_bias_type == "t5": + assert ( + self.num_heads % self.attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( + self.num_heads // self.attention_group_size, + max_distance=relative_attention_bias_args.get( + "t5_bias_max_distance", 1000 + ), + symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), + ) + else: + raise NotImplementedError + + self.encoder_embedding = MeanVarianceNormLayer( + self.encoder_embedding_config["input_size"] + ) + + def compute_lens_change(self, feature_lens): + """feature_lens: int + return updated feature lens. + + This used to return a different lambda function for each case that + computed the right thing. That does not work within Torchscript. + If you really need this to be faster, create nn.Module()-s for all + the cases and return one of them. Torchscript does support that. + """ + if self.input_layer == "nemo_conv": + # Handle the special causal case + subsampling_causal_cond = self.nemo_conv_settings.get( + "subsampling", "dw_striding" + ) in [ + "dw_striding", + "striding", + "striding_conv1d", + ] + is_causal = self.nemo_conv_settings.get("is_causal", False) + if is_causal and subsampling_causal_cond: + lens_change = ( + torch.ceil(feature_lens / self.time_reduction).long() + if isinstance(feature_lens, Tensor) + else math.ceil(feature_lens / self.time_reduction) + ) + feature_lens_remainder = feature_lens % self.time_reduction + if isinstance(feature_lens, Tensor): + lens_change[feature_lens_remainder != 1] += 1 + elif feature_lens_remainder != 1: + lens_change += 1 + return lens_change + ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil + return ceil_func(feature_lens / self.time_reduction) + + @abc.abstractmethod + def forward(self): + """Abstract forward method implementation.""" + + def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + """If chunk size is a list, we will randomly select a chunk size.""" + + if chunk_size is None: + chunk_size = self.chunk_size + if left_chunk is None: + left_chunk = self.left_chunk + if isinstance(chunk_size, list): + # Variable chunk size during training + chunk_size_index = int( + torch.randint(low=0, high=len(chunk_size), size=(1,)) + ) + chunk_size_train_eff = chunk_size[chunk_size_index] + if not isinstance(left_chunk, list): + raise ValueError( + "Since chunk_size is a list, left_chunk must be a list" + ) + if len(left_chunk) != len(chunk_size): + raise ValueError( + "The length of left_chunk must be the same as length of " + "chunk_size." + ) + left_chunk_train_eff = left_chunk[chunk_size_index] + else: + chunk_size_train_eff = chunk_size + left_chunk_train_eff = left_chunk + + return chunk_size_train_eff, left_chunk_train_eff + + def _get_embed_class(self, embed): + # pylint: disable=protected-access + is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) + is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) + embed_class = embed + if is_embed_using_act_chkpt: + embed_class = embed._checkpoint_wrapped_module + if is_embed_fsdp_wrapped: + embed_class = embed.module + return embed_class + + def _forward_embeddings_core(self, input_tensor, masks): + embed_class = self._get_embed_class(self.embed) + assert isinstance(embed_class, NemoConvSubsampling) + input_tensor, masks = self.embed(input_tensor, masks) + return input_tensor, masks + + def _position_embedding(self, input_tensor): + pos_k = None + pos_v = None + if self.relative_attention_bias_layer is None: + input_tensor = self.pos_emb( + input_tensor + ) # default to add abs sinusoid embedding + return pos_k, pos_v + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( + chunk_size, left_chunk + ) + + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) + + enc_streaming_mask = ( + adaptive_enc_mask( + seq_len, chunk_start_idx, left_window=left_chunk_train_eff + ) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) + return enc_streaming_mask + + def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None): + """Forwarding the inputs through the top embedding layers + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + input mask + chunk_size_nc: (optional, default is None) chunk size for + non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for + non-causal layers + """ + # pylint: disable=R0915 + # get new lens. + seq_len = int(self.compute_lens_change(xs_pad.shape[1])) + if seq_len <= 0: + raise ValueError( + f"""The sequence length after time reduction is invalid: + {seq_len}. Your input feature is too short. Consider + filtering out the very short sentence from data + loader""", + ) + + batch_size = xs_pad.shape[0] + + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.chunk_size, self.left_chunk + ) + + if xs_pad.is_cuda: + enc_streaming_mask = enc_streaming_mask.cuda() + xs_pad = xs_pad.cuda() + + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask( + seq_len, batch_size, chunk_size_nc, left_chunk_nc + ) + if xs_pad.is_cuda: + enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: + hs_mask_nc = enc_streaming_mask_nc + else: + hs_mask_nc = None + + pos_k, pos_v = self._position_embedding(input_tensor) + + if chunk_size_nc is None: + return input_tensor, pos_k, pos_v, hs_mask, masks + return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc + + def get_offset(self): + """Returns offset used when retaining inputs for decoding. + + This is essentially, how many additional frames have to be added to + the front-end CNN input to ensure it can produce a single output. + So if the "padding" parameter is 0, typically offset will be > 0. + """ + return get_offset(self.input_layer, self.time_reduction) + + +class ConformerEncoder(TransformerEncoderBase): + """ConformerEncoder module. + see original paper for more details: + https://arxiv.org/abs/2005.08100 + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + left_chunk: int + number of chunks used for masking in streaming mode. + num_lang: int + This parameter is used to store the number of languages in the + lang_dict, only used for multiseed/multilingual models. + default None. + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + linear_units: + the number of units of position-wise feed forward. + default 2048 + num_block: + number of Transformer layer. default 6 + dropout_rate: float, optional + dropout rate. default 0.1 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + ext_pw_out_channel: int, optional + the number of channel for CNN + before depthwise_seperable_CNN. + If 0 then use linear. default 0. + ext_pw_kernel_size: int, optional + kernel size of N before depthwise_seperable_CNN. + only work for ext_pw_out_channel > 0. + default 1 + depthwise_seperable_out_channel: int, optional + the number of channel for + depthwise_seperable_CNN. + default 256. + depthwise_multiplier: int, optional + the number of multiplier for + depthwise_seperable_CNN. + default 1. + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + kernel_size: int, optional + the number of kernels for depthwise_seperable_CNN. + default 3. + activation: str, optional + FeedForward block activation. + one of ["relu", "swish", "sigmoid"] + default "relu". + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation used use glu in depthwise_seperable_CNN, + default "sigmoid" + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. default True + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_glu_type: str + only work for glu_in_attention !=0 + default "swish". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + extra_layer_output_idx: int + the layer index to be exposed. + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see + transformer_base.py) + time_reduction: int optional + time reduction factor + default 4 + use_pt_scaled_dot_product_attention: whether to use pytorch scaled + dot product attention in training. + Default: False + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default: None + usage: nemo_conv_settings= + { + "subsampling": + dw_striding/striding/dw_striding_conv1d/striding_conv1d, + "conv_channels": int, + "subsampling_conv_chunking_factor": int, + "is_causal": True/False + } + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True) + Default: none + replication_pad_for_subsample_embedding: For batched-streaming + decoding, use "replication" padding for the cache at start of + utterance. + Default: False + attention_group_size: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query + Attention + attention_group_size = attention_heads = Multi-Query Attention + """ + + extra_multi_layer_output_idxs: list[int] + + def __init__( # pylint: disable-all + self, + input_size, + chunk_size, + left_chunk, + num_lang=None, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + input_layer="nemo_conv", + causal=True, + batch_norm=False, + cnn_out=-1, + cnn_layer_norm=False, + ext_pw_out_channel=0, + ext_pw_kernel_size=1, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + chunk_se=0, + kernel_size=3, + activation="relu", + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_glu_type="swish", + export=False, + extra_layer_output_idx=-1, + extra_multi_layer_output_idxs=[], # noqa + activation_checkpointing="", + relative_attention_bias_args=None, + time_reduction=4, + use_pt_scaled_dot_product_attention=False, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + replication_pad_for_subsample_embedding=False, + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__( + input_size, + chunk_size, + left_chunk, + attention_dim, + attention_heads, + input_layer, + cnn_out, + cnn_layer_norm, + time_reduction, + dropout_rate=dropout_rate, + relative_attention_bias_args=relative_attention_bias_args, + positional_dropout_rate=0.0, + nemo_conv_settings=nemo_conv_settings, + conv2d_extra_padding=conv2d_extra_padding, + attention_group_size=attention_group_size, + encoder_embedding_config=encoder_embedding_config, + ) + self.num_blocks = num_blocks + self.num_lang = num_lang + self.kernel_size = kernel_size + self.replication_pad_for_subsample_embedding: bool = ( + replication_pad_for_subsample_embedding + ) + assert ( + self.num_heads % attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.num_heads_k = self.num_heads // attention_group_size + + self.encoders = MultiSequential( + *[ + ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=activation_checkpointing, + export=export, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) + for _ in range(num_blocks) + ] + ) + self.extra_layer_output_idx = extra_layer_output_idx + self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs + # Make a zeros scalar we can use in get_initial_state to determine + # the device and the needed dtype: + self.register_buffer("dev_type", torch.zeros(()), persistent=False) + + def init_relative_attention_bias(self, input_tensor): + if self.relative_attention_bias_layer: + return self.relative_attention_bias_layer(input_tensor) + + def calculate_hs_mask(self, xs_pad, device, mask): + max_audio_length = xs_pad.shape[1] + batch_size = xs_pad.shape[0] + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.chunk_size, self.left_chunk + ) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + @torch.jit.ignore + def forward(self, xs_pad, masks): + """Conformer Forward function + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + post-embedding input lengths + """ + xs_pad = self.encoder_embedding(xs_pad) + input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( + xs_pad, masks + ) + + unfolded = False + ori_bz, seq_len, D = input_tensor.shape + max_seq_len = 500 # maximum position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks + # of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple + # of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + input_tensor_pad = F.pad( + input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0 + ) + input_tensor = input_tensor_pad.to(input_tensor.device) + input_tensor = unfold_tensor(input_tensor, max_seq_len) + if masks is not None: + # revise hs_mask here because the previous calculated hs_mask + # did not consider extra pad + subsampled_pad_mask = masks.squeeze( + 1 + ) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = ( + extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + ) + masks_unfold = unfold_tensor( + extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze( + -1 + ).bool() # unfold op does not support bool tensor + else: + masks_unfold = None + hs_mask = self.calculate_hs_mask( + input_tensor, input_tensor.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + + # layer_emb = None + + relative_attention_bias = self.init_relative_attention_bias(input_tensor) + + _simplified_path = ( + self.extra_layer_output_idx == -1 and relative_attention_bias is None + ) + + if _simplified_path: + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) + else: + for i, layer in enumerate(self.encoders): + input_tensor, _, _, _ = layer( + input_tensor, + pos_k, + pos_v, + hs_mask, + relative_attention_bias=relative_attention_bias, + ) + + # if i == self.extra_layer_output_idx: + # layer_emb = input_tensor + + if unfolded: + embed_dim = input_tensor.shape[-1] + input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + input_tensor = input_tensor[:, :-chunk_pad_size, :] + + return input_tensor, masks # , layer_emb + + +class WindowQformer(nn.Module): + """Window-level Qformer""" + + def __init__( + self, + window_size: int = 8, + num_queries: int = 1, + num_blocks: int = 2, + attention_dim: int = 512, + attention_heads: int = 8, + linear_units: int = 2048, + dropout_rate: float = 0.0, + normalize_before: bool = True, + ): + super().__init__() + + self.decoders = nn.ModuleList( + [ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) + for _ in range(num_blocks) + ] + ) + + self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) + self.after_norm = ( + nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None + ) + self.window_size = window_size + + def forward(self, audio_embed, mask, embed_len=None): + """forward decoder""" + # audio_embed: N x T x D => N x D x T + + audio_embed = audio_embed.transpose(1, 2) + # audio_embed: N x D x 1 x T => N x DK x T' + padding = audio_embed.shape[-1] % self.window_size + if padding > 0: + audio_embed = F.pad( + audio_embed, (0, self.window_size - padding), "constant", 0 + ) + + embed_chunk = F.unfold( + audio_embed[..., None, :], + kernel_size=(1, self.window_size), + stride=(1, self.window_size), + ) + bsz, _, slen = embed_chunk.shape + # N x D x K x T' + embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen) + # N x T' x K x D + embed_chunk = embed_chunk.transpose(1, 3).contiguous() + # NT' x K x D + embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1) + # NT' x 1 x D + q = self.queries.expand(bsz * slen, -1, -1) + for layer in self.decoders: + q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask) + + if self.after_norm is not None: + q = self.after_norm(q) + + if embed_len is not None: + embed_len = embed_len // self.window_size + # N x T' x D + out = q.view(bsz, slen, -1) + + return out, embed_len + + +class AudioEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + self.config = config + # n_embed or hidden_size for text LM + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size + + # self.wte = nn.Embedding(config.vocab_size, hidden_size) + + audio_dim_out = ( + None # Set this variable according to the actual audio processor + ) + self.layer_idx = -2 + + if ( + isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades" + ): + encoder_config = config.audio_processor.get("config", None) + assert encoder_config is not None + self.encoder = ConformerEncoder(**encoder_config) + + audio_dim_out = encoder_config["attention_dim"] + n_mels = encoder_config["input_size"] + else: + raise NotImplementedError("") + + assert audio_dim_out is not None, "Remember to set values for audio_dim_out" + self.audio_dim_out = audio_dim_out + self.audio_dim_in = n_mels + + self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False) + + self.downsample_rate = kwargs.get("downsample_rate", 1) + + if kwargs.get("use_qformer", False): + qformer_config = kwargs.get("qformer_config", {}) + qformer_config["attention_dim"] = audio_dim_out + self.qformer = WindowQformer(**qformer_config) + else: + self.qformer = None + + if kwargs.get("use_conv_downsample", False): + assert ( + self.qformer is None + ), "don't support use qformer and conv downsample together" + nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.downsample_rate, + "feat_in": audio_dim_out, + "feat_out": audio_dim_out, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.conv_ds = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + self.conv_ds = None + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.audio_projection = nn.Linear(audio_dim_out, hidden_size) + elif projection_cls == "mlp": + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + self.linear_downsample_rate = ( + 1 if (self.qformer or self.conv_ds) else self.downsample_rate + ) + layers = [ + nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) + ] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.audio_projection = nn.Sequential(*layers) + # NOTE vision-speech tasks use a separate projection layer + layers = [ + nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) + ] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.audio_projection_for_vision = nn.Sequential(*layers) + else: + raise NotImplementedError( + f"projection_cls = {projection_cls}, not implemented" + ) + + # TODO: audio sequence compression - Qformer + self.vocab_size = config.vocab_size + self.input_embeds = None + self.audio_embed_sizes = None + + def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + self.input_embeds = input_embeds + + def set_audio_embed_sizes(self, audio_embed_sizes: torch.LongTensor) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def get_audio_features( + self, + input_embeds: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", + ) -> torch.FloatTensor: + """ + arguments: + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ + if self.freeze_audio_processor: + with torch.no_grad(): + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) + else: + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) + + if self.qformer is not None: + audio_features, _ = self.qformer(audio_features, mask=None) + + if self.conv_ds is not None: + if masks is not None: + masks = masks.squeeze(1) + + audio_features, masks = self.conv_ds(audio_features, mask=masks) + + if self.linear_downsample_rate != 1: + bs, seq_len, feat_dim = audio_features.size() + padding = seq_len % self.linear_downsample_rate + if padding > 0: + audio_features = F.pad( + audio_features, + (0, 0, 0, self.linear_downsample_rate - padding), + "constant", + 0, + ) + + seq_len = audio_features.size(1) + audio_features = audio_features.view( + bs, + seq_len // self.linear_downsample_rate, + feat_dim * self.linear_downsample_rate, + ) + + if audio_projection_mode == "speech": + audio_set_tensor = self.audio_projection(audio_features) + elif audio_projection_mode == "vision": + audio_set_tensor = self.audio_projection_for_vision(audio_features) + else: + raise ValueError( + f"audio_projection_mode = {audio_projection_mode} not " "implemented" + ) + + return audio_set_tensor + + def forward( + self, + audio_features: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", + ) -> torch.FloatTensor: + """ + arguments: + audio_features: audio features (num_audio_tokens, T, D) + + returns: + audio_embeds: audio embeddings (num_audio_tokens, hidden_dim) + """ + audio_embeds = self.get_audio_features( + audio_features, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + return audio_embeds diff --git a/python/sglang/srt/models/phi4mm_utils.py b/python/sglang/srt/models/phi4mm_utils.py new file mode 100644 index 000000000..e6bf35ebf --- /dev/null +++ b/python/sglang/srt/models/phi4mm_utils.py @@ -0,0 +1,1917 @@ +# Copyright 2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +#!/usr/bin/env python3 +import math +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class BlockBase(nn.Module): + """Block abstract module""" + + def __init__(self, input_size, output_size): + super().__init__() + self.input_size = input_size + self.output_size = output_size + + +def get_activation(name="relu"): + """Select an activation function by name + + Args: + name: str + activation function name, + one of ["relu", "gelu", "swish", "sigmoid"], + default "relu". + """ + name = name.lower() + if name == "relu": + return nn.ReLU(inplace=True) + if name == "gelu": + return nn.GELU() + if name == "swish": + return Swish() + if name == "sigmoid": + return torch.nn.Sigmoid() + return nn.Identity() + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. + It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for + chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + Torch 1.0.1 + tensor([[1., 1., 0., 0.], + [0., 1., 1., 0.], + [0., 0., 1., 1.]]) + Torch 1.4.1 + tensor([[True., True., False., False.], + [False., True., True., False.], + [False., False., True., True.]]) + """ + chunk_start_idx = torch.Tensor( + chunk_start_idx + ).long() # first idx of each chunk, such as [0,18,36,48]. + start_pad = torch.nn.functional.pad( + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[ + :, 1 + ] # idx size: [x_len] + # boundary = end_pad[idx] # boundary size: [x_len] + seq_range_expand = ( + torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Swish(nn.Module): + """Implement Swish activation module. + From https://arxiv.org/pdf/2005.03191.pdf + + """ + + def __init__(self) -> None: + super().__init__() + self.act_fn = nn.Sigmoid() + + def forward(self, x: Tensor) -> Tensor: + """Apply Swish function + + Args: + x: torch.Tensor + Input. + """ + return x * self.act_fn(x) + + +class GLU(nn.Module): + """Implement Gated Linear Unit (GLU) module""" + + def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: + super().__init__() + self.dim = dim + self.act_name = act_name.lower() + + if self.act_name == "relu": + self.act_fn = nn.ReLU(inplace=True) + elif self.act_name == "gelu": + self.act_fn = nn.GELU() + elif self.act_name == "swish": + self.act_fn = Swish() + elif self.act_name == "sigmoid": + self.act_fn = nn.Sigmoid() + else: + self.act_fn = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """GLU forward + Apply Swish function on the first half of input matrices + with sigmoid of the second half. + + Args: + x: torch.Tensor + Input. + + """ + half_x, gate = x.chunk(2, dim=self.dim) + return half_x * self.act_fn(gate) + + +# TODO: Abdel, this can be improved using GLU module +class GLUPointWiseConv(nn.Module): + """GLUPointWiseConv module + used for conformer architecture, + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + output_dim: int + output channel size. + kernel_size: int + kernel size + glu_type: str, optional + activation function one of + ["sigmoid", "relu", "gelu"] + default "sigmoid". + bias_in_glu: bool, optional + use addtive bias in glu + causal: bool, optional + if set to True, padding is set to the half of + kernel size, ie, convolution can't see future frames. + default False. + + """ + + def __init__( + self, + input_dim, + output_dim, + kernel_size, + glu_type="sigmoid", + bias_in_glu=True, + causal=False, + ): + super().__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + self.bias_in_glu = bias_in_glu + if causal: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1), + ) + else: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + + if glu_type == "sigmoid": + self.glu_act = nn.Sigmoid() + elif glu_type == "relu": + self.glu_act = nn.ReLU() + elif glu_type == "gelu": + self.glu_act = nn.GELU() + elif glu_type == "swish": + self.glu_act = Swish() + else: + raise ValueError(f"Unsupported activation type {self.glu_act}") + + if bias_in_glu: + self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) + self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) + + def forward(self, x): + """ + Args: + x: torch.Tensor + input tensor + """ + # to be consistent with GLULinear, we assume the input always has the + # #channel (#dim) in the last dimension of the tensor, so need to + # switch the dimension first for 1D-Conv case + x = x.permute([0, 2, 1]) + x = self.ext_pw_conv_1d(x) + if self.glu_type == "bilinear": + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * ( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * ( + x[:, self.output_dim : self.output_dim * 2, :] + ) + else: + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + ) + + x = x.permute([0, 2, 1]) + return x + + +class DepthWiseSeperableConv1d(nn.Module): + """DepthWiseSeperableConv1d module used in Convnet module + for the conformer, for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a channel_out + of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + kernel_size: int + kernel_size + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + padding: int, optional + padding for the conv1d, + default: 0. + + """ + + def __init__( + self, + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=0, + ): + super().__init__() + + self.dw_conv = nn.Conv1d( + input_dim, + input_dim * depthwise_multiplier, + kernel_size, + 1, + padding=padding, + groups=input_dim, + ) + + if depthwise_seperable_out_channel != 0: + self.pw_conv = nn.Conv1d( + input_dim * depthwise_multiplier, + depthwise_seperable_out_channel, + 1, + 1, + 0, + ) + else: + self.pw_conv = nn.Identity() + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + + def forward(self, x): + """ + + Args: + x: torch.Tensor + input tensor + """ + x = self.dw_conv(x) + if self.depthwise_seperable_out_channel != 0: + x = self.pw_conv(x) + return x + + +class ConvModule(nn.Module): + """ConvModule Module for the conformer block. + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation. + default False + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + chunk_size: int, optional + chunk size for cnn. default 18 + activation: str, optional + activation function used in ConvModule, + default: "relu". + glu_type: str, optional + activation function used for the glu, + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + export: bool, optional, + if set to True, padding is equal to 0. This is for inference, + or onnx export. Typically this is set by the export program or + the decoder program, and it isn't present in your config file. + default False + """ + + def __init__( + self, + input_dim, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal=False, + batch_norm=False, + chunk_se=0, + chunk_size=18, + activation="relu", + glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + export=False, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.input_dim = input_dim + self.ext_pw_out_channel = ext_pw_out_channel + self.ext_pw_kernel_size = ext_pw_kernel_size + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.glu_type = glu_type + self.bias_in_glu = bias_in_glu + self.linear_glu_in_convm = linear_glu_in_convm + self.causal = causal + + self._add_ext_pw_layer() + + self.batch_norm = batch_norm + self.kernel_size = kernel_size + + if batch_norm: + self.bn_layer = nn.BatchNorm1d(input_dim) + + self.act = get_activation(activation) + self.dropout = nn.Dropout(dropout_rate) + self.export = export + + if causal: + padding = 0 if export else kernel_size - 1 + else: + padding = (kernel_size - 1) // 2 + + self.dw_sep_conv_1d = DepthWiseSeperableConv1d( + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=padding, + ) + + if depthwise_seperable_out_channel != 0: + if input_dim != depthwise_seperable_out_channel: + self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) + else: + if depthwise_multiplier != 1: + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) + + def _add_ext_pw_layer(self): + """ + This function is an extension of __init__ function + and dedicated to the convolution module creation + of the conformer. + """ + self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = ( + nn.Identity() + ) # jit hacks. + self.squeeze_excitation = nn.Identity() # jit. + self.apply_ln1 = self.fix_len1 = False # jit. + + if self.ext_pw_out_channel != 0: + if self.causal: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1), + ) + if self.ext_pw_kernel_size > 1: + self.fix_len1 = True + else: + self.fix_len1 = False + else: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1) // 2, + ) + self.fix_len1 = False + + if self.linear_glu_in_convm: + self.glu = GLULinear( + self.input_dim, + self.ext_pw_out_channel, + self.glu_type, + self.bias_in_glu, + ) + else: + self.glu = GLUPointWiseConv( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + self.glu_type, + self.bias_in_glu, + self.causal, + ) + + if self.input_dim != self.ext_pw_out_channel: + self.apply_ln1 = True + self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) + else: + self.apply_ln1 = False + else: + self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) + self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) + + def forward(self, x): + """ConvModule Forward. + + Args: + x: torch.Tensor + input tensor. + """ + x = self.layer_norm(x) + + if self.ext_pw_out_channel != 0: + x = self.glu(x) + if self.causal and self.ext_pw_kernel_size > 1: + x = x[:, : -(self.ext_pw_kernel_size - 1), :] + if self.apply_ln1: + x = self.ln1(x) + else: + x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] + x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] + x = x_0 + x_1 + + x = x.permute([0, 2, 1]) + + x = self.dw_sep_conv_1d(x) + if self.causal and self.kernel_size > 1: + x = x[:, :, : -(self.kernel_size - 1)] + if hasattr(self, "ln2"): + x = x.permute([0, 2, 1]) + x = self.ln2(x) + x = x.permute([0, 2, 1]) + if self.batch_norm: + x = self.bn_layer(x) + x = self.act(x) + + if self.ext_pw_out_channel != 0: + x = self.ext_pw_conv_1d(x) + if self.fix_len1: + x = x[:, :, : -(self.ext_pw_kernel_size - 1)] + + if self.apply_ln1: + x = x.permute([0, 2, 1]) + x = self.ln1(x) + x = x.permute([0, 2, 1]) + + x = x.permute([0, 2, 1]) + else: + x = x.unsqueeze(1).permute([0, 1, 3, 2]) + x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] + x = x.squeeze(1) + + x = self.dropout(x) + return x + + +class GLULinear(nn.Module): + """Linear + GLU module + + Args: + input_dim: int + input size + output_dim: int + output size. + glu_type: + activation function name used in glu module. + default "sigmoid" (swish function). + bias_in_glu: bool, optional + If True, the addtive bias is added. Default False. + """ + + def __init__( + self, + input_dim, + output_dim, + glu_type="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) + self.glu_act = GLU(-1, glu_type) + + def forward(self, x): + """GLULinear forward + + Args: + x: torch.Tensor + inpute tensor. + """ + x = self.linear(x) + return self.glu_act(x) + + +class FeedForward(nn.Module): + """FeedForward Module. + For more details see Conformer paper: + https://arxiv.org/pdf/2005.08100.pdf + + Args: + d_model: int + input size. + d_inner: int + output size. + dropout_rate: float, + dropout rate. + activation: str, + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "sigmoid". + bias_in_glu: bool, optional + """ + + def __init__( + self, + d_model, + d_inner, + dropout_rate, + activation="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.d_model = d_model + self.d_inner = d_inner + + self.layer_norm = nn.LayerNorm(d_model) + module = GLULinear(d_model, d_inner, activation, bias_in_glu) + self.net = nn.Sequential( + module, + nn.Dropout(dropout_rate), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout_rate), + ) + + def forward(self, x): + """FeedForward forward function. + + Args: + x: torch.Tensor + input tensor. + """ + out = self.net(self.layer_norm(x)) + + return out + + +#### positional encoding starts here +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward + compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class T5RelativeAttentionLogitBias(nn.Module): + """ + This module implements the relative position bias described in Section + 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf + + The Huggingface implementation is used as a reference + https://github.com/huggingface/transformers/blob/v4.30.0/src/ + transformers/models/t5/modeling_t5.py#L435 + + Modifies attention as Q*K^T + B, where B is a learned scalar bias based + on relative position of the query and key. It is HxNxN, where H is the + number of heads, N is the sequence length. + + I've made these modifications to the original T5 bias: + - Skipping of the bucketing step. Original T5 bias converted rel + position distances into logarithmically increasing buckets. This is + supposed to help with length generalization. + - I just directly use rel position index as bias values, as we don't + need length generalization (40s max is good enough for ASR encoder), + and it keeps ONNX export simple. + - I've also extended it so that biases can be asymmetric, the default + implementation treats L->R and R->L the same. Asymmetric was found to + yield better results in my experiments. + + Args: + num_heads: int + Number of attention heads + num_buckets: int + Number of buckets to use for relative attention bias. This is the + size of the learnable bias parameter. Bucketing is not yet + supported, so this defaults to -1 which means no bucketing is + used (max_distance determines size of bias param). + max_distance: int + Maximum distance to use for relative attention bias. With + num_buckets=-1, this directly controls the max size of the bias + parameter. When num_buckets > 0 is supported, this will control + the maximum distance for logarithmic bucketing after which all + positions are in the same bucket. + symmetric: bool + Whether to use symmetric or asymmetric biases. symmetric=False uses + 2x number of bias params to distinguish L->R from R->L. This was + found to be better for the encoder. + """ + + def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False): + super().__init__() + self.num_heads = num_heads + self.num_buckets = num_buckets + self.max_distance = max_distance + self.symmetric = symmetric + self._skip_bucketing = self.num_buckets < 0 + if self._skip_bucketing: + self.num_buckets = max_distance + else: + raise NotImplementedError( + "T5 attention bias with bucketed positions is not yet tested" + ) + if not self.symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + maxpos = x.size(1) + context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[ + :, None + ] + memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[ + None, : + ] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX + # export + relative_position = relative_position.masked_fill( + relative_position < -self.max_distance, -self.max_distance + ) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + if self._skip_bucketing: + bias_idx = relative_position + else: + bias_idx = self._bucket_relative_position(relative_position) + if self.symmetric: + bias_idx = bias_idx.abs() + else: + bias_idx += self.num_buckets // 2 + + t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0) # [1, H, L, L] + + return t5_rel_att_bias + + def _bucket_relative_position(self, relative_position): + # this is a placeholder (isn't tested, likely buggy) using HuggingFace + # implem as a reference this also needs to be extended to support + # asymmetric +/- ve positions + relative_buckets = 0 + if not self.causal: + self.num_buckets //= 2 + relative_buckets += (relative_position > 0).to( + torch.long + ) * self.num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = self.num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in + # positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(self.max_distance / max_exact) + * (self.num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, self.num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + +class AbsolutePositionalEncoding(nn.Module): + """Absolute Positional encoding module. + This module implement Absolute sinusoidal positional encoding + from: https://arxiv.org/pdf/1706.03762.pdf + + Args: + d_model: int + Input embedding size. + dropout_rate: float + dropout rate + max_len: int, optional + Maximum input length sequence, Default 5000 + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings. + + Args: + x: torch.Tensor + """ + if self.pe is not None and self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x: torch.Tensor + Input tensor. shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +#### forward embedding layers starts here +class MeanVarianceNormLayer(nn.Module): + """Mean/variance normalization layer. + + Will subtract mean and multiply input by inverted standard deviation. + Typically used as a very first layer in a model. + + Args: + input_size: int + layer input size. + """ + + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.global_mean = nn.Parameter(torch.zeros(input_size)) + self.global_invstd = nn.Parameter(torch.ones(input_size)) + + def forward(self, input_: Tensor) -> Tensor: + """MeanVarianceNormLayer Forward + + Args: + input_: torch.Tensor + input tensor. + """ + return (input_ - self.global_mean) * self.global_invstd + + +class CausalConv1D(nn.Conv1d): + """ + A causal version of nn.Conv1d where each step would have limited access to + locations on its right or left + All arguments are the same as nn.Conv1d except padding. + + If padding is set None, then paddings are set automatically to make it a + causal convolution where each location would not see any steps on its right. + + If padding is set as a list (size of 2), then padding[0] would be used as + left padding and padding[1] as right padding. + It would make it possible to control the number of steps to be accessible + on the right and left. + This mode is not supported when stride > 1. padding[0]+padding[1] should + be equal to (kernel_size - 1). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + self.cache_drop_size = None + if padding is None: + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + else: + if stride != 1 and padding != kernel_size - 1: + raise ValueError("No striding allowed for non-symmetric convolutions!") + if isinstance(padding, int): + self._left_padding = padding + self._right_padding = padding + elif ( + isinstance(padding, list) + and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1 + ): + self._left_padding = padding[0] + self._right_padding = padding[1] + else: + raise ValueError(f"Invalid padding param: {padding}!") + + self._max_cache_len = self._left_padding + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def update_cache(self, x, cache=None): + if cache is None: + new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache + else: + new_x = F.pad(x, pad=(0, self._right_padding)) + new_x = torch.cat([cache, new_x], dim=-1) + if self.cache_drop_size > 0: + next_cache = new_x[:, :, : -self.cache_drop_size] + else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1) :] + return new_x, next_cache + + def forward(self, x, cache=None): + x, cache = self.update_cache(x, cache=cache) + x = super().forward(x) + if cache is None: + return x + else: + return x, cache + + +class CausalConv2D(nn.Conv2d): + """ + A causal version of nn.Conv2d where each location in the 2D matrix would + have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be + set as None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + if padding is not None: + raise ValueError("Argument padding should be set to None for CausalConv2D.") + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + + padding = 0 + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + def forward( + self, + x, + ): + x = F.pad( + x, + pad=(self._left_padding, self._right_padding, 0, 0), + ) + x = super().forward(x) + return x + + +class NemoConvSubsampling(torch.nn.Module): + """Convlutional subsampling module, taken from NeMo ASR + (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a + 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) + + Striding Subsampling: "Speech-Transformer: A No-Recurrence + Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong + et al. (https://ieeexplore.ieee.org/document/8462506) + + + Compared with the EncoderConv2D (`input_layer: custom`), this is a + much simplified approach, and uses no LayerNorm and far fewer Conv2Ds. + Moreover, depthwise convolutions are used to reduce FLOPs, but the first + layer is kept as a regular convolution so as not to degrade accuracy. + + `Striding` and `dw_striding` are the same except that the latter uses + depthwise convolutions after the first layer, whereas the former does not. + + Args: + subsampling_factor (int): Time reduction factor + feat_in (int): size of the input features + feat_out (int): size of the output features + subsampling (str): The subsampling technique, choose from + {"striding", "dw-striding", "striding_conv1d", + "dw_striding_conv1d"} + conv_channels (int): Number of channels for the convolution layers, + default is 256. + subsampling_conv_chunking_factor (int): Input chunking factor which + can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 + activation (Module): activation function, default is nn.ReLU() + is_causal (bool): whether to use causal Conv1/2D, where each step will + have limited access to locations on its right or left + """ + + def __init__( + self, + feat_in, + feat_out, + subsampling_factor=4, + subsampling="dw_striding", + conv_channels=256, + subsampling_conv_chunking_factor=1, + activation=nn.ReLU(), # noqa: B008 + is_causal=False, + ): + super().__init__() + self._subsampling = subsampling + self._conv_channels = conv_channels + self._feat_in = feat_in + self._feat_out = feat_out + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + self.subsampling_factor = subsampling_factor + self.is_causal = is_causal + self.subsampling_causal_cond = subsampling in ( + "dw_striding", + "striding", + "striding_conv1d", + ) + + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a " "power of 2" + ) + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + in_channels = 1 + layers = [] + + if subsampling == "dw_striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + # Layer 1 + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + groups=in_channels, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ) + ) + + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv1D( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "dw_striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + + # Layer 1 + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out if self._sampling_num == 1 else conv_channels + ), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 2 + else conv_channels + ), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + layers.append(activation) + in_channels = conv_channels + + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + if subsampling in ["dw_striding", "striding"]: + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = calc_length( + lengths=in_length, + all_paddings=self._left_padding + self._right_padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + repeat_num=self._sampling_num, + ) + self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) + self.conv2d_subsampling = True + elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: + self.out = None + self.conv2d_subsampling = False + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + self.conv = torch.nn.Sequential(*layers) + + def get_sampling_frames(self): + return [1, self.subsampling_factor] + + def get_streaming_cache_size(self): + return [0, self.subsampling_factor + 1] + + def forward(self, x, mask): + """ + Forward method for NeMo subsampling. + + Args: + x[Batch, Time, Filters]: torch.Tensor + input tensor + x_mask: torch.Tensor + input mask + + Returns: + x: torch.Tensor + Resulting tensor from subsampling (B, T // + time_reduction_factor, feat_out) + pad_mask: torch.Tensor + tensor of padded hidden state sequences (B, 1, T // + time_reduction_factor) + """ + x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) + + # split inputs if chunking_factor is set + if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: + if self.subsampling_conv_chunking_factor == 1: + # if subsampling_conv_chunking_factor is 1, we split only + # if needed. + # avoiding a bug / feature limiting indexing of tensors + # to 2**31. + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + need_to_split = torch.numel(x) > x_ceil + else: + # if subsampling_conv_chunking_factor > 1 we always split + need_to_split = True + + if need_to_split: + x, success = self.conv_split_by_batch(x) + if not success: # if unable to split by batch, try by channel + if self._subsampling == "dw_striding": + x = self.conv_split_by_channel(x) + else: + x = self.conv(x) # try anyway + else: + x = self.conv(x) + else: + x = self.conv(x) + + # Flatten Channel and Frequency Axes + if self.conv2d_subsampling: + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, -1)) + # Transpose to Channel Last mode + else: + x = x.transpose(1, 2) + + if mask is None: + return x, None + + max_audio_length = x.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + if self.is_causal and self.subsampling_causal_cond: + feature_lens_remainder = feature_lens % self.subsampling_factor + padding_length[feature_lens_remainder != 1] += 1 + pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + return x, pad_mask.unsqueeze(1) + + def reset_parameters(self): + # initialize weights + if self._subsampling == "dw_striding": + with torch.no_grad(): + # init conv + scale = 1.0 / self._kernel_size + dw_max = (self._kernel_size**2) ** -0.5 + pw_max = self._conv_channels**-0.5 + + torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) + torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) + + for idx in range(2, len(self.conv), 3): + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) + + # init fc (80 * 64 = 5120 from https://github.com/kssteven418/ + # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/ + # src/models/conformer_encoder.py#L487 + fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 + torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) + torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) + + def conv_split_by_batch(self, x): + """Tries to split input by batch, run conv and concat results""" + b, _, _, _ = x.size() + if b == 1: # can't split if batch size is 1 + return x, False + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) + cf = 2**p + + new_batch_size = b // cf + if new_batch_size == 0: # input is too big + return x, False + + return ( + torch.cat( + [self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)] + ), + True, + ) + + def conv_split_by_channel(self, x): + """For dw convs, tries to split input by time, run conv and concat + results""" + x = self.conv[0](x) # full conv2D + x = self.conv[1](x) # activation + + for i in range(self._sampling_num - 1): + _, c, t, _ = x.size() + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors + # to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) + cf = 2**p + + new_c = int(c // cf) + if new_c == 0: + new_c = 1 + + new_t = int(t // cf) + if new_t == 0: + new_t = 1 + + x = self.channel_chunked_conv( + self.conv[i * 3 + 2], new_c, x + ) # conv2D, depthwise + + # splitting pointwise convs by time + x = torch.cat( + [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], + 2, + ) # conv2D, pointwise + x = self.conv[i * 3 + 4](x) # activation + return x + + def channel_chunked_conv(self, conv, chunk_size, x): + """Performs channel chunked convolution""" + + ind = 0 + out_chunks = [] + for chunk in torch.split(x, chunk_size, 1): + step = chunk.size()[1] + + if self.is_causal: + chunk = nn.functional.pad( + chunk, + pad=( + self._kernel_size - 1, + self._stride - 1, + self._kernel_size - 1, + self._stride - 1, + ), + ) + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=0, + groups=step, + ) + else: + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=self._left_padding, + groups=step, + ) + out_chunks.append(ch_out) + ind += step + + return torch.cat(out_chunks, 1) + + def change_subsampling_conv_chunking_factor( + self, subsampling_conv_chunking_factor: int + ): + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a " "power of 2" + ) + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + +def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1): + """Calculates the output length of a Tensor passed through a convolution or + max pooling layer""" + add_pad: float = all_paddings - kernel_size + one: float = 1.0 + for i in range(repeat_num): + lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one + lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths) + return lengths.to(dtype=torch.int) + + +#### multihead attention starts here +class AttModule(nn.Module): + """Attention abstraction module""" + + def __init__(self): + super().__init__() + self.export_mode = False + + def set_export(self, mode=True): + """set the export mode""" + self.export_mode = mode + + def forward( + self, + x: Tensor, + memory: Optional[Tensor] = None, + pos_emb: Optional[Tensor] = None, + att_mask: Optional[Tensor] = None, + ) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + """AttModule forward + + Args: + x: torch.Tensor + input tensor. + memory: torch.Tensor, optional + memory tensor. + pos_emb: torch.Tensor, optional + positional encoder embedding. + att_mask: torch.Tensor, optional + attention mask tensor. + """ + return x, memory, pos_emb, att_mask + + +class AttBlock(BlockBase, AttModule): + """Attention Block module to support both Attention and Block module.""" + + def memory_dims(self, max_len=False): + """memory dimensions""" + return (1, self.input_size) + + +def masked_softmax( + scores, + mask: Optional[Tensor], +): + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -torch.inf) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + return attn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer with optional relative position embedding + and GLU. + + Args: + n_head: int + the number of heads. + n_feat: int + input size features. + dropout_rate: float + dropout rate. + use_LN: bool + apply layer norm or not + dropout_at_output: bool + whether to apply dropout at output + attention_inner_dim: int, optional + the attention dimension used in the class, + it can be different from the input dimension n_feat. + default: -1 (equal to n_feat). + use_pt_scaled_dot_product_attention: bool, optional + if set True, use pytorch scaled dot product attention in training. + NOTE: this will NOT be used in ONNX decoding due to a lack of + support. In that case, we use the original attention + implementation, which shows no regression. + default: False. + n_value: int, optional + if set to values other than -1, use a different dimension for + value. With the default value (i.e. -1), it is backward compatible. + group_size: int, optional. must divide `n_head` + if group_size > 1: GQA + if group_size = 1: MHA + if group_size = n_head: MQA + """ + + inv_sqrt_d_k: torch.jit.Final[float] + h: torch.jit.Final[int] + h_k: torch.jit.Final[int] + g: torch.jit.Final[int] + + def __init__( + self, + n_head, + n_feat, + dropout_rate, + attention_inner_dim=-1, + glu_type="swish", + bias_in_glu=True, + use_pt_scaled_dot_product_attention=False, + n_value=-1, + group_size: int = 1, + ): + super().__init__() + if n_value == -1: + n_value = n_feat + if attention_inner_dim == -1: + attention_inner_dim = n_feat + assert attention_inner_dim % n_head == 0 + + # We assume d_v always equals d_k + self.d_k = attention_inner_dim // n_head + self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) + self.h = n_head + assert n_head % group_size == 0, "group_size must divide n_head" + self.g = group_size + self.h_k = n_head // group_size + + self.linear_q = nn.Linear(n_feat, attention_inner_dim) + self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) + self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) + self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) + + self.attn = torch.jit.Attribute(None, Optional[Tensor]) + self.dropout = nn.Dropout(p=dropout_rate) + self.dropout_rate = dropout_rate + self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention + + if use_pt_scaled_dot_product_attention and group_size > 1: + raise ValueError("Cannot use PT Scaled Attention with GQA") + + # Torchscript eager quantization. Note that these functions below are + # NOOPs and have very little impact on performance unless quantization + # is enabled. + self.quant_q = torch.ao.quantization.QuantStub() + self.quant_x = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.ffunc = torch.ao.nn.quantized.FloatFunctional() + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_k: Tensor, + pos_v: Tensor, + mask: Optional[Tensor], + relative_attention_bias: Optional[Tensor] = None, + ): + """Compute 'Scaled Dot Product Attention'. + + Args: + query: torch.Tensor + query tensor (batch, time1, size) + key: torch.Tensor + key tensor (batch, time2, size) + value: torch.Tensor + value tensor (batch, time1, size) + pos_k: torch.Tensor + key tensor used for relative positional embedding. + pos_v: torch.Tensor + value tensor used for relative positional embedding. + mask: torch.Tensor + mask tensor (batch, time1, time2) + relative_attention_bias: torch.Tensor + bias added to attention logits w.r.t. relative positions + (1, n_head, time1, time2) + """ + n_batch = query.size(0) + + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) # (b, t, d) + v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) + q = ( + q.transpose(1, 2) + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() + else q.transpose(1, 2) * self.inv_sqrt_d_k + ) + k = k.transpose(1, 2) # (batch, head_k, time2, d_k) + v = v.transpose(1, 2) # (batch, head_k, time2, d_k) + + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): + attn_mask = None + if mask is not None: + mask = mask.unsqueeze(1) + if relative_attention_bias is not None: + attn_mask = mask + relative_attention_bias + else: + attn_mask = mask + if mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + with torch.nn.attention.sdpa_kernel( + [ + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + torch.nn.attention.SDPBackend.MATH, + torch.nn.attention.SDPBackend.CUDNN_ATTENTION, + ] + ): + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.dropout_rate, + ) + else: + if self.h != self.h_k: + q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) + A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) + else: + A = torch.matmul(q, k.transpose(-2, -1)) + if pos_k is not None: + if self.h != self.h_k: + B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) + else: + reshape_q = ( + q.contiguous() + .view(n_batch * self.h, -1, self.d_k) + .transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul( + reshape_q, pos_k.transpose(-2, -1) + ) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view( + n_batch, self.h, pos_k.size(0), pos_k.size(1) + ) + scores = A + B + else: + scores = A + + if relative_attention_bias is not None: + scores = scores + relative_attention_bias + + attn = masked_softmax(scores, mask) # (batch, head, time1, time2) + + self.attn = attn + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) + if pos_v is not None: + reshape_attn = ( + p_attn.contiguous() + .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) + .transpose(0, 1) + ) # (t1, bh, t2) + + attn_v = ( + torch.matmul(reshape_attn, pos_v) + .transpose(0, 1) + .contiguous() + .view(n_batch, self.h, pos_v.size(0), self.d_k) + ) + x = x + attn_v + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential""" + + @torch.jit.ignore + def forward(self, *args): + """Forward method implementation.""" + for m in self: + args = m(*args) + return args + + +def get_offset(input_layer: str, time_reduction: int): + """Get an offset. We will use the offset for determining #frames of a + subsampled feature. + + Args: + input_layer (str): Type of an input layer + time_reduction (int): time reduction factor for downsampling a feature + Returns: + int: offset + """ + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: + return 3 + if input_layer in ("conv2d",) and time_reduction == 6: + return 1 + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: + return 7 + return 0 + + +def unfold_tensor(xs_pad, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is + longer than max_seq_len, this function unfold it to a + (NT', max_seq_len, D) where T' is T // max_seq_len. + Args: + xs_pad: N, T, D + """ + _, _, D = xs_pad.shape + xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T + # N x D x 1 x T => N x (D x max_seq_len) x T' + xs_pad = F.unfold( + xs_pad[..., None, :], + kernel_size=(1, max_seq_len), + stride=(1, max_seq_len), + ) + new_bsz, _, slen = xs_pad.shape + # N x D x max_seq_len x T' + xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) + # N x T' x max_seq_len x D + xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() + # NT' x max_seq_len x D + xs_pad = xs_pad.view(-1, max_seq_len, D) + return xs_pad diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 5c44c4d49..6c6495c5f 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -158,6 +158,7 @@ class BaseMultimodalProcessor(ABC): "pixel_values_videos": Modality.VIDEO, "image_sizes": Modality.IMAGE, "image_grid_thw": Modality.IMAGE, + "image_attention_mask": Modality.IMAGE, "image_emb_mask": Modality.IMAGE, "image_spatial_crop": Modality.IMAGE, "tgt_size": Modality.IMAGE, @@ -170,6 +171,7 @@ class BaseMultimodalProcessor(ABC): "audio_feature_lens": Modality.AUDIO, "input_features": Modality.AUDIO, "input_features_mask": Modality.AUDIO, + "audio_attention_mask": Modality.AUDIO, # Video-related attributes "video_grid_thw": Modality.VIDEO, # Generic attributes that could apply to multiple modalities @@ -251,7 +253,11 @@ class BaseMultimodalProcessor(ABC): @staticmethod def _load_single_item( - data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True + data, + modality: Modality, + frame_count_limit=None, + audio_sample_rate: Optional[int] = None, + discard_alpha_channel=True, ): """ Load a single multimodal data. @@ -268,7 +274,7 @@ class BaseMultimodalProcessor(ABC): elif modality == Modality.VIDEO: return load_video(data, frame_count_limit) elif modality == Modality.AUDIO: - return load_audio(data) + return load_audio(data, audio_sample_rate) except Exception as e: raise RuntimeError(f"Error while loading data {data}: {e}") @@ -282,6 +288,7 @@ class BaseMultimodalProcessor(ABC): image_estimated_frames_iter: Optional[iter] = None, image_scaling_factor: float = 1.0, max_image_frames: int = 30, + audio_sample_rate: Optional[int] = None, ) -> Tuple[List, List]: """ load multimodal data parallelly using iterators. @@ -324,6 +331,7 @@ class BaseMultimodalProcessor(ABC): data, modality, frame_count_limit, + audio_sample_rate, discard_alpha_channel, ) ) @@ -352,6 +360,7 @@ class BaseMultimodalProcessor(ABC): audio_data: Optional[list] = None, return_text: Optional[bool] = True, discard_alpha_channel: bool = True, + audio_sample_rate: Optional[int] = None, ) -> BaseMultiModalProcessorOutput: """ Each frame of video/image will be replaced by a single image token @@ -390,6 +399,7 @@ class BaseMultimodalProcessor(ABC): multimodal_tokens=multimodal_tokens, data_iterators=data_iterators, discard_alpha_channel=discard_alpha_channel, + audio_sample_rate=audio_sample_rate, ) task_info_iter = iter(task_info) futures_iter = iter(futures) diff --git a/python/sglang/srt/multimodal/processors/phi4mm.py b/python/sglang/srt/multimodal/processors/phi4mm.py index aea06506d..8772403db 100644 --- a/python/sglang/srt/multimodal/processors/phi4mm.py +++ b/python/sglang/srt/multimodal/processors/phi4mm.py @@ -1,6 +1,8 @@ import logging from typing import List, Union +from transformers.processing_utils import ProcessorMixin + from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.phi4mm import Phi4MMForCausalLM from sglang.srt.multimodal.processors.base_processor import ( @@ -10,18 +12,58 @@ from sglang.srt.multimodal.processors.base_processor import ( logger = logging.getLogger(__name__) -_IMAGE_SPECIAL_TOKEN = "<|endoftext10|>" -_IMAGE_SPECIAL_TOKEN_ID = 200010 + +# It is an adapter of hf phi4 mm processor to make it work for sglang +# Ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py#L693 +class Phi4MMProcessorAdapter(ProcessorMixin): + def __init__(self, _processor) -> None: + self._processor = _processor + + def __call__(self, **kwargs): + result = self._processor(**kwargs) + + # Map HuggingFace output keys to sglang standard keys + key_mapping = { + "input_image_embeds": "pixel_values", + "input_audio_embeds": "audio_features", + "audio_embed_sizes": "audio_feature_lens", + } + for hf_key, sglang_key in key_mapping.items(): + if hf_key in result: + result[sglang_key] = result[hf_key] + + # Filter out None or empty tensors from the result. + # This prevents the sglang function base_processor.collect_mm_items_from_processor_output() + # from misclassifying audio content as image content, and vice versa. + filtered_result = { + k: v + for k, v in result.items() + if v is not None and (not hasattr(v, "numel") or v.numel() > 0) + } + return filtered_result -class Phi4MMImageProcessor(BaseMultimodalProcessor): +class Phi4MMMultimodalProcessor(BaseMultimodalProcessor): models = [Phi4MMForCausalLM] def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + self.processor = Phi4MMProcessorAdapter(_processor) + super().__init__(hf_config, server_args, self.processor) + + # the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file + # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py + self.IMAGE_TOKEN = "<|endoftext10|>" + self.AUDIO_TOKEN = "<|endoftext11|>" + self.IM_TOKEN_ID = 200010 + self.AUDIO_TOKEN_ID = 200011 + self.AUDIO_SAMPLE_RATE = 16000 + self.multimodal_tokens = MultimodalSpecialTokens( - image_token=_IMAGE_SPECIAL_TOKEN, - ).build(_processor) + image_token=self.IMAGE_TOKEN, + image_token_id=self.IM_TOKEN_ID, + audio_token=self.AUDIO_TOKEN, + audio_token_id=self.AUDIO_TOKEN_ID, + ).build(self.processor) async def process_mm_data_async( self, @@ -32,46 +74,29 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor): max_req_input_len, **kwargs, ): - if audio_data: - logger.warning( - "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize." - ) - audio_data = [] - base_output = self.load_mm_data( prompt=input_text, max_req_input_len=max_req_input_len, audio_data=audio_data, image_data=image_data, multimodal_tokens=self.multimodal_tokens, - ) - if base_output is None: - return None - - res = self.process_mm_data( - input_text=base_output.input_text, - images=base_output.images, - audios=base_output.audios, + audio_sample_rate=self.AUDIO_SAMPLE_RATE, ) - input_ids = res["input_ids"].flatten() - image_offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=_IMAGE_SPECIAL_TOKEN_ID, - ) + if base_output.audios is not None: + # hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file requires the audio input to be tuple of (audio, sample_rate) + # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py + base_output.audios = [ + (audio, self.AUDIO_SAMPLE_RATE) for audio in base_output.audios + ] - items = [ - MultimodalDataItem( - feature=res["input_image_embeds"], - image_sizes=res["image_sizes"], - image_emb_mask=res["image_attention_mask"], - offsets=image_offsets, - modality=Modality.IMAGE, - ) - ] + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.multimodal_tokens + ) return { - "mm_items": items, "input_ids": input_ids.tolist(), - "im_token_id": _IMAGE_SPECIAL_TOKEN_ID, + "mm_items": mm_items, + "im_token_id": self.IM_TOKEN_ID, + "audio_token_id": self.AUDIO_TOKEN_ID, } diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ce159a4da..dc6e72d75 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -691,12 +691,17 @@ def decode_video_base64(video_base64): ) # Return an empty array and size tuple if no frames were found -def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray: +def load_audio( + audio_file: str, sr: Optional[int] = None, mono: bool = True +) -> np.ndarray: # Use soundfile here, since librosa use it under the hood, # and librosa will not support audio loading in the future import soundfile as sf from scipy.signal import resample + if sr is None: + sr = 16000 + # Load audio data if isinstance(audio_file, bytes): audio, original_sr = sf.read(BytesIO(audio_file)) diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index f6152ea76..534989461 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -200,16 +200,17 @@ class TestPhi4MMServer(TestOpenAIVisionServer): "0.70", "--disable-radix-cache", "--max-loras-per-batch", - "1", + "2", "--revision", revision, "--lora-paths", f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora", + f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora", ], ) cls.base_url += "/v1" - def get_request_kwargs(self): + def get_vision_request_kwargs(self): return { "extra_body": { "lora_path": "vision", @@ -218,8 +219,21 @@ class TestPhi4MMServer(TestOpenAIVisionServer): } } - def test_video_chat_completion(self): - pass + def get_audio_request_kwargs(self): + return { + "extra_body": { + "lora_path": "speech", + "top_k": 1, + "top_p": 1.0, + } + } + + def test_audio_chat_completion(self): + self._test_audio_speech_completion() + # TODO: currently phi4-mm cannot pass this test. + # We are investigating this issue. + # Response: La ciudad está situada en la costa este de la isla, en la desembocadura del río St. Lawrence. + # self._test_audio_ambient_completion() class TestVILAServer(TestOpenAIVisionServer): diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 5d958fd5a..341db654e 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -47,6 +47,12 @@ class TestOpenAIVisionServer(CustomTestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) + def get_audio_request_kwargs(self): + return self.get_request_kwargs() + + def get_vision_request_kwargs(self): + return self.get_request_kwargs() + def get_request_kwargs(self): return {} @@ -71,7 +77,7 @@ class TestOpenAIVisionServer(CustomTestCase): }, ], temperature=0, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -134,7 +140,7 @@ class TestOpenAIVisionServer(CustomTestCase): }, ], temperature=0, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -177,7 +183,7 @@ class TestOpenAIVisionServer(CustomTestCase): }, ], temperature=0, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -333,7 +339,7 @@ class TestOpenAIVisionServer(CustomTestCase): temperature=0, max_tokens=1024, stream=False, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) video_response = response.choices[0].message.content @@ -376,7 +382,7 @@ class TestOpenAIVisionServer(CustomTestCase): + r"""\}""" ) - extra_kwargs = self.get_request_kwargs() + extra_kwargs = self.get_vision_request_kwargs() extra_kwargs.setdefault("extra_body", {})["regex"] = regex response = client.chat.completions.create( @@ -443,7 +449,7 @@ class TestOpenAIVisionServer(CustomTestCase): {"role": "user", "content": content}, ], temperature=0, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -486,7 +492,7 @@ class TestOpenAIVisionServer(CustomTestCase): temperature=0, max_tokens=128, stream=False, - **(self.get_request_kwargs()), + **(self.get_audio_request_kwargs()), ) audio_response = response.choices[0].message.content @@ -500,7 +506,7 @@ class TestOpenAIVisionServer(CustomTestCase): self.assertIsNotNone(audio_response) self.assertGreater(len(audio_response), 0) - return audio_response + return audio_response.lower() def _test_audio_speech_completion(self): # a fragment of Trump's speech