# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import cast

import torch
import torch.nn as nn

DEFAULT_MAX_AUDIO_LEN_S = 655
DEFAULT_MERGE_FACTOR = 4
# Default convolution parameters: (padding, kernel_size, stride)
# These correspond to the two conv layers in GlmAsrEncoder
DEFAULT_CONV_PARAMS = [(1, 3, 1), (1, 3, 2)]


def _calculate_conv_output_length(
    input_length: torch.Tensor, padding: int, kernel_size: int, stride: int
) -> torch.Tensor:
    """Calculate Conv1d output length using the standard formula."""
    # in sync with `hf_processor._get_audio_token_length`
    return (input_length + 2 * padding - (kernel_size - 1) - 1) // stride + 1


def _as_list_chunk_counts(
    chunk_counts: torch.Tensor | list[int] | list[torch.Tensor],
) -> list[int]:
    """Coerce chunk counts (a tensor, ints, or scalar tensors) to list[int]."""
    if isinstance(chunk_counts, torch.Tensor):
        return chunk_counts.tolist()
    if chunk_counts and isinstance(chunk_counts[0], torch.Tensor):
        tensor_counts = cast(list[torch.Tensor], chunk_counts)
        return [int(c.item()) for c in tensor_counts]
    return [int(c) for c in chunk_counts]


def _normalize_chunk_counts(
    chunk_counts: torch.Tensor | list[int] | list[torch.Tensor] | None,
    num_chunks: int,
) -> list[int]:
    """Default to one chunk per audio when no chunk counts are provided."""
    if chunk_counts is None:
        return [1] * num_chunks
    return _as_list_chunk_counts(chunk_counts)


def _get_audio_output_lengths_from_lengths(
    audio_lengths: torch.Tensor,
    merge_factor: int,
    conv_params: list[tuple[int, int, int]],
) -> torch.Tensor:
    """Map raw feature lengths through the conv stack, then the merge step."""
    for padding, kernel_size, stride in conv_params:
        audio_lengths = _calculate_conv_output_length(
            audio_lengths, padding, kernel_size, stride
        )
    return (audio_lengths - merge_factor) // merge_factor + 1
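

# Worked example for the length math above (illustrative; the 3000-frame
# input is an assumed value, roughly 30 s of mel features at 100 frames/s):
#   conv1 (p=1, k=3, s=1): (3000 + 2 - 2 - 1) // 1 + 1 = 3000
#   conv2 (p=1, k=3, s=2): (3000 + 2 - 2 - 1) // 2 + 1 = 1500
#   merge (factor 4):      (1500 - 4) // 4 + 1 = 375 audio tokens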


def _get_audio_output_lengths_from_mask(
    mask: torch.Tensor,
    merge_factor: int,
    conv_params: list[tuple[int, int, int]],
) -> torch.Tensor:
    """Derive feature lengths from a padding mask, then compute output lengths."""
    audio_lengths = mask.sum(-1)
    return _get_audio_output_lengths_from_lengths(
        audio_lengths, merge_factor, conv_params
    )


def _get_audio_output_lengths_for_tower(
    audio_tower: nn.Module,
    audio_lengths: torch.Tensor,
    merge_factor: int,
    conv_params: list[tuple[int, int, int]],
) -> torch.Tensor:
    """
    Calculate the output lengths after audio processing.

    The output length accounts for:
    1. Convolution layers (downsampling)
    2. Merge factor (further downsampling during projection)

    Args:
        audio_tower: The audio encoder module
        audio_lengths: Input feature lengths [batch_size]
        merge_factor: Factor for merging adjacent features
        conv_params: List of (padding, kernel_size, stride) for each conv layer

    Returns:
        Output lengths after all processing [batch_size]
    """
    # First, calculate the output length after convolutions
    if hasattr(audio_tower, "_get_feat_extract_output_lengths"):
        _, conv_output_lengths = audio_tower._get_feat_extract_output_lengths(
            audio_lengths
        )
    else:
        conv_output_lengths = audio_lengths
        for padding, kernel_size, stride in conv_params:
            conv_output_lengths = _calculate_conv_output_length(
                conv_output_lengths, padding, kernel_size, stride
            )

    # Then, apply merge_factor to get the final output length
    # Formula: (conv_output_lengths - merge_factor) // merge_factor + 1
    return (conv_output_lengths - merge_factor) // merge_factor + 1


def _flatten_audio_features_by_length(
    audio_features: torch.Tensor,
    audio_output_lengths: torch.Tensor,
) -> torch.Tensor:
    """Drop padding: keep only the first `length` tokens of each chunk."""
    num_chunks, max_audio_tokens, embed_dim = audio_features.shape
    audio_output_lengths = audio_output_lengths.unsqueeze(1)
    audio_features_mask = (
        torch.arange(max_audio_tokens)
        .expand(num_chunks, max_audio_tokens)
        .to(audio_output_lengths.device)
        < audio_output_lengths
    )
    return audio_features[audio_features_mask].view(-1, embed_dim)


def _group_audio_embeddings(
    chunk_embeddings: Sequence[torch.Tensor],
    chunk_counts: Sequence[int],
) -> tuple[torch.Tensor, ...]:
    """Concatenate consecutive chunk embeddings back into per-audio tensors."""
    grouped_embeddings = []
    current_idx = 0
    for count in chunk_counts:
        audio_chunks = chunk_embeddings[current_idx : current_idx + count]
        grouped_embeddings.append(torch.cat(audio_chunks, dim=0))
        current_idx += count
    return tuple(grouped_embeddings)
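

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the lengths below are assumed
    # values, not part of the vLLM API or any real checkpoint config).
    lengths = torch.tensor([3000, 1500])  # hypothetical mel-frame counts
    out_lens = _get_audio_output_lengths_from_lengths(
        lengths, DEFAULT_MERGE_FACTOR, DEFAULT_CONV_PARAMS
    )
    print(out_lens)  # tensor([375, 187])

    # Padded encoder output: [num_chunks, max_tokens, embed_dim]
    feats = torch.randn(2, int(out_lens.max()), 8)
    flat = _flatten_audio_features_by_length(feats, out_lens)
    print(flat.shape)  # torch.Size([562, 8]) -- 375 + 187 valid tokens

    # Regroup: both chunks belong to one original audio (chunk_counts=[2])
    per_chunk = list(flat.split(out_lens.tolist()))
    grouped = _group_audio_embeddings(per_chunk, [2])
    print(grouped[0].shape)  # torch.Size([562, 8])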