# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import math
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL import Image
from torch import nn
from transformers.modeling_outputs import (BaseModelOutput,
                                           CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
    get_optimal_tiled_canvas)

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP

logger = init_logger(__name__)

# Token id / text of the image placeholder used to build encoder prompts.
MLLAMA_IMAGE_TOKEN_ID = 128256
MLLAMA_IMAGE_TOKEN = "<|image|>"


class MllamaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: """
    """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)"""
    aspect_ratio_ids: torch.Tensor
    """Shape: `(batch_size, max_num_image)`"""
    aspect_ratio_mask: torch.Tensor
    """Shape: `(batch_size, max_num_image, max_num_tiles)`"""


# TODO: support LlamaImageEmbeddingInputs


def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
    """Rewrite ``llm_inputs`` so the encoder prompt matches the image tiles.

    The encoder prompt is copied to the decoder prompt when the latter is
    missing. For text-only requests the encoder prompt is emptied. Otherwise
    the encoder prompt becomes one ``<|image|>`` token per vision token,
    where the number of tiles per image comes from
    ``get_optimal_tiled_canvas``.

    ``llm_inputs`` is modified in place and also returned.
    """
    # move encoder_prompt to prompt
    if llm_inputs.get("prompt") is None:
        llm_inputs["prompt"] = llm_inputs["encoder_prompt"]
        llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"]

    # process multi-modal data
    assert "decoder_multi_modal_data" not in llm_inputs, \
        "multi-modal data should be put in encoder message of mllama"
    multi_modal_data = llm_inputs.get("encoder_multi_modal_data")

    if multi_modal_data is None or "image" not in multi_modal_data \
        or multi_modal_data["image"] is None:
        # text-only
        llm_inputs["encoder_prompt"] = ""
        llm_inputs["encoder_prompt_token_ids"] = []
        llm_inputs["encoder_multi_modal_data"] = {}
        return llm_inputs

    # get num_tiles
    if isinstance(multi_modal_data['image'], Image.Image):
        # Normalise a single image to a one-element list.
        multi_modal_data['image'] = [multi_modal_data['image']]
    hf_config = ctx.model_config.hf_config
    # Loop-invariant config values, looked up once.
    tile_size = hf_config.vision_config.image_size
    max_image_tiles = hf_config.vision_config.max_num_tiles
    num_tiles = 0
    for image in multi_modal_data["image"]:
        width, height = image.size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
            image_height=height,
            image_width=width,
            max_image_tiles=max_image_tiles,
            tile_size=tile_size,
        )
        num_tiles_height = canvas_height // tile_size
        num_tiles_width = canvas_width // tile_size
        num_tiles += num_tiles_height * num_tiles_width

    # set encoder prompt based on num_tiles
    assert hf_config.vision_config.image_size % 14 == 0, \
        "chunk size should be multiple of 14"
    # (image_size // 14)**2 patches per tile, plus one extra token per tile
    # (the vision model prepends a class embedding to every tile).
    token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
    num_tokens = num_tiles * token_per_chunk
    llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
    llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID
                                              ] * num_tokens

    return llm_inputs


def get_max_mllama_image_tokens(ctx: InputContext) -> int:
    """Return the token count of one image that uses ``max_num_tiles`` tiles."""
    hf_config = ctx.model_config.hf_config
    token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
    return hf_config.vision_config.max_num_tiles * token_per_chunk


def dummy_decoder_seq_data(seq_len: int, num_images: int):
    """Build dummy decoder tokens: image tokens followed by zero padding."""
    # <|image|> * num_images + 0 * (seq_len - num_images)
    assert seq_len >= num_images, \
        "seq_len should be greater than or equal to num_images"
    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                      [MLLAMA_IMAGE_TOKEN_ID]) * num_images
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images)
    return SequenceData(token_ids)


def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
    """Build dummy encoder tokens sized for ``num_images`` maximal images."""
    num_tokens = get_max_mllama_image_tokens(ctx) * num_images
    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                      [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens
    return SequenceData(token_ids)


def dummy_image(num_images: int):
    """Return multi-modal dummy data holding black 1024x1024 RGB image(s).

    A single image is returned bare; several images are returned as a list
    that repeats the same ``Image`` object.
    """
    width = height = 1024
    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Dummy decoder data: token sequence only, no multi-modal payload."""
    num_images = mm_counts["image"]
    return dummy_decoder_seq_data(seq_len, num_images), None


def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Dummy encoder data: image-token sequence plus dummy image(s)."""
    num_images = mm_counts["image"]
    return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images)


def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Expand a per-tile validity mask into an additive 4D attention mask.

    Args:
        aspect_ratio_mask: ``(batch_size, max_num_tiles)`` tensor; non-zero
            entries mark tiles that are in use.
        num_patches: Number of real (unpadded) tokens per tile.
        target_length: Number of tokens per tile after padding.
        dtype: Dtype of the returned mask; its ``finfo.min`` is the masking
            value.

    Returns:
        Tensor of shape ``(batch_size, 1, max_num_tiles * target_length,
        max_num_tiles * target_length)``.
    """
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
                                            1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches.
    # Guard against pad_patches == 0: ``x[:, :, -0:]`` is the *whole* axis,
    # so an unguarded assignment would zero every position of the mask.
    pad_patches = target_length - num_patches
    if pad_patches > 0:
        attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
    attention_mask = attention_mask.reshape(batch_size,
                                            max_num_tiles * target_length, 1)
    # Outer product of the inverted mask with itself, scaled by finfo.min.
    attention_mask = attention_mask @ attention_mask.transpose(
        -1, -2) * torch.finfo(dtype).min
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask


class ColumnParallelConv2dPatch(torch.nn.Module):
    """Conv2D Patching layer with model parallelism.
    Column parallel over unfolded input.
    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel.
        stride (default 1): Stride for convolution.
        bias (default False): Use bias in Conv2d.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        # The convolution is expressed as unfold (patch extraction) followed
        # by a column-parallel linear projection of each flattened patch.
        self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
        self._linear = ColumnParallelLinear(
            in_channels * kernel_size[0] * kernel_size[1],
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold ``x`` into patches and project each one linearly."""
        x = self._unfold(x)
        # Move the patch axis ahead of the flattened-patch feature axis.
        x = x.permute(0, 2, 1).contiguous()
        x, _ = self._linear(x)
        return x


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Per-tile embedding looked up by aspect-ratio id, optionally gated."""

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        # One row per aspect-ratio id; each row holds max_num_tiles vectors.
        self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
                                      self.max_num_tiles * self.hidden_size)
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Add the (tanh-gated, if enabled) tile embedding to ``hidden_state``."""
        embeddings = self.embedding(aspect_ratio_ids)
        # The singleton axis broadcasts one vector over all tokens of a tile.
        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1,
                                        self.hidden_size)

        if self.is_gated:
            embeddings = embeddings * self.gate.tanh()

        hidden_state = hidden_state + embeddings
        return hidden_state


class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Gated blend of a shared position embedding and a tile-position one."""

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # Patches per tile plus one (the class token added by the model).
        self.num_patches = (config.image_size // config.patch_size)**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # position embedding
        position_embedding = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * position_embedding)

        # tile position embedding
        self.tile_embedding = nn.Embedding(
            self.max_aspect_ratio_id + 1,
            self.max_num_tiles * self.num_patches * self.hidden_size)

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Add both position embeddings, weighted by ``1 - g`` and ``g``.

        ``g`` is ``tanh(self.gate)``.
        """
        # position embeddings
        gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
        hidden_state = hidden_state + gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size)

        # precomputed tile position embeddings
        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
        batch_size = hidden_state.shape[0]
        tile_position_embedding = tile_position_embedding.reshape(
            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size)
        gated_tile_position_embedding = self.gate.tanh(
        ) * tile_position_embedding
        hidden_state = hidden_state + gated_tile_position_embedding

        return hidden_state


# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
    """Tensor-parallel multi-head self-attention for the vision encoder."""

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()

        model_parallel_size = get_tensor_model_parallel_world_size()
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        # Heads handled by this tensor-parallel rank.
        self.num_local_heads = self.num_heads // model_parallel_size
        self.q_size = self.num_local_heads * self.head_dim
        self.kv_size = self.num_local_heads * self.head_dim
        self.scale = 1 / math.sqrt(self.head_dim)

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention over the token axis of ``hidden_state``.

        ``attention_mask``, when given, is added to the attention scores
        before the softmax.
        """
        qkv, _ = self.qkv_proj(hidden_state)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)

        # TODO: remove padding in image encoder
        # NOTE: attention is computed explicitly, one batch element at a
        # time. A batched F.scaled_dot_product_attention(q, k, v,
        # attn_mask=attention_mask, dropout_p=0.0) call was left disabled
        # by the original author.
        attn_output = torch.empty_like(q)
        for i in range(q.shape[0]):
            attn = q[i] @ (k[i] * self.scale).permute(0, 2, 1)
            if attention_mask is not None:
                attn = attn + attention_mask[i]
            attn = torch.softmax(attn, dim=-1)
            attn_output[i] = attn @ v[i]

        # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(attn_output.shape[0],
                                          attn_output.shape[1], -1)
        output, _ = self.o_proj(attn_output)
        return output


class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm transformer layer with optional tanh-gated residuals."""

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = False):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(config)
        self.mlp = CLIPMLP(config)

        self.input_layernorm = nn.LayerNorm(self.hidden_size,
                                            eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
                                                     eps=config.norm_eps)

        # there used to be an if else here, no code path
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Run the attention and MLP sub-blocks with residual connections.

        When ``is_gated`` is set, each sub-block output is scaled by the
        tanh of its gate parameter before being added to the residual.
        """
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state,
                                      attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state


class MllamaVisionEncoder(nn.Module):
    """Stack of ``MllamaVisionEncoderLayer`` that can export hidden states."""

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 num_layers=32,
                 is_gated=False,
                 output_hidden_states=None):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            MllamaVisionEncoderLayer(config, is_gated)
            for _ in range(num_layers)
        ])
        # Layer indices whose hidden states are collected in forward().
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Return ``(final_hidden_states, collected_states)``.

        For each index ``i`` in ``output_hidden_states`` the *input* to
        layer ``i`` is collected. If the last layer's index is listed, the
        final output is appended as well.
        """
        encoder_states = ()

        for i, encoder_layer in enumerate(self.layers):
            if i in self.output_hidden_states:
                encoder_states = encoder_states + (hidden_states, )
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )

        if len(self.layers) - 1 in self.output_hidden_states:
            encoder_states = encoder_states + (hidden_states, )

        return hidden_states, encoder_states


class MllamaVisionModel(nn.Module):
    """Tiled vision encoder: patch embedding, local and global transformer."""

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile plus one class token.
        self.num_patches = (self.image_size // self.patch_size)**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale *
                                            torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            config)

        self.pre_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
        self.post_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        self.transformer = MllamaVisionEncoder(
            config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices)
        self.global_transformer = MllamaVisionEncoder(config,
                                                      config.num_global_layers,
                                                      is_gated=True)

    def apply_class_embedding(self,
                              hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding along dim 1 of a 3D ``hidden_state``."""
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1,
                                                      hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(self, pixel_values: torch.Tensor,
                aspect_ratio_ids: torch.Tensor,
                aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
        """Encode tiled images.

        Args:
            pixel_values: ``(batch_size, num_concurrent_media, num_tiles,
                num_channels, height, width)``.
            aspect_ratio_ids: Reshaped to
                ``(batch_size * num_concurrent_media, -1)``.
            aspect_ratio_mask: Reshaped to
                ``(batch_size * num_concurrent_media, -1)``.

        Returns:
            ``(batch_size, num_concurrent_media, num_tiles, num_patches, D)``
            tensor. The last axis concatenates the global transformer output
            with the stacked intermediate hidden states.
        """
        batch_size, num_concurrent_media, num_tiles, num_channels, \
            height, width = pixel_values.shape

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels,
            height, width)
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1)

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.layernorm_pre.weight.dtype))
        hidden_state = patch_embeds
        # The patch projection is column-parallel, so gather its output
        # across the tensor-parallel group.
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles, num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state,
                                                       aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0, 0, 0, num_padding_patches
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        # ``None`` keeps the whole axis when nothing was padded; ``-0`` would
        # produce an empty slice.
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        attention_mask = aspect_ratio_mask.reshape(
            batch_size * num_concurrent_media, -1)
        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=attention_mask,
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=self.layernorm_pre.weight.dtype,
        )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
                                         dim)
        output = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        hidden_state, intermediate_hidden_states = output[0], output[1]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states,
                                                 dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles,
                                            num_patches + num_padding_patches,
                                            dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches), dim)
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask)[0]
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles,
                                            num_patches + num_padding_patches,
                                            dim)
        # Drop the padding tokens again.
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
                                            num_tiles, num_patches, dim)
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media, num_tiles,
            num_patches + num_padding_patches, -1)
        intermediate_hidden_states = intermediate_hidden_states[:, :, :
                                                                slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1)
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
                                 dim=-1)
        return hidden_state


class MllamaTextRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        """
        MllamaTextRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """RMS-normalise the last axis in float32, then scale by ``weight``."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_idx: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.model_parallel_size = get_tensor_model_parallel_world_size()
        self.num_heads = self.config.num_attention_heads
        self.num_local_heads = self.num_heads // self.model_parallel_size
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_local_key_value_heads = \
            self.num_key_value_heads // self.model_parallel_size
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        # TODO: change to Q/KV separate linear after #7448 is merged
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = Attention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cross_attention_states: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Cross-attend from ``hidden_states`` to ``cross_attention_states``.

        Queries come from ``hidden_states``; keys and values come from
        ``cross_attention_states``. When the latter is ``None``, ``k`` and
        ``v`` are passed to the attention layer as ``None``.
        ``attention_mask`` is accepted but not used here.
        """
        # The fused projection is run on both inputs; only the Q slice of
        # the decoder side and the K/V slices of the encoder side are kept.
        qkv_dec, _ = self.qkv_proj(hidden_states)
        q, _, _ = qkv_dec.split(
            [self.q_local_size, self.kv_local_size, self.kv_local_size],
            dim=-1)
        if cross_attention_states is None:
            k = None
            v = None
        else:
            qkv_enc, _ = self.qkv_proj(cross_attention_states)
            _, k, v = qkv_enc.split(
                [self.q_local_size, self.kv_local_size, self.kv_local_size],
                dim=-1)
            k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
            k = self.k_norm(k)
        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        output = self.attn(q,
                           k,
                           v,
                           kv_cache,
                           attn_metadata,
                           attn_type=AttentionType.ENCODER_DECODER)
        out, _ = self.o_proj(output)
        return out


class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention
    and feedforward."""

    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
                 quant_config: Optional[QuantizationConfig]) \
        -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_idx=layer_idx,
            quant_config=quant_config,
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply gated cross-attention followed by a gated MLP.

        Both sub-block outputs are multiplied by
        ``full_text_row_masked_out_mask`` before the gated residual add, so
        rows where the mask is zero keep only their residual.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_attn_gate.tanh(
        ) * hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_mlp_gate.tanh(
        ) * hidden_states
        return hidden_states


class MllamaTextModel(nn.Module):
    """Decoder stack mixing Llama self-attention and cross-attention layers."""
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(self, config: config_mllama.MllamaTextConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig]):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # NOTE(review): the 8 extra rows presumably hold special tokens such
        # as <|image|> that lie beyond config.vocab_size -- confirm.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
                                                   config.hidden_size)
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_idx in range(config.num_hidden_layers):
            if layer_idx in self.cross_attention_layers:
                layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        config, layer_idx, quant_config=quant_config))
            else:
                # TODO: force LlamaDecoderLayer to config.attention_bias=False
                layers.append(
                    LlamaDecoderLayer(config,
                                      cache_config=cache_config,
                                      quant_config=quant_config))

        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run every decoder layer, then final norm.

        Cross-attention layers are skipped entirely when
        ``skip_cross_attention`` is true.

        Raises:
            ValueError: If a layer is neither a cross-attention layer nor a
                ``LlamaDecoderLayer``.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if not skip_cross_attention:
                    hidden_states = decoder_layer(
                        hidden_states=hidden_states,
                        cross_attention_states=cross_attention_states,
                        cross_attention_mask=cross_attention_mask,
                        full_text_row_masked_out_mask=
                        full_text_row_masked_out_mask,
                        kv_cache=kv_caches[idx],
                        attn_metadata=attn_metadata,
                    )
            elif isinstance(decoder_layer, LlamaDecoderLayer):
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    kv_cache=kv_caches[idx],
                    attn_metadata=attn_metadata,
                    residual=None,
                )
                # The layer returns its residual separately; fold it back in
                # because the next layer is called with residual=None.
                hidden_states = hidden_states + residual
            else:
                raise ValueError(
                    f"Unknown decoder layer type {type(decoder_layer)}")
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MllamaForCausalLM(nn.Module):
    """Text model plus LM head; ``forward`` returns hidden states only."""
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
    ]

    def __init__(self, config: config_mllama.MllamaTextConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig]):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Delegate to the text model; ``lm_head`` is not applied here."""
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """Mllama vision encoder + projector + cross-attending language model."""

    def __init__(self,
                 config: config_mllama.MllamaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = config.vision_config.image_size

        self.vision_model = MllamaVisionModel(config.vision_config)
        self.language_model = MllamaForCausalLM(
            config.text_config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        # NOTE(review): the first argument is config.output_hidden_states,
        # which looks unrelated to a vocabulary size -- confirm against the
        # LogitsProcessor signature before changing.
        self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                config.text_config.vocab_size)
        self.sampler = Sampler()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.language_model.lm_head,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Collect image kwargs into zero-padded batch tensors.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            given, otherwise a ``MllamaImagePixelInputs``.

        Raises:
            ValueError: If both pixel values and image embeds are given, or
                if no request contains an image.
            NotImplementedError: If only ``image_embeds`` is given.
        """
        # tensor with the same shape will be batched together by
        # MultiModalInputs.batch, so pixel_values here can be:
        #   - List[List[torch.Tensor]]:
        #       with shape (num_tiles, 3, image_res, image_res)
        #   - List[torch.Tensor]:
        #       with shape (num_image, num_tiles, 3, image_res, image_res)
        #   - torch.Tensor:
        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
        pixel_values: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
        aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]],
                                         List[torch.Tensor],
                                         torch.Tensor]] = kwargs.pop(
                                             "aspect_ratio_ids", None)
        aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]],
                                          List[torch.Tensor],
                                          torch.Tensor]] = kwargs.pop(
                                              "aspect_ratio_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            assert aspect_ratio_ids is not None
            assert aspect_ratio_mask is not None
            # Each batch entry is indexed with [0] before reaching the
            # per-image data.
            max_num_images = max(len(x[0]) for x in pixel_values)
            if max_num_images == 0:
                raise ValueError("No images provided.")
            max_num_tiles = max(
                max(len(x) for x in y[0]) for y in pixel_values)
            device = self.multi_modal_projector.weight.device
            bsz = len(pixel_values)
            # Padded outputs: unused image / tile slots keep their initial
            # values (zeros for images and mask, ones for aspect-ratio ids).
            out_images = torch.zeros(
                bsz,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
                device=device,
            )
            out_ar_ids = torch.ones(bsz,
                                    max_num_images,
                                    dtype=torch.int64,
                                    device=device)
            out_ar_mask = torch.zeros(bsz,
                                      max_num_images,
                                      max_num_tiles,
                                      dtype=torch.int64,
                                      device=device)
            for b, sample in enumerate(pixel_values):
                for i, img in enumerate(sample[0]):
                    out_images[b, i, :img.shape[0]] = img
                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]

            return MllamaImagePixelInputs(
                type="pixel_values",
                data=out_images,
                aspect_ratio_ids=out_ar_ids,
                aspect_ratio_mask=out_ar_mask,
            )

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata):
        """Flatten per-request vision tokens and build the text-row mask.

        Returns:
            A tuple ``(cross_attention_states, full_text_row_masked_out_mask)``.
            The first is ``(sum(encoder_seq_lens), hidden)``: the first
            ``encoder_seq_len`` tokens of each request, concatenated. The
            second is a ``(num_prefill_tokens, 1)`` bool tensor that is
            ``False`` for the text rows of requests whose encoder length
            is 0.
        """
        cross_attention_states_flat = torch.zeros(
            sum(attn_metadata.encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
        for seq_len, vision_token_in_batch in zip(
                attn_metadata.encoder_seq_lens, cross_attention_states):
            end_pos = start_pos + seq_len
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
            start_pos = end_pos
        cross_attention_states = cross_attention_states_flat

        # Built without a device argument, then moved to the target device
        # once it is complete.
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
        for seq_len, encoder_seq_len in zip(
                attn_metadata.seq_lens_tensor.cpu(),
                attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
            start_pos += seq_len
        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            cross_attention_states.device)

        return cross_attention_states, full_text_row_masked_out_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the vision model (if images are given) and the language model.

        Raises:
            ValueError: If the batch mixes prefill and decode tokens.
        """
        if attn_metadata.num_prefill_tokens > 0 and \
            attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")
        image_inputs = self._parse_and_validate_image_input(**kwargs)
        if image_inputs is None:
            cross_attention_mask = None
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                    input_ids.device)
            cross_attention_states = None
            # Cross-attention layers are skipped only when no request in the
            # batch has encoder tokens.
            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
        else:
            # NOTE: llama's reference implementation runs vision model on CPU
            pixel_values = image_inputs['data']
            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
            cross_attention_states = self.vision_model(pixel_values,
                                                       aspect_ratio_ids,
                                                       aspect_ratio_mask)
            cross_attention_states = self.multi_modal_projector(
                cross_attention_states)

            bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
            cross_attention_states = cross_attention_states.view(
                bsz, -1, image_token_dim)

            cross_attention_states, full_text_row_masked_out_mask = \
                self.flat_encoder_result(cross_attention_states, attn_metadata)
            skip_cross_attention = False
            # TODO: support multi-image by this mask
            cross_attention_mask = None

        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )

        return outputs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, fusing q/k/v and gate/up shards.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` and ``gate_proj`` /
        ``up_proj`` tensors are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters through each parameter's
        ``weight_loader``. The patch-embedding weight is flattened to 2D to
        fit the linear layer that replaces the convolution.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'patch_embedding.weight' in name:
                name = name.replace('patch_embedding.weight',
                                    'patch_embedding._linear.weight')
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Not popped: a fused parameter is visited once per shard.
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)