Support Dots.ocr model (#11071)
python/sglang/srt/configs/__init__.py
@@ -1,6 +1,7 @@
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.dots_ocr import DotsOCRConfig
from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
@@ -28,4 +29,5 @@ __all__ = [
    "Step3VisionEncoderConfig",
    "Qwen3NextConfig",
    "DotsVLMConfig",
    "DotsOCRConfig",
]

python/sglang/srt/configs/dots_ocr.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from typing import Optional

from transformers import AutoProcessor, Qwen2_5_VLProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.models.qwen2 import Qwen2Config

from sglang.srt.configs.dots_vlm import DotsVisionConfig


class DotsOCRConfig(Qwen2Config):
    model_type = "dots_ocr"

    def __init__(
        self,
        image_token_id=151665,
        video_token_id=151656,
        vision_config: Optional[dict] = None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_config = DotsVisionConfig(**(vision_config or {}))

    def save_pretrained(self, save_directory, **kwargs):
        self._auto_class = None
        super().save_pretrained(save_directory, **kwargs)


class DummyVideoProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __call__(self, *args, **kwargs):
        return None


class DotsVLProcessor(Qwen2_5_VLProcessor):
    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        **kwargs
    ):
        if video_processor is None:
            video_processor = DummyVideoProcessor()
        super().__init__(
            image_processor, tokenizer, video_processor, chat_template=chat_template
        )
        self.image_token = (
            "<|imgpad|>"
            if not hasattr(tokenizer, "image_token")
            else tokenizer.image_token
        )
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None) is not None
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )


AutoProcessor.register(DotsOCRConfig, DotsVLProcessor)
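
A quick sketch of what this module sets up: constructing the config directly (text defaults come from Qwen2Config, vision defaults from DotsVisionConfig), and noting that the register() call above is what lets AutoProcessor resolve dots_ocr checkpoints. The checkpoint path in the comment is illustrative, not from this commit:

from sglang.srt.configs.dots_ocr import DotsOCRConfig

config = DotsOCRConfig(vision_config={})
assert config.model_type == "dots_ocr"
assert config.image_token_id == 151665  # default image pad token id

# Because of AutoProcessor.register above, loading a processor from a
# checkpoint whose config resolves to DotsOCRConfig yields DotsVLProcessor:
#   AutoProcessor.from_pretrained("/path/to/dots-ocr")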

python/sglang/srt/configs/model_config.py
@@ -778,6 +778,7 @@ multimodal_model_archs = [
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
    "DotsVLMForCausalLM",
    "DotsOCRForCausalLM",
    "Sarashina2VisionForCausalLM",
]
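
This list is what marks an architecture as multimodal: a checkpoint whose config.json declares "architectures": ["DotsOCRForCausalLM"] is now routed through the multimodal path. A minimal sketch of the assumed membership check (the helper name is hypothetical; it simply references the list above):

def is_multimodal(hf_architectures):
    # multimodal_model_archs is the list from the hunk above
    return any(arch in multimodal_model_archs for arch in hf_architectures)

# is_multimodal(["DotsOCRForCausalLM"]) -> True once this entry lands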

python/sglang/srt/hf_transformers_utils.py
@@ -38,6 +38,7 @@ from sglang.srt.configs import (
    ChatGLMConfig,
    DbrxConfig,
    DeepseekVL2Config,
    DotsOCRConfig,
    DotsVLMConfig,
    ExaoneConfig,
    KimiVLConfig,
@@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    LongcatFlashConfig.model_type: LongcatFlashConfig,
    Qwen3NextConfig.model_type: Qwen3NextConfig,
    DotsVLMConfig.model_type: DotsVLMConfig,
    DotsOCRConfig.model_type: DotsOCRConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
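
The loop over the registry (its body is truncated in this hunk) registers each entry with transformers' AutoConfig. For the new entry, the effect is equivalent to the following, so checkpoints whose config.json declares model_type "dots_ocr" deserialize into DotsOCRConfig:

from transformers import AutoConfig

from sglang.srt.configs import DotsOCRConfig

AutoConfig.register("dots_ocr", DotsOCRConfig)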

python/sglang/srt/models/dots_ocr.py (new file, 173 lines)
@@ -0,0 +1,173 @@
# coding=utf-8
# Adapted from the Qwen2.5-VL SGLang implementation

import logging
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
from transformers.activations import ACT2FN

from sglang.srt.configs import DotsOCRConfig
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class DotsOCRForCausalLM(nn.Module):
    def __init__(
        self,
        config: DotsOCRConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Initialize the vision transformer
        self.visual = DotsVisionTransformer(
            config.vision_config,
        )

        # Initialize the language model
        self.model = Qwen2ForCausalLM(config, quant_config)

        # Initialize the LM head, sharing embeddings when they are tied
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        self.logits_processor = LogitsProcessor(config)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        # Extract pixel values and grid information (following the reference pattern)
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thw = torch.concat(
            [item.image_grid_thw for item in items], dim=0
        ).to(self.visual.device)

        # Dimension checks, as in the reference code
        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"

        # Process through the vision tower
        image_embeds = self.visual(pixel_values, image_grid_thw)

        # Cast to the language model's embedding dtype so downstream attention
        # kernels (e.g. FlashInfer) see a consistent dtype
        if hasattr(self.model, "embed_tokens"):
            target_dtype = self.model.embed_tokens.weight.dtype
            if image_embeds.dtype != target_dtype:
                image_embeds = image_embeds.to(target_dtype)

        return image_embeds

    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
        """Pad attention qkv/proj/norm weights with zeros for dummy heads."""
        num_dummy_heads = self.config.vision_config.num_dummy_heads
        if num_dummy_heads == 0:
            return loaded_weight
        head_dim = self.config.vision_config.head_dim

        if "attn.qkv_proj" in name:
            wq, wk, wv = loaded_weight.chunk(3, dim=0)
            if name.endswith(".weight"):
                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
            elif name.endswith(".bias"):
                dummy_shape = [num_dummy_heads, head_dim]
            else:
                raise RuntimeError(f"Unsupported weight with name={name}")
            pad_func = lambda x: torch.cat(
                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
            ).flatten(0, 1)
            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
            loaded_weight = torch.cat([wq, wk, wv], dim=0)
        if "attn.proj.weight" in name:
            padded_weight = loaded_weight.new_zeros(
                loaded_weight.shape[0], head_dim * num_dummy_heads
            )
            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
        return loaded_weight
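
To make the padding concrete, here is a small standalone shape check of the qkv branch above; the numbers are made up, and in the real path num_dummy_heads and head_dim come from the vision config:

import torch

head_dim, num_heads, num_dummy_heads, hidden = 8, 4, 2, 32

# A fused qkv weight: 3 * num_heads * head_dim rows
qkv = torch.randn(3 * num_heads * head_dim, hidden)
wq, wk, wv = qkv.chunk(3, dim=0)

# Same pad_func as in the method: append zero rows per dummy head
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
pad = lambda x: torch.cat(
    [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
).flatten(0, 1)

padded = torch.cat([pad(wq), pad(wk), pad(wv)], dim=0)
assert padded.shape == (3 * (num_heads + num_dummy_heads) * head_dim, hidden)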

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ) -> torch.Tensor:
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            multimodal_model=self,
            language_model=self.model,
        )
        return hidden_states
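
general_mm_embed_routine is sglang's shared helper that embeds the text tokens and splices the vectors returned by get_image_feature into the positions that pad_input_ids filled with image_token_id. A simplified, hypothetical illustration of that splicing step, not the helper's actual implementation:

import torch

def splice_image_embeds(
    input_ids: torch.Tensor,    # (seq_len,)
    text_embeds: torch.Tensor,  # (seq_len, hidden)
    image_embeds: torch.Tensor, # (num_image_tokens, hidden)
    image_token_id: int,
) -> torch.Tensor:
    # Replace the embedding at every image-pad position with a vision vector
    out = text_embeds.clone()
    out[input_ids == image_token_id] = image_embeds.to(out.dtype)
    return out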

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model, separating vision and language weights."""
        weights = list(weights)

        # Separate vision tower weights from language model weights
        vision_weights = []
        language_weights = []

        for name, loaded_weight in weights:
            if name.startswith("vision_tower."):
                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
                vision_weights.append((vision_name, loaded_weight))
            else:
                # All other weights go to the language model
                language_weights.append((name, loaded_weight))

        # Load vision tower weights
        vision_state_dict = dict(vision_weights)
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in vision_state_dict.items():
            name = name.replace("vision_tower", "visual")
            if name not in params_dict:
                raise ValueError(f"Weight {name} not found in params_dict")
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
            weight_loader(param, loaded_weight)

        if language_weights:
            self.model.load_weights(language_weights)

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight


EntryClass = [DotsOCRForCausalLM]
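
With the architecture registered via EntryClass, the model can be served like any other sglang multimodal model. A usage sketch; the checkpoint path, served model name, and port are illustrative assumptions:

import requests

# First launch the server, e.g.:
#   python -m sglang.launch_server --model-path /path/to/dots-ocr --port 30000
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "dots-ocr",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/page.png"},
                    },
                    {"type": "text", "text": "Extract the text in this document."},
                ],
            }
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])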

python/sglang/srt/multimodal/processors/dots_vlm.py
@@ -5,6 +5,7 @@ from typing import Dict, List, Union
from PIL import Image

from sglang.srt.models.dots_ocr import DotsOCRForCausalLM
from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
@@ -14,7 +15,7 @@ from sglang.srt.multimodal.processors.qwen_vl import resize_image_async

class DotsVLMImageProcessor(BaseMultimodalProcessor):
-    models = [DotsVLMForCausalLM]
+    models = [DotsVLMForCausalLM, DotsOCRForCausalLM]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@@ -82,11 +83,9 @@ class DotsVLMImageProcessor(BaseMultimodalProcessor):
            for image in base_output.images
        ]
        base_output.images = await asyncio.gather(*resize_tasks)

        combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )

        if combined_mm_item is None:
            return None