Support Dots.ocr model (#11071)

This commit is contained in:
qrskannbara
2025-10-01 03:18:39 +08:00
committed by GitHub
parent a6cc86df9d
commit fb367acfcb
6 changed files with 244 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.dots_ocr import DotsOCRConfig
from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
@@ -28,4 +29,5 @@ __all__ = [
"Step3VisionEncoderConfig",
"Qwen3NextConfig",
"DotsVLMConfig",
"DotsOCRConfig",
]

View File

@@ -0,0 +1,64 @@
from typing import Optional
from transformers import AutoProcessor, Qwen2_5_VLProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.models.qwen2 import Qwen2Config
from sglang.srt.configs.dots_vlm import DotsVisionConfig
class DotsOCRConfig(Qwen2Config):
model_type = "dots_ocr"
def __init__(
self,
image_token_id=151665,
video_token_id=151656,
vision_config: Optional[dict] = None,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.vision_config = DotsVisionConfig(**(vision_config or {}))
def save_pretrained(self, save_directory, **kwargs):
self._auto_class = None
super().save_pretrained(save_directory, **kwargs)
class DummyVideoProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
def __call__(self, *args, **kwargs):
return None
class DotsVLProcessor(Qwen2_5_VLProcessor):
def __init__(
self,
image_processor=None,
tokenizer=None,
video_processor=None,
chat_template=None,
**kwargs
):
if video_processor is None:
video_processor = DummyVideoProcessor()
super().__init__(
image_processor, tokenizer, video_processor, chat_template=chat_template
)
self.image_token = (
"<|imgpad|>"
if not hasattr(tokenizer, "image_token")
else tokenizer.image_token
)
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None) is not None
else tokenizer.convert_tokens_to_ids(self.image_token)
)
AutoProcessor.register(DotsOCRConfig, DotsVLProcessor)

View File

@@ -778,6 +778,7 @@ multimodal_model_archs = [
"VILAForConditionalGeneration",
"Step3VLForConditionalGeneration",
"DotsVLMForCausalLM",
"DotsOCRForCausalLM",
"Sarashina2VisionForCausalLM",
]