# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property from typing import Annotated, Literal import torch import torch.nn as nn import torch.nn.functional as F from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image from transformers import BatchFeature, PixtralVisionConfig, TensorType from transformers.image_utils import ImageInput from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) from transformers.models.pixtral.modeling_pixtral import ( PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid, ) from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFieldConfig, MultiModalUUIDDict, NestedTensors, ) from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems from vllm.multimodal.processing import ( BaseMultiModalProcessor, BaseProcessingInfo, MultiModalProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import init_vllm_registered_model, maybe_prefix from .vision import ( VisionEncoderInfo, VisionFeatureSelectStrategy, resolve_visual_encoder_outputs, ) try: # Note: vLLM does not install xformers by default. from xformers import ops as xops if current_platform.is_cuda() and current_platform.has_device_capability(100): # Xformers FA is not compatible with B200 USE_XFORMERS_OPS = False else: USE_XFORMERS_OPS = True except ImportError: USE_XFORMERS_OPS = False PATCH_MERGE = "patch_merge" class PixtralImagePixelInputs(TensorSchema): """ Dimensions: - bn: Batch size * number of images - c: Number of channels (3) - h: Height of each image - w: Width of each image The result of stacking `ImageEncoding.tokens` from each prompt. """ type: Literal["pixel_values"] = "pixel_values" images: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), ] class PixtralProcessorAdapter: """ Provide a HF-compatible interface for `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`. """ def __init__(self, tokenizer: MistralTokenizer) -> None: super().__init__() self.tokenizer = tokenizer @property def image_processor(self) -> ImageEncoder: image_encoder = self.tokenizer.instruct.mm_encoder assert isinstance(image_encoder, ImageEncoder) return image_encoder @cached_property def image_break_id(self) -> int: return self.image_processor.special_ids.img_break @cached_property def image_token_id(self) -> int: return self.image_processor.special_ids.img @cached_property def image_end_id(self) -> int: return self.image_processor.special_ids.img_end @cached_property def image_size(self) -> int: return self.image_processor.mm_config.max_image_size @cached_property def patch_size(self) -> int: return self.image_processor.mm_config.image_patch_size def __call__( self, text: TextInput | list[TextInput] | None = None, images: ImageInput | list[ImageInput] | None = None, return_tensors: str | TensorType | None = None, **kwargs, ) -> Mapping[str, NestedTensors]: if text is None: text = [] if not isinstance(text, list): text = [text] if images is None: images = [] if not isinstance(images, list): images = [images] if not images: input_ids = self.tokenizer(text).input_ids return {"input_ids": torch.tensor(input_ids)} # Allow dummy text, which is used for profiling as well as token inputs if any(len(t) > 0 for t in text): raise ValueError( "You've passed text inputs instead of token inputs. " "Make sure to process your input via `mistral_common`'s " "tokenizer or pass a chat completion request. " "For more info, see: " "https://github.com/vllm-project/vllm/issues/8411." ) images_processed = list[torch.Tensor]() images_tokens = list[torch.Tensor]() for image in images: image_inputs = self.image_processor(ImageChunk(image=image)) image_processed = torch.tensor(image_inputs.image) image_tokens = torch.tensor(image_inputs.tokens) images_processed.append(image_processed) images_tokens.append(image_tokens) return BatchFeature( { "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), "images": images_processed, } ) class PixtralProcessingInfo(BaseProcessingInfo): def get_tokenizer(self) -> MistralTokenizer: tokenizer = cached_tokenizer_from_config(self.ctx.model_config) if not isinstance(tokenizer, MistralTokenizer): raise ValueError("This model requires `--tokenizer-mode mistral`") return tokenizer def get_hf_processor(self) -> PixtralProcessorAdapter: return PixtralProcessorAdapter(self.get_tokenizer()) def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_vision_config( self, processor: PixtralProcessorAdapter | None = None, ): if processor is None: processor = self.get_hf_processor() return PixtralVisionConfig( image_size=processor.image_size, patch_size=processor.patch_size, ) def get_num_image_tokens( self, *, image_width: int, image_height: int, processor: PixtralProcessorAdapter | None = None, ) -> int: if processor is None: processor = self.get_hf_processor() ncols, nrows = processor.image_processor._image_to_num_tokens( Image.new("RGB", (image_width, image_height)) ) return ncols * nrows def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_hf_processor().image_processor max_image_size = image_processor.mm_config.max_image_size return ImageSize(width=max_image_size, height=max_image_size) class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { "image": self._get_dummy_images( width=target_width, height=target_height, num_images=num_images, overrides=image_overrides, ) } def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> ProcessorInputs: tokenizer = self.info.get_tokenizer() dummy_text = self.get_dummy_text(mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) dummy_images = dummy_mm_data.get("image", []) tokenization_kwargs = {"truncation": False} request = ChatCompletionRequest( messages=[ UserMessage( content=[ TextChunk(text=dummy_text), *(ImageChunk(image=image) for image in dummy_images), ] ), ] ) res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens return ProcessorInputs( prompt=dummy_tokens, mm_data=dummy_mm_data, tokenization_kwargs=tokenization_kwargs, ) class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]): def _get_mm_fields_config( self, hf_inputs: Mapping[str, NestedTensors], hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict(images=MultiModalFieldConfig.batched("image")) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_break_id = processor.image_break_id image_token_id = processor.image_token_id image_end_id = processor.image_end_id def get_replacement(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) ncols, nrows = processor.image_processor._image_to_num_tokens( Image.new("RGB", (image_size.width, image_size.height)) ) tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens[-1] = image_end_id return PromptUpdateDetails.select_token_id(tokens, image_token_id) return [ PromptReplacement( modality="image", target="", # Never match the prompt (see below note) replacement=get_replacement, ), ] def _cached_apply_hf_processor( self, prompt: str | list[int], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], mm_uuids: MultiModalUUIDDict | None = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) # NOTE: The tokens are already inserted by the chat template return prompt_ids, mm_info, True @MULTIMODAL_REGISTRY.register_processor( PixtralMultiModalProcessor, info=PixtralProcessingInfo, dummy_inputs=PixtralDummyInputsBuilder, ) class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None raise ValueError("Only image modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} vision_args = { key: value for key, value in self.config.vision_config.to_dict().items() if key in dataclass_fields } self.vision_args = VisionEncoderArgs(**vision_args) # init MistralForCausalLM self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) if multimodal_config.get_limit_per_prompt("image"): self.vision_encoder = VisionTransformer(self.vision_args) self.pre_mm_projector_norm = ( RMSNorm(self.vision_args.hidden_size, eps=1e-5) if self.vision_args.add_pre_mm_projector_layer_norm else None ) self.patch_merger = ( PatchMerger( vision_encoder_dim=self.vision_args.hidden_size, spatial_merge_size=self.vision_args.spatial_merge_size, use_mlp_bias=False, ) if self.vision_args.mm_projector_id == PATCH_MERGE else None ) self.vision_language_adapter = VisionLanguageAdapter( self.vision_args, dim=config.text_config.hidden_size ) else: self.vision_encoder = None self.pre_mm_projector_norm = None self.patch_merger = None self.vision_language_adapter = None self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors ) def _parse_and_validate_image_input( self, **kwargs: object ) -> PixtralImagePixelInputs | None: images = kwargs.pop("images", None) if images is None: return None return PixtralImagePixelInputs( type="pixel_values", images=images, ) def _process_image_input( self, image_input: PixtralImagePixelInputs, ) -> tuple[torch.Tensor, ...]: assert ( self.vision_encoder is not None and self.vision_language_adapter is not None ) images = image_input["images"] image_features = self.vision_encoder(images) feature_sizes = [image_feature.shape[0] for image_feature in image_features] image_features = torch.cat(image_features) if self.pre_mm_projector_norm is not None: image_features = self.pre_mm_projector_norm(image_features) if self.patch_merger is not None: patch_size = self.vision_args.patch_size spatial_merge_size_square = self.vision_args.spatial_merge_size**2 img_patch_dims = [ (img.shape[1] // patch_size, img.shape[2] // patch_size) for img in images ] feature_sizes = [ feature_size // spatial_merge_size_square for feature_size in feature_sizes ] image_features = self.patch_merger( image_features, image_sizes=img_patch_dims ) image_embeds = self.vision_language_adapter(image_features) image_embeds = torch.split(image_embeds, feature_sizes) return image_embeds def get_language_model(self) -> torch.nn.Module: return self.language_model def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor | IntermediateTensors: """Run forward pass for pixtral.""" if intermediate_tensors is not None: inputs_embeds = None hidden_states = self.language_model.model( input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith("vision_encoder") def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith("vision_language_adapter") def is_patch_merger(weight: tuple[str, torch.Tensor]): return weight[0].startswith("patch_merger") def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): return weight[0].startswith("pre_mm_projector_norm") # Get references to parameters for direct loading vision_encoder_dict = ( dict(self.vision_encoder.named_parameters()) if self.vision_encoder is not None else {} ) patch_merger_dict = ( dict(self.patch_merger.named_parameters()) if self.patch_merger is not None else {} ) pre_mm_projector_norm_dict = ( dict(self.pre_mm_projector_norm.named_parameters()) if self.pre_mm_projector_norm is not None else {} ) vision_lang_adapter_dict = ( dict(self.vision_language_adapter.named_parameters()) if self.vision_language_adapter is not None else {} ) def llm_weights_generator(): # Single pass over weights for name, w in weights: if is_vision_encoder_weights((name, w)): if self.vision_encoder is None: continue # Load vision encoder weights directly trimmed_name = ".".join(name.split(".")[1:]) param = vision_encoder_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_patch_merger((name, w)): if self.patch_merger is None: continue # Load vision patch merger weights directly trimmed_name = ".".join(name.split(".")[1:]) param = patch_merger_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_pre_mm_projector_norm((name, w)): if self.pre_mm_projector_norm is None: continue # Load vision pre_mm_projector_norm weights directly trimmed_name = ".".join(name.split(".")[1:]) param = pre_mm_projector_norm_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_vision_lang_adapter_weights((name, w)): if self.vision_language_adapter is None: continue # Load vision-language adapter weights directly trimmed_name = ".".join(name.split(".")[1:]) param = vision_lang_adapter_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) else: # LLM weights: yield them to be loaded # by language_model.load_weights yield (name, w) # Now we call the language model load with the generator self.language_model.load_weights(llm_weights_generator()) # Vision encoder @dataclass class VisionEncoderArgs: hidden_size: int num_channels: int image_size: int patch_size: int intermediate_size: int num_hidden_layers: int num_attention_heads: int rope_theta: float # for rope-2D image_token_id: int adapter_bias: bool = True spatial_merge_size: int = 1 add_pre_mm_projector_layer_norm: bool = False mm_projector_id: str = "" def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ freqs_cis: complex - (seq_len, head_dim / 2) x: complex - (bsz, seq_len, head_dim / 2) """ ndim = x.ndim assert ndim > 1 assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( freqs_cis.shape, (x.shape[1], x.shape[-1]), ) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) def precompute_freqs_cis_2d( dim: int, height: int, width: int, theta: float, ) -> torch.Tensor: """ freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by (height, width) position tuples """ # (dim / 2) frequency bases freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) h = torch.arange(height, device=freqs.device) w = torch.arange(width, device=freqs.device) freqs_h = torch.outer(h, freqs[::2]).float() freqs_w = torch.outer(w, freqs[1::2]).float() freqs_2d = torch.cat( [ freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1), ], dim=-1, ) return torch.polar(torch.ones_like(freqs_2d), freqs_2d) def apply_rotary_emb_vit( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) assert freqs_cis.dtype == torch.complex64 freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) class FeedForward(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() assert args.intermediate_size is not None self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Attention(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args assert not args.hidden_size % args.num_attention_heads self.n_heads = args.num_attention_heads self.head_dim = args.hidden_size // args.num_attention_heads self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) def forward( self, x: torch.Tensor, mask: torch.Tensor, freqs_cis: torch.Tensor, ) -> torch.Tensor: batch, patches, _ = x.shape q, k, v = self.wq(x), self.wk(x), self.wv(x) q = q.reshape(batch, patches, self.n_heads, self.head_dim) k = k.reshape(batch, patches, self.n_heads, self.head_dim) v = v.reshape(batch, patches, self.n_heads, self.head_dim) q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) if USE_XFORMERS_OPS: out = xops.memory_efficient_attention(q, k, v, attn_bias=mask) else: q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) out = out.transpose(1, 2) out = out.reshape(batch, patches, self.n_heads * self.head_dim) return self.wo(out) class TransformerBlock(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.attention = Attention(args) self.feed_forward = FeedForward(args) self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) def forward( self, x: torch.Tensor, mask: torch.Tensor, freqs_cis: torch.Tensor, ) -> torch.Tensor: r = self.attention.forward( self.attention_norm(x), mask=mask, freqs_cis=freqs_cis ) h = x + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r return out class Transformer(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.layers = torch.nn.ModuleList() for _ in range(args.num_hidden_layers): self.layers.append(TransformerBlock(args)) def forward( self, x: torch.Tensor, mask: torch.Tensor, freqs_cis: torch.Tensor | None, ) -> torch.Tensor: for layer in self.layers: x = layer(x, mask=mask, freqs_cis=freqs_cis) return x def position_meshgrid( patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: positions = torch.cat( [ torch.stack( torch.meshgrid( torch.arange(p.shape[-2]), torch.arange(p.shape[-1]), indexing="ij", ), dim=-1, ).reshape(-1, 2) for p in patch_embeds_list ] ) return positions class VisionTransformer(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args self.patch_conv = Conv2dLayer( in_channels=args.num_channels, out_channels=args.hidden_size, kernel_size=args.patch_size, stride=args.patch_size, bias=False, ) self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) self.transformer = Transformer(args) head_dim = self.args.hidden_size // self.args.num_attention_heads assert head_dim % 2 == 0, "ROPE requires even head_dim" self._freqs_cis: torch.Tensor | None = None @property def max_patches_per_side(self) -> int: return self.args.image_size // self.args.patch_size @property def device(self) -> torch.types.Device: return next(self.parameters()).device @property def dtype(self) -> torch.dtype: return next(self.parameters()).dtype @property def freqs_cis(self) -> torch.Tensor: if self._freqs_cis is None: self._freqs_cis = precompute_freqs_cis_2d( dim=self.args.hidden_size // self.args.num_attention_heads, height=self.max_patches_per_side, width=self.max_patches_per_side, theta=self.args.rope_theta, ) if self._freqs_cis.device != self.device: self._freqs_cis = self._freqs_cis.to(device=self.device) return self._freqs_cis def forward( self, images: list[torch.Tensor], ) -> torch.Tensor: """ Args: images: list of N_img images of variable sizes, each of shape (C, H, W) Returns: image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images ] patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] embed_sizes = [p.shape[1] for p in patch_embeds] # flatten to a single sequence patch_embeds = torch.cat(patch_embeds, dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings positions = position_meshgrid(patch_embeds_list).to(self.device) freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] # pass through Transformer with a block diagonal mask delimiting images if USE_XFORMERS_OPS: mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) else: from transformers.models.pixtral.modeling_pixtral import ( generate_block_attention_mask, ) mask = generate_block_attention_mask( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds ) out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) # squeeze dim 0 and split into separate tensors for each image return torch.split(out.squeeze(0), embed_sizes) class VisionLanguageAdapter(nn.Module): def __init__(self, args: VisionEncoderArgs, dim: int): super().__init__() assert isinstance(args, VisionEncoderArgs) self.w_in = nn.Linear( args.hidden_size, dim, bias=args.adapter_bias, ) self.gelu = nn.GELU() self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_out(self.gelu(self.w_in(x))) class PatchMerger(nn.Module): """ Learned merging of spatial_merge_size ** 2 patches """ def __init__( self, vision_encoder_dim: int, spatial_merge_size: int, use_mlp_bias: bool = False, ) -> None: super().__init__() mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2) self.spatial_merge_size = spatial_merge_size self.mlp_input_dim = mlp_input_dim self.merging_layer = nn.Linear( mlp_input_dim, vision_encoder_dim, bias=use_mlp_bias, ) def forward( self, x: torch.Tensor, image_sizes: list[tuple[int, int]] ) -> torch.Tensor: # image_sizes specified in tokens assert sum([h * w for h, w in image_sizes]) == len(x) # x is (N, vision_encoder_dim) x = self.permute(x, image_sizes) # x is (N / spatial_merge_size ** 2, # vision_encoder_dim * spatial_merge_size ** 2) x = self.merging_layer(x) # x is (N / spatial_merge_size ** 2, vision_encoder_dim) return x def permute( self, x: torch.Tensor, image_sizes: list[tuple[int, int]], ) -> torch.Tensor: """ Args: x: (N, D) where N is flattened and concatenated patch tokens for all images image_sizes: list of tuple of (height, width) in tokens for each image Returns: image_features: reorders patch tokens so each grid of (spatial_merge_size, spatial_merge_size) is contiguous. now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2) """ sub_grids = get_sub_grids( x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size ) # list of [d x sub_grid_size x sub_grid_size x n_patches] permuted_tensor: list[torch.Tensor] = [] for grid in sub_grids: n_patches = grid.shape[-1] permuted_tensor.append( grid.view(-1, n_patches).t() ) # n_patches x d * sub_grid_size * sub_grid_size return torch.cat( permuted_tensor, dim=0 ) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2) def get_sub_grids( x: torch.Tensor, image_sizes: list[tuple[int, int]], spatial_merge_size: int, ) -> list[torch.Tensor]: # image_sizes specified in tokens tokens_per_image = [h * w for h, w in image_sizes] d = x.shape[-1] all_img_sub_grids: list[torch.Tensor] = [] sub_grid_size = spatial_merge_size for image_index, image_tokens in enumerate(x.split(tokens_per_image)): # Reshape image_tokens into a 2D grid h, w = image_sizes[image_index] image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[ None, :, :, : ] # 1 x d x h x w sub_grids = torch.nn.functional.unfold( image_grid, kernel_size=sub_grid_size, stride=sub_grid_size ) sub_grids = sub_grids.view( 1, d, sub_grid_size, sub_grid_size, -1 ) # 1 x d x sub_grid_size x sub_grid_size x n_patches all_img_sub_grids.append(sub_grids[0]) return all_img_sub_grids #### HF Transformers version of Pixtral #### # Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py # This model follows the Llava family, meaning image embeddings are placed # instead of the `[IMG]` token placeholders. # The model uses [`PixtralVisionModel`] for its vision encoder, # and [`MistralForCausalLM`] for its language decoder. class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: ncols, nrows = self.get_patch_grid_size( image_width=image_width, image_height=image_height, ) return ncols * nrows def get_image_size(self) -> int: return self.vision_config.image_size def get_patch_size(self) -> int: # spatial_merge_size is needed for Mistral3 spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1) return self.vision_config.patch_size * spatial_merge_size def get_patch_grid_length(self) -> int: image_size, patch_size = self.get_image_size(), self.get_patch_size() # Since interpolation is applied, the image size need not be divisible # assert image_size % patch_size == 0 return image_size // patch_size # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99 def get_patch_grid_size( self, *, image_width: int, image_height: int, ) -> tuple[int, int]: max_width = max_height = self.get_image_size() patch_width = patch_height = self.get_patch_size() ratio = max(image_width / max_width, image_height / max_height) if ratio > 1: image_width = int(math.floor(image_width / ratio)) image_height = int(math.floor(image_height / ratio)) nrows, ncols = _get_pixtral_hf_num_image_tokens( (image_height, image_width), (patch_height, patch_width), ) # type: ignore return ncols, nrows class PixtralHFMLP(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: super().__init__() assert config.intermediate_size is not None self.gate_up_proj = MergedColumnParallelLinear( input_size=config.hidden_size, output_sizes=[config.intermediate_size] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( input_size=config.intermediate_size, output_size=config.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.down_proj", ) self.act_and_mul = get_act_and_mul_fn(config.hidden_act) def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) x = self.act_and_mul(gate_up) x, _ = self.down_proj(x) return x class PixtralHFAttention(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: super().__init__() self.config = config assert not config.hidden_size % config.num_attention_heads self.total_num_heads = config.num_attention_heads tp_size = get_tensor_model_parallel_world_size() self.n_heads = divide(config.num_attention_heads, tp_size) self.head_dim = config.hidden_size // config.num_attention_heads self.qkv_proj = QKVParallelLinear( hidden_size=config.hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) assert self.total_num_heads * self.head_dim == config.hidden_size self.o_proj = RowParallelLinear( input_size=config.hidden_size, output_size=config.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor | None]: batch, patches, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) q, k, v = qkv_states.chunk(3, dim=-1) # Transpose q and k to apply HF's Rotary Position Embedding q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(batch, patches, self.n_heads, self.head_dim) cos, sin = position_embeddings q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0) if USE_XFORMERS_OPS: # Transpose q and k back for attention q = q.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous() out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask) else: v = v.transpose(1, 2) out = nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attention_mask ) out = out.transpose(1, 2) out = out.reshape(batch, patches, self.n_heads * self.head_dim) attn_output, _ = self.o_proj(out) return attn_output, None class PixtralHFTransformerBlock(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: super().__init__() self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) self.attention = PixtralHFAttention( config, quant_config=quant_config, prefix=f"{prefix}.attention" ) self.feed_forward = PixtralHFMLP( config, quant_config=quant_config, prefix=f"{prefix}.feed_forward" ) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, ) -> torch.Tensor: r, _ = self.attention.forward( self.attention_norm(hidden_states), attention_mask=attention_mask, position_embeddings=position_embeddings, ) h = hidden_states + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r return out class PixtralHFTransformer(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, num_hidden_layers_override: int | None = None, prefix: str = "", ) -> None: super().__init__() if num_hidden_layers_override is None: num_hidden_layers = config.num_hidden_layers else: num_hidden_layers = num_hidden_layers_override self.layers = nn.ModuleList( [ PixtralHFTransformerBlock( config=config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", ) for layer_idx in range(num_hidden_layers) ] ) def forward( self, x: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, return_all_hidden_states: bool, ) -> torch.Tensor: hidden_states_pool = [x] for layer in self.layers: x = layer(x, attention_mask, position_embeddings) if return_all_hidden_states: hidden_states_pool.append(x) # If we have multiple feature sample layers, we return all hidden # states in order and grab the ones we need by index. if return_all_hidden_states: return hidden_states_pool return x class PixtralHFVisionModel(nn.Module): def __init__( self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, prefix: str = "", ) -> None: super().__init__() self.config = config self.patch_conv = Conv2dLayer( in_channels=config.num_channels, out_channels=config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size, bias=False, ) self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) self.transformer = PixtralHFTransformer( config, quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.transformer", ) num_hidden_layers = config.num_hidden_layers if len(self.transformer.layers) > config.num_hidden_layers: raise ValueError( f"The original encoder only has {num_hidden_layers} " f"layers, but you requested {len(self.transformer.layers)} " "layers." ) if require_post_norm is True: msg = "PixtralHFVisionModel does not have post-layernorm" raise ValueError(msg) self.dtype = next(self.parameters()).dtype self.device = next(self.parameters()).device self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device) def forward( self, pixel_values: list[torch.Tensor], *, select_layers: list[int] | None = None, feature_select_strategy: VisionFeatureSelectStrategy | None = None, ) -> tuple[torch.Tensor, ...]: """ Args: pixel_values: Each image to be processed will be a separate tensor in pixel_values. This means it will be a list of tensors because multiple requests batched can have multiple images, each with their own shape potentially select_layers: Layer indices whose features should be concatenated and used as the visual encoder output. If none are provided, the last layer is used. Returns: image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values ] patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] embed_sizes = [p.shape[1] for p in patch_embeds] # flatten to a single sequence patch_embeds = torch.cat(patch_embeds, dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings position_ids = position_ids_in_meshgrid( patch_embeds_list, max_width=self.config.image_size // self.config.patch_size, ).to(self.device) position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) if USE_XFORMERS_OPS: attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) else: from transformers.models.pixtral.modeling_pixtral import ( generate_block_attention_mask, ) attention_mask = generate_block_attention_mask( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds ) out = self.transformer( patch_embeds, attention_mask, position_embedding, return_all_hidden_states=select_layers is not None, ) out = resolve_visual_encoder_outputs( out, None, select_layers=select_layers, max_possible_layers=self.config.num_hidden_layers, feature_select_strategy=feature_select_strategy, ) # squeeze dim 0 and split into separate tensors for each image return torch.split(out.squeeze(0), embed_sizes) # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() layer_count = len(self.transformer.layers) for name, loaded_weight in weights: # omit layers when num_hidden_layers_override is set if name.startswith("transformer.layers"): layer_idx = int(name.split(".")[2]) if layer_idx >= layer_count: continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params