model(vlm): pixtral (#5084)
This commit is contained in:
@@ -194,6 +194,21 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="mistral",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
|
||||
"user": ("[INST] ", " [/INST]"),
|
||||
"assistant": ("", " </s><s>"),
|
||||
},
|
||||
stop_str=("</s>",),
|
||||
image_token="[IMG]",
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="llama-3-instruct",
|
||||
@@ -509,13 +524,19 @@ def match_vicuna(model_path: str):
|
||||
@register_chat_template_matching_function
|
||||
def match_llama2_chat(model_path: str):
|
||||
if re.search(
|
||||
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
|
||||
r"llama-2.*chat|codellama.*instruct",
|
||||
model_path,
|
||||
re.IGNORECASE,
|
||||
):
|
||||
return "llama-2-chat"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_mistral(model_path: str):
|
||||
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
|
||||
return "mistral"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_llama3_instruct(model_path: str):
|
||||
if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
|
||||
|
||||
@@ -545,6 +545,7 @@ multimodal_model_archs = [
|
||||
"Llama4ForConditionalGeneration",
|
||||
"LlavaMistralForCausalLM",
|
||||
"LlavaQwenForCausalLM",
|
||||
"LlavaForConditionalGeneration",
|
||||
"LlavaVidForCausalLM",
|
||||
"MiniCPMO",
|
||||
"MiniCPMV",
|
||||
|
||||
@@ -634,6 +634,20 @@ register_conv_template(
|
||||
)
|
||||
)
|
||||
|
||||
# reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="mistral",
|
||||
system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
|
||||
roles=("[INST]", "[/INST]"),
|
||||
sep_style=SeparatorStyle.LLAMA2,
|
||||
sep=" ",
|
||||
sep2=" </s><s>",
|
||||
stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
|
||||
image_token="[IMG]",
|
||||
)
|
||||
)
|
||||
|
||||
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
@@ -880,13 +894,19 @@ def match_vicuna(model_path: str):
|
||||
@register_conv_template_matching_function
|
||||
def match_llama2_chat(model_path: str):
|
||||
if re.search(
|
||||
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
|
||||
r"llama-2.*chat|codellama.*instruct",
|
||||
model_path,
|
||||
re.IGNORECASE,
|
||||
):
|
||||
return "llama-2"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_mistral(model_path: str):
|
||||
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
|
||||
return "mistral"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_deepseek_vl(model_path: str):
|
||||
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers.models.auto.processing_auto import (
|
||||
PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
|
||||
)
|
||||
|
||||
import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.models.llava import (
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaLlamaForCausalLM,
|
||||
LlavaMistralForCausalLM,
|
||||
LlavaQwenForCausalLM,
|
||||
@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
img_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
)
|
||||
|
||||
res = await asyncio.gather(*res)
|
||||
for pixel_v, image_h, image_s in res:
|
||||
pixel_values.append(pixel_v)
|
||||
@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class LlavaMultimodalProcessor(BaseMultimodalProcessor):
|
||||
"""
|
||||
This is a wrapper class used to identify the multimodal processor for Llava architecture models.
|
||||
"""
|
||||
|
||||
models = [LlavaForConditionalGeneration]
|
||||
|
||||
def _get_sgl_processor_cls(self, model_type: str):
|
||||
if hf_name := HF_MAPPING_NAMES.get(model_type):
|
||||
sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
|
||||
sgl_processor_cls = list(
|
||||
filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
|
||||
)
|
||||
if sgl_processor_cls:
|
||||
return sgl_processor_cls[0]
|
||||
raise ValueError(
|
||||
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
|
||||
)
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
assert hasattr(hf_config, "vision_config")
|
||||
assert hasattr(hf_config, "text_config")
|
||||
self.vision_config = hf_config.vision_config
|
||||
self.text_config = hf_config.text_config
|
||||
self.hf_config = hf_config
|
||||
|
||||
if vision_type := getattr(self.vision_config, "model_type"):
|
||||
self.inner = self._get_sgl_processor_cls(vision_type)(
|
||||
hf_config, server_args, _processor
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
|
||||
)
|
||||
|
||||
async def process_mm_data_async(self, *args, **kwargs):
|
||||
return await self.inner.process_mm_data_async(*args, **kwargs)
|
||||
|
||||
127
python/sglang/srt/managers/multimodal_processors/pixtral.py
Normal file
127
python/sglang/srt/managers/multimodal_processors/pixtral.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.models.pixtral.image_processing_pixtral import (
|
||||
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||
)
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
)
|
||||
from sglang.srt.models.pixtral import PixtralVisionModel
|
||||
|
||||
|
||||
class PixtralProcessor(BaseMultimodalProcessor):
|
||||
models = [PixtralVisionModel]
|
||||
|
||||
PAD_TOKEN = "<pad>"
|
||||
IMG_BREAK_TOKEN_ID = 12
|
||||
IMG_END_TOKEN_ID = 13
|
||||
|
||||
def get_patch_grid_size(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> tuple[int, int]:
|
||||
max_width = max_height = self.image_size
|
||||
patch_width = patch_height = self.patch_size
|
||||
|
||||
ratio = max(image_width / max_width, image_height / max_height)
|
||||
|
||||
if ratio > 1:
|
||||
image_width = int(math.floor(image_width / ratio))
|
||||
image_height = int(math.floor(image_height / ratio))
|
||||
|
||||
nrows, ncols = _get_pixtral_hf_num_image_tokens(
|
||||
(image_height, image_width),
|
||||
(patch_height, patch_width),
|
||||
)
|
||||
|
||||
return ncols, nrows
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.image_token_id = getattr(
|
||||
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
|
||||
)
|
||||
# Instantiate the patcher logic helper using the class defined above
|
||||
|
||||
self.vision_config = hf_config.vision_config
|
||||
self.image_size = self.vision_config.image_size
|
||||
self.patch_size = self.vision_config.patch_size
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token
|
||||
)
|
||||
_processor.tokenizer.add_special_tokens(
|
||||
{
|
||||
"pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
|
||||
}
|
||||
)
|
||||
|
||||
async def _resize(self, image):
|
||||
num_w_tokens, num_h_tokens = self.get_patch_grid_size(
|
||||
image_width=image.size[0],
|
||||
image_height=image.size[1],
|
||||
)
|
||||
new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
|
||||
return image.resize(new_size)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
mm_data = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
max_req_input_len=kwargs.get("max_req_input_len", 4096),
|
||||
image_data=image_data,
|
||||
return_text=True,
|
||||
)
|
||||
|
||||
if mm_data.images:
|
||||
resize_tasks = [self._resize(image) for image in mm_data.images]
|
||||
mm_data.images = await asyncio.gather(*resize_tasks)
|
||||
|
||||
processor_output = self.process_mm_data(
|
||||
input_text=mm_data.input_text,
|
||||
images=mm_data.images,
|
||||
)
|
||||
|
||||
if "pixel_values" in processor_output:
|
||||
mm_items = [
|
||||
MultimodalDataItem(
|
||||
pixel_values=processor_output["pixel_values"],
|
||||
image_sizes=processor_output["image_sizes"],
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
]
|
||||
|
||||
input_ids = processor_output["input_ids"].view(-1).tolist()
|
||||
processor_output.update(
|
||||
input_ids=input_ids,
|
||||
mm_items=mm_items,
|
||||
# there's no im_start_id for pixtral, only im_token and im_end_token
|
||||
im_end_id=self.IMG_END_TOKEN_ID,
|
||||
im_token_id=self.image_token_id,
|
||||
)
|
||||
return processor_output
|
||||
@@ -15,7 +15,8 @@
|
||||
|
||||
import math
|
||||
import re
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -28,10 +29,18 @@ from transformers import (
|
||||
Qwen2Config,
|
||||
SiglipVisionModel,
|
||||
)
|
||||
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
|
||||
# leave till last and symbol only in case circular import
|
||||
import sglang.srt.models as sgl_models
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
|
||||
from sglang.srt.managers.mm_utils import general_mm_embed_routine
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
)
|
||||
from sglang.srt.mm_utils import (
|
||||
get_anyres_image_grid_shape,
|
||||
unpad_image,
|
||||
@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.srt.utils import add_prefix, flatten_nested_list
|
||||
from sglang.srt.utils import add_prefix, flatten_nested_list, logger
|
||||
|
||||
|
||||
class LlavaBaseForCausalLM(nn.Module):
|
||||
@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
image_inputs.image_offsets = offset_list
|
||||
return input_ids
|
||||
|
||||
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
def encode_images(
|
||||
self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
encode images by vision tower and multimodal projector
|
||||
Args:
|
||||
pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
|
||||
Returns:
|
||||
torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
|
||||
"""
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
|
||||
|
||||
@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
|
||||
)
|
||||
|
||||
|
||||
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
||||
class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
|
||||
"""
|
||||
An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
|
||||
It follows the structure of (vision_tower, multi_modal_projector, language_model)
|
||||
|
||||
Once a model config is loaded, text_config and vision_config will be extracted, and
|
||||
LlavaForConditionalGeneration will load the language_model and vision_tower models
|
||||
according to config.
|
||||
"""
|
||||
|
||||
MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
if hasattr(self.vision_tower, "pad_input_ids"):
|
||||
return self.vision_tower.pad_input_ids(input_ids, image_inputs)
|
||||
else:
|
||||
return super().pad_input_ids(input_ids, image_inputs)
|
||||
|
||||
def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
|
||||
"""
|
||||
Get the SGLang model implementation class according to config.
|
||||
|
||||
Args:
|
||||
config: The config object of the model.
|
||||
auto_model_type: The type of the auto model.
|
||||
|
||||
Returns:
|
||||
The SGLang model implementation class.
|
||||
"""
|
||||
config_cls_name = config.__class__.__name__
|
||||
arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
|
||||
if arch := arch_name_mapping.get(config_cls_name):
|
||||
if isinstance(arch, tuple):
|
||||
arch = arch[0]
|
||||
logger.warning(
|
||||
f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
|
||||
)
|
||||
try:
|
||||
return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
|
||||
)
|
||||
|
||||
@lru_cache
|
||||
def _config_cls_name_to_arch_name_mapping(
|
||||
self, auto_model_type: Type[AutoModel]
|
||||
) -> Dict[str, str]:
|
||||
mapping = {}
|
||||
for config_cls, archs in auto_model_type._model_mapping.items():
|
||||
if isinstance(archs, tuple):
|
||||
mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
|
||||
else:
|
||||
mapping[config_cls.__name__] = archs.__name__
|
||||
return mapping
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert hasattr(config, "text_config")
|
||||
assert hasattr(config, "vision_config")
|
||||
self.config = config
|
||||
self.text_config = config.text_config
|
||||
self.vision_config = config.vision_config
|
||||
|
||||
if not hasattr(self.config, "vocab_size"):
|
||||
self.config.vocab_size = self.config.text_config.vocab_size
|
||||
if not hasattr(self.config, "image_aspect_ratio"):
|
||||
self.config.image_aspect_ratio = "anyres"
|
||||
if not hasattr(self.config, "image_grid_pinpoints"):
|
||||
# from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
|
||||
# self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
|
||||
self.config.image_grid_pinpoints = [
|
||||
[96, 96],
|
||||
[224, 224],
|
||||
[384, 384],
|
||||
[512, 512],
|
||||
[768, 768],
|
||||
[1024, 1024],
|
||||
]
|
||||
if not hasattr(self.config, "mm_patch_merge_type"):
|
||||
self.config.mm_patch_merge_type = "flat"
|
||||
if not hasattr(self.config, "image_token_index"):
|
||||
self.config.image_token_index = 10
|
||||
if not hasattr(self.config, "projector_hidden_act"):
|
||||
self.config.projector_hidden_act = "gelu"
|
||||
|
||||
self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
|
||||
self.vision_feature_select_strategy = getattr(
|
||||
config, "vision_feature_select_strategy", "full"
|
||||
)
|
||||
self.image_size = self.config.vision_config.image_size
|
||||
self.patch_size = self.config.vision_config.patch_size
|
||||
|
||||
self.mm_patch_merge_type = config.mm_patch_merge_type
|
||||
self.image_aspect_ratio = config.image_aspect_ratio
|
||||
self.image_grid_pinpoints = config.image_grid_pinpoints
|
||||
|
||||
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
|
||||
|
||||
self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
|
||||
|
||||
language_model_cls = self._get_sgl_model_cls(
|
||||
config.text_config, AutoModelForCausalLM
|
||||
)
|
||||
vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
|
||||
self.language_model = language_model_cls(
|
||||
config.text_config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("language_model", prefix),
|
||||
)
|
||||
self.vision_tower = vision_model_cls(
|
||||
config.vision_config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("vision_tower", prefix),
|
||||
)
|
||||
|
||||
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||
self.language_model.model.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||
)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
"""Extract features from image inputs.
|
||||
|
||||
Args:
|
||||
items: List of MultimodalDataItem objects containing image data
|
||||
Note that an item can be either "image" or "multi-images"
|
||||
|
||||
Returns:
|
||||
torch.Tensor: features from image inputs, concatenated
|
||||
"""
|
||||
features = []
|
||||
for item in items:
|
||||
# in each item, we assume pixel_values is always batched
|
||||
pixel_values, image_sizes = item.pixel_values, item.image_sizes
|
||||
image_outputs = self.vision_tower(
|
||||
pixel_values, image_sizes, output_hidden_states=True
|
||||
)
|
||||
selected_image_feature = image_outputs.hidden_states[
|
||||
self.vision_feature_layer
|
||||
]
|
||||
|
||||
if self.vision_feature_select_strategy in ["default", "patch"]:
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected select feature: {self.vision_feature_select_strategy}"
|
||||
)
|
||||
features.append(
|
||||
self.multi_modal_projector(selected_image_feature.squeeze(0))
|
||||
)
|
||||
ret = torch.cat(features, dim=0)
|
||||
return ret
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
get_embedding: bool = False,
|
||||
):
|
||||
hidden_states = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
forward_batch=forward_batch,
|
||||
get_embedding=get_embedding,
|
||||
language_model=self.language_model,
|
||||
image_data_embedding_func=self.get_image_feature,
|
||||
placeholder_tokens=None, # using mm_item.pad_value
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
"""Load weights for LlavaForConditionalGeneration.
|
||||
|
||||
Unlike the base class implementation, this one doesn't need to handle
|
||||
weight name remapping as the weights are already properly structured with
|
||||
'language_model' and 'vision_tower' prefixes in the safetensors files.
|
||||
"""
|
||||
if (
|
||||
self.vision_feature_select_strategy == "patch"
|
||||
or self.vision_feature_select_strategy == "full"
|
||||
):
|
||||
pass
|
||||
elif self.vision_feature_select_strategy == "cls_patch":
|
||||
self.image_feature_len += 1
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected select feature: {self.vision_feature_select_strategy}"
|
||||
)
|
||||
|
||||
# Create dictionaries for direct parameter loading
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
# Load weights directly without remapping
|
||||
for name, loaded_weight in weights:
|
||||
for part in ("language_model", "vision_tower"):
|
||||
if name.startswith(part):
|
||||
name = name[len(part + ".") :]
|
||||
getattr(self, part).load_weights([(name, loaded_weight)])
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = [
|
||||
LlavaLlamaForCausalLM,
|
||||
LlavaQwenForCausalLM,
|
||||
LlavaMistralForCausalLM,
|
||||
LlavaForConditionalGeneration,
|
||||
]
|
||||
|
||||
467
python/sglang/srt/models/pixtral.py
Normal file
467
python/sglang/srt/models/pixtral.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# Copyright 2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""
|
||||
Using mistral-community/pixtral-12b as reference.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PixtralVisionConfig, PretrainedConfig
|
||||
from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
|
||||
from transformers.models.pixtral.modeling_pixtral import (
|
||||
generate_block_attention_mask as _get_pixtral_attention_mask,
|
||||
)
|
||||
from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class PixtralHFMLP(nn.Module):
|
||||
"""MLP for PixtralHFVisionModel using SGLang components."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert config.intermediate_size is not None
|
||||
|
||||
# Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
input_size=config.hidden_size,
|
||||
output_sizes=[config.intermediate_size, config.intermediate_size],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=config.intermediate_size,
|
||||
output_size=config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate_up_output, _ = self.gate_up_proj(x)
|
||||
|
||||
# Apply SiLU activation and multiply
|
||||
gate_up = self.act_fn(gate_up_output)
|
||||
|
||||
# Project back to hidden size
|
||||
out, _ = self.down_proj(gate_up)
|
||||
return out
|
||||
|
||||
|
||||
class PixtralHFTransformerBlock(nn.Module):
|
||||
"""Transformer block for PixtralHFVisionModel using SGLang components."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
# Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
|
||||
self.attention = VisionAttention(
|
||||
embed_dim=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
projection_size=config.hidden_size,
|
||||
use_qkv_parallel=True,
|
||||
quant_config=quant_config,
|
||||
dropout=0.0,
|
||||
use_context_forward=False,
|
||||
softmax_in_single_precision=False,
|
||||
flatten_batch=False,
|
||||
prefix=f"{prefix}.attention",
|
||||
)
|
||||
|
||||
self.feed_forward = PixtralHFMLP(
|
||||
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
|
||||
)
|
||||
|
||||
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
# Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
|
||||
batch_size, seq_len, hidden_dim = hidden_states.shape
|
||||
|
||||
# Apply attention norm - normalize along the last dimension
|
||||
attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
|
||||
batch_size, seq_len, hidden_dim
|
||||
)
|
||||
|
||||
# Pass through attention layer
|
||||
attention_output = self.attention(
|
||||
attn_normalized,
|
||||
attention_mask=attention_mask,
|
||||
cu_seqlens=None,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
# Apply first residual connection
|
||||
hidden_states = hidden_states + attention_output
|
||||
|
||||
# Apply feed-forward norm - normalize along the last dimension
|
||||
ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
|
||||
batch_size, seq_len, hidden_dim
|
||||
)
|
||||
|
||||
# Pass through feed-forward layer
|
||||
# First reshape to 2D for the feed-forward network, then reshape back
|
||||
ffn_output = self.feed_forward(ffn_normalized)
|
||||
|
||||
# Apply second residual connection
|
||||
output = hidden_states + ffn_output
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class PixtralHFTransformer(nn.Module):
|
||||
"""Transformer for PixtralHFVisionModel using SGLang components."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PixtralVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
if num_hidden_layers_override is not None:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
PixtralHFTransformerBlock(
|
||||
config=config,
|
||||
layer_id=layer_idx,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
)
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
return_all_hidden_states: bool = False,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Forward pass through transformer layers.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
attention_mask: Optional attention mask
|
||||
position_embeddings: Optional position embeddings for rotary attention
|
||||
return_all_hidden_states: Whether to return all hidden states
|
||||
|
||||
Returns:
|
||||
Either the final hidden state, or a list of all hidden states if
|
||||
return_all_hidden_states is True
|
||||
"""
|
||||
# For HF model compatibility, always start with the input
|
||||
hidden_states = x
|
||||
all_hidden_states = [hidden_states] if return_all_hidden_states else None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attention_mask, position_embeddings)
|
||||
if return_all_hidden_states:
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
if return_all_hidden_states:
|
||||
return all_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
def resolve_visual_encoder_outputs(
|
||||
outputs: Union[torch.Tensor, List[torch.Tensor]],
|
||||
feature_sample_layers: Optional[List[int]],
|
||||
post_norm: Optional[nn.Module],
|
||||
num_hidden_layers: int,
|
||||
) -> torch.Tensor:
|
||||
"""Resolve outputs from visual encoder based on feature_sample_layers."""
|
||||
if feature_sample_layers is None:
|
||||
# Just use the last layer's output
|
||||
if isinstance(outputs, list):
|
||||
outputs = outputs[-1]
|
||||
if post_norm is not None:
|
||||
outputs = post_norm(outputs)
|
||||
return outputs
|
||||
|
||||
# Handle the case where we want to use specific layers
|
||||
if not isinstance(outputs, list):
|
||||
raise ValueError(
|
||||
"Expected outputs to be a list when feature_sample_layers is provided"
|
||||
)
|
||||
|
||||
# Validate layer indices
|
||||
for layer_idx in feature_sample_layers:
|
||||
if layer_idx < 0 or layer_idx > num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"Feature sample layer index {layer_idx} is out of range "
|
||||
f"[0, {num_hidden_layers}]"
|
||||
)
|
||||
|
||||
# Collect outputs from specified layers
|
||||
selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
|
||||
|
||||
# Combine the outputs
|
||||
combined_outputs = torch.cat(selected_outputs, dim=-1)
|
||||
|
||||
if post_norm is not None:
|
||||
combined_outputs = post_norm(combined_outputs)
|
||||
|
||||
return combined_outputs
|
||||
|
||||
|
||||
class PixtralHFVisionModel(nn.Module):
|
||||
"""Hugging Face Pixtral Vision Model implemented using SGLang components."""
|
||||
|
||||
DEFAULT_IMAGE_TOKEN_ID = 10
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
return self.input_padder.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PixtralVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.patch_conv = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=config.patch_size,
|
||||
stride=config.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
self.transformer = PixtralHFTransformer(
|
||||
config,
|
||||
quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.transformer",
|
||||
)
|
||||
|
||||
# Check that num_hidden_layers is valid
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
if len(self.transformer.layers) > config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"The original encoder only has {num_hidden_layers} "
|
||||
f"layers, but you requested {len(self.transformer.layers)} "
|
||||
"layers."
|
||||
)
|
||||
|
||||
# Initialize patch position embedding
|
||||
self.image_token_id = image_token_id
|
||||
self.patch_positional_embedding = PixtralRotaryEmbedding(config)
|
||||
self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
|
||||
[self.image_token_id]
|
||||
)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
image_sizes: list[tuple[int, int]],
|
||||
output_hidden_states: bool = False,
|
||||
feature_sample_layers: Optional[list[int]] = None,
|
||||
) -> Union[torch.Tensor, tuple]:
|
||||
"""
|
||||
Args:
|
||||
pixel_values: [batch_size, C, H, W], padded if multiple images
|
||||
image_sizes: list of (H, W) for each image in the batch
|
||||
output_hidden_states: Whether to return all hidden states.
|
||||
feature_sample_layers: Layer indices whose features should be
|
||||
concatenated and used as the visual encoder output. If none
|
||||
are provided, the last layer is used.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
|
||||
- hidden_states tuple (optional): All hidden states if output_hidden_states=True
|
||||
"""
|
||||
# batch patch images
|
||||
embeds_orig = self.patch_conv(
|
||||
pixel_values.to(device=self.device, dtype=self.dtype)
|
||||
)
|
||||
# crop the embeddings
|
||||
embeds_2d = [
|
||||
embed[..., : h // self.patch_size, : w // self.patch_size]
|
||||
for embed, (h, w) in zip(embeds_orig, image_sizes)
|
||||
]
|
||||
|
||||
# flatten to sequence
|
||||
embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
|
||||
embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
|
||||
|
||||
# positional embeddings
|
||||
position_ids = position_ids_in_meshgrid(
|
||||
embeds_2d,
|
||||
max_width=self.image_size // self.patch_size,
|
||||
).to(self.device)
|
||||
|
||||
# The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
|
||||
# These tensors are used by apply_rotary_pos_emb in the transformer blocks
|
||||
position_embedding = self.patch_positional_embedding(
|
||||
embeds_featurized, position_ids
|
||||
)
|
||||
attention_mask = _get_pixtral_attention_mask(
|
||||
[p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
|
||||
)
|
||||
|
||||
return_all_hidden_states = (
|
||||
output_hidden_states or feature_sample_layers is not None
|
||||
)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
embeds_featurized, # add batch dimension
|
||||
attention_mask,
|
||||
position_embedding,
|
||||
return_all_hidden_states=return_all_hidden_states,
|
||||
)
|
||||
|
||||
# Store all hidden states if requested
|
||||
all_hidden_states = None
|
||||
if isinstance(transformer_outputs, list):
|
||||
all_hidden_states = transformer_outputs
|
||||
# Use the last layer by default if feature_sample_layers is not specified
|
||||
if feature_sample_layers is None:
|
||||
out = transformer_outputs[-1]
|
||||
else:
|
||||
# Resolve outputs based on feature sample layers
|
||||
out = resolve_visual_encoder_outputs(
|
||||
transformer_outputs,
|
||||
feature_sample_layers,
|
||||
None,
|
||||
self.config.num_hidden_layers,
|
||||
)
|
||||
else:
|
||||
out = transformer_outputs
|
||||
|
||||
# Format return to be compatible with HuggingFace vision models
|
||||
if output_hidden_states:
|
||||
return type(
|
||||
"VisualOutput",
|
||||
(),
|
||||
{
|
||||
"last_hidden_state": out,
|
||||
"hidden_states": all_hidden_states,
|
||||
},
|
||||
)
|
||||
else:
|
||||
return out
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
|
||||
"""Load weights from a HuggingFace checkpoint with proper parameter mapping."""
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
# for (param, weight, shard_id): load weight into param as param's shard_id part
|
||||
stacked_params_mapping = [
|
||||
(".attention.qkv_proj", ".attention.q_proj", "q"),
|
||||
(".attention.qkv_proj", ".attention.k_proj", "k"),
|
||||
(".attention.qkv_proj", ".attention.v_proj", "v"),
|
||||
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
|
||||
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
|
||||
]
|
||||
|
||||
# Process each weight
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name in name:
|
||||
# Replace the weight name part with the combined parameter name
|
||||
transformed_name = name.replace(weight_name, param_name)
|
||||
if transformed_name in params_dict:
|
||||
param = params_dict[transformed_name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if ".attention.o_proj" in name:
|
||||
alt_name = name.replace(".attention.o_proj", ".attention.proj")
|
||||
if alt_name in params_dict:
|
||||
name = alt_name
|
||||
if name in params_dict:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class PixtralVisionModel(PixtralHFVisionModel):
|
||||
pass
|
||||
|
||||
|
||||
# Register the model classes for external access
|
||||
EntryClass = [PixtralVisionModel]
|
||||
@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
@@ -211,7 +213,12 @@ class HFRunner:
|
||||
|
||||
# Load the model and tokenizer
|
||||
if self.model_type == "generation":
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
if model_archs := getattr(config, "architectures"):
|
||||
model_cls = getattr(transformers, model_archs[0])
|
||||
else:
|
||||
model_cls = AutoModelForCausalLM
|
||||
self.base_model = model_cls.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
|
||||
Reference in New Issue
Block a user