model(vlm): mistral 3.1 (#5099)

Co-authored-by: KivenChen <sleigh-queue-0y@icloud.com>
This commit is contained in:
Kiv Chen
2025-05-16 18:36:18 -07:00
committed by GitHub
parent 69748d088d
commit 64825b8395
6 changed files with 152 additions and 21 deletions

View File

@@ -549,6 +549,7 @@ multimodal_model_archs = [
"LlavaVidForCausalLM",
"MiniCPMO",
"MiniCPMV",
"Mistral3ForConditionalGeneration",
"MultiModalityCausalLM",
"MllamaForConditionalGeneration",
"Qwen2VLForConditionalGeneration",

View File

@@ -20,6 +20,7 @@ from sglang.srt.models.llava import (
LlavaQwenForCausalLM,
)
from sglang.srt.models.llavavid import LlavaVidForCausalLM
from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
@@ -176,10 +177,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
class LlavaMultimodalProcessor(BaseMultimodalProcessor):
"""
This is a wrapper class used to identify the multimodal processor for Llava architecture models.
This is a wrapper class used to identify the multimodal processor for Llava architectures' vision model.
"""
models = [LlavaForConditionalGeneration]
models = [LlavaForConditionalGeneration, Mistral3ForConditionalGeneration]
def _get_sgl_processor_cls(self, model_type: str):
if hf_name := HF_MAPPING_NAMES.get(model_type):

View File

@@ -135,7 +135,6 @@ class LlavaBaseForCausalLM(nn.Module):
"""
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
@@ -146,7 +145,6 @@ class LlavaBaseForCausalLM(nn.Module):
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
)
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@torch.no_grad()
@@ -613,6 +611,10 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
@property
def dtype(self):
return self.torch_dtype
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
if hasattr(self.vision_tower, "pad_input_ids"):
return self.vision_tower.pad_input_ids(input_ids, image_inputs)
@@ -672,11 +674,17 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
assert hasattr(config, "text_config")
assert hasattr(config, "vision_config")
self.config = config
self.text_config = config.text_config
self.vision_config = config.vision_config
self.text_config = self.config.text_config
self.vision_config = self.config.vision_config
self.torch_dtype = getattr(self.config, "torch_dtype")
if not getattr(self.text_config, "torch_dtype"):
self.text_config.torch_dtype = self.torch_dtype
if not getattr(self.vision_config, "torch_dtype"):
self.vision_config.torch_dtype = self.torch_dtype
if not hasattr(self.config, "vocab_size"):
self.config.vocab_size = self.config.text_config.vocab_size
self.config.vocab_size = self.text_config.vocab_size
if not hasattr(self.config, "image_aspect_ratio"):
self.config.image_aspect_ratio = "anyres"
if not hasattr(self.config, "image_grid_pinpoints"):
@@ -697,39 +705,39 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
if not hasattr(self.config, "projector_hidden_act"):
self.config.projector_hidden_act = "gelu"
self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
self.vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
self.vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", "full"
self.config, "vision_feature_select_strategy", "full"
)
self.image_size = self.config.vision_config.image_size
self.patch_size = self.config.vision_config.patch_size
self.image_size = self.vision_config.image_size
self.patch_size = self.vision_config.patch_size
self.mm_patch_merge_type = config.mm_patch_merge_type
self.image_aspect_ratio = config.image_aspect_ratio
self.image_grid_pinpoints = config.image_grid_pinpoints
self.mm_patch_merge_type = self.config.mm_patch_merge_type
self.image_aspect_ratio = self.config.image_aspect_ratio
self.image_grid_pinpoints = self.config.image_grid_pinpoints
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
language_model_cls = self._get_sgl_model_cls(
config.text_config, AutoModelForCausalLM
self.text_config, AutoModelForCausalLM
)
vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
vision_model_cls = self._get_sgl_model_cls(self.vision_config, AutoModel)
self.language_model = language_model_cls(
config.text_config,
self.text_config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
self.vision_tower = vision_model_cls(
config.vision_config,
self.vision_config,
quant_config=quant_config,
prefix=add_prefix("vision_tower", prefix),
)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
if "unpad" in getattr(self.config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
torch.empty(self.text_config.hidden_size, dtype=self.torch_dtype)
)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:

View File

@@ -13,6 +13,12 @@
# ==============================================================================
"""Inference-only Mistral model."""
from typing import List, Union
import torch
from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.models.llama import LlamaForCausalLM
@@ -20,4 +26,68 @@ class MistralForCausalLM(LlamaForCausalLM):
pass
EntryClass = MistralForCausalLM
class Mistral3ForConditionalGeneration:
    """Mistral 3 multimodal model, built as a thin wrapper around Llava.

    The wrapper constructs a ``LlavaForConditionalGeneration`` (stored as
    ``self.inner``), replaces its multimodal projector with
    ``MULTIMODAL_PROJECTOR_TYPE`` and replaces its ``get_image_feature`` with
    the implementation defined here. Attribute lookups that are not found on
    the wrapper, as well as calls, are forwarded to ``self.inner``.
    """

    MULTIMODAL_PROJECTOR_TYPE = Mistral3MultiModalProjector

    def __init__(self, **kwargs):
        """Build the inner Llava model and patch it for Mistral 3.

        Args:
            **kwargs: Forwarded unchanged to ``LlavaForConditionalGeneration``.
                Must contain ``config``, whose ``vision_config`` is mutated
                in place (``pad_image_border`` is set to ``False``).
        """
        # lazy load inner class
        # to bypass circular import
        from sglang.srt.models.llava import LlavaForConditionalGeneration

        config = kwargs["config"]

        # override config: mistral's projector adds patchmerger that doesn't require padding
        config.vision_config.pad_image_border = False

        self.inner = LlavaForConditionalGeneration(**kwargs)
        self.inner.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
        # Route the inner model's feature extraction through this class so the
        # projector is called with ``image_sizes`` (see get_image_feature).
        self.inner.get_image_feature = self.get_image_feature

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Extract features from image inputs.

        Args:
            items: List of MultimodalDataItem objects containing image data
                Note that an item can be either "image" or "multi-images"

        Returns:
            torch.Tensor: features from image inputs, concatenated

        Raises:
            ValueError: If ``vision_feature_select_strategy`` is not one of
                ``"default"``, ``"patch"`` or ``"full"``.
        """
        features = []
        for item in items:
            # in each item, we assume pixel_values is always batched
            pixel_values, image_sizes = item.pixel_values, item.image_sizes
            image_outputs = self.vision_tower(
                pixel_values, image_sizes, output_hidden_states=True
            )
            selected_image_feature = image_outputs.hidden_states[
                self.vision_feature_layer
            ]

            strategy = self.vision_feature_select_strategy
            if strategy in ("default", "patch"):
                # Drop the first position along dim 1
                # (presumably a CLS-style token -- TODO confirm).
                selected_image_feature = selected_image_feature[:, 1:]
            elif strategy != "full":
                raise ValueError(f"Unexpected select feature: {strategy}")
            # "full" keeps the selected hidden state untouched.

            features.append(
                self.multi_modal_projector(
                    selected_image_feature.squeeze(0), image_sizes
                )
            )
        return torch.cat(features, dim=0)

    def __getattr__(self, name):
        """Forward lookups that miss on the wrapper to the inner model.

        ``__getattr__`` only runs after normal attribute lookup has failed, so
        a request for ``inner`` itself means the inner model has not been
        assigned yet. Raising here avoids re-entering ``__getattr__`` forever
        (which would surface as ``RecursionError``).
        """
        if name == "inner":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.inner, name)

    def __hasattr__(self, name):
        """Report whether the inner model has attribute ``name``.

        NOTE: ``__hasattr__`` is not a Python special method; the built-in
        ``hasattr()`` goes through ``__getattr__`` instead. This is kept only
        so that any explicit callers continue to work.
        """
        return hasattr(self.inner, name)

    def __call__(self, *args, **kwargs):
        """Invoke the inner model with the given arguments."""
        return self.inner(*args, **kwargs)
# Lists both the text-only Mistral model and the Mistral 3 multimodal wrapper.
# NOTE(review): presumably read by the model loader to register architectures -- confirm.
EntryClass = [MistralForCausalLM, Mistral3ForConditionalGeneration]