model(vlm): pixtral (#5084)
@@ -15,7 +15,8 @@
 
 import math
 import re
-from typing import Iterable, List, Optional, Tuple
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -28,10 +29,18 @@ from transformers import (
     Qwen2Config,
     SiglipVisionModel,
 )
+from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
+# imported last, and only as a module symbol, to guard against a circular import
+import sglang.srt.models as sgl_models
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
+from sglang.srt.managers.mm_utils import general_mm_embed_routine
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix, flatten_nested_list
+from sglang.srt.utils import add_prefix, flatten_nested_list, logger
 
 
 class LlavaBaseForCausalLM(nn.Module):
@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs.image_offsets = offset_list
         return input_ids
 
-    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    def encode_images(
+        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> torch.Tensor:
+        """
+        Encode images with the vision tower and the multimodal projector.
+        Args:
+            pixel_values: torch.Tensor or List[torch.Tensor], one tensor per input image
+        Returns:
+            torch.Tensor: encoded image features; if multiple, flattened along the seq_len axis
+        """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
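To make the "flattened along the seq_len axis" contract concrete, here is a tiny standalone sketch; the hidden size and per-image patch counts are made up:

    import torch

    # Hypothetical per-image features after the projector: image i contributes
    # (num_patches_i, hidden) rows; the docstring's contract flattens them
    # along the sequence axis into (sum_i num_patches_i, hidden).
    hidden = 16
    feats = [torch.randn(9, hidden), torch.randn(25, hidden)]  # two images
    flattened = torch.cat(feats, dim=0)
    print(flattened.shape)  # torch.Size([34, 16])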
@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         )
 
 
-EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
+class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
+    """
+    An adapter class that enables support for multiple multimodal LLMs, such as
+    mistral-community/pixtral-12b. It follows the structure of
+    (vision_tower, multi_modal_projector, language_model).
+
+    Once a model config is loaded, text_config and vision_config are extracted,
+    and LlavaForConditionalGeneration loads the language_model and vision_tower
+    according to the config.
+    """
+
+    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        if hasattr(self.vision_tower, "pad_input_ids"):
+            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
+        else:
+            return super().pad_input_ids(input_ids, image_inputs)
+
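For orientation, a toy sketch of the (vision_tower, multi_modal_projector, language_model) layout the docstring describes; the attribute names mirror the class above, but the layers are placeholders, not the real SGLang modules:

    import torch
    import torch.nn as nn

    # Toy stand-in for the three-part layout; the real modules are resolved
    # dynamically from the HF config, this only mirrors the attribute names.
    class TinyVlmSkeleton(nn.Module):
        def __init__(self, vision_dim: int = 32, text_dim: int = 64):
            super().__init__()
            self.vision_tower = nn.Linear(3, vision_dim)          # encoder stand-in
            self.multi_modal_projector = nn.Linear(vision_dim, text_dim)
            self.language_model = nn.Identity()                   # LM stand-in

        def forward(self, pixels: torch.Tensor) -> torch.Tensor:
            x = self.vision_tower(pixels)
            x = self.multi_modal_projector(x)
            return self.language_model(x)

    m = TinyVlmSkeleton()
    print(m(torch.randn(4, 3)).shape)  # torch.Size([4, 64])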
+    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
+        """
+        Get the SGLang model implementation class according to config.
+
+        Args:
+            config: The config object of the model.
+            auto_model_type: The type of the auto model.
+
+        Returns:
+            The SGLang model implementation class.
+        """
+        config_cls_name = config.__class__.__name__
+        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
+        if arch := arch_name_mapping.get(config_cls_name):
+            if isinstance(arch, tuple):
+                # the mapping stores architecture name strings, so log the string itself
+                arch = arch[0]
+                logger.warning(
+                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch}"
+                )
+            try:
+                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
+            except Exception as e:
+                raise ValueError(
+                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from the SGLang ModelRegistry. \n{e}"
+                )
+        else:
+            raise ValueError(
+                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
+            )
+
+    @lru_cache
+    def _config_cls_name_to_arch_name_mapping(
+        self, auto_model_type: Type[AutoModel]
+    ) -> Dict[str, Union[str, Tuple[str, ...]]]:
+        mapping = {}
+        for config_cls, archs in auto_model_type._model_mapping.items():
+            if isinstance(archs, tuple):
+                mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
+            else:
+                mapping[config_cls.__name__] = archs.__name__
+        return mapping
+
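The helper above flattens transformers' `_model_mapping` into config-name to architecture-name pairs. A standalone sketch of the same flattening, assuming a recent transformers release (`_model_mapping` is a private attribute and can change between versions):

    from transformers.models.auto.modeling_auto import AutoModelForCausalLM

    def config_to_arch_names(auto_model_type):
        # Same flattening as _config_cls_name_to_arch_name_mapping above:
        # {config class name -> architecture name or tuple of names}.
        mapping = {}
        for config_cls, archs in auto_model_type._model_mapping.items():
            if isinstance(archs, tuple):
                mapping[config_cls.__name__] = tuple(a.__name__ for a in archs)
            else:
                mapping[config_cls.__name__] = archs.__name__
        return mapping

    names = config_to_arch_names(AutoModelForCausalLM)
    print(names.get("MistralConfig"))  # expected: "MistralForCausalLM"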
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert hasattr(config, "text_config")
+        assert hasattr(config, "vision_config")
+        self.config = config
+        self.text_config = config.text_config
+        self.vision_config = config.vision_config
+
+        if not hasattr(self.config, "vocab_size"):
+            self.config.vocab_size = self.config.text_config.vocab_size
+        if not hasattr(self.config, "image_aspect_ratio"):
+            self.config.image_aspect_ratio = "anyres"
+        if not hasattr(self.config, "image_grid_pinpoints"):
+            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
+            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
+            self.config.image_grid_pinpoints = [
+                [96, 96],
+                [224, 224],
+                [384, 384],
+                [512, 512],
+                [768, 768],
+                [1024, 1024],
+            ]
+        if not hasattr(self.config, "mm_patch_merge_type"):
+            self.config.mm_patch_merge_type = "flat"
+        if not hasattr(self.config, "image_token_index"):
+            self.config.image_token_index = 10
+        if not hasattr(self.config, "projector_hidden_act"):
+            self.config.projector_hidden_act = "gelu"
+
+        self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
+        self.vision_feature_select_strategy = getattr(
+            config, "vision_feature_select_strategy", "full"
+        )
+        self.image_size = self.config.vision_config.image_size
+        self.patch_size = self.config.vision_config.patch_size
+
+        self.mm_patch_merge_type = config.mm_patch_merge_type
+        self.image_aspect_ratio = config.image_aspect_ratio
+        self.image_grid_pinpoints = config.image_grid_pinpoints
+
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+
+        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
+
+        language_model_cls = self._get_sgl_model_cls(
+            config.text_config, AutoModelForCausalLM
+        )
+        vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
+        self.language_model = language_model_cls(
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.vision_tower = vision_model_cls(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
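A minimal look at the composite config that `__init__` consumes, assuming a transformers version that ships `LlavaConfig`; the printed defaults are illustrative, not guaranteed:

    from transformers import LlavaConfig

    # The composite HF config carries a text_config and a vision_config, which
    # the adapter hands to the language model and vision tower respectively.
    cfg = LlavaConfig()
    print(type(cfg.text_config).__name__)    # e.g. "LlamaConfig"
    print(type(cfg.vision_config).__name__)  # e.g. "CLIPVisionConfig"

    # Same arithmetic as image_feature_len in __init__:
    n = (cfg.vision_config.image_size // cfg.vision_config.patch_size) ** 2
    print(n)  # patches per image for the default vision tower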
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data.
+                Note that an item can be either "image" or "multi-images".
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                # drop the leading CLS token, keep the patch tokens
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                pass
+            else:
+                raise ValueError(
+                    f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(selected_image_feature.squeeze(0))
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
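The two select strategies above differ only in whether the leading CLS token is kept; a quick tensor sketch with hypothetical shapes:

    import torch

    # ViT-style output: 1 CLS token + 4 patch tokens, hidden size 8.
    hidden_state = torch.randn(1, 5, 8)  # (batch, 1 + num_patches, hidden)

    patch_only = hidden_state[:, 1:]  # "default"/"patch": drop the CLS token
    full = hidden_state               # "full": keep every token

    print(patch_only.shape)  # torch.Size([1, 4, 8])
    print(full.shape)        # torch.Size([1, 5, 8])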
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            get_embedding=get_embedding,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            placeholder_tokens=None,  # using mm_item.pad_value
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for LlavaForConditionalGeneration.
+
+        Unlike the base class implementation, this one doesn't need to handle
+        weight name remapping, as the weights are already properly structured with
+        'language_model' and 'vision_tower' prefixes in the safetensors files.
+        """
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
+            )
+
+        # Create dictionaries for direct parameter loading
+        params_dict = dict(self.named_parameters())
+
+        # Load weights directly without remapping
+        for name, loaded_weight in weights:
+            for part in ("language_model", "vision_tower"):
+                if name.startswith(part):
+                    # strip the submodule prefix and delegate to that submodule
+                    name = name[len(part + ".") :]
+                    getattr(self, part).load_weights([(name, loaded_weight)])
+                    break
+            else:
+                # no prefix matched: load into this module's own parameters
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
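The weight dispatch relies on Python's for/else: the else branch runs only when no `break` fired, i.e. when no submodule prefix matched. A self-contained sketch with made-up weight names:

    # Stand-in (name, weight) pairs; real weights come from safetensors files.
    weights = [
        ("language_model.lm_head.weight", "w0"),
        ("vision_tower.patch_embed.weight", "w1"),
        ("multi_modal_projector.linear_1.weight", "w2"),
    ]

    for name, w in weights:
        for part in ("language_model", "vision_tower"):
            if name.startswith(part):
                # prefix matched: strip it and delegate
                print(f"{part} <- {name[len(part) + 1:]}")
                break
        else:
            # no prefix matched: keep the name and load locally
            print(f"self <- {name}")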
+
+
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+    LlavaForConditionalGeneration,
+]