# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM Granite speech model."""
import math
from collections.abc import Iterable, Mapping
from typing import Annotated, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import BatchFeature, PretrainedConfig

from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .blip2 import Blip2QFormerModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, embed_multimodal,
                    init_vllm_registered_model, maybe_prefix)


### Audio Input
class GraniteSpeechAudioInputs(TensorSchema):
    """
    Audio input features for Granite Speech model.

    Dimensions:
        - b: Batch size
        - fi: Number of input features from the Mel spectrogram
        - fo: Number of output features, i.e., the embedding size
        - 160: Fixed feature dimension for Mel spectrogram features
    """

    input_features: Annotated[torch.Tensor, TensorShape("b", "fi", 160)]
    """Audio input features."""

    input_features_mask: Annotated[torch.Tensor, TensorShape("b", "fo")]
    """Mask for variable length audio features."""

    audio_embed_sizes: Annotated[list[int], TensorShape("b")]
    """List of audio embedding sizes for each item in batch."""


class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": 1}

    # There is no limit to the maximum number of audio tokens that can be
    # encoded as features; we pick ~5000 as a number that is probably higher
    # than we would expect to encounter. The sequence of length
    # get_max_audio_len() produces get_max_audio_tokens().
    def get_max_audio_tokens(self):
        return 5001

    def get_max_audio_len(self):
        return 8000000


### Input Processing & Multimodal utils
class GraniteSpeechMultiModalProcessor(
        BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.info.get_hf_processor().audio_processor
        sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
        return MultiModalDataParser(target_sr=sampling_rate)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            input_features=MultiModalFieldConfig.batched("audio"),
            audio_embed_sizes=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        feature_extractor = processor.audio_processor
        vocab = tokenizer.get_vocab()

        # Use getattr with default to be compatible with transformers<4.48
        audio_token = getattr(processor, "audio_token", "<|audio|>")
        audio_token_id = vocab[audio_token]

        def get_replacement(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio = audios.get(item_idx)
            audio_length = audio.shape[-1]
            num_projector_features = feature_extractor._get_num_audio_features(
                [audio_length])[0]
            return [audio_token_id] * num_projector_features

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_replacement,
            )
        ]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        if audios:
            # GraniteSpeechFeatureExtractor accepts "audio"
            mm_data["audio"] = audios

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        if "audio" in mm_data:
            # Calculate the number of audio tokens per entry in the batch;
            # this is used to split the batch back out after padding.
            audio_token_index = self.info.get_hf_config().audio_token_index
            processed_outputs["audio_embed_sizes"] = [
                torch.sum(indices == audio_token_index).item()
                for indices in processed_outputs["input_ids"]
            ]

        return processed_outputs


class GraniteSpeechDummyInputsBuilder(
        BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]):

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_audios = mm_counts.get("audio", 0)

        return {
            "audio":
            self._get_dummy_audios(
                length=self.info.get_max_audio_len(),
                num_audios=num_audios,
            )
        }

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        hf_processor = self.info.get_hf_processor()
        audio_token = getattr(hf_processor, "audio_token", "<|audio|>")
        return audio_token * num_audios


### QFormer Projector
class GraniteSpeechEncoderProjector(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.projector_config.hidden_size
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = config.window_size // config.downsample_rate

        self.query = nn.Parameter(
            torch.zeros(1, self.num_queries,
                        config.projector_config.hidden_size))

        # NOTE - this is implemented generically in transformers,
        # but for now we create the QFormer model directly since
        # all existing models use this for the projector.
        self.qformer = Blip2QFormerModel(
            config.projector_config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.qformer",
        )
        self.linear = nn.Linear(config.projector_config.hidden_size,
                                config.text_config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = hidden_states.size()
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad),
                                          "constant", 0)
        hidden_states = hidden_states.view(batch_size * nblocks,
                                           self.window_size, dim)

        last_hidden_state = self.qformer(
            query_embeds=self.query.data,
            encoder_hidden_states=hidden_states,
        )

        query_proj = self.linear(
            last_hidden_state.view(
                batch_size,
                nblocks * self.window_size // self.downsample_rate,
                -1,
            ))
        return query_proj


# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
# NOTE - it would be nice to see if we can align this with other models using
# conformer in vLLM, e.g., phi4mm audio.
class GraniteSpeechConformerFeedForward(nn.Module):
    """Feedforward module for conformer encoder blocks."""

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config.hidden_dim)

        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.hidden_dim * config.feedforward_mult,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.silu = nn.SiLU()
        self.down_proj = RowParallelLinear(
            input_size=config.hidden_dim * config.feedforward_mult,
            output_size=config.hidden_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        hidden_states, _ = self.up_proj(hidden_states)
        hidden_states = self.silu(hidden_states)
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states


class GraniteSpeechConformerAttention(nn.Module):
    """Attention for conformer blocks using Shaw's relative positional
    embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
    for more details.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()

        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = self.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1,
                                        self.dim_head)

        if self.context_size <= 0 or self.context_size > self.max_pos_emb:
            raise ValueError(
                "Context size is either less than 0 or exceeds the max_pos_emb"
            )

    def forward(self, hidden_states: torch.Tensor,
                attention_dists: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        bsz, num_features, _ = hidden_states.shape

        num_blocks = math.ceil(num_features / self.context_size)
        remainder = num_features % self.context_size
        if remainder > 0:
            # right padding to reach block size
            hidden_states = torch.nn.functional.pad(
                hidden_states, (0, 0, 0, self.context_size - remainder))

        # NOTE: would be nice to try to use qkvparallellinear
        # here for this block attention implementation if possible
        query_states = self.to_q(hidden_states)
        key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)

        query_states = query_states.reshape(bsz, num_blocks,
                                            self.context_size, self.num_heads,
                                            -1).transpose(2, 3)
        key_states = key_states.reshape(bsz, num_blocks, self.context_size,
                                        self.num_heads, -1).transpose(2, 3)
        value_states = value_states.reshape(bsz, num_blocks,
                                            self.context_size, self.num_heads,
                                            -1).transpose(2, 3)

        # shaw's relative positional embedding
        dist = attention_dists.to(hidden_states.device)
        rel_pos_emb = self.rel_pos_emb(dist)
        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
                                                list(rel_pos_emb.shape))
        pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
                             dim=-1) * self.scale

        if remainder > 0:
            # masked attention in the extended block
            mask = torch.ones(self.context_size,
                              self.context_size,
                              dtype=bool,
                              device=hidden_states.device)
            mask[:remainder, :remainder] = 0
            mask_value = -torch.finfo(pos_attn.dtype).max
            pos_attn[:, -1, :].masked_fill_(mask, mask_value)

        with torch.nn.attention.sdpa_kernel(
                torch.nn.attention.SDPBackend.MATH):
            out = F.scaled_dot_product_attention(query_states,
                                                 key_states,
                                                 value_states,
                                                 attn_mask=pos_attn,
                                                 scale=self.scale)
        out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
        return self.to_out(out[:, :num_features, :])


class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Wrapper for padded 1D depthwise convolution."""

    def __init__(self,
                 chan_in: int,
                 chan_out: int,
                 kernel_size: int,
                 prefix: str = ""):
        super().__init__()
        # Padding for the 1D conv is symmetric or close (i.e., offset by one).
        pad = kernel_size // 2
        pad_offset = (kernel_size + 1) % 2
        self.padding = (pad, pad - pad_offset)

        self.conv = nn.Conv1d(chan_in,
                              chan_out,
                              kernel_size,
                              groups=chan_in,
                              bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = F.pad(hidden_states, self.padding)
        return self.conv(hidden_states)


class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer conv module consisting of several 1D/depthwise 1D
    convolutional layers.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        inner_dim = config.hidden_dim * config.conv_expansion_factor

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            inner_dim,
            inner_dim,
            kernel_size=config.conv_kernel_size,
            prefix=f"{prefix}.depth_conv",
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.norm(hidden_states)
        hidden_states = self.up_conv(hidden_states.permute(0, 2, 1))
        hidden_states = self.glu(hidden_states)
        hidden_states = self.depth_conv(hidden_states)
        hidden_states = self.silu(self.batch_norm(hidden_states))
        hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
        return hidden_states


class GraniteSpeechConformerBlock(nn.Module):
    """Conformer block, consisting largely of linear layers,
    attention, and convolutional layers."""

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(config,
                                                     prefix=f"{prefix}.ff1")
        self.attn = GraniteSpeechConformerAttention(config,
                                                    prefix=f"{prefix}.attn")
        self.conv = GraniteSpeechConformerConvModule(config,
                                                     prefix=f"{prefix}.conv")
        self.ff2 = GraniteSpeechConformerFeedForward(config,
                                                     prefix=f"{prefix}.ff2")
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(self, hidden_states: torch.Tensor,
                attention_dists: torch.Tensor) -> torch.Tensor:
        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
        hidden_states = self.attn(
            hidden_states, attention_dists=attention_dists) + hidden_states
        hidden_states = self.conv(hidden_states) + hidden_states
        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
        hidden_states = self.post_norm(hidden_states)
        return hidden_states


class GraniteSpeechCTCEncoder(nn.Module):
    """CTC Encoder comprising conformer blocks and additional linear layers."""

    def __init__(self,
                 config: PretrainedConfig,
                 prefix: str,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config

        # Precompute clamped relative positional encoding distances
        seq = torch.arange(config.context_size)
        relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
        self.attention_dists = torch.clamp(
            relpos_dist, -config.context_size,
            config.context_size) + config.max_pos_emb

        self.input_linear = nn.Linear(config.input_dim,
                                      config.hidden_dim,
                                      bias=True)
        self.layers = nn.ModuleList([
            GraniteSpeechConformerBlock(
                config,
                prefix=f"{prefix}.layers.{idx}",
            ) for idx in range(config.num_layers)
        ])

        self.out = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.output_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out",
        )

        self.out_mid = RowParallelLinear(
            input_size=config.output_dim,
            output_size=config.hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_mid",
        )
        self.softmax = nn.Softmax(dim=-1)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_linear(hidden_states)
        for idx, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states,
                                  attention_dists=self.attention_dists)

            if idx == self.num_layers // 2:
                hidden_states_mid = hidden_states.clone()
                hidden_states_mid, _ = self.out(hidden_states_mid)
                hidden_states_mid = self.softmax(hidden_states_mid)
                hidden_states_mid, _ = self.out_mid(hidden_states_mid)
                hidden_states += hidden_states_mid
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    GraniteSpeechMultiModalProcessor,
    info=GraniteSpeechMultiModalProcessingInfo,
    dummy_inputs=GraniteSpeechDummyInputsBuilder)
class GraniteSpeechForConditionalGeneration(
        nn.Module,
        SupportsMultiModal,
        SupportsPP,
        SupportsLoRA,
):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config
        self.config = config
        self.quant_config = quant_config
        self.cache_config = cache_config

        # The language model is typically a Granite LLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Conformer encoder
        self.encoder = GraniteSpeechCTCEncoder(
            config=config.encoder_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )

        # Blip2 QFormer
        self.projector = GraniteSpeechEncoderProjector(
            config=config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.projector",
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_audio_input(
        self,
        **kwargs: object,
    ) -> Optional[GraniteSpeechAudioInputs]:
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)
        audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
        if input_features is None:
            return None

        # If we have a batch of variable feature length audio clips, we need
        # to mask the features; usually we would get an input_features_mask
        # from the processor, but we handle rebuilding it here since
        # vLLM generally processes everything independently + batches.
        if input_features_mask is None:
            input_features_mask = self._build_input_features_mask(
                audio_embed_sizes)

        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio input features. "
                             f"Got type: {type(input_features)}")

        if input_features_mask is not None and not isinstance(
                input_features_mask, torch.Tensor):
" f"Got type: {type(input_features_mask)}") if isinstance(input_features, torch.Tensor): # Granite speech currently only allows one audio token per instance # and features are already unsqueezed in the processor, so one # instance will have shape [1, {num_features}, 160]. As such, # input features will usually be of shape # [bsz, 1, num_features, 160], which we squeeze to be 3D here. if len(input_features.shape) == 4: input_features = input_features.squeeze(1) if len(input_features.shape) != 3: raise ValueError( "Squeezed input features should be 3D but are of shape " f"{input_features.shape}") input_features = input_features.to( self.encoder.input_linear.weight.dtype) else: # Otherwise we have a list of tensors, which are almost certainly # differing in their respective numbers of audio features; # stack them into a 3D tensor of size [bsz, most_num_features, 160]. input_features = self._pad_and_stack_input_features( input_features, ).to(self.encoder.input_linear.weight.dtype) return GraniteSpeechAudioInputs( input_features=input_features, input_features_mask=input_features_mask, audio_embed_sizes=audio_embed_sizes.flatten().tolist(), ) def _build_input_features_mask( self, audio_embed_sizes: torch.Tensor, ) -> torch.Tensor: """Calculate the input features mask, which will generally be used to mask the padded features for all entries in the batch except for those with the most audio features. Args: audio_embed_sizes: torch.Tensor Tensor of num features in each seq in the batch. Returns: torch.Tensor: Mask of shape (bsz, num_features) to be applied to the audio features prior to splitting the audio embeddings. """ most_audio_features = torch.max(audio_embed_sizes).item() mask_indices = torch.arange( most_audio_features, device=audio_embed_sizes.device, ).view(1, -1) input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1) return input_features_mask def _pad_and_stack_input_features( self, input_features: list[torch.Tensor], ) -> torch.Tensor: """Given a list of input features of varying length, pad them to the same length and stack them into a torch.Tensor. NOTE: Usually, padding is done in the input processor/feature extractor and zero padded prior to the computation of the Mel features; the resulting values are only constant within a batch and generally nonzero (i.e., slightly negative nums); we should validate that this is okay since we don't use a feature attention mask, but the more important thing is that we apply the input_features_mask with variable len batches. Args: input_features: list[torch.Tensor] Input features to be coerced into a tensor. Returns: torch.Tensor: Tensor of shape [bsz, num_features, 160], where num_features is the max number of features of any entry in the batch. """ # Input features are of shape [bsz, num_features, 160] feat_lens = [feats.shape[1] for feats in input_features] padding = [max(feat_lens) - length for length in feat_lens] # TODO (Alex) - Validate that it's okay to zero pad like this; # in transformers we zero pad prior to calculating the speech features, # so the value is not zero and is dependent on the batched features. padded = [ torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0)) for feats, pad in zip(input_features, padding) ] stacked_features = torch.cat(padded, dim=0).to(input_features[0]) return stacked_features def _process_audio_input( self, audio_input: GraniteSpeechAudioInputs, ) -> tuple[torch.Tensor]: """Compute the audio features to be merged into the LLM embeddings. 

        Args:
            audio_input: GraniteSpeechAudioInputs
                Audio inputs object containing Mel features, an input features
                mask, and the (flattened) number of audio tokens per instance.

        Returns:
            tuple[torch.Tensor]: Tuple of length bsz.
        """
        # TODO (Alex) - support embedding inputs
        encoder_embeds = self.encoder(audio_input["input_features"])
        # [bsz, <num_features>, 4096]
        projected_embeds = self.projector(encoder_embeds)
        # Apply mask on variable length audio features
        masked_embeds = projected_embeds[audio_input["input_features_mask"]]
        # Split variable length features into a tuple
        return torch.split(masked_embeds, audio_input["audio_embed_sizes"])

    def get_multimodal_embeddings(
        self,
        **kwargs: object,
    ) -> MultiModalEmbeddings:
        """Compute the audio embeddings if audio inputs are present."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []

        audio_features = self._process_audio_input(audio_input)
        return audio_features

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Compute the merged LLM / audio embeddings."""
        if multimodal_embeddings is None \
            or len(multimodal_embeddings) == 0:
            return self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.audio_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner;
        # this condition is for v0 compatibility.
        elif inputs_embeds is None:
            audio_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
            input_ids = None

        model_output = self.language_model(input_ids, positions,
                                           intermediate_tensors,
                                           inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="projector",
            tower_model="encoder",
        )