model support: Sarashina2VisionForCausalLM (#10632)
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
{#
|
||||
In sglang, the default chat templates often assume message['content'] is a plain string.
|
||||
That works fine for simple text conversations, but it ignores multimodal inputs (e.g. image_url, tool_call).
|
||||
To align with the original model behavior and support richer content,
|
||||
we iterate over message['content'] as a list of typed items and extract their values directly.
|
||||
This way, both text and non-text inputs are preserved in the prompt.
|
||||
Original template: https://huggingface.co/sbintuitions/sarashina2-vision-8b?chat_template=default
|
||||
#}
|
||||
{{ bos_token + '<|prefix|><|file|><|suffix|>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions.\n\n' }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Human: ' }}{%- if message['content'] is string %}{{ message['content'] }}{%- else %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %}{{ '\n' }}{% elif message['role'] == 'assistant' %}{{ '### Assistant: ' }}{%- if message['content'] is string %}{{ message['content'] }}{%- else %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %}{{ '\n' }}{% endif %}{% endfor %}{% if messages[-1]['role'] == 'user' %}{{ '### Assistant:' }}{% endif %}
|
||||
@@ -756,6 +756,7 @@ multimodal_model_archs = [
|
||||
"VILAForConditionalGeneration",
|
||||
"Step3VLForConditionalGeneration",
|
||||
"DotsVLMForCausalLM",
|
||||
"Sarashina2VisionForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -374,8 +374,8 @@ def get_processor(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# fix: for Qwen2-VL model, inject default 'size' if not provided.
|
||||
if config.model_type in {"qwen2_vl"}:
|
||||
# fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
|
||||
if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
|
||||
if "size" not in kwargs:
|
||||
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
|
||||
|
||||
|
||||
@@ -385,6 +385,10 @@ class LlamaModel(nn.Module):
|
||||
"Self attention has no KV cache scaling " "factor attribute!"
|
||||
)
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
    """Return the model's token embedding layer (`self.embed_tokens`).

    Exposed so multimodal wrappers can embed text tokens themselves and
    splice in image features before running the transformer layers.
    """
    return self.embed_tokens
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
# BitsAndBytes specific attributes
|
||||
|
||||
269
python/sglang/srt/models/sarashina2_vision.py
Normal file
269
python/sglang/srt/models/sarashina2_vision.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Inference-only Sarashina2Vision model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VisionTransformer
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sarashina2VisionForCausalLM(nn.Module):
    """Sarashina2Vision: a Llama text backbone (sbintuitions/sarashina2-7b)
    combined with a Qwen2VL vision encoder.

    Submodule names (``visual``, ``norm``, ``llm``) mirror the original
    HuggingFace checkpoint layout so most weights map over 1:1 by name.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        # A monolithic config (no nested text/vision sections) doubles as the
        # text config; the vision tower is optional.
        text_config = getattr(config, "text_config", config)
        vision_config = getattr(config, "vision_config", None)

        # Create the vision transformer first, matching the original model's
        # module order.
        if vision_config is not None:
            self.visual = Qwen2VisionTransformer(
                vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-5),
                quant_config=quant_config,
                prefix=add_prefix("visual", prefix),
            )
        else:
            self.visual = None

        # LayerNorm applied to vision features before they enter the LLM
        # (matches the original model).
        self.norm = nn.LayerNorm(text_config.hidden_size)

        # Build the Llama text model (named 'llm' to match the checkpoint).
        if hasattr(text_config, "model_type") and text_config.model_type == "llama":
            llm_config = LlamaConfig(**text_config.__dict__)
            # The top-level config may carry the (image-token-extended)
            # vocabulary size; propagate it to the text config.
            if hasattr(config, "vocab_size"):
                llm_config.vocab_size = config.vocab_size
        else:
            # fix: removed a dead no-op here
            # (`config.vocab_size = config.vocab_size` self-assignment).
            llm_config = config
        self.llm = LlamaForCausalLM(
            llm_config,
            quant_config=quant_config,
            prefix=add_prefix("llm", prefix),
        )

        # Special token ids delimiting image content; defaults follow the
        # sbintuitions/sarashina2-vision-8b configuration.
        self.image_token_index = getattr(config, "image_token_index", 14)
        self.start_image_token_index = getattr(
            config, "start_image_token_index", 102397
        )
        self.end_image_token_index = getattr(config, "end_image_token_index", 102398)

        # Keep the inner model's config in sync with the outer vocab size.
        if hasattr(config, "vocab_size"):
            self.llm.config.vocab_size = config.vocab_size

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad input tokens with multimodal data hashes for RadixAttention."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_input_embeddings(self):
        """Return the token embedding layer of the language model."""
        return self.llm.get_input_embeddings()

    def get_image_embeds(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and normalize its output features.

        Args:
            pixel_values: flattened image patches for all images.
            image_grid_thw: per-image (t, h, w) patch grid, rank 2.

        Raises:
            ValueError: if the model was built without a vision encoder.
        """
        if self.visual is None:
            raise ValueError("Visual encoder not initialized")

        # Reuse the existing Qwen2VisionTransformer forward method, then
        # apply the model's vision-output LayerNorm.
        hidden_states = self.visual(pixel_values, image_grid_thw)
        return self.norm(hidden_states)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Extract image features for SGLang's multimodal embed routine.

        Batches all items into a single vision-tower call.
        """
        if self.visual is None:
            raise ValueError("Visual encoder not initialized")

        # Concatenate pixel values and grid_thw from all items.
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thw = torch.cat([item.image_grid_thw for item in items], dim=0)

        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()

        return self.get_image_embeds(pixel_values, image_grid_thw)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        """Forward pass: embed text + image tokens, run the LLM, then either
        pool (embedding mode) or compute logits.
        """
        # general_mm_embed_routine handles the token-to-feature mapping for
        # expanded image tokens (calls back into get_image_feature).
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm.model,
            multimodal_model=self,
            positions=positions,
        )

        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.llm.lm_head, forward_batch
        )

    def _load_param(self, params_dict, name, loaded_weight, loaded_params):
        """Load one tensor into parameter *name* via its registered loader."""
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)

    def _fuse_and_load(
        self, params_dict, collected, shard_order, fused_suffix, loaded_params
    ):
        """Concatenate collected checkpoint shards and load the fused param.

        ``collected`` maps a parameter-name prefix to {shard_name: tensor};
        complete groups are concatenated along dim 0 in ``shard_order``.
        """
        for base, shards in collected.items():
            if not all(s in shards for s in shard_order):
                continue
            fused_name = f"{base}.{fused_suffix}.weight"
            if fused_name in params_dict:
                fused = torch.cat([shards[s] for s in shard_order], dim=0)
                self._load_param(params_dict, fused_name, fused, loaded_params)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, fusing q/k/v and gate/up projections.

        The checkpoint stores separate q/k/v and gate/up tensors while
        sglang's Llama uses fused ``qkv_proj`` / ``gate_up_proj`` parameters,
        so those shards are collected first and fused at the end.
        """
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        # Shards waiting to be fused, keyed by parameter-name prefix.
        qkv_weights = {}
        gate_up_weights = {}

        # (checkpoint suffix, collector dict, shard key)
        shard_map = (
            (".self_attn.q_proj.weight", qkv_weights, "q"),
            (".self_attn.k_proj.weight", qkv_weights, "k"),
            (".self_attn.v_proj.weight", qkv_weights, "v"),
            (".mlp.gate_proj.weight", gate_up_weights, "gate"),
            (".mlp.up_proj.weight", gate_up_weights, "up"),
        )

        for name, loaded_weight in weights:
            # Visual attention weights: checkpoint 'qkv' -> sglang 'qkv_proj'.
            if ".attn.qkv." in name:
                mapped_name = name.replace(".attn.qkv.", ".attn.qkv_proj.")
                if mapped_name in params_dict:
                    self._load_param(
                        params_dict, mapped_name, loaded_weight, loaded_params
                    )
                continue

            # Llama attention/MLP shards: stash for fusion below.
            for suffix, collector, shard in shard_map:
                if suffix in name:
                    base = name.replace(f".{shard}_proj.weight", "")
                    collector.setdefault(base, {})[shard] = loaded_weight
                    break
            else:
                # Everything else maps 1:1 by name.
                if name in params_dict:
                    self._load_param(
                        params_dict, name, loaded_weight, loaded_params
                    )

        self._fuse_and_load(
            params_dict, qkv_weights, ("q", "k", "v"), "qkv_proj", loaded_params
        )
        self._fuse_and_load(
            params_dict, gate_up_weights, ("gate", "up"), "gate_up_proj", loaded_params
        )
|
||||
|
||||
|
||||
# Register the model: sglang's model loader scans modules for `EntryClass`
# to map HF architecture names to implementations.
EntryClass = Sarashina2VisionForCausalLM
|
||||
81
python/sglang/srt/multimodal/processors/sarashina2_vision.py
Normal file
81
python/sglang/srt/multimodal/processors/sarashina2_vision.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.models.sarashina2_vision import Sarashina2VisionForCausalLM
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
|
||||
|
||||
class Sarashina2VisionProcessor(BaseMultimodalProcessor):
    """Multimodal input processor for Sarashina2Vision.

    Wraps the HuggingFace processor: registers the model's image special
    tokens and patches the image preprocessor for kwarg compatibility.
    """

    # Architectures this processor serves (used for registry dispatch).
    models = [Sarashina2VisionForCausalLM]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)

        # Sarashina2Vision specific tokens (default is <|file|>).
        # Token-id defaults mirror Sarashina2VisionForCausalLM's __init__.
        self.IMAGE_TOKEN = "<|file|>"
        self.IM_TOKEN_ID = getattr(hf_config, "image_token_index", 14)
        self.IM_START_ID = getattr(hf_config, "start_image_token_index", 102397)
        self.IM_END_ID = getattr(hf_config, "end_image_token_index", 102398)

        self.mm_tokens = MultimodalSpecialTokens(
            image_token=self.IMAGE_TOKEN,
            image_token_id=self.IM_TOKEN_ID,
        ).build(_processor)

        # Patch the processor's image processor to handle parameter
        # compatibility: drop kwargs the custom _preprocess does not accept.
        if hasattr(_processor, "image_processor") and hasattr(
            _processor.image_processor, "_preprocess"
        ):
            original_preprocess = _processor.image_processor._preprocess

            def patched_preprocess(*args, **kwargs):
                # Filter kwargs to only include parameters that the custom
                # _preprocess method accepts, based on the
                # Sarashina2VisionImageProcessor._preprocess signature.
                # NOTE(review): anything not listed (e.g. 'size') is silently
                # dropped — confirm against the upstream signature.
                allowed_params = {
                    "do_resize",
                    "resample",
                    "do_rescale",
                    "rescale_factor",
                    "do_normalize",
                    "image_mean",
                    "image_std",
                    "do_convert_rgb",
                    "data_format",
                    "input_data_format",
                }
                filtered_kwargs = {
                    k: v for k, v in kwargs.items() if k in allowed_params
                }
                return original_preprocess(*args, **filtered_kwargs)

            _processor.image_processor._preprocess = patched_preprocess

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Process image data for Sarashina2Vision model using standard SGLang pattern.

        Loads the images referenced by the prompt, expands image tokens, and
        returns the dict consumed by the tokenizer manager.
        """
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,
        )

        # NOTE(review): `ret` is unused here; other fields are returned
        # explicitly below.
        mm_items, input_ids, ret = self.process_and_combine_mm_data(
            base_output=base_output,
            mm_tokens=self.mm_tokens,
        )

        return {
            "mm_items": mm_items,
            "input_ids": input_ids.tolist(),
            "im_token_id": self.mm_tokens.image_token_id,
            "im_start_id": self.IM_START_ID,
            "im_end_id": self.IM_END_ID,
        }
|
||||
Reference in New Issue
Block a user