# SPDX-License-Identifier: Apache-2.0
import math
import re
from functools import lru_cache
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import scipy.signal
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig, SiglipVisionConfig
from transformers.utils import logging

from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext)
from vllm.inputs.data import TokenInputs, token_inputs
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsV0Only
from .phi4mm_audio import AudioEmbedding
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

# <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
# <|endoftext11|>
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011

_AUDIO_MAX_SOUNDFILE_SIZE = 241_000
DUMMY_SAMPLING_FREQUENCY = 16_000  # Hz

DYNAMIC_HD = 16
AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>"
IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>"

SIGLIP_NAME = "siglip-so400m-patch14-448"
VISION_ENCODER_TO_PROCESSING_CONFIG = {
    'siglip-so400m-patch14-448': {
        'dynamic_hd': 16,
        'vit_image_size': 448,
        'vit_patch_size': 14,
        'token_compression_factor': 2,
    },
}

logger = logging.get_logger(__name__)

# This is a workaround to prevent text (user input) + audio + image
# from being used in the same prompt.
# It includes token ids for "\n" and tokens in added_tokens_decoder
# from the tokenizer_config.json file.
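# For example, 200010 and 200011 are the image/audio placeholder ids defined
# above, and 198 is presumably the "\n" token id referred to here, so a prompt
# consisting only of placeholder tokens and newlines does not count as user
# text input.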
NON_USER_INPUT_TOKENS = { 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, 200023, 200024, 200025, 200026, 200027, 200028 } def get_max_dummy_image(ctx: InputContext): hf_config = ctx.get_hf_config() vision_encoder_name = hf_config.img_processor if vision_encoder_name is None: vision_encoder_name = SIGLIP_NAME prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] dynamic_hd_size = prepro_config['dynamic_hd'] vit_image_size = prepro_config['vit_image_size'] max_side = vit_image_size * dynamic_hd_size dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side) return dummy_image # image token length def get_max_phi4mm_image_tokens(ctx: InputContext): dummy_image = get_max_dummy_image(ctx) hf_config = ctx.get_hf_config() vision_encoder_name = hf_config.img_processor if vision_encoder_name is None: vision_encoder_name = SIGLIP_NAME prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] dynamic_hd_size = prepro_config['dynamic_hd'] vit_image_size = prepro_config['vit_image_size'] vit_patch_size = prepro_config['vit_patch_size'] token_compression_factor = prepro_config['token_compression_factor'] image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size, vit_image_size, vit_patch_size, token_compression_factor) return image_num_tokens def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio return best_ratio def _find_target_aspect_ratio(image, image_size, max_num, min_num): orig_width, orig_height = image.size w_crop_num = math.ceil(orig_width / float(image_size)) h_crop_num = math.ceil(orig_height / float(image_size)) if w_crop_num * h_crop_num > max_num: aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set((i, j) for i in range(1, max_num + 1) for j in range(1, max_num + 1) if i * j <= max_num and i * j >= min_num) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] logger.debug("target_aspect_ratio: %s", target_aspect_ratio) else: target_width = image_size * w_crop_num target_height = image_size * h_crop_num target_aspect_ratio = (w_crop_num, h_crop_num) return target_aspect_ratio, target_height, target_width def _get_padding_size(image, target_height, target_width): orig_width, orig_height = image.size ratio_width = target_width / orig_width ratio_height = target_height / orig_height if ratio_width < ratio_height: padding_width = 0 padding_height = target_height - int(orig_height * ratio_width) else: padding_width = target_width - int(orig_width * ratio_height) padding_height = 0 return padding_height, padding_width def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, mask_size=27): target_aspect_ratio, target_height, target_width =\ _find_target_aspect_ratio( image, image_size, max_num, min_num) padding_height, 
padding_width = _get_padding_size(image, target_height, target_width) # Calculate the ratio orig_width, orig_height = image.size ratio_width = target_width / orig_width ratio_height = target_height / orig_height if ratio_width < ratio_height: new_size = (target_width, int(orig_height * ratio_width)) else: new_size = (int(orig_width * ratio_height), target_height) attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), int(mask_size * target_aspect_ratio[0]))) if padding_width >= 14: attention_mask[:, -math.floor(padding_width / 14):] = 0 if padding_height >= 14: attention_mask[-math.floor(padding_height / 14):, :] = 0 assert attention_mask.sum( ) > 0, f'attention mask is empty {attention_mask}' if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10: raise ValueError(f'the aspect ratio is very extreme {new_size}') image = T.functional.resize( image, [new_size[1], new_size[0]], ) resized_img = T.functional.pad(image, [0, 0, padding_width, padding_height], fill=[255, 255, 255]) return resized_img, attention_mask def pad_to_max_num_crops(images, max_crops=5): """ images: B x 3 x H x W, B<=max_crops """ B, _, H, W = images.shape if max_crops > B: pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) images = torch.cat([images, pad], dim=0) return images def pad_mask_to_max_num_crops(masks, max_crops=5): B, H, W = masks.shape if max_crops > B: pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device) masks = torch.cat([masks, pad], dim=0) return masks def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): # Basic settings. img_processor = T.Compose([ T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) # Dynamic HD base_resolution = vit_resolution images = [image.convert('RGB') for image in images] # cover 384 and 448 resolution mask_resolution = base_resolution // vit_patch_size elems, image_attention_masks = [], [] for im in images: elem, attention_mask = dynamic_preprocess(im, max_num=dynamic_hd_size, image_size=base_resolution, mask_size=mask_resolution) elems.append(elem) image_attention_masks.append(attention_mask) hd_images = [img_processor(im) for im in elems] global_image = [ torch.nn.functional.interpolate( im.unsqueeze(0).float(), size=(base_resolution, base_resolution), mode='bicubic', ).to(im.dtype) for im in hd_images ] shapes = [[im.size(1), im.size(2)] for im in hd_images] mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks] global_attention_mask = [ torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images ] hd_images_reshape = [ im.reshape(1, 3, h // base_resolution, base_resolution, w // base_resolution, base_resolution).permute( 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution, base_resolution).contiguous() for im, (h, w) in zip(hd_images, shapes) ] attention_masks_reshape = [ mask.reshape(1, h // mask_resolution, mask_resolution, w // mask_resolution, mask_resolution).permute( 0, 1, 3, 2, 4).reshape(-1, mask_resolution, mask_resolution).contiguous() for mask, (h, w) in zip(image_attention_masks, mask_shapes) ] # NOTE token compression is hard coded here, and odd numbers seems to fail downsample_attention_masks = [ mask[:, 0::2, 0::2].reshape(1, h // mask_resolution, w // mask_resolution, mask_resolution // 2 + mask_resolution % 2, mask_resolution // 2 + mask_resolution % 2).permute( 0, 1, 3, 2, 4) for mask, (h, w) in zip(attention_masks_reshape, mask_shapes) ] downsample_attention_masks = [ mask.reshape(mask.size(1) * 
mask.size(2), mask.size(3) * mask.size(4)) for mask in downsample_attention_masks ] # NOTE hard coded number of tokens num_img_tokens = [ 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 for mask in downsample_attention_masks ] hd_images_reshape = [ torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape) ] hd_masks_reshape = [ torch.cat([_global_mask] + [_mask], dim=0) for _global_mask, _mask in zip( global_attention_mask, attention_masks_reshape) ] max_crops = max([img.size(0) for img in hd_images_reshape]) image_transformed = [ pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape ] image_transformed = torch.stack(image_transformed, dim=0) mask_transformed = [ pad_mask_to_max_num_crops(mask, max_crops) \ for mask in hd_masks_reshape ] mask_transformed = torch.stack(mask_transformed, dim=0) returned_input_image_embeds = image_transformed returned_image_sizes = torch.tensor(shapes, dtype=torch.long) returned_image_attention_mask = mask_transformed returned_num_img_tokens = num_img_tokens data = { "pixel_values": returned_input_image_embeds, "image_sizes": returned_image_sizes, "image_attention_mask": returned_image_attention_mask, "num_img_tokens": returned_num_img_tokens, } return data def get_navit_vision_model(layer_idx: int = -1, **kwargs): vision_config = { "hidden_size": 1152, "image_size": 448, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14, } model_config = SiglipVisionConfig(**vision_config, **kwargs) if layer_idx < 0: num_hidden_layers = model_config.num_hidden_layers \ + layer_idx + 1 else: num_hidden_layers = layer_idx + 1 vision_model = Idefics2VisionTransformer( config=model_config, require_post_norm=False, num_hidden_layers_override=num_hidden_layers, ) return vision_model class Phi4MMImageEncoder(nn.Module): """Image embedding.""" def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], prefix: str = "", model_dir: str = "") -> None: super().__init__() # n_embed or hidden_size hidden_size = config.n_embd if hasattr( config, 'n_embd') else config.hidden_size # layer_idx to output the img features if isinstance(config.img_processor, dict): self.layer_idx = config.img_processor.get('layer_idx', -2) self.type_feature = config.img_processor.get( 'type_feature', 'patch') else: self.layer_idx = -2 self.type_feature = 'patch' self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx) pe_weight = self.img_processor.embeddings.position_embedding.weight L, D = pe_weight.size() H = int(math.sqrt(L)) assert H**2 == L, f'position embedding size {L} is not square' if H % 2 != 0: self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) H += 1 image_dim_out = D # ((448/14)//2)**2 self.num_img_tokens = (H // 2)**2 self.base_feat_height_target = H self.image_dim_out = image_dim_out self.img_sizes = None self.image_attention_mask = None # global_gn and sub_gn for hd transform, serves as line separator self.use_hd_transform = True self.with_learnable_separator = True self.hd_transform_order = "sub_glb" self.freeze_img_processor = False self.crop_size = 448 # image token compression self.image_token_compression_cls = 'avg_pool_2d' self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) self.base_feat_height_reduction = 1 self.base_feat_height_target = self.base_feat_height_target // 2 # with_hd_transform and with_learnable_separator should have same value assert 
self.use_hd_transform == self.with_learnable_separator, \ 'use_hd_transform and with_learnable_separator should have same value' assert self.use_hd_transform, \ 'learnable separator is only for hd transform' # 1024 * 4, merge spatial to channel dimension self.glb_GN = nn.Parameter( torch.zeros([ 1, 1, self.image_dim_out * self.base_feat_height_reduction**2 ])) self.sub_GN = nn.Parameter( torch.zeros([ 1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2 ])) dim_projection = hidden_size depth = 2 layers = [ nn.Linear(image_dim_out * self.base_feat_height_reduction**2, dim_projection) ] for _ in range(1, depth): layers.extend( [nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) self.vocab_size = config.vocab_size self.img_features = None self.use_out_place_operations = False def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: img_feature = self.img_processor(img_embeds, patch_attention_mask=attention_mask) if self.type_feature == "patch": patch_feature = img_feature use_token_compression = self.image_token_compression is not None use_padding = getattr(self, 'img_processor_padding', None) is not None if use_token_compression or use_padding: # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) if use_padding: patch_feature = self.img_processor_padding(patch_feature) if use_token_compression: patch_feature = self.image_token_compression(patch_feature) # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) patch_feature = patch_feature.view( -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) return patch_feature raise NotImplementedError def forward(self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, image_attention_mask: torch.Tensor) -> torch.FloatTensor: """ process image and return vision embeddings. 
pixel_values: (num_images, num_crops, c, h, w) image_sizes: [[h1, w1], [h2, w2]] image_attention_mask: num_images x num_crops x 32 x 32 output: (num_images, num_img_tokens, hidden_size) """ # eg # pixel_values: torch.Size([1, 7, 3, 448, 448]) # image_sizes: tensor([[ 896, 1344]], device='cuda:0') # output: torch.Size([1, 1841, 3072]) if isinstance(self.img_projection, nn.Sequential): target_device = self.img_projection[0].bias.device target_dtype = self.img_projection[0].bias.dtype else: # It's a single nn.Linear layer target_device = self.img_projection.bias.device target_dtype = self.img_projection.bias.dtype img_sizes = image_sizes num_images, num_crops, c, h, w = pixel_values.shape bs = num_images pixel_values = pixel_values.flatten(0, 1) img_features = self.get_img_features( pixel_values, image_attention_mask.type(torch.BoolTensor).flatten( 0, 1).to(target_device)) base_feat_height_target = self.base_feat_height_target base_resolution = self.crop_size base_feat_height_reduction = self.base_feat_height_reduction base_feat_height = base_feat_width = int(np.sqrt( img_features.shape[1])) assert base_feat_height == base_feat_height_target \ and base_feat_width == base_feat_height_target, \ f'base_feat_height: {base_feat_height},"\ f" base_feat_width: {base_feat_width}, "\ f"expect {base_feat_height_target} features for hd transform' # bs x max_num_crops x (24x24) x C img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out) C = self.image_dim_out H = base_feat_height output_imgs = [] output_len = [] # training is tensor, inference is list if isinstance(img_sizes, torch.Tensor): img_sizes = img_sizes.view(-1, 2) for _bs in range(bs): h, w = img_sizes[_bs] h = h // base_resolution w = w // base_resolution B_ = h * w # 1 x (24x24) x 1024 global_img_feature = img_features[_bs, :1] # 1 x 12 x 12 x 4096 glb_img = global_img_feature.reshape(1, H, H, C).reshape( 1, H // base_feat_height_reduction, base_feat_height_reduction, H // base_feat_height_reduction, base_feat_height_reduction, C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( 1, H // base_feat_height_reduction, H // base_feat_height_reduction, base_feat_height_reduction * base_feat_height_reduction * C).contiguous() temp_glb_GN = self.sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1) # 1 x 156 x 4096 glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape( 1, -1, base_feat_height_reduction * base_feat_height_reduction * C) # (max_num_crops-1) x (12x12) x C sub_img = img_features[_bs, 1:] # 16x574x1024 # get rid of padding sub_img sub_img = sub_img[:B_] # (num_crops, 12, 2, 12, 2, 1024) -> # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) sub_img = sub_img.reshape(B_, H, H, C).reshape( B_, H // base_feat_height_reduction, base_feat_height_reduction, H // base_feat_height_reduction, base_feat_height_reduction, C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( B_, -1, base_feat_height_reduction * base_feat_height_reduction * C).contiguous() sub_img = sub_img.reshape( 1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction, -1).permute(0, 1, 3, 2, 4, 5).reshape( 1, h * base_feat_height // base_feat_height_reduction, w * base_feat_width // base_feat_height_reduction, base_feat_height_reduction * base_feat_height_reduction * C) if image_attention_mask is not None and len( image_attention_mask) > 0: reshaped_image_attention_mask = image_attention_mask[ _bs, 1:B_ + 1, 0::2, 0::2].reshape( 1, h, w, base_feat_height // 
base_feat_height_reduction, base_feat_width // base_feat_height_reduction).permute( 0, 1, 3, 2, 4).reshape( 1, h * base_feat_height // base_feat_height_reduction, w * base_feat_width // base_feat_height_reduction) useful_height = int( reshaped_image_attention_mask[0, :, 0].sum().item()) useful_width = int( reshaped_image_attention_mask[0, 0, :].sum().item()) sub_img = sub_img[:, :useful_height, :useful_width] temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) temp_len = int( image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item( )) + (useful_height + 1) + base_feat_height // base_feat_height_reduction else: temp_sub_GN = self.sub_GN.repeat( 1, h * base_feat_height // base_feat_height_reduction, 1, 1) temp_len = int((h * w + 1) * self.num_img_tokens + 1 + (h + 1) * base_feat_height // base_feat_height_reduction) sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape( 1, -1, base_feat_height_reduction * base_feat_height_reduction * C) # (1, num_img_tokens, 1024*4) # glb + sub if self.hd_transform_order == 'glb_sub': output_imgs.append( torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) elif self.hd_transform_order == 'sub_glb': output_imgs.append( torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) else: raise NotImplementedError( f'hd_transform_order = {self.hd_transform_order}, "\ "not implemented') #temp_len = int((h*w+1)*144 + 1 + (h+1)*12) assert temp_len == output_imgs[-1].shape[ 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: "\ "{output_imgs[-1].shape[1]}' output_len.append(temp_len) img_set_tensor = [] for _output_img in output_imgs: img_feature_proj = self.img_projection( _output_img.to(target_device).to(target_dtype)) img_set_tensor.append(img_feature_proj) return img_set_tensor class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: Tuple[NestedTensors] """Shape: `((batch_size, num_audios, 80, M), )""" class Phi4MMAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] data: NestedTensors """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): """Create a Mel filter-bank the same as SpeechLib FbankFC. Args: sample_rate (int): Sample rate in Hz. number > 0 [scalar] n_fft (int): FFT size. int > 0 [scalar] n_mel (int): Mel filter size. int > 0 [scalar] fmin (float): lowest frequency (in Hz). If None use 0.0. float >= 0 [scalar] fmax: highest frequency (in Hz). If None use sample_rate / 2. 
float >= 0 [scalar] Returns out (numpy.ndarray): Mel transform matrix [shape=(n_mels, 1 + n_fft/2)] """ bank_width = int(n_fft // 2 + 1) if fmax is None: fmax = sample_rate / 2 if fmin is None: fmin = 0 assert fmin >= 0, "fmin cannot be negative" assert (fmin < fmax <= sample_rate / 2), "fmax must be between (fmin, samplerate / 2]" def mel(f): return 1127.0 * np.log(1.0 + f / 700.0) def bin2mel(fft_bin): return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) def f2bin(f): return int((f * n_fft / sample_rate) + 0.5) # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] klo = f2bin(fmin) + 1 khi = f2bin(fmax) khi = max(khi, klo) # Spec 2: SpeechLib uses triangles in Mel space mlo = mel(fmin) mhi = mel(fmax) m_centers = np.linspace(mlo, mhi, n_mels + 2) ms = (mhi - mlo) / (n_mels + 1) matrix = np.zeros((n_mels, bank_width), dtype=np.float32) for m in range(0, n_mels): left = m_centers[m] center = m_centers[m + 1] right = m_centers[m + 2] for fft_bin in range(klo, khi): mbin = bin2mel(fft_bin) if left < mbin < right: matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms return matrix class LogFbankProcessor: def __init__(self): self._eightk_method = "fillzero" self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T self._hamming400 = np.hamming(400) # for 16k audio self._hamming200 = np.hamming(200) # for 8k audio def extract_spectrogram(self, wav, fs): """Extract spectrogram features from waveform. Args: wav (1D array): waveform of the input fs (int): sampling rate of the waveform, 16000 or 8000. If fs=8000, the waveform will be resampled to 16000Hz. Output: log_fbank (2D array): a TxD matrix of log Mel filterbank features. D=80, and T is the number of frames. """ if wav.ndim > 1: wav = np.squeeze(wav) # by default, we extract the mean if stereo if len(wav.shape) == 2: wav = wav.mean(1) # Resample to 16000 or 8000 if needed if fs > 16000: wav = scipy.signal.resample_poly(wav, 1, fs // 16000) fs = 16000 elif 8000 < fs < 16000: wav = scipy.signal.resample_poly(wav, 1, fs // 8000) fs = 8000 elif fs < 8000: raise RuntimeError(f"Unsupported sample rate {fs}") if fs == 8000: if self._eightk_method == "resample": # Input audio is 8 kHz. Convert to 16 kHz before feature # extraction wav = scipy.signal.resample_poly(wav, 2, 1) fs = 16000 # Do nothing here for fillzero method elif fs != 16000: # Input audio is not a supported sample rate. raise RuntimeError( f"Input data using an unsupported sample rate: {fs}") preemphasis = 0.97 if fs == 8000: n_fft = 256 win_length = 200 hop_length = 80 fft_window = self._hamming200 elif fs == 16000: n_fft = 512 win_length = 400 hop_length = 160 fft_window = self._hamming400 # Spec 1: SpeechLib cut remaining sample insufficient for a hop n_batch = (wav.shape[0] - win_length) // hop_length + 1 # Here we don't use stride_tricks since the input array may not satisfy # memory layout requirement and we need writeable output # Here we only use list of views before copy to destination # so it is more efficient than broadcasting y_frames = np.array( [ wav[_stride:_stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length) ], dtype=np.float32, ) # Spec 2: SpeechLib applies preemphasis within each batch y_frames_prev = np.roll(y_frames, 1, axis=1) y_frames_prev[:, 0] = y_frames_prev[:, 1] y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64) if fs == 8000: # Need to pad the output to look like 16 kHz data but with zeros in # the 4 to 8 kHz bins. 
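# Worked shape example (derived from the parameters above, not from a real
# run): at 8 kHz, n_fft=256 gives 256 // 2 + 1 = 129 rfft bins, so dropping
# the original Nyquist column (S[:, 0:-1], 128 columns) and appending 129
# zero columns yields 257 columns, the same layout as 16 kHz with n_fft=512.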
frames, bins = S.shape padarray = np.zeros((frames, bins)) S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero spec = np.abs(S).astype(np.float32) return spec def extract_features(self, wav, fs): """Extract log filterbank features from waveform. Args: wav (1D array): waveform of the input fs (int): sampling rate of the waveform, 16000 or 8000. If fs=8000, the waveform will be resampled to 16000Hz. Output: log_fbank (2D array): a TxD matrix of log Mel filterbank features. D=80, and T is the number of frames. """ spec = self.extract_spectrogram(wav, fs) spec_power = spec**2 fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) log_fbank = np.log(fbank_power).astype(np.float32) return log_fbank @lru_cache def audio_feature_extractor() -> LogFbankProcessor: # Creates an instance of the audio processor, needed to extract the # the audio features from the sound file # LRU cache ensures that we only make one copy return LogFbankProcessor() def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, vit_patch_size, token_compression_factor): """ compute the number of tokens an image is expected to take up considering the image encoder architecture and exclude output features containing only padding pixels for siglip, vit_image_size=448, vit_patch_size=14, so output will be 32x32 feature map NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 """ assert vit_image_size % vit_patch_size == 0, \ "vit_image_size must be divisible by vit_patch_size" assert vit_image_size // vit_patch_size % token_compression_factor == 0, \ "vit_image_size // vit_patch_size must be divisible by "\ "token_compression_factor" target_aspect_ratio, target_height, target_width = ( _find_target_aspect_ratio(image, vit_image_size, dynamic_hd_size, min_num=1)) assert target_aspect_ratio[ 0] * vit_image_size == target_width, \ f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" assert target_aspect_ratio[ 1] * vit_image_size == target_height, \ f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" assert (target_height % vit_image_size == 0 and target_width % vit_image_size == 0) padding_height, padding_width = _get_padding_size(image, target_height, target_width) assert padding_width == 0 or padding_height == 0, \ "padding_width or padding_height must be 0" target_feat_width = target_width // vit_patch_size target_feat_height = target_height // vit_patch_size if padding_width >= vit_patch_size: assert padding_height == 0, "padding_height not 0" non_pad_feat_width = target_feat_width - math.floor( padding_width / vit_patch_size) non_pad_feat_height = target_feat_height elif padding_height >= vit_patch_size: assert padding_width == 0, "padding_width not 0" non_pad_feat_height = target_feat_height - math.floor( padding_height / vit_patch_size) non_pad_feat_width = target_feat_width else: # small padding shorter than a vit patch non_pad_feat_width = target_feat_width non_pad_feat_height = target_feat_height feat_width = non_pad_feat_width // token_compression_factor feat_height = non_pad_feat_height // token_compression_factor # NOTE it's possible that the non-padding feature is not divisible if non_pad_feat_width % token_compression_factor != 0: feat_width += 1 if non_pad_feat_height % token_compression_factor != 0: feat_height += 1 num_hd_patch_tokens = feat_width * feat_height num_hd_newline_tokens = feat_height vit_feature_size = vit_image_size // vit_patch_size num_global_image_tokens = (vit_feature_size // token_compression_factor)**2 
num_sep_tokens = 1 num_global_image_newline_tokens = \ vit_feature_size // token_compression_factor return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens + num_hd_newline_tokens + num_global_image_newline_tokens) def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: """ Compute the output size of the `extract_features` method. Args: wav_length (int): Length of the input waveform in samples. fs (int): Sampling rate of the waveform, either 16000 or 8000. Returns: tuple (int, int): Output size as (T, D), where: T: Number of time frames. D: Number of Mel filterbank bins (80). """ # Resample to 16000 or 8000 if needed if fs > 16000: wav_length //= fs // 16000 fs = 16000 elif 8000 <= fs < 16000: # We'll resample to 16K from 8K wav_length *= 2 fs = 16000 elif fs < 8000: raise RuntimeError(f"Unsupported sample rate {fs}") # Spectrogram parameters for 16 kHz win_length = 400 # Frame length in samples hop_length = 160 # Frame shift in samples mel_bins = 80 # Number of mel filterbank bins # Calculate number of frames (T) T = (wav_length - win_length) // hop_length + 1 if T < 1: raise ValueError("Waveform too short for given parameters.") # Return time frames (T) and mel bins (D) return T, mel_bins def _get_audio_embed_sizes(audios, ctx: InputContext): """ Get the audio embedding sizes for each audio file. Args: audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of waveform and sample rate. ctx (InputContext): Input context. Returns: List[int]: List of audio embedding sizes. """ audio_embed_sizes = [] for audio in audios: audio_data, sf = audio audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(), audio_frames) audio_embed_sizes.append(audio_embed_size) return audio_embed_sizes def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): """ The following will search for `<|audio_{idx}|>` tokens and return a mapping of audio placeholder tokens to audio placeholder token ids based on the size of the audio embeddings. Args: audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of waveform and sample rate. ctx (InputContext): Input context. prompt_str (str): The prompt string. Returns: Dict[str, List[int]]: Mapping of audio placeholder tokens to audio placeholder token ids. """ if len(audios) == 0: return {} audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) audio_ids = [int(audio_id) for audio_id in audio_ids] assert len(audio_ids) == len( audio_embed_sizes ), "Number of audio tokens and audio features do not match" assert tuple(audio_ids) == tuple(range(1, len(audio_ids) + 1)), "Audio ids are not in order!" 
audio_id_to_input_ids = { f"<|audio_{audio_id}|>": [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes) } return audio_id_to_input_ids def _count_image_tokens(images, ctx: InputContext): hf_config = ctx.get_hf_config() vision_encoder_name = hf_config.img_processor if vision_encoder_name is None: vision_encoder_name = SIGLIP_NAME prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] dynamic_hd_size = prepro_config['dynamic_hd'] vit_image_size = prepro_config['vit_image_size'] vit_patch_size = prepro_config['vit_patch_size'] token_compression_factor = prepro_config['token_compression_factor'] image_token_counts = [ _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, vit_patch_size, token_compression_factor) for image in images ] return image_token_counts def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): if len(images) == 0: return {} image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt) image_ids = [int(image_id) for image_id in image_ids] assert len(image_ids) == len( set(image_ids)), "Duplicate image tokens in prompt" assert len(images) == len( image_ids), "Number of images and image tokens in prompt do not match" # NOTE the following assertion is not strictly necessary assert tuple(image_ids) == tuple(range(1, len(image_ids) + 1)), "Image ids are not in order" image_token_counts = _count_image_tokens(images, ctx) image_id_to_input_ids = { f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens for image_id, num_tokens in zip(image_ids, image_token_counts) } return image_id_to_input_ids def input_processor_for_phi4mm(ctx: InputContext, inputs: DecoderOnlyInputs) -> TokenInputs: """ Implements the input processor, which transforms the input prompt ids to include the audio placeholder token. This will become the `input_ids` in `forward` for the model. Args: ctx (InputContext): Input context. inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) to process. Returns: TokenInputs: Processed inputs """ multi_modal_data = inputs.get("multi_modal_data") if (multi_modal_data is None or ("audio" not in multi_modal_data and "image" not in multi_modal_data)): # pure text input, so no need to do pre-processing return inputs prompt_str = inputs.get("prompt") prompt_token_ids = inputs.get("prompt_token_ids") # for offline_inference, we will get str input and we parse MM special # tokens from it # (ignore prompt_token_ids) # for OAI server, we will get prompt_token_ids, where MM special tokens # are already parsed if 'audio' in multi_modal_data: audios = multi_modal_data["audio"] if not isinstance(audios, list): audios = [audios] if prompt_str is not None: audio_id_to_input_ids = _get_audio_id_to_input_ids( audios, ctx, prompt_str=prompt_str) audio_embed_sizes = [] elif prompt_token_ids is not None: audio_id_to_input_ids = {} audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) else: audio_id_to_input_ids = {} audio_embed_sizes = [] if 'image' in multi_modal_data: # PIL Image or list of PIL Images images = multi_modal_data["image"] if not isinstance(images, list): images = [images] if prompt_str is not None: image_id_to_input_ids = _get_image_id_to_input_ids( images, prompt_str, ctx) image_token_counts = [] elif prompt_token_ids is not None: image_id_to_input_ids = {} image_token_counts = _count_image_tokens(images, ctx) else: image_id_to_input_ids = {} image_token_counts = [] # Handle the case where the prompt is a string and we need to manually # tokenize it. 
# In this case, the `audio_id_to_input_ids` dict will be mapping from # an audio placeholder # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the # given audio length. if prompt_str: pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)" prompt_chunk_strings = re.split(pattern, prompt_str) prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""] # Create the new input_ids with the placeholder image and audio # tokens inserted tokenizer = cached_tokenizer_from_config(ctx.model_config) input_ids = [] has_imag, has_audio, has_user_text_input = False, False, False for prompt_chunk_string in prompt_chunk_strings: if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string): input_ids.extend(image_id_to_input_ids[prompt_chunk_string]) has_imag = True elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string): input_ids.extend(audio_id_to_input_ids[prompt_chunk_string]) has_audio = True else: curr_token_ids = tokenizer(prompt_chunk_string).input_ids if not has_user_text_input: for token_id in curr_token_ids: if token_id not in NON_USER_INPUT_TOKENS: has_user_text_input = True break input_ids.extend(curr_token_ids) if has_audio and has_imag and has_user_text_input: raise ValueError( "Phi4MMForCausalLM does not support text + audio + image" + " inputs in the same prompt") # Handle the case where the prompt is already tokenized else: assert prompt_token_ids is not None, \ "If string prompt isn't provided, prompt_token_ids must be" i = 0 input_ids = prompt_token_ids # only needed for later assertion img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0 image_token_count_iter = iter(image_token_counts) audio_embed_size_iter = iter(audio_embed_sizes) while i < len(input_ids): token_id = input_ids[i] if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID: token_count = next(audio_embed_size_iter) audio_cnt += 1 elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID: token_count = next(image_token_count_iter) img_cnt += 1 else: user_text_input_cnt += 1 if token_id not in \ NON_USER_INPUT_TOKENS else 0 i += 1 continue tokens = [token_id] * token_count input_ids = input_ids[:i] + tokens + input_ids[i + 1:] i += token_count if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0: raise ValueError( "Phi4MMForCausalLM does not support text + audio + image" + " inputs in the same prompt") # If the below assertion fails, it might be that input pure-text # messages contain image/audio special tokens literally # (<|endoftext10|>, <|endoftext11|>). assert (img_cnt == len(image_token_counts)), ( f"Number of image tokens in prompt_token_ids ({img_cnt}) " f"does not match number of images ({len(image_token_counts)})") assert (audio_cnt == len(audio_embed_sizes)), ( f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " f"does not match number of audios ({len(audio_embed_sizes)})") # NOTE: Create a defensive copy of the original inputs return token_inputs( prompt_token_ids=input_ids, prompt=prompt_str, multi_modal_data=multi_modal_data, ) def _compute_audio_embed_size(hf_config, audio_frames): """ Compute the audio embedding size based on the audio frames and compression rate. 
""" compression_rate = hf_config.embd_layer['audio_embd_layer'][ 'compression_rate'] # NOTE: this is a hard-coded value but might be configurable in the future qformer_compression_rate = 1 integer = audio_frames // compression_rate remainder = audio_frames % compression_rate result = integer if remainder == 0 else integer + 1 integer = result // qformer_compression_rate remainder = result % qformer_compression_rate result = integer if remainder == 0 else integer + 1 # qformer compression return result def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: return 10000 def dummy_audio_for_phi4mm(audio_count: int) -> dict: """ Create dummy audio data for the Phi4MM model, which is used for profiling. Args: audio_count (int): Number of audio samples. Returns: dict: Dummy audio data. """ dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0) return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count def dummy_image_for_phi4mm(width: int, height: int): image = Image.new('RGB', (width, height), color='black') return image def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]) -> DummyData: """ Create dummy sequence (input_ids) and audio data for the Phi4MM model, which is used for profiling. In this case, the sequence data is a bunch of 0s with a number of audio tokens that correspond to the audio embed size of the _AUDIO_MAX_SOUNDFILE_SIZE. Args: ctx (InputContext): Input context. seq_len (int): Length of the sequence. mm_counts (Mapping[str, int]): Multi-modal counts. Returns: Tuple: Dummy sequence data and dummy audio data. """ audio_count = mm_counts["audio"] audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE, DUMMY_SAMPLING_FREQUENCY) audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(), audio_frames) image_count = mm_counts["image"] dummy_image = get_max_dummy_image(ctx) max_image_tokens = get_max_phi4mm_image_tokens(ctx) total_image_tokens = image_count * max_image_tokens if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: raise RuntimeError( f"Phi4MM cannot process {audio_count} audios and {image_count}" f"images in a prompt, please increase max_model_len to be at" f" larger than " f"{audio_feature_size * audio_count + total_image_tokens}" " or reduce audio/image limit by --limit-mm-per-prompt.") if audio_feature_size * audio_count > total_image_tokens: seq_data = SequenceData.from_prompt_token_counts( (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count), (0, seq_len - audio_feature_size * audio_count), ) mm_data = { "audio": dummy_audio_for_phi4mm(audio_count), } else: seq_data = SequenceData.from_prompt_token_counts( (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens), (0, seq_len - total_image_tokens), ) mm_data = { "image": [dummy_image] * image_count, } return DummyData(seq_data, mm_data) def input_mapper_for_phi4mm_audio(ctx: InputContext, data: object) -> MultiModalKwargs: """ This function is used to create the MultiModalKwargs for the Phi4MM (audio) model. Specifically, for audio, we extract the audio features from the sound file and create pairs of audio features and audio embed lengths (the latter of which is used to repeat the audio placeholder token in the input prompt IDs). These pairs are used, downstream, in `_audio_features_to_embeddings` (via `_process_audio_input`). Note that the incoming audio data (each entry in `data`) is a tuple of the audio data and the sampling frequency (e.g. from soundfile.read). Args: ctx (InputContext): Input context. 
data (object): Audio data. Returns: MultiModalKwargs: Multi-modal inputs. """ if not isinstance(data, list): data = [data] if len(data) == 0: return MultiModalKwargs() audio_features = [] for audio_input in data: if not isinstance(audio_input, tuple): raise NotImplementedError( f"Unsupported data type: {type(audio_input)}") audio, sf = audio_input feature_extractor = audio_feature_extractor() single_audio_features = feature_extractor.extract_features(audio, sf) feat_stride = (1 if not hasattr(feature_extractor, "stride") else feature_extractor.stride) audio_frames = len(single_audio_features) * feat_stride single_audio_embed_size = _compute_audio_embed_size( ctx.get_hf_config(), audio_frames) single_audio_feature_audio_len_pair = ( single_audio_features, [single_audio_embed_size], ) audio_features.append(single_audio_feature_audio_len_pair) return MultiModalKwargs({"audio_features": audio_features}) def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): if not isinstance(data, list): data = [data] # data: list of PIL images if len(data) == 0: return MultiModalKwargs() hf_config = ctx.get_hf_config() vision_encoder_name = hf_config.img_processor if vision_encoder_name is None: vision_encoder_name = SIGLIP_NAME prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] dynamic_hd_size = prepro_config['dynamic_hd'] vit_image_size = prepro_config['vit_image_size'] vit_patch_size = prepro_config['vit_patch_size'] image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size, vit_patch_size) return MultiModalKwargs({ "pixel_values": image_input_dict["pixel_values"], "image_sizes": image_input_dict["image_sizes"], "image_attention_mask": image_input_dict["image_attention_mask"], "num_img_tokens": image_input_dict["num_img_tokens"], }) def cat_with_pad(tensors, dim, padding_value=0): """ cat along dim, while pad to max for all other dims """ ndim = tensors[0].dim() assert all( t.dim() == ndim for t in tensors[1:]), "All tensors must have the same number of dimensions" out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] out_size[dim] = sum(t.shape[dim] for t in tensors) output = tensors[0].new_full(out_size, padding_value) index = 0 for t in tensors: # Create a slice list where every dimension except dim is full slice slices = [slice(0, t.shape[d]) for d in range(ndim)] # Update only the concat dimension slice slices[dim] = slice(index, index + t.shape[dim]) output[slices] = t index += t.shape[dim] return output @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_phi4mm_audio) @MULTIMODAL_REGISTRY.register_input_mapper("image", input_mapper_for_phi4mm_image) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "audio", get_max_phi4mm_audio_tokens) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "image", get_max_phi4mm_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm) class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, SupportsV0Only): """ Implements the Phi-4-multimodal-instruct model in vLLM. 
""" packed_modules_mapping = { "qkv_proj": [ "qkv_proj", ], "gate_up_proj": [ "gate_up_proj", ], } hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ "base_layer.": "", }, orig_to_new_prefix={ "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.", "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.", "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.", "model.embed_tokens_extend.image_embed.": "vision_encoder.", }, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config assert multimodal_config, "multimodal_config is required" quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config self.lora_config = lora_config # Tensor/Pipeline parallel not supported for now. assert get_pp_group( ).world_size == 1, "pipeline parallel is not supported" self.vision_encoder = Phi4MMImageEncoder( config, quant_config, prefix="model.vision_embed_tokens", model_dir=config._name_or_path) if isinstance(config.embd_layer["audio_embd_layer"], dict): embedding_config = { "embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"], **config.embd_layer["audio_embd_layer"], } else: embedding_config = { "embedding_cls": self.config.embd_layer["embedding_cls"] } self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) self.model = LlamaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=( DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size), quant_config=quant_config, ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) self.sampler = Sampler() def _audio_features_to_embeddings( self, input_ids: torch.Tensor, input_features: List[torch.Tensor], audio_input_sizes: torch.Tensor, audio_projection_mode: str, ) -> torch.Tensor: """ Convert audio features to embeddings, which are used as input to the model (via `inputs_embeds`). Args: input_ids (torch.Tensor): Input IDs (the prompt in this case). input_features (list[torch.Tensor]): Input features (the audio embeddings). audio_input_sizes (list[torch.Tensor]): Audio input sizes (the audio embed lengths to use for padding the audio placeholder token in the input prompt IDs). 
""" # The audio projection can either be a single linear or Sequential, # so handle both cases if isinstance(self.embed_tokens_extend.audio_projection, nn.Sequential): target_dtype = self.embed_tokens_extend.audio_projection[ 0].bias.dtype else: target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype audio_input = [ input.unsqueeze(0).to(target_dtype) for input in input_features ] kwargs = { "wte": self.model.embed_tokens, 'audio_projection_mode': audio_projection_mode } audio_embeddings = self.embed_tokens_extend(input_ids, audio_input, audio_input_sizes, **kwargs) audio_embeddings = audio_embeddings.to(target_dtype) return audio_embeddings def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: """ Parse and validate the audio input to the model. This handles both audio features and audio embeddings, but only the former is used for now. Args: kwargs (object): Keyword arguments. Returns: Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. """ audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) if audio_features is None and audio_embeds is None: return None if audio_features is not None: if not isinstance(audio_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio features. " f"Got type: {type(audio_features)}") return Phi4MMAudioFeatureInputs(type="audio_features", data=audio_features) if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of audio embeds. " f"Got type: {type(audio_embeds)}") return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") def _process_audio_input(self, input_ids: torch.Tensor, audio_input: Phi4MMAudioInputs, audio_projection_mode: str) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input is pairs of audio features and audio embed lengths. The audio input is created by `input_mapper_for_phi4mm_audio`. Args: input_ids (torch.Tensor): Input IDs (the prompt in this case, before the audio token replication). audio_input (Phi4MMAudioInputs): Audio input. Returns: NestedTensors: Audio embeddings """ if audio_input["type"] == "audio_embeds": return audio_input["data"] audio_features = audio_input["data"] # (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple audios in the same example) audio_feature = [i[0] for j in audio_features for i in j] audio_feature_len = [i[1].item() for j in audio_features for i in j] # Add the batch dim via `squeeze` return self._audio_features_to_embeddings( input_ids.unsqueeze(0), audio_feature, audio_feature_len, audio_projection_mode, ).squeeze(0) def _parse_and_validate_image_input(self, **kwargs: object) -> Optional[Dict]: pixel_values: Optional[Dict] = kwargs.get("pixel_values") if pixel_values is None: return None image_sizes = kwargs.get("image_sizes") image_attention_mask = kwargs.get("image_attention_mask") num_img_tokens = kwargs.get("num_img_tokens") assert image_sizes is not None and image_attention_mask is not None\ and num_img_tokens is not None, "Missing image inputs" if isinstance(pixel_values, list): assert pixel_values[0].dim() == 5, "Incorrect image inputs" # list len is batch_size. # each tensor has dimension: num_img_per_example, num_hd_patches, # channels, height, width. # need to pad along num_hd_patches. # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. 
pixel_values = cat_with_pad(pixel_values, dim=0) elif isinstance(pixel_values, torch.Tensor): # dimension: batch_size, num_img_per_example, num_hd_patches, # channels, height, width. # we flatten first 2 dims to make it a single large batch for # SigLIP Encoder. assert pixel_values.dim() == 6, "Incorrect image inputs" pixel_values = pixel_values.flatten(0, 1) else: raise ValueError("Incorrect pixel_values inputs") if isinstance(image_attention_mask, list): image_attention_mask = cat_with_pad(image_attention_mask, dim=0) elif isinstance(image_attention_mask, torch.Tensor): image_attention_mask = image_attention_mask.flatten(0, 1) else: raise ValueError("Incorrect image_attention_mask inputs") if isinstance(image_sizes, list): image_sizes = torch.cat(image_sizes, dim=0) elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: raise ValueError("Incorrect image_attention_mask inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ n for num_tensor in num_img_tokens for n in num_tensor.tolist() ] elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: raise ValueError("Incorrect image_attention_mask inputs") return { 'pixel_values': pixel_values, 'image_sizes': image_sizes, 'image_attention_mask': image_attention_mask, 'num_img_tokens': num_img_tokens, } def merge_image_features_to_inputs_embeds( self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, image_set_tensors: List[torch.Tensor], ): position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero( as_tuple=True) assert all([t.shape[0] == 1 for t in image_set_tensors ]), 'img_set_tensor should have shape (1, N_tokens, C)' # Shape: (merged_N_tokens, C) image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to( inputs_embeds.device) merged_embeds = inputs_embeds.index_put( indices=position_tuple, values=image_set_tensor, accumulate=False, ) return merged_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: input_ids = None inputs_embeds = None else: # Each entry in this is a pair of audio_features and audio_embed # lengths audio_input = self._parse_and_validate_audio_input(**kwargs) image_inputs = self._parse_and_validate_image_input(**kwargs) has_audio = audio_input is not None has_image = image_inputs is not None if has_audio: audio_projection_mode = 'vision' if has_image else 'speech' inputs_embeds = self._process_audio_input( input_ids, audio_input, audio_projection_mode) if has_image: dtype = self.vision_encoder.img_processor.embeddings.\ patch_embedding.weight.dtype pixel_values = image_inputs['pixel_values'].to(dtype) image_sizes = image_inputs['image_sizes'] image_attention_mask = image_inputs['image_attention_mask'] image_set_tensors = self.vision_encoder( pixel_values, image_sizes, image_attention_mask) if not has_audio: inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds = self.merge_image_features_to_inputs_embeds( input_ids, inputs_embeds, image_set_tensors) if has_image or has_audio: # multi-modal input, we have set inputs_embeds properly in # previous steps input_ids = None else: # text-only, we keep using original input_ids inputs_embeds = None hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states def compute_logits( 
self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: weights = ((name, data) for name, data in weights if "lora" not in name) loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ return MultiModelKeys.from_string_field( language_model="model.", connector=["audio_projection_for_vision", "audio_projection"], tower_model=["vision_encoder", "embed_tokens_extend"], )
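

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (not part of the model): it shows how the
# log-filterbank front end and `compute_logfbank_output_size` above are meant
# to agree on output shapes. The 1-second silent waveform is made up purely
# for illustration.
if __name__ == "__main__":
    _wav = np.zeros(16_000, dtype=np.float32)  # 1 s of audio at 16 kHz
    _feats = audio_feature_extractor().extract_features(_wav, 16_000)
    _t, _d = compute_logfbank_output_size(len(_wav), 16_000)
    # With win_length=400 and hop_length=160 both paths give (98, 80).
    assert _feats.shape == (_t, _d), (_feats.shape, (_t, _d))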