# 2024-11-22 22:16:53 +08:00
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# 2024-01-08 04:37:50 +00:00
""" Inference-only LLaVa model compatible with HuggingFace weights. """
# 2024-02-11 05:50:13 -08:00
# 2024-08-24 05:11:16 +08:00
import math
import re
# 2025-05-13 00:16:10 -07:00
from functools import lru_cache
from typing import Dict , Iterable , List , Optional , Tuple , Type , Union
# 2024-01-08 04:37:50 +00:00
import numpy as np
import torch
# 2024-04-22 22:38:09 +08:00
from torch import nn
# 2024-06-12 21:48:40 -07:00
from transformers import (
CLIPVisionConfig ,
CLIPVisionModel ,
LlavaConfig ,
MistralConfig ,
Qwen2Config ,
    # 2024-08-24 05:11:16 +08:00
SiglipVisionModel ,
    # 2024-06-12 21:48:40 -07:00
)
# 2025-05-13 00:16:10 -07:00
from transformers . models . auto . modeling_auto import AutoModel , AutoModelForCausalLM
# 2024-04-22 22:38:09 +08:00
from transformers . models . llava . modeling_llava import LlavaMultiModalProjector
# 2025-05-13 00:16:10 -07:00
# leave till last and symbol only in case circular import
import sglang . srt . models as sgl_models
# 2024-09-19 20:53:11 +08:00
from sglang . srt . layers . quantization . base_config import QuantizationConfig
# 2025-05-13 00:16:10 -07:00
from sglang . srt . managers . mm_utils import general_mm_embed_routine
from sglang . srt . managers . schedule_batch import (
Modality ,
MultimodalDataItem ,
MultimodalInputs ,
)
# 2024-09-30 02:41:11 -07:00
from sglang . srt . model_executor . forward_batch_info import ForwardBatch
# 2024-12-02 23:22:13 +08:00
from sglang . srt . model_loader . weight_utils import default_weight_loader
# 2024-09-02 21:44:45 -07:00
from sglang . srt . models . llama import LlamaForCausalLM
# 2024-05-27 03:29:51 +08:00
from sglang . srt . models . mistral import MistralForCausalLM
# 2024-06-12 21:48:40 -07:00
from sglang . srt . models . qwen2 import Qwen2ForCausalLM
# 2025-06-27 11:58:24 -07:00
from sglang . srt . multimodal . mm_utils import (
get_anyres_image_grid_shape ,
unpad_image ,
unpad_image_shape ,
)
# 2025-05-13 00:16:10 -07:00
from sglang . srt . utils import add_prefix , flatten_nested_list , logger
# 2024-01-08 04:37:50 +00:00
# 2024-08-28 08:38:50 -07:00
class LlavaBaseForCausalLM(nn.Module):
    """Shared base for LLaVA-style vision-language models.

    Subclasses (Llama/Qwen/Mistral variants) build the language model and
    projector in ``__init__``; this base provides image-token padding, image
    encoding, the multimodal forward pass, and weight loading (which also
    lazily instantiates the vision tower).
    """

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        """Replace each image token in ``input_ids`` with ``new_image_feature_len``
        pad values so that text and vision embeddings line up.

        Also records per-image offsets and pad lengths on ``image_inputs``
        (``image_offsets`` / ``image_pad_len``) for the fill-in step in
        ``forward``. Returns the padded input id list.
        """
        image_sizes = flatten_nested_list(
            [item.image_sizes for item in image_inputs.mm_items]
        )
        pad_values = [item.pad_value for item in image_inputs.mm_items]

        # Hardcoded policy for spatial_unpad + anyres: multi-image / video
        # inputs use "pad", a single image uses "anyres".
        if any(
            item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
            for item in image_inputs.mm_items
        ):
            image_aspect_ratio = "pad"
        else:
            image_aspect_ratio = "anyres"

        offset_list = []
        image_inputs.image_pad_len = []
        for image_idx, image_s in enumerate(image_sizes):
            if len(image_sizes) > 16:
                # Many images: features are 2x2-pooled with stride 2.
                new_image_feature_len = (
                    math.ceil(self.image_size / self.patch_size / 2) ** 2
                )
            else:
                new_image_feature_len = self.image_feature_len  # base feature count

            height = width = self.num_patches_per_side
            if "anyres" in image_aspect_ratio:
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_s,
                    self.image_grid_pinpoints,
                    self.vision_tower.config.image_size,
                )
                h = num_patch_height * height
                w = num_patch_width * width
                new_h, new_w = unpad_image_shape(h, w, image_s)

                if "anyres_max" in self.config.image_aspect_ratio:
                    matched_anyres_max_num_patches = re.match(
                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
                    )
                    if matched_anyres_max_num_patches:
                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
                        # times = math.sqrt(h * w / (max_num_patches * unit**2))
                        times = math.sqrt(
                            new_h * new_w / (max_num_patches * self.image_feature_len)
                        )
                        # Downscale when the unpadded grid exceeds the cap.
                        if times > 1.1:
                            new_h = int(new_h // times)
                            new_w = int(new_w // times)
                # +1 per row accounts for the appended image_newline token.
                new_image_feature_len += new_h * (new_w + 1)

            try:
                offset = input_ids.index(self.config.image_token_index)
            except ValueError:
                offset = 0
            # old_len + pad_len - 1, because we need to remove image_token_id
            input_ids = (
                input_ids[:offset]
                + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                + input_ids[offset + 1 :]
            )
            offset_list.append(offset)
            image_inputs.image_pad_len.append(new_image_feature_len)

        image_inputs.image_offsets = offset_list
        return input_ids

    def encode_images(
        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
    ) -> torch.Tensor:
        """Encode images with the vision tower and project them.

        Args:
            pixel_values: tensor (or list of tensors) of input image pixels.
        Returns:
            torch.Tensor: projected image features; if multiple images, they
            are flattened along the sequence-length axis.
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will
        # save all the hidden states.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            # Drop the CLS token.
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            # FIX: report the attribute that was actually tested above; the
            # config object is not guaranteed to carry this field.
            raise ValueError(
                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
            )
        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the multimodal forward pass.

        In extend (prefill) mode, encodes any needed images, splices their
        features into the token embeddings at the recorded offsets, and calls
        the language model with ``input_embeds``. In decode mode, delegates
        directly to the language model.
        """
        image_inputs = forward_batch.mm_inputs

        if forward_batch.forward_mode.is_extend():
            # Clamp input ids. This is because the input_ids for the image tokens are
            # filled with the hash values of the image for the prefix matching in the radix attention.
            # Those values are useless because their embeddings will be replaced by vision embeddings anyway.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Flatten List[List[...]] modalities into a flat list; also compute,
            # per request, the last position covered by an image (-1 if none).
            modalities_list = []
            max_image_offset = []
            for im in image_inputs:
                if im:
                    modalities_list.extend([item.modality for item in im.mm_items])
                if im and im.image_offsets:
                    max_image_offset.append(
                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
                    )
                else:
                    max_image_offset.append(-1)

            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
            # A request needs vision only if its extend window starts at or
            # before its last image position (otherwise images are in the prefix).
            need_vision = start_positions <= np.array(max_image_offset)

            if need_vision.any():
                bs = forward_batch.batch_size
                pixel_values = flatten_nested_list(
                    [
                        [item.feature for item in image_inputs[i].mm_items]
                        for i in range(bs)
                        if need_vision[i]
                    ]
                )
                image_sizes = [
                    flatten_nested_list(
                        [item.image_sizes for item in image_inputs[i].mm_items]
                    )
                    for i in range(bs)
                    if need_vision[i]
                ]

                ########## Encode Image ########
                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    # FIX: removed a dead `np.concatenate(pixel_values, axis=0)`
                    # statement whose result was discarded (the concatenation is
                    # done once, below).
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    height = width = self.num_patches_per_side
                    for image_idx, image_feature in enumerate(image_features):
                        if modalities_list[image_idx] == Modality.IMAGE:
                            image_aspect_ratio = (
                                self.config.image_aspect_ratio
                            )  # single image
                        elif (
                            modalities_list[image_idx] == Modality.MULTI_IMAGES
                            or modalities_list[image_idx] == Modality.VIDEO
                        ):
                            image_aspect_ratio = "pad"  # multi image
                            # image_aspect_ratio = (
                            #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
                            # )
                        if (
                            image_feature.shape[0] > 1
                            and "anyres" in image_aspect_ratio
                            and modalities_list[image_idx] == Modality.IMAGE
                        ):
                            # First slice is the base (global) view; the rest are
                            # the anyres patch views.
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            assert height * width == base_image_feature.shape[0]

                            if "anyres_max" in image_aspect_ratio:
                                matched_anyres_max_num_patches = re.match(
                                    r"anyres_max_(\d+)", image_aspect_ratio
                                )
                                if matched_anyres_max_num_patches:
                                    max_num_patches = int(
                                        matched_anyres_max_num_patches.group(1)
                                    )

                            if (
                                image_aspect_ratio == "anyres"
                                or "anyres_max" in image_aspect_ratio
                            ):
                                vision_tower_image_size = self.image_size
                                try:
                                    num_patch_width, num_patch_height = (
                                        get_anyres_image_grid_shape(
                                            image_sizes[image_idx][0],
                                            self.config.image_grid_pinpoints,
                                            vision_tower_image_size,
                                        )
                                    )
                                except Exception as e:
                                    print(f"Error: {e}")
                                    num_patch_width, num_patch_height = 2, 2
                                image_feature = image_feature.view(
                                    num_patch_height, num_patch_width, height, width, -1
                                )
                            else:
                                image_feature = image_feature.view(
                                    2, 2, height, width, -1
                                )
                                # (
                                #     num_patch_width,
                                #     num_patch_height,
                                # ) = get_anyres_image_grid_shape(
                                #     image_sizes[image_idx][0],
                                #     self.image_grid_pinpoints,
                                #     self.vision_tower.config.image_size,
                                # )
                                # image_feature = image_feature.view(
                                #     num_patch_height, num_patch_width, height, width, -1
                                # )

                            if "unpad" in self.mm_patch_merge_type:
                                unit = image_feature.shape[2]
                                image_feature = image_feature.permute(
                                    4, 0, 2, 1, 3
                                ).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(
                                    2, 3
                                )
                                image_feature = unpad_image(
                                    image_feature, image_sizes[image_idx][0]
                                )
                                if (
                                    "anyres_max" in image_aspect_ratio
                                    and matched_anyres_max_num_patches
                                ):
                                    c, h, w = image_feature.shape
                                    times = math.sqrt(
                                        h * w / (max_num_patches * unit**2)
                                    )
                                    if times > 1.1:
                                        image_feature = image_feature[None]
                                        image_feature = nn.functional.interpolate(
                                            image_feature,
                                            [int(h // times), int(w // times)],
                                            mode="bilinear",
                                        )[0]
                                # Append the image_newline embedding at the end
                                # of every row of the spatial grid.
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[
                                            :, None, None
                                        ].expand(*image_feature.shape[:-1], 1),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(
                                    0, 1
                                )
                            else:
                                image_feature = image_feature.permute(
                                    0, 2, 1, 3, 4
                                ).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat(
                                (base_image_feature, image_feature), dim=0
                            )
                            image_feature = image_feature.unsqueeze(0)
                        else:
                            if modalities_list[image_idx] == Modality.VIDEO:  # video
                                # 2x2 pooling over each frame's patch grid.
                                num_of_frames = image_feature.shape[0]
                                image_feature = image_feature.view(
                                    num_of_frames, height, width, -1
                                )
                                image_feature = image_feature.permute(
                                    0, 3, 1, 2
                                ).contiguous()  # N, C, H, W
                                height, weight = image_feature.shape[2:]
                                scaled_shape = [
                                    math.ceil(height / 2),
                                    math.ceil(weight / 2),
                                ]
                                image_feature = nn.functional.interpolate(
                                    image_feature, size=scaled_shape, mode="bilinear"
                                )
                                image_feature = (
                                    image_feature.flatten(2)
                                    .transpose(1, 2)
                                    .contiguous()
                                )  # N, C, H*W
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
                                        self.language_model.model.image_newline[
                                            None, None
                                        ].expand(
                                            image_feature.shape[0],
                                            1,
                                            image_feature.shape[-1],
                                        ),
                                    ),
                                    dim=1,
                                )

                        new_image_features.append(image_feature)
                    image_features = new_image_features

                # Fill in the placeholder for the image
                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue
                    start_idx = extend_start_loc_cpu[i]
                    seq_len = extend_seq_lens[i]
                    prefix_len = prefix_lens_cpu[i]
                    # Multiple images per request; clip each image's feature
                    # span to the current extend window.
                    for image_idx, image_offset in enumerate(
                        image_inputs[i].image_offsets
                    ):
                        if (
                            image_offset + image_inputs[i].image_pad_len[image_idx]
                            <= prefix_len
                        ):
                            # Image fully inside the cached prefix: nothing to fill.
                            continue
                        if image_offset >= prefix_len + seq_len:
                            # Image starts after this window; later ones do too.
                            break

                        tmp_image_feature = image_features[pt][image_idx]
                        pad_len = tmp_image_feature.shape[0]

                        input_offset = image_offset - prefix_len
                        left_idx = start_idx + input_offset
                        right_idx = left_idx + pad_len
                        assert right_idx > start_idx
                        if input_offset < 0:
                            # Image begins inside the prefix: drop the cached part.
                            left_idx = start_idx
                            tmp_image_feature = tmp_image_feature[-input_offset:]
                        if right_idx > start_idx + seq_len:
                            # Image extends past this window: truncate the tail.
                            tmp_image_feature = tmp_image_feature[
                                : start_idx + seq_len - right_idx
                            ]
                            right_idx = start_idx + seq_len
                        try:
                            input_embeds[left_idx:right_idx] = tmp_image_feature
                        except RuntimeError as e:
                            print(f"RuntimeError in image encoding: {e}")
                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
                            print(
                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
                            )
                    pt += 1

            return self.language_model(
                input_ids, positions, forward_batch, input_embeds=input_embeds
            )
        elif forward_batch.forward_mode.is_decode():
            return self.language_model(input_ids, positions, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights; also (lazily) build the vision tower.

        Projector / vision-tower / image_newline weights are remapped and
        loaded here; everything else is forwarded to the language model.
        """
        # Load clip vision model by cfg['mm_vision_tower']:
        # huggingface_name or path_of_clip_relative_to_llava_model_dir
        # We put the initialization here instead of __init__ to allow it being
        # reused by other subclasses.
        vision_path = self.config.mm_vision_tower
        if "clip" in vision_path:
            self.vision_tower = CLIPVisionModel.from_pretrained(
                vision_path, torch_dtype=torch.float16
            ).cuda()
        elif "siglip" in vision_path:
            self.vision_tower = SiglipVisionModel.from_pretrained(
                vision_path, torch_dtype=torch.float16
            ).cuda()
            # Siglip needs all feature tokens
            self.config.mm_vision_select_feature = "full"
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
        if (
            self.vision_feature_select_strategy == "patch"
            or self.vision_feature_select_strategy == "full"
        ):
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            # FIX: was `self.select_feature`, an attribute that does not exist,
            # which turned this ValueError into an AttributeError.
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
            "model.image_newline": "language_model.model.image_newline",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                self.language_model.load_weights([(name, loaded_weight)])

    @property
    def num_patches_per_side(self):
        # Patch-grid side length of the vision tower's square input.
        return self.image_size // self.patch_size
# 2024-01-08 04:37:50 +00:00
# 2024-08-28 08:38:50 -07:00
class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
    """LLaVA model with a Llama language backbone."""

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        # The vision tower is created lazily in load_weights().
        self.vision_tower = None

        cfg = self.config
        cfg.vision_config.hidden_size = config.mm_hidden_size
        cfg.text_config.hidden_size = config.hidden_size

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = LlamaForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        # "unpad" spatial merging needs a learned newline embedding that is
        # appended after each row of image patches.
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            newline = torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            self.language_model.model.image_newline = nn.Parameter(newline)
class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
    """LLaVA model with a Qwen2 language backbone."""

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        # The vision tower is created lazily in load_weights().
        self.vision_tower = None

        cfg = self.config
        # Fill in sub-configs / defaults that older checkpoints may omit.
        if getattr(cfg, "vision_config", None) is None:
            cfg.vision_config = CLIPVisionConfig(cfg.mm_vision_tower)
        if getattr(cfg, "text_config", None) is None:
            cfg.text_config = Qwen2Config(cfg._name_or_path)

        cfg.vision_config.hidden_size = config.mm_hidden_size
        cfg.text_config.hidden_size = config.hidden_size

        if getattr(cfg, "projector_hidden_act", None) is None:
            cfg.projector_hidden_act = "gelu"
        if getattr(cfg, "image_token_index", None) is None:
            # Default Qwen image-token id.
            cfg.image_token_index = 151646

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = Qwen2ForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        # "unpad" spatial merging needs a learned newline embedding.
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            newline = torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            self.language_model.model.image_newline = nn.Parameter(newline)
# 2024-08-28 08:38:50 -07:00
class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
    """LLaVA model with a Mistral language backbone."""

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        # The vision tower is created lazily in load_weights().
        self.vision_tower = None

        cfg = self.config
        # Fill in sub-configs / defaults that older checkpoints may omit.
        if getattr(cfg, "vision_config", None) is None:
            cfg.vision_config = CLIPVisionConfig(cfg.mm_vision_tower)
        if getattr(cfg, "text_config", None) is None:
            cfg.text_config = MistralConfig(cfg._name_or_path)

        cfg.vision_config.hidden_size = config.mm_hidden_size
        cfg.text_config.hidden_size = config.hidden_size

        if getattr(cfg, "projector_hidden_act", None) is None:
            cfg.projector_hidden_act = "gelu"
        if getattr(cfg, "image_token_index", None) is None:
            # Default Mistral/Llama-tokenizer image-token id.
            cfg.image_token_index = 32000

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = MistralForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        # "unpad" spatial merging needs a learned newline embedding.
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            newline = torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            self.language_model.model.image_newline = nn.Parameter(newline)
# 2025-05-13 00:16:10 -07:00
class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
    """
    An adaptor class to enable support for multiple mmlm such as
    mistral-community/pixtral-12b. It follows the structure of
    (vision_tower, multi_modal_projector, language_model).

    Once a model config is loaded, text_config and vision_config will be
    extracted, and LlavaForConditionalGeneration will load the language_model
    and vision_tower models according to config.
    """

    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector

    @property
    def dtype(self):
        # The dtype recorded from the top-level config at construction time.
        return self.torch_dtype

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        """Delegate padding to the vision tower when it implements it,
        otherwise fall back to the base-class anyres padding."""
        if hasattr(self.vision_tower, "pad_input_ids"):
            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
        else:
            return super().pad_input_ids(input_ids, image_inputs)

    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
        """
        Get the SGLang model implementation class according to config.

        Args:
            config: The config object of the model.
            auto_model_type: The type of the auto model.

        Returns:
            The SGLang model implementation class.

        Raises:
            ValueError: if no architecture is registered for the config class,
                or the architecture cannot be resolved in SGLang's registry.
        """
        config_cls_name = config.__class__.__name__
        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
        if arch := arch_name_mapping.get(config_cls_name):
            if isinstance(arch, tuple):
                arch = arch[0]
                # FIX: `arch` is an architecture *name* (str) here, so the old
                # `{arch.__name__}` raised AttributeError inside the warning.
                logger.warning(
                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch}"
                )
            try:
                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
            except Exception as e:
                raise ValueError(
                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry.\n{e}"
                )
        else:
            raise ValueError(
                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
            )

    # FIX: the original decorated an *instance* method with @lru_cache (B019),
    # which keys the cache on `self` and keeps every model instance alive for
    # the process lifetime. The mapping depends only on `auto_model_type`, so
    # a cached staticmethod is equivalent and leak-free; `self._config_cls_...`
    # call sites are unchanged.
    @staticmethod
    @lru_cache
    def _config_cls_name_to_arch_name_mapping(
        auto_model_type: Type[AutoModel],
    ) -> Dict[str, Union[str, Tuple[str, ...]]]:
        """Map config-class names to architecture name(s) for an auto model."""
        mapping = {}
        # Iterate keys and re-fetch: _model_mapping is a lazy mapping, so this
        # mirrors the original access pattern exactly.
        for config_cls in auto_model_type._model_mapping.keys():
            archs = auto_model_type._model_mapping.get(config_cls, None)
            if archs is not None:
                if isinstance(archs, tuple):
                    mapping[config_cls.__name__] = tuple(
                        arch.__name__ for arch in archs
                    )
                else:
                    mapping[config_cls.__name__] = archs.__name__
        return mapping

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        assert hasattr(config, "text_config")
        assert hasattr(config, "vision_config")
        self.config = config
        self.text_config = self.config.text_config
        self.vision_config = self.config.vision_config
        self.torch_dtype = getattr(self.config, "torch_dtype")

        # Propagate the top-level dtype into sub-configs when they lack one.
        if not getattr(self.text_config, "torch_dtype"):
            self.text_config.torch_dtype = self.torch_dtype
        if not getattr(self.vision_config, "torch_dtype"):
            self.vision_config.torch_dtype = self.torch_dtype

        if not hasattr(self.config, "vocab_size"):
            self.config.vocab_size = self.text_config.vocab_size

        # Defaults for configs that omit LLaVA-specific fields.
        if not hasattr(self.config, "image_aspect_ratio"):
            self.config.image_aspect_ratio = "anyres"
        if not hasattr(self.config, "image_grid_pinpoints"):
            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
            self.config.image_grid_pinpoints = [
                [96, 96],
                [224, 224],
                [384, 384],
                [512, 512],
                [768, 768],
                [1024, 1024],
            ]
        if not hasattr(self.config, "mm_patch_merge_type"):
            self.config.mm_patch_merge_type = "flat"
        if not hasattr(self.config, "image_token_index"):
            self.config.image_token_index = 10
        if not hasattr(self.config, "projector_hidden_act"):
            self.config.projector_hidden_act = "gelu"

        self.vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
        self.vision_feature_select_strategy = getattr(
            self.config, "vision_feature_select_strategy", "full"
        )
        self.image_size = self.vision_config.image_size
        self.patch_size = self.vision_config.patch_size

        self.mm_patch_merge_type = self.config.mm_patch_merge_type
        self.image_aspect_ratio = self.config.image_aspect_ratio
        self.image_grid_pinpoints = self.config.image_grid_pinpoints

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)

        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)

        # Resolve and instantiate SGLang implementations for both submodules.
        language_model_cls = self._get_sgl_model_cls(
            self.text_config, AutoModelForCausalLM
        )
        vision_model_cls = self._get_sgl_model_cls(self.vision_config, AutoModel)
        self.language_model = language_model_cls(
            self.text_config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        self.vision_tower = vision_model_cls(
            self.vision_config,
            quant_config=quant_config,
            prefix=add_prefix("vision_tower", prefix),
        )

        if "unpad" in getattr(self.config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(self.text_config.hidden_size, dtype=self.torch_dtype)
            )

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Extract features from image inputs.

        Args:
            items: List of MultimodalDataItem objects containing image data.
                Note that an item can be either "image" or "multi-images".

        Returns:
            torch.Tensor: features from image inputs, concatenated.
        """
        features = []
        for item in items:
            # in each item, we assume pixel_values is always batched
            pixel_values, image_sizes = item.feature, item.image_sizes
            image_outputs = self.vision_tower(
                pixel_values, image_sizes, output_hidden_states=True
            )
            selected_image_feature = image_outputs.hidden_states[
                self.vision_feature_layer
            ]
            if self.vision_feature_select_strategy in ["default", "patch"]:
                # Drop the CLS token.
                selected_image_feature = selected_image_feature[:, 1:]
            elif self.vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            else:
                raise ValueError(
                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
                )
            features.append(
                self.multi_modal_projector(selected_image_feature.squeeze(0))
            )
        ret = torch.cat(features, dim=0)
        return ret

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Forward via the generic multimodal embedding routine, which splices
        image features (from get_image_feature) into the token embeddings."""
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            get_embedding=get_embedding,
            language_model=self.language_model,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
            },
            placeholder_tokens=None,  # using mm_item.pad_value
            positions=positions,
        )
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for LlavaForConditionalGeneration.

        Unlike the base class implementation, this one doesn't need to handle
        weight name remapping as the weights are already properly structured
        with 'language_model' and 'vision_tower' prefixes in the safetensors
        files.
        """
        if (
            self.vision_feature_select_strategy == "patch"
            or self.vision_feature_select_strategy == "full"
        ):
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            # NOTE(review): this mutates image_feature_len on every call;
            # assumes load_weights runs once per model — confirm.
            self.image_feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # Create dictionaries for direct parameter loading
        params_dict = dict(self.named_parameters())

        # Load weights directly without remapping
        for name, loaded_weight in weights:
            for part in ("language_model", "vision_tower"):
                if name.startswith(part):
                    # Strip the submodule prefix and delegate to it.
                    name = name[len(part + ".") :]
                    getattr(self, part).load_weights([(name, loaded_weight)])
                    break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
# Model classes exported to SGLang's model registry.
EntryClass = [
    LlavaLlamaForCausalLM,
    LlavaQwenForCausalLM,
    LlavaMistralForCausalLM,
    LlavaForConditionalGeneration,
]