From c1815a99b78e1146e8c47020c1959f787cf31b10 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Thu, 18 Sep 2025 17:30:38 -0700 Subject: [PATCH] model support: Sarashina2VisionForCausalLM (#10632) --- .../vision_template_sarashina_vl.jinja | 9 + python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/hf_transformers_utils.py | 4 +- python/sglang/srt/models/llama.py | 4 + python/sglang/srt/models/sarashina2_vision.py | 269 ++++++++++++++++++ .../processors/sarashina2_vision.py | 81 ++++++ 6 files changed, 366 insertions(+), 2 deletions(-) create mode 100644 examples/chat_template/vision_template_sarashina_vl.jinja create mode 100644 python/sglang/srt/models/sarashina2_vision.py create mode 100644 python/sglang/srt/multimodal/processors/sarashina2_vision.py diff --git a/examples/chat_template/vision_template_sarashina_vl.jinja b/examples/chat_template/vision_template_sarashina_vl.jinja new file mode 100644 index 000000000..caff34415 --- /dev/null +++ b/examples/chat_template/vision_template_sarashina_vl.jinja @@ -0,0 +1,9 @@ +{# + In sglang, the default chat templates often assume message['content'] is a plain string. + That works fine for simple text conversations, but it ignores multimodal inputs (e.g. image_url, tool_call). + To align with the original model behavior and support richer content, + we iterate over message['content'] as a list of typed items and extract their values directly. + This way, both text and non-text inputs are preserved in the prompt. + Original template: https://huggingface.co/sbintuitions/sarashina2-vision-8b?chat_template=default +#} +{{ bos_token + '<|prefix|><|file|><|suffix|>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions.\n\n' }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Human: ' }}{%- if message['content'] is string %}{{ message['content'] }}{%- else %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %}{{ '\n' }}{% elif message['role'] == 'assistant' %}{{ '### Assistant: ' }}{%- if message['content'] is string %}{{ message['content'] }}{%- else %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %}{{ '\n' }}{% endif %}{% endfor %}{% if messages[-1]['role'] == 'user' %}{{ '### Assistant:' }}{% endif %} diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 14da05af1..3e4cd2688 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -756,6 +756,7 @@ multimodal_model_archs = [ "VILAForConditionalGeneration", "Step3VLForConditionalGeneration", "DotsVLMForCausalLM", + "Sarashina2VisionForCausalLM", ] diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 6cf22a85e..89c5b63f6 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -374,8 +374,8 @@ def get_processor( **kwargs, ) - # fix: for Qwen2-VL model, inject default 'size' if not provided. - if config.model_type in {"qwen2_vl"}: + # fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided. 
+ if config.model_type in {"qwen2_vl", "sarashina2_vision"}: if "size" not in kwargs: kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520} diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index fc0ce930a..420a9d0f4 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -385,6 +385,10 @@ class LlamaModel(nn.Module): "Self attention has no KV cache scaling " "factor attribute!" ) + def get_input_embeddings(self) -> nn.Embedding: + """Get input embeddings from the model.""" + return self.embed_tokens + class LlamaForCausalLM(nn.Module): # BitandBytes specific attributes diff --git a/python/sglang/srt/models/sarashina2_vision.py b/python/sglang/srt/models/sarashina2_vision.py new file mode 100644 index 000000000..eae341349 --- /dev/null +++ b/python/sglang/srt/models/sarashina2_vision.py @@ -0,0 +1,269 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Inference-only Sarashina2Vision model compatible with HuggingFace weights.""" + +import logging +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import LlamaConfig + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.mm_utils import ( + MultimodalDataItem, + MultimodalInputs, + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.llama import LlamaForCausalLM +from sglang.srt.models.qwen2_vl import Qwen2VisionTransformer +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + + +class Sarashina2VisionForCausalLM(nn.Module): + """ + Sarashina2Vision model that combines: + - Llama text backbone (sbintuitions/sarashina2-7b) + - Qwen2VL vision encoder + """ + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + # Extract text and vision configurations + text_config = getattr(config, "text_config", config) + vision_config = getattr(config, "vision_config", None) + + # Create vision transformer first (like original model) + if vision_config is not None: + self.visual = Qwen2VisionTransformer( + vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-5), + quant_config=quant_config, + prefix=add_prefix("visual", prefix), + ) + else: + self.visual = None + + # Layer norm for vision outputs (matching original model) + self.norm = nn.LayerNorm(text_config.hidden_size) + + # Create Llama text model (using 'llm' name to match original) + if hasattr(text_config, 
"model_type") and text_config.model_type == "llama": + llama_config = LlamaConfig(**text_config.__dict__) + # Set vocab_size from main config if available + if hasattr(config, "vocab_size"): + llama_config.vocab_size = config.vocab_size + self.llm = LlamaForCausalLM( + llama_config, + quant_config=quant_config, + prefix=add_prefix("llm", prefix), + ) + else: + # Set vocab_size from main config if available + if hasattr(config, "vocab_size"): + config.vocab_size = config.vocab_size + self.llm = LlamaForCausalLM( + config, + quant_config=quant_config, + prefix=add_prefix("llm", prefix), + ) + + # Image token indices from config + self.image_token_index = getattr(config, "image_token_index", 14) + self.start_image_token_index = getattr( + config, "start_image_token_index", 102397 + ) + self.end_image_token_index = getattr(config, "end_image_token_index", 102398) + + # Ensure vocabulary size matches + if hasattr(config, "vocab_size"): + self.llm.config.vocab_size = config.vocab_size + + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + """Pad input tokens with multimodal data hashes for RadixAttention.""" + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def get_input_embeddings(self): + """Get input embeddings from the language model.""" + return self.llm.get_input_embeddings() + + def get_image_embeds( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + ) -> torch.Tensor: + """Extract image embeddings using the vision transformer.""" + if self.visual is None: + raise ValueError("Visual encoder not initialized") + + # Use the existing Qwen2VisionTransformer forward method + hidden_states = self.visual(pixel_values, image_grid_thw) + + # Apply normalization layer + return self.norm(hidden_states) + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + """Extract image features for SGLang compatibility.""" + if self.visual is None: + raise ValueError("Visual encoder not initialized") + + # Concatenate pixel values and grid_thw from all items + pixel_values = torch.cat([item.feature for item in items], dim=0).type( + self.visual.dtype + ) + image_grid_thw = torch.cat([item.image_grid_thw for item in items], dim=0) + + assert pixel_values.dim() == 2, pixel_values.dim() + assert image_grid_thw.dim() == 2, image_grid_thw.dim() + + # Use the get_image_embeds method + return self.get_image_embeds(pixel_values, image_grid_thw) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + get_embedding: bool = False, + ) -> torch.Tensor: + """Forward pass through the model.""" + # Handles token-to-feature mapping for expanded tokens + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.llm.model, + multimodal_model=self, + positions=positions, + ) + + if get_embedding: + return self.pooler(hidden_states, forward_batch) + else: + return self.logits_processor( + input_ids, hidden_states, self.llm.lm_head, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load model weights.""" + params_dict = dict(self.named_parameters()) + loaded_params = set() + + # Collect weights that need to be fused + qkv_weights = {} + gate_up_weights = {} + + for name, loaded_weight in weights: + # Handle 
weight name mappings + + # Map visual attention weights: qkv -> qkv_proj + if ".attn.qkv." in name: + mapped_name = name.replace(".attn.qkv.", ".attn.qkv_proj.") + if mapped_name in params_dict: + param = params_dict[mapped_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(mapped_name) + continue + + # Handle Llama attention weights - need to fuse q, k, v into qkv + if ".self_attn.q_proj.weight" in name: + base = name.replace(".q_proj.weight", "") + qkv_weights[base] = qkv_weights.get(base, {}) + qkv_weights[base]["q"] = loaded_weight + continue + elif ".self_attn.k_proj.weight" in name: + base = name.replace(".k_proj.weight", "") + qkv_weights[base] = qkv_weights.get(base, {}) + qkv_weights[base]["k"] = loaded_weight + continue + elif ".self_attn.v_proj.weight" in name: + base = name.replace(".v_proj.weight", "") + qkv_weights[base] = qkv_weights.get(base, {}) + qkv_weights[base]["v"] = loaded_weight + continue + + # Handle Llama MLP weights - need to fuse gate and up projections + if ".mlp.gate_proj.weight" in name: + base = name.replace(".gate_proj.weight", "") + gate_up_weights[base] = gate_up_weights.get(base, {}) + gate_up_weights[base]["gate"] = loaded_weight + continue + elif ".mlp.up_proj.weight" in name: + base = name.replace(".up_proj.weight", "") + gate_up_weights[base] = gate_up_weights.get(base, {}) + gate_up_weights[base]["up"] = loaded_weight + continue + + # Direct mapping for other weights + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + # Fuse QKV weights for Llama attention layers + for base, weights_dict in qkv_weights.items(): + if "q" in weights_dict and "k" in weights_dict and "v" in weights_dict: + qkv_name = f"{base}.qkv_proj.weight" + if qkv_name in params_dict: + # Concatenate q, k, v weights + q, k, v = weights_dict["q"], weights_dict["k"], weights_dict["v"] + qkv = torch.cat([q, k, v], dim=0) + param = params_dict[qkv_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, qkv) + loaded_params.add(qkv_name) + + # Fuse gate and up weights for Llama MLP layers + for base, weights_dict in gate_up_weights.items(): + if "gate" in weights_dict and "up" in weights_dict: + gate_up_name = f"{base}.gate_up_proj.weight" + if gate_up_name in params_dict: + # Concatenate gate and up weights + gate, up = weights_dict["gate"], weights_dict["up"] + gate_up = torch.cat([gate, up], dim=0) + param = params_dict[gate_up_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, gate_up) + loaded_params.add(gate_up_name) + + +# Register the model +EntryClass = Sarashina2VisionForCausalLM diff --git a/python/sglang/srt/multimodal/processors/sarashina2_vision.py b/python/sglang/srt/multimodal/processors/sarashina2_vision.py new file mode 100644 index 000000000..fc7bdf3c9 --- /dev/null +++ b/python/sglang/srt/multimodal/processors/sarashina2_vision.py @@ -0,0 +1,81 @@ +from typing import List, Union + +from sglang.srt.models.sarashina2_vision import Sarashina2VisionForCausalLM +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) + + +class Sarashina2VisionProcessor(BaseMultimodalProcessor): + models = [Sarashina2VisionForCausalLM] + + def __init__(self, hf_config, server_args, 
_processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) + + # Sarashina2Vision specific tokens (default is <|file|>) + self.IMAGE_TOKEN = "<|file|>" + self.IM_TOKEN_ID = getattr(hf_config, "image_token_index", 14) + self.IM_START_ID = getattr(hf_config, "start_image_token_index", 102397) + self.IM_END_ID = getattr(hf_config, "end_image_token_index", 102398) + + self.mm_tokens = MultimodalSpecialTokens( + image_token=self.IMAGE_TOKEN, + image_token_id=self.IM_TOKEN_ID, + ).build(_processor) + + # Patch the processor's image processor to handle parameter compatibility + if hasattr(_processor, "image_processor") and hasattr( + _processor.image_processor, "_preprocess" + ): + original_preprocess = _processor.image_processor._preprocess + + def patched_preprocess(*args, **kwargs): + # Filter kwargs to only include parameters that the custom _preprocess method accepts + # Based on Sarashina2VisionImageProcessor._preprocess signature + allowed_params = { + "do_resize", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "data_format", + "input_data_format", + } + filtered_kwargs = { + k: v for k, v in kwargs.items() if k in allowed_params + } + return original_preprocess(*args, **filtered_kwargs) + + _processor.image_processor._preprocess = patched_preprocess + + async def process_mm_data_async( + self, + image_data: List[Union[str, bytes]], + input_text, + request_obj, + *args, + **kwargs, + ): + """Process image data for Sarashina2Vision model using standard SGLang pattern.""" + base_output = self.load_mm_data( + prompt=input_text, + image_data=image_data, + multimodal_tokens=self.mm_tokens, + ) + + mm_items, input_ids, ret = self.process_and_combine_mm_data( + base_output=base_output, + mm_tokens=self.mm_tokens, + ) + + return { + "mm_items": mm_items, + "input_ids": input_ids.tolist(), + "im_token_id": self.mm_tokens.image_token_id, + "im_start_id": self.IM_START_ID, + "im_end_id": self.IM_END_ID, + }
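
Usage sketch (illustrative): once a server is launched with this model and the new chat template, it can be queried through SGLang's OpenAI-compatible endpoint. The template added above iterates over message['content'] as a list of typed items, so the request below passes an image_url item plus a text item. The model path, port, and image URL here are placeholder assumptions, not values taken from this patch.

# Assumed launch command (standard SGLang CLI flags):
#   python -m sglang.launch_server \
#       --model-path sbintuitions/sarashina2-vision-8b \
#       --chat-template examples/chat_template/vision_template_sarashina_vl.jinja \
#       --port 30000

import openai

# Point the OpenAI client at the local SGLang server (URL is an assumption).
client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# message['content'] is a list of typed items: the Jinja template inlines the
# text items, while the image is expanded into <|file|> tokens by the
# Sarashina2VisionProcessor added in this patch.
response = client.chat.completions.create(
    model="sbintuitions/sarashina2-vision-8b",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        }
    ],
    max_tokens=128,
)
print(response.choices[0].message.content)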