model: support dots.vlm1 model (#8778)
Co-authored-by: weishi <bushou@xiaohongshu.com>
Co-authored-by: Ezra-Yu <1105212286@qq.com>
Co-authored-by: Jianfei Wang <905787410@qq.com>
Co-authored-by: qianwu <wangjianfei@xiaohongshu.com>
Committed by GitHub. Parent: 6d40308905. Commit: 1b1701f1f7.
@@ -44,7 +44,6 @@ runtime_common = [
    "pynvml",
    "python-multipart",
    "pyzmq>=25.1.2",
    "sentencepiece",
    "soundfile==0.13.1",
    "scipy",
    "timm==1.0.16",
@@ -1,6 +1,7 @@
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
@@ -26,4 +27,5 @@ __all__ = [
    "Step3TextConfig",
    "Step3VisionEncoderConfig",
    "Qwen3NextConfig",
    "DotsVLMConfig",
]
python/sglang/srt/configs/dots_vlm.py (new file, 139 lines)
@@ -0,0 +1,139 @@
from typing import Any, List, Optional, Union

from transformers import AutoProcessor, LlamaTokenizerFast, PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

try:
    from transformers import Qwen2_5_VLProcessor
except ImportError:
    raise ImportError(
        "Qwen2_5_VLProcessor can not be found. Please upgrade your transformers version."
    )

from sglang.srt.configs.deepseekvl2 import DeepseekV2Config


class DotsVisionConfig(PretrainedConfig):
    model_type: str = "dots_vit"

    def __init__(
        self,
        embed_dim: int = 1536,  # vision encoder embed size
        hidden_size: int = 1536,  # after merger hidden size
        intermediate_size: int = 4224,
        num_hidden_layers: int = 42,
        num_attention_heads: int = 12,
        num_channels: int = 3,
        patch_size: int = 14,
        spatial_merge_size: int = 2,
        temporal_patch_size: int = 1,
        rms_norm_eps: float = 1e-5,
        use_bias: bool = False,
        attn_implementation="flash_attention_2",  # "eager", "sdpa", "flash_attention_2"
        initializer_range=0.02,
        init_merger_std=0.02,
        is_causal=False,  # ve causal forward
        post_norm=True,
        gradient_checkpointing=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.rms_norm_eps = rms_norm_eps
        self.use_bias = use_bias
        self.attn_implementation = attn_implementation
        self.initializer_range = initializer_range
        self.init_merger_std = init_merger_std
        self.is_causal = is_causal
        self.post_norm = post_norm
        self.gradient_checkpointing = gradient_checkpointing


class DotsVLMConfig(PretrainedConfig):
    model_type = "dots_vlm"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vision_config = kwargs.get("vision_config", {})
        self.im_span_id = kwargs.get("image_token_id", 128815)
        self.video_span_id = kwargs.get("video_token_id", 128836)
        self.vision_config = DotsVisionConfig(**vision_config)
        self.language_config = DeepseekV2Config(**kwargs)
        self.architectures = ["DotsVLMForCausalLM"]


class DotsVLMProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


class DotsVLMProcessor(Qwen2_5_VLProcessor):
    r"""
    Constructs a DotsVLM processor which derives from Qwen2_5_VLProcessor, but overrides the image and video token ids.
    Besides, its tokenizer is a LlamaTokenizerFast instead of Qwen2TokenizerFast.
    [`DotsVLMProcessor`] offers all the functionalities of [`DotsVisionConfig`] and [`LlamaTokenizerFast`]. See the
    [`~DotsVLMProcessor.__call__`] and [`~DotsVLMProcessor.decode`] for more information.
    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]

    valid_kwargs = ["chat_template"]

    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    def __init__(
        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
    ):
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
        self.image_token = (
            "<|imgpad|>"
            if not hasattr(tokenizer, "image_token")
            else tokenizer.image_token
        )
        self.video_token = (
            "<|video_pad|>"
            if not hasattr(tokenizer, "video_token")
            else tokenizer.video_token
        )
        self.img_token = (
            "<|img|>" if not hasattr(tokenizer, "img_token") else tokenizer.img_token
        )
        self.endofimg_token = (
            "<|endofimg|>"
            if not hasattr(tokenizer, "endofimg_token")
            else tokenizer.endofimg_token
        )
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.encode(self.image_token)[0]
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None)
            else tokenizer.encode(self.video_token)[0]
        )


AutoProcessor.register(DotsVLMConfig, DotsVLMProcessor)
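
Note (not part of the diff): a minimal sketch of what the constructor defaults above imply; the printed values follow directly from this file.

# Minimal sketch: inspect the vision defaults and the fallback special-token
# ids defined above (128815 for <|imgpad|>, 128836 for video padding).
from sglang.srt.configs.dots_vlm import DotsVisionConfig

vcfg = DotsVisionConfig()
print(vcfg.embed_dim, vcfg.num_hidden_layers, vcfg.patch_size)  # 1536 42 14
# Each merged vision token covers a (patch_size * spatial_merge_size)-pixel square:
print(vcfg.patch_size * vcfg.spatial_merge_size)                # 28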
@@ -216,6 +216,7 @@ class ModelConfig:
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
            or "LongcatFlashForCausalLM" in self.hf_config.architectures
            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
            or "DotsVLMForCausalLM" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
@@ -734,6 +735,7 @@ multimodal_model_archs = [
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
    "DotsVLMForCausalLM",
]
@@ -38,6 +38,7 @@ from sglang.srt.configs import (
    ChatGLMConfig,
    DbrxConfig,
    DeepseekVL2Config,
    DotsVLMConfig,
    ExaoneConfig,
    KimiVLConfig,
    LongcatFlashConfig,
@@ -60,6 +61,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    Step3VLConfig.model_type: Step3VLConfig,
    LongcatFlashConfig.model_type: LongcatFlashConfig,
    Qwen3NextConfig.model_type: Qwen3NextConfig,
    DotsVLMConfig.model_type: DotsVLMConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
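
Illustrative sketch (not part of this PR) of how a model_type registry like the one above resolves a checkpoint whose config.json declares "dots_vlm"; the local dict mirrors just the new entry.

from sglang.srt.configs import DotsVLMConfig

registry = {DotsVLMConfig.model_type: DotsVLMConfig}  # mirrors the new registry entry
model_type = "dots_vlm"                               # as read from a checkpoint's config.json
print(registry[model_type].__name__)                  # DotsVLMConfig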
python/sglang/srt/models/dots_vlm.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# Copyright 2025 The RedNote HiLab team.
# Copyright 2025 The SGLang team.
#
# This code is based on the DeepseekVL2ForCausalLM and DotsVisionTransformer
# implementation in this library.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Dots-VL model compatible with HuggingFace weights."""

from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM

from .dots_vlm_vit import DotsVisionTransformer


class DotsVLMForCausalLM(nn.Module):
    """DotsVLM model for sglang inference"""

    def __init__(
        self, config: DotsVLMConfig, quant_config: Optional[QuantizationConfig] = None
    ) -> None:
        super().__init__()

        self.config = config
        self.image_token_id = config.im_span_id
        self.video_token_id = config.video_span_id

        self.language_model = DeepseekV2ForCausalLM(
            config.language_config, quant_config
        )

        # Initialize vision tower (matching transformers naming for weight compatibility)
        self.vision_tower = DotsVisionTransformer(config.vision_config)

    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
        """pad attn qkv weights for dummy heads"""
        num_dummy_heads = self.config.vision_config.num_dummy_heads
        if num_dummy_heads == 0:
            return loaded_weight
        head_dim = self.config.vision_config.head_dim

        if "attn.qkv_proj" in name:
            wq, wk, wv = loaded_weight.chunk(3, dim=0)
            if name.endswith(".weight"):
                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
            elif name.endswith(".bias"):
                dummy_shape = [num_dummy_heads, head_dim]
            else:
                raise RuntimeError(f"Unsupported weight with name={name}")
            pad_func = lambda x: torch.cat(
                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
            ).flatten(0, 1)
            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
            loaded_weight = torch.cat([wq, wk, wv], dim=0)
        if "attn.proj.weight" in name:
            padded_weight = loaded_weight.new_zeros(
                loaded_weight.shape[0], head_dim * num_dummy_heads
            )
            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
        return loaded_weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model, separating vision and language weights"""
        weights = list(weights)

        # Separate vision tower weights and language model weights
        vision_weights = []
        language_weights = []

        for name, loaded_weight in weights:
            if name.startswith("vision_tower."):
                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
                vision_weights.append((vision_name, loaded_weight))
            else:
                # All other weights go to language model
                language_weights.append((name, loaded_weight))

        # Load vision tower weights
        vision_state_dict = dict(vision_weights)
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in vision_state_dict.items():
            if name not in params_dict:
                raise ValueError(f"Weight {name} not found in params_dict")
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
            weight_loader(param, loaded_weight)

        # Load language model weights
        if language_weights:
            self.language_model.load_weights(language_weights)

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        return DeepseekV2ForCausalLM.get_model_config_for_expert_location(config)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad input_ids with multimodal tokens"""
        # Get image token ID for padding pattern
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        padded_input_ids = pattern.pad_input_tokens(input_ids, mm_inputs)
        return padded_input_ids

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        # Extract pixel values and grid information (following reference pattern)
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.vision_tower.dtype
        )
        image_grid_thw = torch.concat(
            [item.image_grid_thw for item in items], dim=0
        ).to(self.vision_tower.device)

        # Add dimension checks like in reference code
        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"

        # Process through vision tower
        image_embeds = self.vision_tower(pixel_values, image_grid_thw)

        # Ensure consistent dtype for FlashInfer compatibility
        # Force bfloat16 to match model's expected dtype
        if image_embeds.dtype != torch.bfloat16 and hasattr(
            self.language_model.model, "embed_tokens"
        ):
            target_dtype = self.language_model.model.embed_tokens.weight.dtype
            image_embeds = image_embeds.to(target_dtype)

        return image_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ) -> torch.Tensor:
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            multimodal_model=self,
            language_model=self.language_model,
        )
        return hidden_states


EntryClass = [DotsVLMForCausalLM]
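
The weight padding performed by _pad_vit_attn_dummy_heads above is easiest to see with toy sizes. A self-contained sketch (illustrative shapes only, not code from this PR):

# Toy illustration of the dummy-head padding: extend a fused qkv weight from
# 2 real heads to 3 heads per projection so it splits evenly under tensor parallelism.
import torch

head_dim, num_heads, num_dummy_heads, in_dim = 4, 2, 1, 8
wqkv = torch.randn(3 * num_heads * head_dim, in_dim)  # fused [q; k; v] weight

wq, wk, wv = wqkv.chunk(3, dim=0)
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
pad = lambda x: torch.cat(
    [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
).flatten(0, 1)
padded = torch.cat([pad(wq), pad(wk), pad(wv)], dim=0)
print(padded.shape)  # torch.Size([36, 8]) == 3 * (2 + 1) * 4 rows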
python/sglang/srt/models/dots_vlm_vit.py (new file, 337 lines)
@@ -0,0 +1,337 @@
import logging
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm
from transformers.modeling_utils import PreTrainedModel

from sglang.srt.configs.dots_vlm import DotsVisionConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class PatchMerger(nn.Module):
    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        pre_norm="layernorm",
        init_merger_std=None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.pre_norm = pre_norm
        if self.pre_norm == "layernorm":
            self.ln_q = LayerNorm(context_dim, eps=1e-6)
        elif self.pre_norm == "rmsnorm":
            self.ln_q = RMSNorm(context_dim, eps=1e-6)
        else:
            logger.warning(f"no norm in patch merger: {self.pre_norm}")

        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

        if init_merger_std is not None:
            nn.init.normal_(self.mlp[0].weight, mean=0.0, std=init_merger_std)
            nn.init.zeros_(self.mlp[0].bias)
            nn.init.normal_(self.mlp[2].weight, mean=0.0, std=init_merger_std)
            nn.init.zeros_(self.mlp[2].bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pre_norm:
            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        else:
            x = self.mlp(x.view(-1, self.hidden_size))
        return x


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class DotsSwiGLUFFN(nn.Module):
    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        hidden_features = config.intermediate_size
        in_features = config.embed_dim
        bias = config.use_bias

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.fc1(x)) * self.fc3(x)
        x = self.fc2(x)
        return x


class DotsPatchEmbed(nn.Module):
    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.num_channels = config.num_channels
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.embed_dim = config.embed_dim
        self.config = config
        self.proj = nn.Conv2d(
            config.num_channels,
            config.embed_dim,
            kernel_size=(config.patch_size, config.patch_size),
            stride=(config.patch_size, config.patch_size),
        )
        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        x = x.view(
            -1,
            self.num_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )[:, :, 0]
        x = self.proj(x).view(-1, self.embed_dim)
        x = self.norm(x)
        return x


class DotsViTPreprocessor(nn.Module):
    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.patch_h = config.patch_size
        self.patch_w = config.patch_size
        self.embed_dim = config.embed_dim
        self.config = config
        self.patchifier = DotsPatchEmbed(config, quant_config)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        tokens = self.patchifier(x, grid_thw)
        return tokens


class DotsVisionBlock(nn.Module):
    def __init__(
        self,
        config: DotsVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        attn_implementation: str = "flash_attention_2",
    ):
        super().__init__()
        if attn_implementation == "flash_attention_2":
            qkv_backend = "fa3"
            softmax_in_single_precision = False
        else:
            raise RuntimeError("Unimplemented")
        self.attn = VisionAttention(
            embed_dim=config.embed_dim,
            num_heads=config.num_attention_heads,
            projection_size=config.embed_dim,
            use_qkv_parallel=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
            num_dummy_heads=config.num_dummy_heads,
            qkv_bias=config.use_bias,
            proj_bias=config.use_bias,
        )
        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        self.mlp = DotsSwiGLUFFN(config, quant_config)
        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            position_embeddings=rotary_pos_emb,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class DotsVisionTransformer(PreTrainedModel):
    def __init__(
        self,
        config: DotsVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__(config)
        self.config = config
        self._update_vision_config()
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = DotsViTPreprocessor(config, quant_config)
        self._init_weights(self.patch_embed.patchifier.proj)

        head_dim = config.embed_dim // config.num_attention_heads

        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        _num_hidden_layers = config.num_hidden_layers
        self.blocks = nn.ModuleList(
            [
                DotsVisionBlock(
                    config, quant_config, f"blocks.{i}", config.attn_implementation
                )
                for i in range(_num_hidden_layers)
            ]
        )

        if self.config.post_norm:
            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
            spatial_merge_size=config.spatial_merge_size,
            init_merger_std=self.config.init_merger_std,
            quant_config=quant_config,
        )

        self.gradient_checkpointing = False

    def _update_vision_config(self):
        """update vision config to support tp"""
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        num_heads = self.config.num_attention_heads
        head_dim = self.config.embed_dim // num_heads
        num_dummy_heads = 0

        if num_heads % world_size != 0:
            num_dummy_heads = (
                (num_heads + world_size) // world_size
            ) * world_size - num_heads

        setattr(self.config, "head_dim", head_dim)
        setattr(self.config, "num_dummy_heads", num_dummy_heads)

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.fc2.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    def get_pos_ids_by_grid(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        return pos_ids

    def rot_pos_emb(self, grid_thw):
        pos_ids = self.get_pos_ids_by_grid(grid_thw)
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def calc_cos_sin(self, rotary_pos_emb):
        cos = rotary_pos_emb.cos()
        sin = rotary_pos_emb.sin()
        cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
        sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
        rotary_pos_emb = (cos, sin)
        return rotary_pos_emb

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True
    ) -> torch.Tensor:
        if bf16:
            hidden_states = hidden_states.bfloat16()
        hidden_states = self.patch_embed(hidden_states, grid_thw)

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        rotary_pos_emb = self.calc_cos_sin(rotary_pos_emb)

        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            )

        if self.config.post_norm:
            hidden_states = self.post_trunk_norm(hidden_states)

        hidden_states = self.merger(hidden_states)
        return hidden_states
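
A small self-contained walk-through (not part of the diff) of the cu_seqlens bookkeeping used in DotsVisionTransformer.forward, for an assumed toy pair of patch grids:

# Two images with (t, h, w) patch grids of (1, 4, 6) and (1, 2, 2) patches.
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])
cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens.tolist())  # [0, 24, 28]: per-image patch-token boundaries for attention

# After the 2x2 PatchMerger, each image contributes (h * w) / 4 tokens to the
# language model: 24 / 4 = 6 and 4 / 4 = 1 here.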
python/sglang/srt/multimodal/processors/dots_vlm.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import asyncio
import math
import re
from typing import Dict, List, Union

from PIL import Image

from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)
from sglang.srt.multimodal.processors.qwen_vl import resize_image_async


class DotsVLMImageProcessor(BaseMultimodalProcessor):
    models = [DotsVLMForCausalLM]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
        # The single, pre-expanded image token.
        self.IMAGE_TOKEN = "<|img|><|imgpad|><|endofimg|>"
        # The regex that matches expanded image tokens.
        self.IMAGE_TOKEN_REGEX = re.compile(r"<\|img\|>(?:<\|imgpad\|>)+<\|endofimg\|>")

        assert len(_processor.tokenizer.encode("<|img|>")) == 1
        self.im_start_id = _processor.tokenizer.encode("<|img|>")[0]
        self.im_end_id = _processor.tokenizer.encode("<|endofimg|>")[0]
        self.image_token_id = _processor.tokenizer.encode("<|imgpad|>")[0]
        self.IM_TOKEN_ID = self.image_token_id
        self.IM_START_ID = self.im_start_id
        self.IM_END_ID = self.im_end_id

        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size

        self.IMAGE_FACTOR = patch_size * merge_size
        self.MIN_PIXELS = _processor.image_processor.min_pixels
        self.MAX_PIXELS = _processor.image_processor.max_pixels
        self.MAX_RATIO = 200
        self.mm_tokens = MultimodalSpecialTokens(
            image_token=self.IMAGE_TOKEN,
            image_token_id=self.image_token_id,
            image_token_regex=self.IMAGE_TOKEN_REGEX,
        ).build(_processor)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes, Dict]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        if isinstance(image_data, str):
            image_data = [image_data]

        if (
            isinstance(image_data, list)
            and image_data
            and isinstance(image_data[0], list)
        ):
            image_data = sum(image_data, [])

        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,
        )

        # Qwen-specific: resize images if they are raw Image objects
        if base_output.images and isinstance(base_output.images[0], Image.Image):
            resize_tasks = [
                resize_image_async(
                    image,
                    min_pixels=self.MIN_PIXELS,
                    max_pixels=self.MAX_PIXELS,
                    size_factor=self.IMAGE_FACTOR,
                )
                for image in base_output.images
            ]
            base_output.images = await asyncio.gather(*resize_tasks)

        combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )

        if combined_mm_item is None:
            return None

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": combined_mm_item,
            "im_start_id": self.im_start_id,
            "im_end_id": self.im_end_id,
            "im_token_id": self.image_token_id,
        }
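
For intuition (illustrative arithmetic only, assuming the Qwen2-VL-style expansion of one <|imgpad|> token per merged patch), the constants derived above relate as follows for the default vision config:

patch_size, merge_size = 14, 2
IMAGE_FACTOR = patch_size * merge_size                 # 28: image sides are resized to multiples of this
width, height = 448, 336                               # an already-resized image (example values)
patches = (width // patch_size) * (height // patch_size)   # 32 * 24 = 768 vision patches
imgpad_tokens = patches // (merge_size**2)             # 192 <|imgpad|> tokens between <|img|> and <|endofimg|>
print(IMAGE_FACTOR, patches, imgpad_tokens)            # 28 768 192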
python/sglang/srt/multimodal/processors/qwen_vl.py
@@ -67,10 +67,15 @@ def smart_resize(
     return h_bar, w_bar


-def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+def resize_image(
+    image,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
+    size_factor: int = IMAGE_FACTOR,
+) -> Image.Image:
     width, height = image.size
-    min_pixels = MIN_PIXELS
-    max_pixels = MAX_PIXELS
+    min_pixels = min_pixels
+    max_pixels = max_pixels
     resized_height, resized_width = smart_resize(
         height,
         width,
@@ -97,8 +102,13 @@ def floor_by_factor(number: int, factor: int) -> int:
     return math.floor(number / factor) * factor


-async def resize_image_async(image):
-    return resize_image(image)
+async def resize_image_async(
+    image,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
+    size_factor: int = IMAGE_FACTOR,
+):
+    return resize_image(image, min_pixels, max_pixels, size_factor)


 def smart_nframes(
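
A hedged usage sketch of the widened helper; the pixel bounds below are placeholders chosen for illustration, not values taken from this PR:

import asyncio
from PIL import Image
from sglang.srt.multimodal.processors.qwen_vl import resize_image_async

async def demo() -> None:
    img = Image.new("RGB", (1000, 700))  # synthetic image
    resized = await resize_image_async(
        img, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28, size_factor=28
    )
    print(resized.size)  # both sides rounded to multiples of size_factor via smart_resize

asyncio.run(demo())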