From 1b1701f1f7ea59ffca374cfbd8cd53ed5fd39df8 Mon Sep 17 00:00:00 2001 From: "chenge@xiaohongshu.com" Date: Fri, 12 Sep 2025 17:38:38 +0800 Subject: [PATCH] model: support dots.vlm1 model (#8778) Co-authored-by: weishi Co-authored-by: Ezra-Yu <1105212286@qq.com> Co-authored-by: Jianfei Wang <905787410@qq.com> Co-authored-by: qianwu --- benchmark/mmmu/bench_sglang.py | 10 +- benchmark/mmmu/eval_utils.py | 31 +- python/pyproject.toml | 1 - python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/dots_vlm.py | 139 ++++++++ python/sglang/srt/configs/model_config.py | 2 + python/sglang/srt/hf_transformers_utils.py | 2 + python/sglang/srt/models/dots_vlm.py | 174 +++++++++ python/sglang/srt/models/dots_vlm_vit.py | 337 ++++++++++++++++++ .../srt/multimodal/processors/dots_vlm.py | 99 +++++ .../srt/multimodal/processors/qwen_vl.py | 20 +- 11 files changed, 806 insertions(+), 11 deletions(-) create mode 100644 python/sglang/srt/configs/dots_vlm.py create mode 100644 python/sglang/srt/models/dots_vlm.py create mode 100644 python/sglang/srt/models/dots_vlm_vit.py create mode 100644 python/sglang/srt/multimodal/processors/dots_vlm.py diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index d8834ea5f..9a0bf4529 100644 --- a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -124,7 +124,9 @@ async def eval_mmmu(args) -> None: answer_dict = {} out_samples = {} client = openai.AsyncOpenAI( - api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1" + api_key="sk", + base_url=f"http://127.0.0.1:{args.port}/v1", + timeout=20 * 60 * 60, ) start = time.perf_counter() base_url = f"http://127.0.0.1:{args.port}" @@ -146,13 +148,14 @@ async def eval_mmmu(args) -> None: _, response = await process_sample( client, sample, sampling_params, lora_path ) + sample["original_response"] = response answer = ( re.search(args.response_answer_regex, response) if response is not None else None ) process_result( - answer.group(1) if answer else response, + answer.group(1).strip() if answer else response, sample, answer_dict, out_samples, @@ -168,13 +171,14 @@ async def eval_mmmu(args) -> None: for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)): sample, response = await coro + sample["original_response"] = response answer = ( re.search(args.response_answer_regex, response) if response is not None else None ) process_result( - answer.group(1) if answer else response, + answer.group(1).strip() if answer else response, sample, answer_dict, out_samples, diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index ca0e87c6a..17cf850f6 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -18,6 +18,7 @@ from data_utils import ( construct_prompt, load_yaml, process_single_sample, + save_json, ) from datasets import concatenate_datasets, load_dataset from tqdm import tqdm @@ -28,7 +29,7 @@ class EvalArgs: seed: int = 42 split: str = "validation" image_pixels_limit: int = -1 - result_filename: str = "" + result_filename: str = f"./val_sglang.json" prompt_format_file: str = "prompt_format.yaml" dataset_path: str = "MMMU/MMMU" extra_request_body: Optional[str] = None @@ -445,6 +446,18 @@ def eval_multi_choice(gold_i, pred_i): Evaluate a multiple choice instance. """ correct = False + # for case like Answer: A, Answer is A, answer is A, answer: A + for _exp in ["Answer:", "Answer is ", "answer is ", "answer: "]: + if _exp in pred_i: + pred_i = pred_i.split(_exp)[1].strip() + break + # for case like (A), (B), (C), (D) ...... + if "(" in pred_i and ")" in pred_i: + try: + pred_i = re.search(r"\(([A-Z])\)", pred_i).group(1) + except: + print(f"Error to extract answer from: {pred_i}") + pass # only they are exactly the same, we consider it as correct if isinstance(gold_i, list): for answer in gold_i: @@ -535,7 +548,12 @@ def process_result(response, sample, answer_dict, out_samples): else: # open question pred_ans = response - out_samples[sample["id"]] = pred_ans + out_samples[sample["id"]] = { + "pred_ans": pred_ans, + "original_response": sample["original_response"], + "ground_truth": sample["answer"], + "question_type": sample["question_type"], + } # set ground truth answer answer_dict[sample["id"]] = { @@ -554,6 +572,12 @@ def eval_result(model_answer_path, answer_dict, eval_output_path=None): # group by category output_dict_w_cat = {} for data_id, parsed_pred in output_dict.items(): + if isinstance(parsed_pred, str): + parsed_pred = parsed_pred + elif isinstance(parsed_pred, dict): + parsed_pred = parsed_pred["pred_ans"] + else: + raise ValueError(f"Unknown type of parsed_pred: {type(parsed_pred)}") category = "_".join(data_id.split("_")[1:-1]) if category not in output_dict_w_cat: output_dict_w_cat.update({category: {}}) @@ -600,9 +624,12 @@ def eval_result(model_answer_path, answer_dict, eval_output_path=None): judge_dict, metric_dict = evaluate(exampels_to_eval) metric_dict.update({"num_example": len(exampels_to_eval)}) + for key, value in judge_dict.items(): + output_dict[key]["judge"] = value evaluation_result[category] = metric_dict + save_json(model_answer_path, output_dict) printable_results = {} # pdb.set_trace() # add domain Subject diff --git a/python/pyproject.toml b/python/pyproject.toml index f2e69b3c0..7b0bda1f5 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -44,7 +44,6 @@ runtime_common = [ "pynvml", "python-multipart", "pyzmq>=25.1.2", - "sentencepiece", "soundfile==0.13.1", "scipy", "timm==1.0.16", diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index ef880c911..0a57a8b26 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -1,6 +1,7 @@ from sglang.srt.configs.chatglm import ChatGLMConfig from sglang.srt.configs.dbrx import DbrxConfig from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config +from sglang.srt.configs.dots_vlm import DotsVLMConfig from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.kimi_vl import KimiVLConfig @@ -26,4 +27,5 @@ __all__ = [ "Step3TextConfig", "Step3VisionEncoderConfig", "Qwen3NextConfig", + "DotsVLMConfig", ] diff --git a/python/sglang/srt/configs/dots_vlm.py b/python/sglang/srt/configs/dots_vlm.py new file mode 100644 index 000000000..155d6ee47 --- /dev/null +++ b/python/sglang/srt/configs/dots_vlm.py @@ -0,0 +1,139 @@ +from typing import Any, List, Optional, Union + +from transformers import AutoProcessor, LlamaTokenizerFast, PretrainedConfig +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessingKwargs, Unpack +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +try: + from transformers import Qwen2_5_VLProcessor +except ImportError: + raise ImportError( + "Qwen2_5_VLProcessor can not be found. Please upgrade your transformers version." + ) + +from sglang.srt.configs.deepseekvl2 import DeepseekV2Config + + +class DotsVisionConfig(PretrainedConfig): + model_type: str = "dots_vit" + + def __init__( + self, + embed_dim: int = 1536, # vision encoder embed size + hidden_size: int = 1536, # after merger hidden size + intermediate_size: int = 4224, + num_hidden_layers: int = 42, + num_attention_heads: int = 12, + num_channels: int = 3, + patch_size: int = 14, + spatial_merge_size: int = 2, + temporal_patch_size: int = 1, + rms_norm_eps: float = 1e-5, + use_bias: bool = False, + attn_implementation="flash_attention_2", # "eager","sdpa","flash_attention_2" + initializer_range=0.02, + init_merger_std=0.02, + is_causal=False, # ve causal forward + post_norm=True, + gradient_checkpointing=False, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.rms_norm_eps = rms_norm_eps + self.use_bias = use_bias + self.attn_implementation = attn_implementation + self.initializer_range = initializer_range + self.init_merger_std = init_merger_std + self.is_causal = is_causal + self.post_norm = post_norm + self.gradient_checkpointing = gradient_checkpointing + + +class DotsVLMConfig(PretrainedConfig): + model_type = "dots_vlm" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + vision_config = kwargs.get("vision_config", {}) + self.im_span_id = kwargs.get("image_token_id", 128815) + self.video_span_id = kwargs.get("video_token_id", 128836) + self.vision_config = DotsVisionConfig(**vision_config) + self.language_config = DeepseekV2Config(**kwargs) + self.architectures = ["DotsVLMForCausalLM"] + + +class DotsVLMProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + } + + +class DotsVLMProcessor(Qwen2_5_VLProcessor): + r""" + Constructs a DotsVLM processor which derives from Qwen2_5_VLProcessor, but overrides the image and video token ids. + Besides, its tokenizer is a LlamaTokenizerFast instead of Qwen2TokenizerFast. + [`DotsVLMProcessor`] offers all the functionalities of [`DotsVisionConfig`] and [`LlamaTokenizerFast`]. See the + [`~DotsVLMProcessor.__call__`] and [`~DotsVLMProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + + valid_kwargs = ["chat_template"] + + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + + def __init__( + self, image_processor=None, tokenizer=None, chat_template=None, **kwargs + ): + super().__init__(image_processor, tokenizer, chat_template=chat_template) + self.image_token = ( + "<|imgpad|>" + if not hasattr(tokenizer, "image_token") + else tokenizer.image_token + ) + self.video_token = ( + "<|video_pad|>" + if not hasattr(tokenizer, "video_token") + else tokenizer.video_token + ) + self.img_token = ( + "<|img|>" if not hasattr(tokenizer, "img_token") else tokenizer.img_token + ) + self.endofimg_token = ( + "<|endofimg|>" + if not hasattr(tokenizer, "endofimg_token") + else tokenizer.endofimg_token + ) + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.encode(self.image_token)[0] + ) + self.video_token_id = ( + tokenizer.video_token_id + if getattr(tokenizer, "video_token_id", None) + else tokenizer.encode(self.video_token)[0] + ) + + +AutoProcessor.register(DotsVLMConfig, DotsVLMProcessor) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index d6f34b4d9..2ba4bbc7a 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -216,6 +216,7 @@ class ModelConfig: or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures or "LongcatFlashForCausalLM" in self.hf_config.architectures or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures + or "DotsVLMForCausalLM" in self.hf_config.architectures ): self.head_dim = 256 self.attention_arch = AttentionArch.MLA @@ -734,6 +735,7 @@ multimodal_model_archs = [ "Phi4MMForCausalLM", "VILAForConditionalGeneration", "Step3VLForConditionalGeneration", + "DotsVLMForCausalLM", ] diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index d7dcf8904..6cf22a85e 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -38,6 +38,7 @@ from sglang.srt.configs import ( ChatGLMConfig, DbrxConfig, DeepseekVL2Config, + DotsVLMConfig, ExaoneConfig, KimiVLConfig, LongcatFlashConfig, @@ -60,6 +61,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { Step3VLConfig.model_type: Step3VLConfig, LongcatFlashConfig.model_type: LongcatFlashConfig, Qwen3NextConfig.model_type: Qwen3NextConfig, + DotsVLMConfig.model_type: DotsVLMConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/python/sglang/srt/models/dots_vlm.py b/python/sglang/srt/models/dots_vlm.py new file mode 100644 index 000000000..95475058f --- /dev/null +++ b/python/sglang/srt/models/dots_vlm.py @@ -0,0 +1,174 @@ +# Copyright 2025 The RedNote HiLab team. +# Copyright 2025 The SGLang team. +# +# This code is based on the DeepseekVL2ForCausalLM and DotsVisionTransformer +# implementation in this library. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Dots-VL model compatible with HuggingFace weights.""" + +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.configs.dots_vlm import DotsVLMConfig +from sglang.srt.distributed import parallel_state +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM + +from .dots_vlm_vit import DotsVisionTransformer + + +class DotsVLMForCausalLM(nn.Module): + """DotsVLM model for sglang inference""" + + def __init__( + self, config: DotsVLMConfig, quant_config: Optional[QuantizationConfig] = None + ) -> None: + super().__init__() + + self.config = config + self.image_token_id = config.im_span_id + self.video_token_id = config.video_span_id + + self.language_model = DeepseekV2ForCausalLM( + config.language_config, quant_config + ) + + # Initialize vision tower (matching transformers naming for weight compatibility) + self.vision_tower = DotsVisionTransformer(config.vision_config) + + def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): + """pad attn qkv weights for dummy heads""" + num_dummy_heads = self.config.vision_config.num_dummy_heads + if num_dummy_heads == 0: + return loaded_weight + head_dim = self.config.vision_config.head_dim + + if "attn.qkv_proj" in name: + wq, wk, wv = loaded_weight.chunk(3, dim=0) + if name.endswith(".weight"): + dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]] + elif name.endswith(".bias"): + dummy_shape = [num_dummy_heads, head_dim] + else: + raise RuntimeError(f"Unsupported weight with name={name}") + pad_func = lambda x: torch.cat( + [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0 + ).flatten(0, 1) + wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv) + loaded_weight = torch.cat([wq, wk, wv], dim=0) + if "attn.proj.weight" in name: + padded_weight = loaded_weight.new_zeros( + loaded_weight.shape[0], head_dim * num_dummy_heads + ) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1) + if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name: + padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0) + return loaded_weight + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights for the model, separating vision and language weights""" + weights = list(weights) + + # Separate vision tower weights and language model weights + vision_weights = [] + language_weights = [] + + for name, loaded_weight in weights: + if name.startswith("vision_tower."): + vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.") + vision_weights.append((vision_name, loaded_weight)) + else: + # All other weights go to language model + language_weights.append((name, loaded_weight)) + + # Load vision tower weights + vision_state_dict = dict(vision_weights) + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in vision_state_dict.items(): + if name not in params_dict: + raise ValueError(f"Weight {name} not found in params_dict") + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight) + weight_loader(param, loaded_weight) + + # Load language model weights + if language_weights: + self.language_model.load_weights(language_weights) + + @classmethod + def get_model_config_for_expert_location(cls, config): + return DeepseekV2ForCausalLM.get_model_config_for_expert_location(config) + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + """Pad input_ids with multimodal tokens""" + # Get image token ID for padding pattern + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + padded_input_ids = pattern.pad_input_tokens(input_ids, mm_inputs) + return padded_input_ids + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + # Extract pixel values and grid information (following reference pattern) + pixel_values = torch.cat([item.feature for item in items], dim=0).type( + self.vision_tower.dtype + ) + image_grid_thw = torch.concat( + [item.image_grid_thw for item in items], dim=0 + ).to(self.vision_tower.device) + + # Add dimension checks like in reference code + assert pixel_values.dim() == 2, f"{pixel_values.dim()=}" + assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}" + + # Process through vision tower + image_embeds = self.vision_tower(pixel_values, image_grid_thw) + + # Ensure consistent dtype for FlashInfer compatibility + # Force bfloat16 to match model's expected dtype + if image_embeds.dtype != torch.bfloat16 and hasattr( + self.language_model.model, "embed_tokens" + ): + target_dtype = self.language_model.model.embed_tokens.weight.dtype + image_embeds = image_embeds.to(target_dtype) + + return image_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs: object, + ) -> torch.Tensor: + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + multimodal_model=self, + language_model=self.language_model, + ) + return hidden_states + + +EntryClass = [DotsVLMForCausalLM] diff --git a/python/sglang/srt/models/dots_vlm_vit.py b/python/sglang/srt/models/dots_vlm_vit.py new file mode 100644 index 000000000..e36e01ee3 --- /dev/null +++ b/python/sglang/srt/models/dots_vlm_vit.py @@ -0,0 +1,337 @@ +import logging +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import LayerNorm +from transformers.modeling_utils import PreTrainedModel + +from sglang.srt.configs.dots_vlm import DotsVisionConfig +from sglang.srt.distributed import parallel_state +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchMerger(nn.Module): + def __init__( + self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + pre_norm="layernorm", + init_merger_std=None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.pre_norm = pre_norm + if self.pre_norm == "layernorm": + self.ln_q = LayerNorm(context_dim, eps=1e-6) + elif self.pre_norm == "rmsnorm": + self.ln_q = RMSNorm(context_dim, eps=1e-6) + else: + logger.warning(f"no norm in patch merger: {self.pre_norm}") + + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + if init_merger_std is not None: + nn.init.normal_(self.mlp[0].weight, mean=0.0, std=init_merger_std) + nn.init.zeros_(self.mlp[0].bias) + nn.init.normal_(self.mlp[2].weight, mean=0.0, std=init_merger_std) + nn.init.zeros_(self.mlp[2].bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.pre_norm: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + else: + x = self.mlp(x.view(-1, self.hidden_size)) + return x + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def extra_repr(self) -> str: + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + +class DotsSwiGLUFFN(nn.Module): + def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + hidden_features = config.intermediate_size + in_features = config.embed_dim + bias = config.use_bias + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.fc2 = nn.Linear(hidden_features, in_features, bias=bias) + self.fc3 = nn.Linear(in_features, hidden_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.fc1(x)) * self.fc3(x) + x = self.fc2(x) + return x + + +class DotsPatchEmbed(nn.Module): + def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.num_channels = config.num_channels + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.embed_dim = config.embed_dim + self.config = config + self.proj = nn.Conv2d( + config.num_channels, + config.embed_dim, + kernel_size=(config.patch_size, config.patch_size), + stride=(config.patch_size, config.patch_size), + ) + self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: + x = x.view( + -1, + self.num_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + )[:, :, 0] + x = self.proj(x).view(-1, self.embed_dim) + x = self.norm(x) + return x + + +class DotsViTPreprocessor(nn.Module): + def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.patch_h = config.patch_size + self.patch_w = config.patch_size + self.embed_dim = config.embed_dim + self.config = config + self.patchifier = DotsPatchEmbed(config, quant_config) + + def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: + tokens = self.patchifier(x, grid_thw) + return tokens + + +class DotsVisionBlock(nn.Module): + def __init__( + self, + config: DotsVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + attn_implementation: str = "flash_attention_2", + ): + super().__init__() + if attn_implementation == "flash_attention_2": + qkv_backend = "fa3" + softmax_in_single_precision = False + else: + raise RuntimeError("Unimplemented") + self.attn = VisionAttention( + embed_dim=config.embed_dim, + num_heads=config.num_attention_heads, + projection_size=config.embed_dim, + use_qkv_parallel=True, + qkv_backend=qkv_backend, + softmax_in_single_precision=softmax_in_single_precision, + flatten_batch=True, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + num_dummy_heads=config.num_dummy_heads, + qkv_bias=config.use_bias, + proj_bias=config.use_bias, + ) + self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + self.mlp = DotsSwiGLUFFN(config, quant_config) + self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + position_embeddings=rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class DotsVisionTransformer(PreTrainedModel): + def __init__( + self, + config: DotsVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__(config) + self.config = config + self._update_vision_config() + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = DotsViTPreprocessor(config, quant_config) + self._init_weights(self.patch_embed.patchifier.proj) + + head_dim = config.embed_dim // config.num_attention_heads + + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + _num_hidden_layers = config.num_hidden_layers + self.blocks = nn.ModuleList( + [ + DotsVisionBlock( + config, quant_config, f"blocks.{i}", config.attn_implementation + ) + for i in range(_num_hidden_layers) + ] + ) + + if self.config.post_norm: + self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + self.merger = PatchMerger( + dim=config.hidden_size, + context_dim=config.embed_dim, + spatial_merge_size=config.spatial_merge_size, + init_merger_std=self.config.init_merger_std, + quant_config=quant_config, + ) + + self.gradient_checkpointing = False + + def _update_vision_config(self): + """update vision config to support tp""" + world_size = parallel_state.get_tensor_model_parallel_world_size() + num_heads = self.config.num_attention_heads + head_dim = self.config.embed_dim // num_heads + num_dummy_heads = 0 + + if num_heads % world_size != 0: + num_dummy_heads = ( + (num_heads + world_size) // world_size + ) * world_size - num_heads + + setattr(self.config, "head_dim", head_dim) + setattr(self.config, "num_dummy_heads", num_dummy_heads) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + @property + def device(self) -> torch.device: + return self.blocks[0].mlp.fc2.weight.device + + def get_pos_ids_by_grid(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + return pos_ids + + def rot_pos_emb(self, grid_thw): + pos_ids = self.get_pos_ids_by_grid(grid_thw) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def calc_cos_sin(self, rotary_pos_emb): + cos = rotary_pos_emb.cos() + sin = rotary_pos_emb.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + rotary_pos_emb = (cos, sin) + return rotary_pos_emb + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True + ) -> torch.Tensor: + if bf16: + hidden_states = hidden_states.bfloat16() + hidden_states = self.patch_embed(hidden_states, grid_thw) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.calc_cos_sin(rotary_pos_emb) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) + + if self.config.post_norm: + hidden_states = self.post_trunk_norm(hidden_states) + + hidden_states = self.merger(hidden_states) + return hidden_states diff --git a/python/sglang/srt/multimodal/processors/dots_vlm.py b/python/sglang/srt/multimodal/processors/dots_vlm.py new file mode 100644 index 000000000..a12edccae --- /dev/null +++ b/python/sglang/srt/multimodal/processors/dots_vlm.py @@ -0,0 +1,99 @@ +import asyncio +import math +import re +from typing import Dict, List, Union + +from PIL import Image + +from sglang.srt.models.dots_vlm import DotsVLMForCausalLM +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) +from sglang.srt.multimodal.processors.qwen_vl import resize_image_async + + +class DotsVLMImageProcessor(BaseMultimodalProcessor): + models = [DotsVLMForCausalLM] + + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) + # The single, pre-expanded image token. + self.IMAGE_TOKEN = "<|img|><|imgpad|><|endofimg|>" + # The regex that matches expanded image tokens. + self.IMAGE_TOKEN_REGEX = re.compile(r"<\|img\|>(?:<\|imgpad\|>)+<\|endofimg\|>") + + assert len(_processor.tokenizer.encode("<|img|>")) == 1 + self.im_start_id = _processor.tokenizer.encode("<|img|>")[0] + self.im_end_id = _processor.tokenizer.encode("<|endofimg|>")[0] + self.image_token_id = _processor.tokenizer.encode("<|imgpad|>")[0] + self.IM_TOKEN_ID = self.image_token_id + self.IM_START_ID = self.im_start_id + self.IM_END_ID = self.im_end_id + + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + + self.IMAGE_FACTOR = patch_size * merge_size + self.MIN_PIXELS = _processor.image_processor.min_pixels + self.MAX_PIXELS = _processor.image_processor.max_pixels + self.MAX_RATIO = 200 + self.mm_tokens = MultimodalSpecialTokens( + image_token=self.IMAGE_TOKEN, + image_token_id=self.image_token_id, + image_token_regex=self.IMAGE_TOKEN_REGEX, + ).build(_processor) + + async def process_mm_data_async( + self, + image_data: List[Union[str, bytes, Dict]], + input_text, + request_obj, + max_req_input_len, + *args, + **kwargs, + ): + if isinstance(image_data, str): + image_data = [image_data] + + if ( + isinstance(image_data, list) + and image_data + and isinstance(image_data[0], list) + ): + image_data = sum(image_data, []) + + base_output = self.load_mm_data( + prompt=input_text, + image_data=image_data, + multimodal_tokens=self.mm_tokens, + ) + + # Qwen-specific: resize images if they are raw Image objects + if base_output.images and isinstance(base_output.images[0], Image.Image): + resize_tasks = [ + resize_image_async( + image, + min_pixels=self.MIN_PIXELS, + max_pixels=self.MAX_PIXELS, + size_factor=self.IMAGE_FACTOR, + ) + for image in base_output.images + ] + base_output.images = await asyncio.gather(*resize_tasks) + + combined_mm_item, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) + + if combined_mm_item is None: + return None + + return { + "input_ids": input_ids.tolist(), + "mm_items": combined_mm_item, + "im_start_id": self.im_start_id, + "im_end_id": self.im_end_id, + "im_token_id": self.image_token_id, + } diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index f67f72b95..facddfea5 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -67,10 +67,15 @@ def smart_resize( return h_bar, w_bar -def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image: +def resize_image( + image, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + size_factor: int = IMAGE_FACTOR, +) -> Image.Image: width, height = image.size - min_pixels = MIN_PIXELS - max_pixels = MAX_PIXELS + min_pixels = min_pixels + max_pixels = max_pixels resized_height, resized_width = smart_resize( height, width, @@ -97,8 +102,13 @@ def floor_by_factor(number: int, factor: int) -> int: return math.floor(number / factor) * factor -async def resize_image_async(image): - return resize_image(image) +async def resize_image_async( + image, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + size_factor: int = IMAGE_FACTOR, +): + return resize_image(image, min_pixels, max_pixels, size_factor) def smart_nframes(