# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import Annotated, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import BatchFeature, PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
from transformers.models.pixtral.image_processing_pixtral import (
    _num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
from transformers.models.pixtral.modeling_pixtral import (
    PixtralRotaryEmbedding,
    apply_rotary_pos_emb,
    position_ids_in_meshgrid,
)
from transformers.tokenization_utils_base import TextInput

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    NestedTensors,
)
from vllm.multimodal.parse import (
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalProcessingInfo,
    ProcessorInputs,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
    TimingContext,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
    is_vit_use_data_parallel,
    resolve_visual_encoder_outputs,
)

import ixformer.inference.functions as ixf

try:
    # Note: vLLM does not install xformers by default.
    from xformers import ops as xops
    if current_platform.is_cuda():
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = True
    else:
        USE_XFORMERS_OPS = False
except ImportError:
    USE_XFORMERS_OPS = False

PATCH_MERGE = "patch_merge"

def _is_layer_none_or_staged(layer: nn.Module) -> bool:
    return layer is None or isinstance(layer, StageMissingLayer)


class PixtralImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image

    The result of stacking `ImageEncoding.image` from each prompt.
    """

    type: Literal["pixel_values"] = "pixel_values"

    images: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
    ]


class PixtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()

        self.tokenizer = tokenizer

    @property
    def image_processor(self) -> ImageEncoder:
        image_encoder = self.tokenizer.instruct.mm_encoder
        assert isinstance(image_encoder, ImageEncoder)
        return image_encoder

    @cached_property
    def image_break_id(self) -> int:
        return self.image_processor.special_ids.img_break

    @cached_property
    def image_token_id(self) -> int:
        return self.image_processor.special_ids.img

    @cached_property
    def image_end_id(self) -> int:
        return self.image_processor.special_ids.img_end

    @cached_property
    def image_size(self) -> int:
        return self.image_processor.mm_config.max_image_size

    @cached_property
    def patch_size(self) -> int:
        return self.image_processor.mm_config.image_patch_size

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if not images:
            input_ids = self.tokenizer(text).input_ids
            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411."
            )

        images_processed = list[torch.Tensor]()
        images_tokens = list[torch.Tensor]()

        for image in images:
            image_inputs = self.image_processor(ImageChunk(image=image))
            image_processed = torch.tensor(image_inputs.image)
            image_tokens = torch.tensor(image_inputs.tokens)

            images_processed.append(image_processed)
            images_tokens.append(image_tokens)

        return BatchFeature(
            {
                "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
                "images": images_processed,
            }
        )
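
# NOTE: For image inputs, the returned "input_ids" are the special image
# tokens ([IMG] / [IMG_BREAK] / [IMG_END]) emitted by mistral_common for each
# image, concatenated and broadcast across the text batch; the text tokens are
# expected to come from the chat-completion request instead.
#
# Minimal usage sketch (assuming a Pixtral checkpoint loaded with
# `--tokenizer-mode mistral`):
#
#     processor = PixtralProcessorAdapter(tokenizer)
#     out = processor(text=[""], images=[pil_image])
#     out["input_ids"].shape   # (1, num_image_tokens for the image)
#     out["images"][0].shape   # typically (3, H, W) resized pixel tensor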


class PixtralProcessingInfo(BaseProcessingInfo):
    def get_tokenizer(self) -> MistralTokenizer:
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("This model requires `--tokenizer-mode mistral`")

        return tokenizer

    def get_hf_processor(self) -> PixtralProcessorAdapter:
        return PixtralProcessorAdapter(self.get_tokenizer())

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: PixtralProcessorAdapter,
    ) -> int:
        ncols, nrows = processor.image_processor._image_to_num_tokens(
            Image.new("RGB", (image_width, image_height))
        )
        return ncols * nrows
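
    # NOTE: this counts only the [IMG] grid tokens (ncols * nrows); the
    # per-row [IMG_BREAK] and trailing [IMG_END] tokens are added by the
    # prompt replacement in PixtralMultiModalProcessor below. For example,
    # with a typical Pixtral config (max_image_size=1024, patch_size=16),
    # a 512x512 image maps to a 32x32 grid, i.e. 1024 [IMG] tokens.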

    def get_image_size_with_most_features(self) -> ImageSize:
        image_processor = self.get_hf_processor().image_processor
        max_image_size = image_processor.mm_config.max_image_size

        return ImageSize(width=max_image_size, height=max_image_size)


class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> ProcessorInputs:
        tokenizer = self.info.get_tokenizer()

        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
        dummy_images = dummy_mm_data.get("image", [])

        request = ChatCompletionRequest(
            messages=[
                UserMessage(
                    content=[
                        TextChunk(text=dummy_text),
                        *(ImageChunk(image=image) for image in dummy_images),
                    ]
                ),
            ]
        )
        res = tokenizer.mistral.encode_chat_completion(request)
        dummy_tokens = res.tokens

        dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)

        return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)


class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]):
    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(images=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_break_id = processor.image_break_id
        image_token_id = processor.image_token_id
        image_end_id = processor.image_end_id

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = processor.image_processor._image_to_num_tokens(
                Image.new("RGB", (image_size.width, image_size.height))
            )

            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target="",  # Never match the prompt (see below note)
                replacement=get_replacement,
            ),
        ]
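
    # Token layout produced by `get_replacement` above, e.g. for a 3x2 patch
    # grid (ncols=3, nrows=2):
    #
    #   [IMG] [IMG] [IMG] [IMG_BREAK] [IMG] [IMG] [IMG] [IMG_END]
    #
    # Each row of ncols [IMG] tokens is followed by [IMG_BREAK], and the
    # final break is replaced by [IMG_END].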

    def _cached_apply_hf_processor(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)

        # NOTE: The tokens are already inserted by the chat template
        return prompt_ids, mm_info, True


@MULTIMODAL_REGISTRY.register_processor(
    PixtralMultiModalProcessor,
    info=PixtralProcessingInfo,
    dummy_inputs=PixtralDummyInputsBuilder,
)
class PixtralForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_encoder = VisionTransformer(self.vision_args)
            self.pre_mm_projector_norm = (
                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
                if self.vision_args.add_pre_mm_projector_layer_norm
                else None
            )
            self.patch_merger = (
                PatchMerger(
                    vision_encoder_dim=self.vision_args.hidden_size,
                    spatial_merge_size=self.vision_args.spatial_merge_size,
                    use_mlp_bias=False,
                )
                if self.vision_args.mm_projector_id == PATCH_MERGE
                else None
            )
            self.vision_language_adapter = VisionLanguageAdapter(
                self.vision_args, dim=config.text_config.hidden_size
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> PixtralImagePixelInputs | None:
        images = kwargs.pop("images", None)
        if images is None:
            return None

        return PixtralImagePixelInputs(
            type="pixel_values",
            images=images,
        )

    def _process_image_input(
        self,
        image_input: PixtralImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        images = image_input["images"]

        image_features = self.vision_encoder(images)
        feature_sizes = [image_feature.shape[0] for image_feature in image_features]

        image_features = torch.cat(image_features)
        if self.pre_mm_projector_norm is not None:
            image_features = self.pre_mm_projector_norm(image_features)
        if self.patch_merger is not None:
            patch_size = self.vision_args.patch_size
            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
            img_patch_dims = [
                (img.shape[1] // patch_size, img.shape[2] // patch_size)
                for img in images
            ]
            feature_sizes = [
                feature_size // spatial_merge_size_square
                for feature_size in feature_sizes
            ]
            image_features = self.patch_merger(
                image_features, image_sizes=img_patch_dims
            )

        image_embeds = self.vision_language_adapter(image_features)
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for pixtral."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith(("vision_encoder", "vision_tower"))

        def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith(
                ("vision_language_adapter", "multi_modal_projector")
            )

        def is_patch_merger(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("patch_merger")

        def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("pre_mm_projector_norm")

        # Get references to parameters for direct loading
        vision_encoder_dict = (
            dict(self.vision_encoder.named_parameters())
            if self.vision_encoder is not None
            else {}
        )
        patch_merger_dict = (
            dict(self.patch_merger.named_parameters())
            if self.patch_merger is not None
            else {}
        )
        pre_mm_projector_norm_dict = (
            dict(self.pre_mm_projector_norm.named_parameters())
            if self.pre_mm_projector_norm is not None
            else {}
        )
        vision_lang_adapter_dict = (
            dict(self.vision_language_adapter.named_parameters())
            if self.vision_language_adapter is not None
            else {}
        )

        def llm_weights_generator():
            # Single pass over weights
            for name, w in weights:
                if is_vision_encoder_weights((name, w)):
                    if _is_layer_none_or_staged(self.vision_encoder):
                        continue
                    # Load vision encoder weights directly
                    trimmed_name = ".".join(name.split(".")[1:])
                    param = vision_encoder_dict.get(trimmed_name)
                    if param is not None:
                        with torch.no_grad():
                            default_weight_loader(param, w)
                elif is_patch_merger((name, w)):
                    if _is_layer_none_or_staged(self.patch_merger):
                        continue
                    # Load vision patch merger weights directly
                    trimmed_name = ".".join(name.split(".")[1:])
                    param = patch_merger_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_pre_mm_projector_norm((name, w)):
                    if _is_layer_none_or_staged(self.pre_mm_projector_norm):
                        continue
                    # Load vision pre_mm_projector_norm weights directly
                    trimmed_name = ".".join(name.split(".")[1:])
                    param = pre_mm_projector_norm_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_vision_lang_adapter_weights((name, w)):
                    if _is_layer_none_or_staged(self.vision_language_adapter):
                        continue
                    # Load vision-language adapter weights directly
                    trimmed_name = ".".join(name.split(".")[1:])
                    param = vision_lang_adapter_dict.get(trimmed_name)
                    if param is not None:
                        with torch.no_grad():
                            default_weight_loader(param, w)
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    # Strip "language_model." prefix if present (HF sharded format)
                    name = name.removeprefix("language_model.")
                    yield (name, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())

    def get_mm_mapping(self) -> MultiModelKeys:
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_language_adapter",
            tower_model="vision_encoder",
        )

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        if getattr(self, "patch_merger", None) is None:
            return num_image_tokens
        merge_size = self.vision_args.spatial_merge_size
        return num_image_tokens * (merge_size**2)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        if getattr(self, "patch_merger", None) is None:
            return num_vision_tokens
        merge_size = self.vision_args.spatial_merge_size
        return num_vision_tokens // (merge_size**2)
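
    # Illustrative numbers: with spatial_merge_size=2, a prompt placeholder
    # of 1024 image tokens corresponds to 1024 * 2**2 = 4096 vision-encoder
    # patch tokens, and those 4096 encoder tokens collapse back to
    # 4096 // 2**2 = 1024 tokens after the patch merger (the connector).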


# Vision encoder
@dataclass
class VisionEncoderArgs:
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int
    adapter_bias: bool = True
    spatial_merge_size: int = 1
    add_pre_mm_projector_layer_norm: bool = False
    mm_projector_id: str = ""


def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
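
# For example, in `apply_rotary_emb_vit` below, `x` is complex with shape
# (batch, patches, n_heads, head_dim // 2) and `freqs_cis` has shape
# (patches, head_dim // 2), so it is viewed as (1, patches, 1, head_dim // 2)
# to broadcast over the batch and head dimensions.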


def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
        to be indexed by (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
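
# Worked example (illustrative numbers, not model code): for dim=64 and
# height=width=3, `freqs` holds 32 base frequencies; the height axis uses the
# 16 even-indexed ones and the width axis the 16 odd-indexed ones, so the
# returned tensor has shape (3, 3, 32), with the first 16 channels rotating
# with the row index and the last 16 with the column index.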


def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.dtype == torch.complex64
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class FeedForward(nn.Module):
    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
        self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Attention(nn.Module):
    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args

        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)

        if USE_XFORMERS_OPS:
            v = v.reshape(batch * patches, self.n_heads, self.head_dim)
            q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
            q = q.view(batch * patches, self.n_heads, self.head_dim)
            k = k.view(batch * patches, self.n_heads, self.head_dim)
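            # NOTE: this branch assumes the Iluvatar `ixformer` kernel is
            # available. It consumes the BlockDiagonalMask built in
            # VisionTransformer.forward through its cumulative sequence-start
            # offsets (q_seqinfo / k_seqinfo), standing in for
            # xops.memory_efficient_attention.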
            out = ixf.ixinfer_flash_attn_unpad(
                q,
                k,
                v,
                mask.q_seqinfo.seqstart.to(q.device),
                mask.k_seqinfo.seqstart.to(q.device),
                mask.q_seqinfo.max_seqlen,
                mask.k_seqinfo.max_seqlen,
            )
            # out = memory_efficient_attention(q, k, v, attn_bias=mask)
        else:
            assert False, "xformers failed!"
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)


class TransformerBlock(nn.Module):
    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        r = self.attention.forward(
            self.attention_norm(x), mask=mask, freqs_cis=freqs_cis
        )
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class Transformer(nn.Module):
    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor | None,
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x


def position_meshgrid(
    patch_embeds_list: list[torch.Tensor],
) -> torch.Tensor:
    positions = torch.cat(
        [
            torch.stack(
                torch.meshgrid(
                    torch.arange(p.shape[-2]),
                    torch.arange(p.shape[-1]),
                    indexing="ij",
                ),
                dim=-1,
            ).reshape(-1, 2)
            for p in patch_embeds_list
        ]
    )
    return positions
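
# For example, two images whose conv outputs are 2x2 and 1x3 patches yield
# positions [[0, 0], [0, 1], [1, 0], [1, 1], [0, 0], [0, 1], [0, 2]]:
# per-image (row, col) coordinates concatenated in image order.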


class VisionTransformer(nn.Module):
    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = Conv2dLayer(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: torch.Tensor | None = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: list[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)

        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if USE_XFORMERS_OPS:
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
            )
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask,
            )

            mask = generate_block_attention_mask(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
            )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)
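
# For example, feeding two images whose conv outputs are 32x32 and 16x24
# patch grids returns a tuple of two tensors with 1024 and 384 rows
# respectively, each of width `hidden_size`.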


class VisionLanguageAdapter(nn.Module):
    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=args.adapter_bias,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))


class PatchMerger(nn.Module):
    """
    Learned merging of spatial_merge_size ** 2 patches
    """

    def __init__(
        self,
        vision_encoder_dim: int,
        spatial_merge_size: int,
        use_mlp_bias: bool = False,
    ) -> None:
        super().__init__()

        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)

        self.spatial_merge_size = spatial_merge_size
        self.mlp_input_dim = mlp_input_dim

        self.merging_layer = nn.Linear(
            mlp_input_dim,
            vision_encoder_dim,
            bias=use_mlp_bias,
        )

    def forward(
        self, x: torch.Tensor, image_sizes: list[tuple[int, int]]
    ) -> torch.Tensor:
        # image_sizes specified in tokens
        assert sum([h * w for h, w in image_sizes]) == len(x)

        # x is (N, vision_encoder_dim)
        x = self.permute(x, image_sizes)

        # x is (N / spatial_merge_size ** 2,
        #       vision_encoder_dim * spatial_merge_size ** 2)
        x = self.merging_layer(x)

        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
        return x

    def permute(
        self,
        x: torch.Tensor,
        image_sizes: list[tuple[int, int]],
    ) -> torch.Tensor:
        """
        Args:
            x: (N, D) where N is flattened and concatenated patch tokens
                for all images
            image_sizes: list of tuple of (height, width) in tokens for
                each image
        Returns:
            image_features: reorders patch tokens so each grid of
                (spatial_merge_size, spatial_merge_size) is contiguous.
                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
        """
        sub_grids = get_sub_grids(
            x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size
        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
        permuted_tensor: list[torch.Tensor] = []
        for grid in sub_grids:
            n_patches = grid.shape[-1]
            permuted_tensor.append(
                grid.view(-1, n_patches).t()
            )  # n_patches x d * sub_grid_size * sub_grid_size
        return torch.cat(
            permuted_tensor, dim=0
        )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)


def get_sub_grids(
    x: torch.Tensor,
    image_sizes: list[tuple[int, int]],
    spatial_merge_size: int,
) -> list[torch.Tensor]:
    # image_sizes specified in tokens
    tokens_per_image = [h * w for h, w in image_sizes]
    d = x.shape[-1]
    all_img_sub_grids: list[torch.Tensor] = []
    sub_grid_size = spatial_merge_size

    for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
        # Reshape image_tokens into a 2D grid
        h, w = image_sizes[image_index]
        image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[
            None, :, :, :
        ]  # 1 x d x h x w
        sub_grids = torch.nn.functional.unfold(
            image_grid, kernel_size=sub_grid_size, stride=sub_grid_size
        )
        sub_grids = sub_grids.view(
            1, d, sub_grid_size, sub_grid_size, -1
        )  # 1 x d x sub_grid_size x sub_grid_size x n_patches
        all_img_sub_grids.append(sub_grids[0])

    return all_img_sub_grids
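
# Worked example (illustrative numbers): with spatial_merge_size=2 and a
# single image of 4x4 patch tokens (16 rows of width D), `get_sub_grids`
# unfolds the grid into four non-overlapping 2x2 blocks; `PatchMerger.permute`
# flattens each block into one row of width 4 * D, and `merging_layer`
# projects the resulting (4, 4 * D) matrix back to (4, D).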


#### HF Transformers version of Pixtral ####
# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family, meaning image embeddings are placed
# instead of the `[IMG]` token placeholders.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.


class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        ncols, nrows = self.get_patch_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return ncols * nrows

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        # spatial_merge_size is needed for Mistral3
        spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1)
        return self.vision_config.patch_size * spatial_merge_size

    def get_patch_grid_length(self) -> int:
        image_size, patch_size = self.get_image_size(), self.get_patch_size()
        # Since interpolation is applied, the image size need not be divisible
        # assert image_size % patch_size == 0
        return image_size // patch_size

    # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99
    def get_patch_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        max_width = max_height = self.get_image_size()
        patch_width = patch_height = self.get_patch_size()

        ratio = max(image_width / max_width, image_height / max_height)
        if ratio > 1:
            image_width = int(math.floor(image_width / ratio))
            image_height = int(math.floor(image_height / ratio))

        nrows, ncols = _get_pixtral_hf_num_image_tokens(
            (image_height, image_width),
            (patch_height, patch_width),
        )  # type: ignore

        return ncols, nrows
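
    # Illustrative numbers, assuming image_size=1024 and patch_size=16: an
    # 800x600 image fits within 1024, so it is not rescaled, giving
    # ncols = ceil(800 / 16) = 50 and nrows = ceil(600 / 16) = 38, i.e.
    # 50 * 38 = 1900 tokens from `get_num_image_tokens`. A 2048x1024 image
    # is first downscaled by ratio 2 to 1024x512 before the same ceiling
    # division is applied.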


class PixtralHFMLP(nn.Module):
    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        use_data_parallel = is_vit_use_data_parallel()
        assert config.intermediate_size is not None

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=config.hidden_size,
            output_sizes=[config.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            input_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_and_mul(gate_up)
        x, _ = self.down_proj(x)
        return x
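
# The merged gate_up projection produces [gate, up] of width
# 2 * intermediate_size; `act_and_mul` computes act(gate) * up before the
# down projection, matching the separate w1/w3/w2 formulation of
# `FeedForward` above.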


class PixtralHFAttention(nn.Module):
    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        assert self.total_num_heads * self.head_dim == config.hidden_size

        use_data_parallel = is_vit_use_data_parallel()
        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )

        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
            disable_tp=use_data_parallel,
        )
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.n_heads = divide(config.num_attention_heads, self.tp_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back for attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()

            out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask)
        else:
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask
            )
            out = out.transpose(1, 2)

        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None


class PixtralHFTransformerBlock(nn.Module):
    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = PixtralHFMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        r, _ = self.attention.forward(
            self.attention_norm(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        h = hidden_states + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class PixtralHFTransformer(nn.Module):
    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList(
            [
                PixtralHFTransformerBlock(
                    config=config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor:
        hidden_states_pool = [x]

        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            if return_all_hidden_states:
                hidden_states_pool.append(x)

        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return x


class PixtralHFVisionModel(nn.Module):
    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.patch_conv = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers."
            )

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        self.dtype = next(self.parameters()).dtype
        self.device = next(self.parameters()).device
        self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device)

    def forward(
        self,
        pixel_values: list[torch.Tensor],
        *,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially
            select_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values
        ]

        patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size,
        ).to(self.device)

        position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)

        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
            )
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask,
            )

            attention_mask = generate_block_attention_mask(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
            )

        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=select_layers is not None,
        )

        out = resolve_visual_encoder_outputs(
            out,
            None,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params