diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py
index 0a57a8b26..7d285b3d3 100644
--- a/python/sglang/srt/configs/__init__.py
+++ b/python/sglang/srt/configs/__init__.py
@@ -1,6 +1,7 @@
 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
+from sglang.srt.configs.dots_ocr import DotsOCRConfig
 from sglang.srt.configs.dots_vlm import DotsVLMConfig
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
@@ -28,4 +29,5 @@ __all__ = [
     "Step3VisionEncoderConfig",
     "Qwen3NextConfig",
     "DotsVLMConfig",
+    "DotsOCRConfig",
 ]
diff --git a/python/sglang/srt/configs/dots_ocr.py b/python/sglang/srt/configs/dots_ocr.py
new file mode 100644
index 000000000..8b0693b8e
--- /dev/null
+++ b/python/sglang/srt/configs/dots_ocr.py
@@ -0,0 +1,64 @@
+from typing import Optional
+
+from transformers import AutoProcessor, Qwen2_5_VLProcessor
+from transformers.image_processing_utils import BaseImageProcessor
+from transformers.models.qwen2 import Qwen2Config
+
+from sglang.srt.configs.dots_vlm import DotsVisionConfig
+
+
+class DotsOCRConfig(Qwen2Config):
+    model_type = "dots_ocr"
+
+    def __init__(
+        self,
+        image_token_id=151665,
+        video_token_id=151656,
+        vision_config: Optional[dict] = None,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+        self.vision_config = DotsVisionConfig(**(vision_config or {}))
+
+    def save_pretrained(self, save_directory, **kwargs):
+        self._auto_class = None
+        super().save_pretrained(save_directory, **kwargs)
+
+
+class DummyVideoProcessor(BaseImageProcessor):
+    model_input_names = ["pixel_values"]
+
+    def __call__(self, *args, **kwargs):
+        return None
+
+
+class DotsVLProcessor(Qwen2_5_VLProcessor):
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        video_processor=None,
+        chat_template=None,
+        **kwargs
+    ):
+        if video_processor is None:
+            video_processor = DummyVideoProcessor()
+        super().__init__(
+            image_processor, tokenizer, video_processor, chat_template=chat_template
+        )
+        self.image_token = (
+            "<|imgpad|>"
+            if not hasattr(tokenizer, "image_token")
+            else tokenizer.image_token
+        )
+        self.image_token_id = (
+            tokenizer.image_token_id
+            if getattr(tokenizer, "image_token_id", None) is not None
+            else tokenizer.convert_tokens_to_ids(self.image_token)
+        )
+
+
+AutoProcessor.register(DotsOCRConfig, DotsVLProcessor)
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 92d0e130f..9132fb428 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -778,6 +778,7 @@ multimodal_model_archs = [
     "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "DotsVLMForCausalLM",
+    "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
 ]
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 89c5b63f6..b97af68a1 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -38,6 +38,7 @@ from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
     DeepseekVL2Config,
+    DotsOCRConfig,
     DotsVLMConfig,
     ExaoneConfig,
     KimiVLConfig,
@@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     LongcatFlashConfig.model_type: LongcatFlashConfig,
     Qwen3NextConfig.model_type: Qwen3NextConfig,
     DotsVLMConfig.model_type: DotsVLMConfig,
+    DotsOCRConfig.model_type: DotsOCRConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
diff --git a/python/sglang/srt/models/dots_ocr.py b/python/sglang/srt/models/dots_ocr.py
new file mode 100644
index 000000000..b0202367b
--- /dev/null
+++ b/python/sglang/srt/models/dots_ocr.py
@@ -0,0 +1,173 @@
+# coding=utf-8
+# Adapted from Qwen2.5-VL SGLang implementation
+
+import logging
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.activations import ACT2FN
+
+from sglang.srt.configs import DotsOCRConfig
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class DotsOCRForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: DotsOCRConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        # Initialize vision transformer
+        self.visual = DotsVisionTransformer(
+            config.vision_config,
+        )
+
+        # Initialize language model
+        self.model = Qwen2ForCausalLM(config, quant_config)
+
+        # Initialize LM head
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract pixel values and grid information (following reference pattern)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thw = torch.concat(
+            [item.image_grid_thw for item in items], dim=0
+        ).to(self.visual.device)
+
+        # Add dimension checks like in reference code
+        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
+        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
+
+        # Process through vision tower
+        image_embeds = self.visual(pixel_values, image_grid_thw)
+
+        # Ensure consistent dtype for FlashInfer compatibility
+        # Force bfloat16 to match model's expected dtype
+        if hasattr(self.model, "embed_tokens"):
+            target_dtype = self.model.embed_tokens.weight.dtype
+            if image_embeds.dtype != target_dtype:
+                image_embeds = image_embeds.to(target_dtype)
+
+        return image_embeds
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            multimodal_model=self,
+            language_model=self.model,
+        )
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model, separating vision and language weights"""
+        weights = list(weights)
+
+        # Separate vision tower weights and language model weights
+        vision_weights = []
+        language_weights = []
+
+        for name, loaded_weight in weights:
+            if name.startswith("vision_tower."):
+                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                vision_weights.append((vision_name, loaded_weight))
+            else:
+                # All other weights go to language model
+                language_weights.append((name, loaded_weight))
+
+        # Load vision tower weights
+        vision_state_dict = dict(vision_weights)
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for name, loaded_weight in vision_state_dict.items():
+            name = name.replace("vision_tower", "visual")
+            if name not in params_dict:
+                raise ValueError(f"Weight {name} not found in params_dict")
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+            weight_loader(param, loaded_weight)
+
+        if language_weights:
+            self.model.load_weights(language_weights)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+
+EntryClass = [DotsOCRForCausalLM]
diff --git a/python/sglang/srt/multimodal/processors/dots_vlm.py b/python/sglang/srt/multimodal/processors/dots_vlm.py
index a12edccae..3b95beff3 100644
--- a/python/sglang/srt/multimodal/processors/dots_vlm.py
+++ b/python/sglang/srt/multimodal/processors/dots_vlm.py
@@ -5,6 +5,7 @@ from typing import Dict, List, Union
 
 from PIL import Image
 
+from sglang.srt.models.dots_ocr import DotsOCRForCausalLM
 from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -14,7 +15,7 @@ from sglang.srt.multimodal.processors.qwen_vl import resize_image_async
 
 
 class DotsVLMImageProcessor(BaseMultimodalProcessor):
-    models = [DotsVLMForCausalLM]
+    models = [DotsVLMForCausalLM, DotsOCRForCausalLM]
 
     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@@ -82,11 +83,9 @@ class DotsVLMImageProcessor(BaseMultimodalProcessor):
             for image in base_output.images
         ]
         base_output.images = await asyncio.gather(*resize_tasks)
-
         combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )
-
        if combined_mm_item is None:
            return None
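
Note on the `_pad_vit_attn_dummy_heads` helper added in `python/sglang/srt/models/dots_ocr.py`: when `vision_config.num_dummy_heads` is non-zero, the fused qkv weight is split into q/k/v, each projection is padded with all-zero dummy heads, and the chunks are re-fused (presumably so the vision attention head count divides evenly across tensor-parallel ranks). A minimal standalone sketch of that padding, using toy sizes that are not part of this patch:

    # Toy illustration of the qkv dummy-head padding (illustrative only, not part of the diff).
    import torch

    num_heads, head_dim, hidden, num_dummy_heads = 2, 4, 8, 1
    qkv_weight = torch.randn(3 * num_heads * head_dim, hidden)  # fused [q; k; v]

    wq, wk, wv = qkv_weight.chunk(3, dim=0)
    dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
    pad = lambda x: torch.cat(
        [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
    ).flatten(0, 1)  # append zero heads to each projection, then re-flatten
    padded = torch.cat([pad(wq), pad(wk), pad(wv)], dim=0)

    assert padded.shape == (3 * (num_heads + num_dummy_heads) * head_dim, hidden)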