From 770529a73172499a4d3e6135b1c63f70d63e1f5d Mon Sep 17 00:00:00 2001 From: Mick Date: Fri, 24 Oct 2025 03:15:17 +0800 Subject: [PATCH] model: support deepseek-ocr (#11891) Co-authored-by: yhyang201 <47235274+yhyang201@users.noreply.github.com> Co-authored-by: yhyang201 Co-authored-by: Shi Shuai <126407087+shuaills@users.noreply.github.com> Co-authored-by: Xinyuan Tong --- python/sglang/srt/configs/deepseek_ocr.py | 262 +++ python/sglang/srt/configs/deepseekvl2.py | 289 ++-- python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/model_loader/utils.py | 1 - python/sglang/srt/models/deepseek_ocr.py | 1516 +++++++++++++++++ python/sglang/srt/models/deepseek_v2.py | 1 - .../multimodal/processors/base_processor.py | 1 + .../srt/multimodal/processors/deepseek_ocr.py | 37 + python/sglang/srt/parser/conversation.py | 22 + .../sglang/srt/utils/hf_transformers_utils.py | 48 +- python/sglang/test/test_utils.py | 1 - test/srt/test_vision_openai_server_a.py | 56 + test/srt/test_vision_openai_server_common.py | 7 +- 13 files changed, 2125 insertions(+), 117 deletions(-) create mode 100644 python/sglang/srt/configs/deepseek_ocr.py create mode 100644 python/sglang/srt/models/deepseek_ocr.py create mode 100644 python/sglang/srt/multimodal/processors/deepseek_ocr.py diff --git a/python/sglang/srt/configs/deepseek_ocr.py b/python/sglang/srt/configs/deepseek_ocr.py new file mode 100644 index 000000000..4a4b2456c --- /dev/null +++ b/python/sglang/srt/configs/deepseek_ocr.py @@ -0,0 +1,262 @@ +from typing import Tuple + +import torchvision.transforms as T +from PIL import Image +from transformers import PretrainedConfig + +BASE_SIZE = 1024 +IMAGE_SIZE = 640 +CROP_MODE = True +MIN_CROPS = 2 +MAX_CROPS = 6 # max:9; If your GPU memory is small, it is recommended to set it to 6. +MAX_CONCURRENCY = 100 # If you have limited GPU memory, lower the concurrency count. +NUM_WORKERS = 64 # image pre-process (resize/padding) workers +PRINT_NUM_VIS_TOKENS = False +SKIP_REPEAT = True +MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path + +PROMPT = "\n<|grounding|>Convert the document to markdown." 
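# Illustrative sketch (for reference only; not applied by this patch): how the
# MIN_CROPS / MAX_CROPS constants above translate an input resolution into a
# tile grid. Assumption: this mirrors the ratio search of
# find_closest_aspect_ratio / dynamic_preprocess introduced in deepseekvl2.py
# later in this patch, minus the area-based tie-break.
def _example_tile_grid(width: int, height: int,
                       min_crops: int = MIN_CROPS,
                       max_crops: int = MAX_CROPS) -> tuple:
    # Candidate (cols, rows) grids whose tile count stays within [min_crops, max_crops].
    candidates = [
        (i, j)
        for n in range(min_crops, max_crops + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_crops <= i * j <= max_crops
    ]
    aspect = width / height
    # Pick the grid whose aspect ratio is closest to the input image's.
    return min(candidates, key=lambda wh: abs(aspect - wh[0] / wh[1]))

# e.g. a 1600x960 page selects (3, 2): six 640x640 local crops (IMAGE_SIZE)
# plus one global view padded to 1024x1024 (BASE_SIZE).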
+ + +class ImageTransform: + + def __init__( + self, + mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + ): + self.mean = mean + self.std = std + self.normalize = normalize + + transform_pipelines = [T.ToTensor()] + + if normalize: + transform_pipelines.append(T.Normalize(mean, std)) + + self.transform = T.Compose(transform_pipelines) + + def __call__(self, pil_img: Image.Image): + x = self.transform(pil_img) + return x + + +class VisionEncoderConfig(PretrainedConfig): + model_type: str = "vision" + + model_name: str = "vit_so400m_patch14_siglip_384.webli" + image_size: int = 384 + patch_size: int = 16 + width: int = 1024 + layers: int = 24 + heads: int = 16 + mlp_ratio: int = 4 + global_pool: str = "map" + ignore_head: bool = True + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + weight_init: str = "skip" + deterministic: bool = False + num_recomputing_layers: int = 0 + + def __init__( + self, + model_name: str = "vit_so400m_patch14_siglip_384.webli", + image_size: int = 384, + patch_size: int = 16, + width: int = 1024, + layers: int = 24, + heads: int = 16, + mlp_ratio: int = 4, + global_pool: str = "map", + ignore_head: bool = True, + class_token: bool = False, + num_classes: int = 0, + use_checkpoint: bool = False, + **kwargs, + ): + self.model_name = model_name + self.image_size = image_size + self.patch_size = patch_size + self.width = width + self.layers = layers + self.heads = heads + self.mlp_ratio = mlp_ratio + self.global_pool = global_pool + self.ignore_head = ignore_head + self.class_token = class_token + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + + super().__init__(**kwargs) + + +class MlpProjectorConfig(PretrainedConfig): + model_type = "mlp_projector" + projector_type: str = "downsample_mlp_gelu" + input_dim: int = 1152 + n_embed: int = 2048 + depth: int = 2 + mlp_ratio: int = 1 + downsample_ratio: int = 2 + token_pooling: bool = False + + def __init__( + self, + projector_type: str = "downsample_mlp_gelu", + input_dim: int = 1152, + n_embed: int = 2048, + depth: int = 2, + mlp_ratio: int = 1, + downsample_ratio: int = 2, + **kwargs, + ): + self.projector_type = projector_type + self.input_dim = input_dim + self.n_embed = n_embed + self.depth = depth + self.mlp_ratio = mlp_ratio + self.downsample_ratio = downsample_ratio + + super().__init__(**kwargs) + + +class DeepseekV2Config(PretrainedConfig): + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=None, + n_routed_experts=None, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=None, + topk_group=None, + num_experts_per_tok=None, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + use_mla=True, + 
**kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = float(rms_norm_eps) + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.use_mla = use_mla + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekVLV2Config(PretrainedConfig): + # model_type = "deepseek_vl_v2" + model_type = "deepseek-ocr" + vision_config: VisionEncoderConfig + projector_config: MlpProjectorConfig + + tile_tag: str = "2D" + global_view_pos: str = "head" + candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),) + + def __init__( + self, + tile_tag: str = "tile_tag", + global_view_pos: str = "head", + candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),), + **kwargs, + ): + super().__init__(**kwargs) + + vision_config = kwargs.get("vision_config", {}) + self.vision_config = VisionEncoderConfig(**vision_config) + + projector_config = kwargs.get("projector_config", {}) + self.projector_config = MlpProjectorConfig(**projector_config) + + language_config = kwargs.get("language_config", {}) + self.text_config = DeepseekV2Config(**language_config) + + self.tile_tag = tile_tag + self.global_view_pos = global_view_pos + self.candidate_resolutions = candidate_resolutions + self.vocab_size = self.text_config.vocab_size + self.hidden_size = self.text_config.hidden_size + + +class DeepseekOCRConfig(DeepseekV2Config): + model_type = "DeepseekOCR" diff --git a/python/sglang/srt/configs/deepseekvl2.py b/python/sglang/srt/configs/deepseekvl2.py index 9621f058b..f18efa314 100644 --- a/python/sglang/srt/configs/deepseekvl2.py +++ b/python/sglang/srt/configs/deepseekvl2.py @@ -11,6 +11,8 @@ from transformers import ( ProcessorMixin, ) +from sglang.srt.configs.deepseek_ocr import BASE_SIZE, IMAGE_SIZE, MAX_CROPS, MIN_CROPS + def select_best_resolution(image_size, candidate_resolutions): # used for cropping @@ -61,6 +63,7 @@ class DictOutput(object): class VLChatProcessorOutput(DictOutput): input_ids: torch.LongTensor target_ids: torch.LongTensor + images_crop: torch.LongTensor pixel_values: ( torch.Tensor 
) # rename from "images" to "pixel_values" for compatibility @@ -104,6 +107,68 @@ class ImageTransform(object): return x +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def dynamic_preprocess( + image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False +): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images, target_aspect_ratio + + class DeepseekVLV2Processor(ProcessorMixin): tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") attributes = ["tokenizer"] @@ -133,7 +198,7 @@ class DeepseekVLV2Processor(ProcessorMixin): self.image_std = image_std self.normalize = normalize self.downsample_ratio = downsample_ratio - + self.base_size = BASE_SIZE self.image_transform = ImageTransform( mean=image_mean, std=image_std, normalize=normalize ) @@ -176,7 +241,7 @@ class DeepseekVLV2Processor(ProcessorMixin): **kwargs, ) - def format_messages_v2(self, messages, pil_images, max_req_input_len=-1): + def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1): """play the role of format_messages_v2 and get_images_info in the last version""" tokenized_data = [] masked_tokenized_data = [] # labels @@ -186,35 +251,34 @@ class DeepseekVLV2Processor(ProcessorMixin): image_index = 0 image_token_cnt = messages.count(self.image_token) - tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images( + ( + input_ids, + images, + images_crop, + seq_mask, + spatial_crop, + num_image_tokens, + image_shapes, + ) = self.tokenize_with_images( messages, pil_images[image_index : image_index + image_token_cnt], bos=True, eos=True, cropping=len(pil_images) <= 2, - max_req_input_len=max_req_input_len, ) image_index = image_token_cnt - 
tokenized_data += tokenized_str - if self.mask_prompt: - masked_tokenized_data += [self.ignore_id] * len(tokenized_str) - else: - masked_tokenized_data += tokenized_str images_list += images images_seq_mask += seq_mask - images_spatial_crop += spatial_crop - - assert len(tokenized_data) == len( - images_seq_mask - ), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" + images_spatial_crop = spatial_crop return ( - tokenized_data, + input_ids, masked_tokenized_data, images_list, images_seq_mask, images_spatial_crop, + images_crop, ) @property @@ -251,6 +315,7 @@ class DeepseekVLV2Processor(ProcessorMixin): inference_mode: bool = True, system_prompt: str = "", max_req_input_len: int = -1, + cropping: bool = True, **kwargs, ): """ @@ -274,47 +339,22 @@ class DeepseekVLV2Processor(ProcessorMixin): - num_image_tokens (List[int]): the number of image tokens """ - assert ( - prompt is None or conversations is None - ), "prompt and conversations cannot be used at the same time." - + prompt = conversations or prompt ( - tokenized_str, + input_ids, masked_tokenized_str, images_list, images_seq_mask, images_spatial_crop, - ) = self.format_messages_v2(conversations, images, max_req_input_len) + images_crop, + ) = self.format_messages_v2(prompt, images, max_req_input_len) - assert ( - len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str) - ), ( - f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " - f"imags_seq_mask's length {len(images_seq_mask)}, are not equal" - ) - - input_ids = torch.LongTensor(tokenized_str) target_ids = torch.LongTensor(masked_tokenized_str) - images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) - - # set input_ids < 0 | input_ids == self.image_token_id as ignore_id - target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = ( - self.ignore_id - ) - input_ids[input_ids < 0] = self.pad_id - - if inference_mode: - assert input_ids[-1] == self.eos_id - input_ids = input_ids[:-1] - target_ids = target_ids[:-1] - images_seq_mask = images_seq_mask[:-1] if len(images_list) == 0: images = torch.zeros((1, 3, self.image_size, self.image_size)) - images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) else: images = torch.stack(images_list, dim=0) - images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) images_spatial_crop = torch.stack( [images_spatial_crop], dim=0 @@ -323,6 +363,7 @@ class DeepseekVLV2Processor(ProcessorMixin): prepare = VLChatProcessorOutput( input_ids=input_ids, target_ids=target_ids, + images_crop=images_crop, pixel_values=images, images_seq_mask=images_seq_mask, images_spatial_crop=images_spatial_crop, @@ -340,10 +381,14 @@ class DeepseekVLV2Processor(ProcessorMixin): inference_mode: bool = True, system_prompt: str = "", max_req_input_len: int = -1, + text: list[str] = None, **kwargs, ): + assert text is None or isinstance(text, list) + if text is not None: + text = text[0] prepare = self.process_one( - prompt=prompt, + prompt=prompt or text, conversations=conversations, images=images, apply_sft_format=apply_sft_format, @@ -368,85 +413,83 @@ class DeepseekVLV2Processor(ProcessorMixin): bos: bool = True, eos: bool = True, cropping: bool = True, - max_req_input_len: int = -1, ): """Tokenize text with tags.""" - images_list, images_seq_mask, images_spatial_crop = [], [], [] + + conversation = conversation + assert conversation.count(self.image_token) == len(images) text_splits = 
conversation.split(self.image_token) + images_list, images_crop_list, images_seq_mask, images_spatial_crop = ( + [], + [], + [], + [], + ) + image_shapes = [] + num_image_tokens = [] tokenized_str = [] for text_sep, image in zip(text_splits, images): """encode text_sep""" tokenized_sep = self.encode(text_sep, bos=False, eos=False) + tokenized_str += tokenized_sep images_seq_mask += [False] * len(tokenized_sep) - """select best resolution for anyres""" - if cropping: - best_width, best_height = select_best_resolution( - image.size, self.candidate_resolutions - ) + image_shapes.append(image.size) + + if image.size[0] <= 640 and image.size[1] <= 640: + crop_ratio = [1, 1] else: - best_width, best_height = self.image_size, self.image_size - # print(image.size, (best_width, best_height)) # check the select_best_resolutions func + if cropping: + images_crop_raw, crop_ratio = dynamic_preprocess( + image, image_size=IMAGE_SIZE + ) + else: + crop_ratio = [1, 1] """process the global view""" + if self.image_size <= 640 and not cropping: + image = image.resize((self.image_size, self.image_size)) + global_view = ImageOps.pad( image, - (self.image_size, self.image_size), + (self.base_size, self.base_size), color=tuple(int(x * 255) for x in self.image_transform.mean), ) images_list.append(self.image_transform(global_view)) - """process the local views""" - local_view = ImageOps.pad( - image, - (best_width, best_height), - color=tuple(int(x * 255) for x in self.image_transform.mean), - ) - for i in range(0, best_height, self.image_size): - for j in range(0, best_width, self.image_size): - images_list.append( - self.image_transform( - local_view.crop( - (j, i, j + self.image_size, i + self.image_size) - ) - ) - ) - - """record height / width crop num""" - num_width_tiles, num_height_tiles = ( - best_width // self.image_size, - best_height // self.image_size, - ) + num_width_tiles, num_height_tiles = crop_ratio images_spatial_crop.append([num_width_tiles, num_height_tiles]) + if num_width_tiles > 1 or num_height_tiles > 1: + for i in range(len(images_crop_raw)): + images_crop_list.append(self.image_transform(images_crop_raw[i])) + """add image tokens""" - h = w = math.ceil( + num_queries = math.ceil( (self.image_size // self.patch_size) / self.downsample_ratio ) - # global views tokens h * (w + 1), 1 is for line separator - tokenized_image = [self.image_token_id] * h * (w + 1) - # add a separator between global and local views - tokenized_image += [self.image_token_id] - # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1) - tokenized_image += ( - [self.image_token_id] - * (num_height_tiles * h) - * (num_width_tiles * w + 1) + num_queries_base = math.ceil( + (self.base_size // self.patch_size) / self.downsample_ratio ) + tokenized_image = ( + [self.image_token_id] * num_queries_base + [self.image_token_id] + ) * num_queries_base + tokenized_image += [self.image_token_id] + if num_width_tiles > 1 or num_height_tiles > 1: + tokenized_image += ( + [self.image_token_id] * (num_queries * num_width_tiles) + + [self.image_token_id] + ) * (num_queries * num_height_tiles) tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) - # print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens + num_image_tokens.append(len(tokenized_image)) """process the last text split""" tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) - # deal with video, limit with request len - if max_req_input_len > -1: 
- if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1: - rest = max_req_input_len - len(tokenized_sep) - 1 - 1024 - tokenized_str = tokenized_str[:rest] - images_seq_mask = images_seq_mask[:rest] + tokenized_str += tokenized_sep images_seq_mask += [False] * len(tokenized_sep) @@ -462,7 +505,64 @@ class DeepseekVLV2Processor(ProcessorMixin): images_seq_mask ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" - return tokenized_str, images_list, images_seq_mask, images_spatial_crop + masked_tokenized_str = [] + for token_index in tokenized_str: + if token_index != self.image_token_id: + masked_tokenized_str.append(token_index) + else: + masked_tokenized_str.append(self.ignore_id) + + assert ( + len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str) + ), ( + f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " + f"imags_seq_mask's length {len(images_seq_mask)}, are not equal" + ) + input_ids = torch.LongTensor(tokenized_str) + target_ids = torch.LongTensor(masked_tokenized_str) + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) + + # set input_ids < 0 | input_ids == self.image_token_id as ignore_id + target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = ( + self.ignore_id + ) + input_ids[input_ids < 0] = self.pad_id + + inference_mode = True + + if inference_mode: + # Remove the ending eos token + assert input_ids[-1] == self.eos_id + input_ids = input_ids[:-1] + target_ids = target_ids[:-1] + images_seq_mask = images_seq_mask[:-1] + + if len(images_list) == 0: + pixel_values = torch.zeros((1, 3, self.base_size, self.base_size)) + images_spatial_crop = torch.zeros((1, 1), dtype=torch.long) + images_crop = torch.zeros( + (1, 3, self.image_size, self.image_size) + ).unsqueeze(0) + else: + pixel_values = torch.stack(images_list, dim=0) + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + if images_crop_list: + images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0) + else: + images_crop = torch.zeros( + (1, 3, self.image_size, self.image_size) + ).unsqueeze(0) + + input_ids = input_ids.unsqueeze(0) + return ( + input_ids, + pixel_values, + images_crop, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + image_shapes, + ) class DeepseekVL2VisionEncoderConfig(PretrainedConfig): @@ -547,7 +647,6 @@ class DeepseekVL2MlpProjectorConfig(PretrainedConfig): class DeepseekV2Config(PretrainedConfig): - model_type = "deepseek_v2" keys_to_ignore_at_inference = ["past_key_values"] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a674178bc..71b420d50 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -921,6 +921,7 @@ multimodal_model_archs = [ "DotsVLMForCausalLM", "DotsOCRForCausalLM", "Sarashina2VisionForCausalLM", + "DeepseekOCRForCausalLM", ] diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py index f6ad79010..3e315e10a 100644 --- a/python/sglang/srt/model_loader/utils.py +++ b/python/sglang/srt/model_loader/utils.py @@ -99,7 +99,6 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS: architectures = resolve_transformers_arch(model_config, architectures) - return ModelRegistry.resolve_model_cls(architectures) 
diff --git a/python/sglang/srt/models/deepseek_ocr.py b/python/sglang/srt/models/deepseek_ocr.py new file mode 100644 index 000000000..fca372a18 --- /dev/null +++ b/python/sglang/srt/models/deepseek_ocr.py @@ -0,0 +1,1516 @@ +# Copyright 2025 The SwissAI Initiative +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 +"""Inference-only Apertus model compatible with HuggingFace weights.""" +import copy +import logging +import math +from functools import partial +from typing import Iterable, List, Optional, Set, Tuple, Type, TypeAlias, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from transformers.models.vitdet.modeling_vitdet import get_rel_pos + +from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.deepseek import DeepseekForCausalLM +from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM +from sglang.srt.models.transformers import maybe_prefix + +NestedTensors: TypeAlias = Union[ + list["NestedTensors"], + list["torch.Tensor"], + "torch.Tensor", + tuple["torch.Tensor", ...], +] + +MultiModalEmbeddings: TypeAlias = list[Tensor] | Tensor | tuple[Tensor, ...] + +logger = logging.getLogger(__name__) + + +def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: + """ + Recursively flattens and concatenates NestedTensors on all but the last + dimension. + """ + + if isinstance(embeddings, torch.Tensor): + # Flatten all but the last dimension. + return embeddings.flatten(0, -2) + + return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) + + +def _embedding_count_expression(embeddings: NestedTensors) -> str: + """ + Constructs a debugging representation of the number of embeddings in the + NestedTensors. + """ + + if isinstance(embeddings, torch.Tensor): + return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) + + return " + ".join(_embedding_count_expression(inner) for inner in embeddings) + + +def _merge_multimodal_embeddings( + inputs_embeds: torch.Tensor, + multimodal_embeddings: NestedTensors, + is_multimodal: torch.Tensor, +) -> torch.Tensor: + """ + Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the + positions in `inputs_embeds` corresponding to placeholder tokens in + `input_ids`. + + Note: + This updates `inputs_embeds` in place. 
+ """ + if len(multimodal_embeddings) == 0: + return inputs_embeds + + mm_embeds_flat = _flatten_embeddings(multimodal_embeddings) + input_dtype = inputs_embeds.dtype + + try: + # NOTE: This can avoid D2H sync (#22105), but fails to + # raise an error if is_multimodal.sum() < len(mm_embeds_flat) + inputs_embeds.masked_scatter_( + is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype) + ) + except RuntimeError as e: + num_actual_tokens = len(mm_embeds_flat) + num_expected_tokens = is_multimodal.sum().item() + + if num_actual_tokens != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + + raise ValueError( + f"Attempted to assign {expr} = {num_actual_tokens} " + f"multimodal tokens to {num_expected_tokens} placeholders" + ) from e + + raise ValueError("Error during masked scatter operation") from e + + return inputs_embeds + + +def isin_list( + elements: torch.Tensor, + test_elements_list: list[int], +) -> torch.Tensor: + test_elements = torch.tensor(test_elements_list, pin_memory=True).to( + device=elements.device, non_blocking=True + ) + + return torch.isin(elements, test_elements) + + +def merge_multimodal_embeddings( + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: NestedTensors, + placeholder_token_id: int | list[int], +) -> torch.Tensor: + """ + Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the + positions in `inputs_embeds` corresponding to placeholder tokens in + `input_ids`. + + `placeholder_token_id` can be a list of token ids (e.g, token ids + of img_start, img_break, and img_end tokens) when needed: This means + the order of these tokens in the `input_ids` MUST MATCH the order of + their embeddings in `multimodal_embeddings` since we need to + slice-merge instead of individually scattering. + + For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where + - T is text token + - S is image start token + - I is image embedding token + - B is image break token + - E is image end token. + + Then the image embeddings (that correspond to I's) from vision encoder + must be padded with embeddings of S, B, and E in the same order of + input_ids for a correct embedding merge. + + Note: + This updates `inputs_embeds` in place. 
+ """ + if isinstance(placeholder_token_id, list): + is_multimodal = isin_list(input_ids, placeholder_token_id) + else: + is_multimodal = input_ids == placeholder_token_id + + return _merge_multimodal_embeddings( + inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + +class MlpProjector(nn.Module): + + def __init__( + self, + projector_type, + input_dim, + n_embed, + depth=1, + mlp_ratio=1, + downsample_ratio=4, + ): + self.projector_type = projector_type + self.input_dim = input_dim + self.n_embed = n_embed + self.depth = depth + self.token_pooling = False + self.conv_fusion_high_low_features = False + + super().__init__() + + if projector_type == "identity": + modules = nn.Identity() + + elif projector_type == "linear": + modules = nn.Linear(input_dim, n_embed) + + elif projector_type == "mlp_gelu": + mlp_depth = depth + modules = [nn.Linear(input_dim, n_embed)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(n_embed, n_embed)) + modules = nn.Sequential(*modules) + + elif projector_type == "normlayer_downsample_mlp_gelu": + mlp_depth = depth + mlp_ratio = mlp_ratio + modules = [ + nn.LayerNorm(input_dim * downsample_ratio * downsample_ratio), + nn.Linear( + input_dim * downsample_ratio * downsample_ratio, + n_embed * mlp_ratio, + ), + ] + for _ in range(1, mlp_depth - 1): + modules.append(nn.GELU()) + modules.append(nn.Linear(n_embed * mlp_ratio, n_embed * mlp_ratio)) + modules.append(nn.GELU()) + modules.append(nn.Linear(n_embed * mlp_ratio, n_embed)) + modules = nn.Sequential(*modules) + + elif projector_type == "downsample_mlp_gelu": + mlp_depth = depth + mlp_ratio = mlp_ratio + modules = [ + nn.Linear( + input_dim * downsample_ratio * downsample_ratio, + n_embed * mlp_ratio, + ) + ] + for _ in range(1, mlp_depth - 1): + modules.append(nn.GELU()) + modules.append(nn.Linear(n_embed * mlp_ratio, n_embed * mlp_ratio)) + modules.append(nn.GELU()) + modules.append(nn.Linear(n_embed * mlp_ratio, n_embed)) + modules = nn.Sequential(*modules) + + elif projector_type == "low_high_hybrid_split_mlp_gelu": + mlp_depth = depth + self.high_up_proj = nn.Linear(input_dim, n_embed // 2) + self.low_up_proj = nn.Linear(input_dim, n_embed // 2) + + modules = [] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(n_embed, n_embed)) + modules = nn.Sequential(*modules) + + elif projector_type == "hybrid_split_feature_mlp_gelu": + mlp_depth = depth + channel_div = 0.5 + self.high_up_proj = nn.Linear(input_dim[0], int(n_embed * channel_div)) + self.low_up_proj = nn.Linear( + input_dim[1], n_embed - int(n_embed * channel_div) + ) + + modules = [] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(n_embed, n_embed)) + modules = nn.Sequential(*modules) + + elif projector_type == "low_high_split_mlp_gelu": + mlp_depth = depth + modules = [] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(n_embed // 2, n_embed // 2)) + modules = nn.Sequential(*modules) + self.high_layers = nn.Sequential(*modules) + self.low_layers = copy.deepcopy(modules) + + else: + raise ValueError(f"Unknown projector type: {projector_type}") + + self.layers = modules + + def forward(self, x): + if self.token_pooling: + batch_size, wxh, channels = x.shape + w = h = int(wxh**0.5) + x = x.view(batch_size, w, h, channels) + x = x.permute(0, 3, 1, 2) + patches = x.unfold(2, 2, 2).unfold(3, 2, 2) + batch_size, channels, h_patches, w_patches, _, _ = 
patches.size() + # Concatenate on channel dimension + patches = patches.contiguous().view( + batch_size, channels, h_patches * w_patches, -1 + ) + + # Pass through linear layer + patches = patches.permute(0, 2, 1, 3).contiguous() + patches = patches.view(batch_size, h_patches * w_patches, channels * 4) + + x = self.token_pooling_layer(patches) + + if self.conv_fusion_high_low_features: + x = self.fusion_layer(x[:, 0]) + x[:, 1] + + if self.projector_type == "low_high_hybrid_split_mlp_gelu": + high_x, low_x = x[0], x[1] + high_x = self.high_up_proj(high_x) + low_x = self.low_up_proj(low_x) + x = torch.concat([high_x, low_x], dim=-1) + + if self.projector_type == "hybrid_split_feature_mlp_gelu": + high_x = x[..., : self.input_dim[0]] + low_x = x[..., self.input_dim[0] :] + high_x = self.high_up_proj(high_x) + low_x = self.low_up_proj(low_x) + x = torch.concat([high_x, low_x], dim=-1) + + if self.projector_type == "low_high_split_mlp_gelu": + high_x, low_x = x[0], x[1] + high_x = self.high_layers(high_x) + low_x = self.low_layers(low_x) + x = torch.concat([high_x, low_x], dim=-1) + return x + + if ( + self.projector_type == "downsample_mlp_gelu" + or self.projector_type == "normlayer_downsample_mlp_gelu" + ): + bs, hw, input_dim = x.shape + h = w = int((hw) ** 0.5) + + """compute padding""" + if h % self.downsample_ratio: + pad = self.downsample_ratio - h % self.downsample_ratio + else: + pad = 0 + x = x.reshape(bs, h, w, input_dim) + if pad > 0: + x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) + + """4 to 1 concat""" + x = x.permute(0, 3, 1, 2) # B, C, H, W + x = F.unfold( + x, + kernel_size=self.downsample_ratio, + stride=self.downsample_ratio, + padding=0, + ) # B, C*4, HW // 4 + x = x.permute(0, 2, 1) + + return self.layers(x) + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +def add_decomposed_rel_pos( + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + rel_h = rel_h.unsqueeze(-1) + rel_w = rel_w.unsqueeze(-2) + rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1) + rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w) + + return rel_h, rel_w + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + rel_h, rel_w = None, None + if self.use_rel_pos: + rel_h, rel_w = add_decomposed_rel_pos( + q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + q = q.view(B, self.num_heads, H * W, -1) + k = k.view(B, self.num_heads, H * W, -1) + v = v.view(B, self.num_heads, H * W, -1) + + if self.use_rel_pos: + rel_h = rel_h.view( + B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3) + ) + rel_w = rel_w.view( + B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3) + ) + attn_bias = (rel_h + rel_w).view( + B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4) + ) + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_bias + ) + # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w) + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + + x = ( + x.view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. 
+ """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +def get_abs_pos_sam(abs_pos, tgt_size): + dtype = abs_pos.dtype + + src_size = abs_pos.size(1) + + if src_size != tgt_size: + old_pos_embed = abs_pos.permute(0, 3, 1, 2) + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, tgt_size), + mode="bicubic", + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + return new_pos_embed + else: + return abs_pos + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
+ self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) + self.net_3 = nn.Conv2d( + 512, 1024, kernel_size=3, stride=2, padding=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + get_abs_pos_sam(self.pos_embed, x.size(1)) + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + x2 = self.net_2(x) + x3 = self.net_3(x2.clone()) + + return x3 + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + image_encoder.eval() + if checkpoint is not None: + state_dict = torch.load(checkpoint) + image_encoder.load_state_dict( + {k[30:]: v for k, v in state_dict.items() if "vision_tower_high" in k}, + strict=True, + ) + return image_encoder + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def get_abs_pos(abs_pos, tgt_size): + # abs_pos: L, C + # tgt_size: M + # return: M, C + dim = abs_pos.size(-1) + abs_pos_new = abs_pos.squeeze(0) + cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:] + + src_size = int(math.sqrt(abs_pos_new.shape[0] - 1)) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + old_pos_embed = ( + old_pos_embed.view(1, src_size, src_size, dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, tgt_size), + mode="bicubic", + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim) + vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0) + vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim) + return vision_pos_embed + else: + return abs_pos + + +class CLIPVisionEmbeddings(nn.Module): + def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3): + super().__init__() + self.embed_dim = hidden_size + self.image_size = image_size + self.patch_size = patch_size + + 
self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = torch.nn.Conv2d( + in_channels=num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", torch.arange(self.num_positions).expand((1, -1)) + ) + + def forward(self, pixel_values, patch_embeds): + batch_size = pixel_values.shape[0] + + if patch_embeds is not None: + patch_embeds = patch_embeds + else: + patch_embeds = self.patch_embedding(pixel_values) + + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + embeddings = embeddings + get_abs_pos( + self.position_embedding(self.position_ids), embeddings.size(1) + ) + return embeddings + + +class NoTPAttention(torch.nn.Module): + def __init__(self, cfg): + super().__init__() + self.num_heads = cfg["num_attention_heads"] + self.n_local_heads = cfg["num_attention_heads"] + self.head_dim = cfg["hidden_size"] // cfg["num_attention_heads"] + self.max_seq_len = cfg["seq_length"] + self.use_flash_attention = cfg["use_flash_attn"] + + self.qkv_proj = torch.nn.Linear( + cfg["hidden_size"], cfg["hidden_size"] * 3, bias=True + ) + self.out_proj = torch.nn.Linear( + cfg["hidden_size"], cfg["hidden_size"], bias=True + ) + + # self.core_attention = CoreAttention(cfg, AttnType.self_attn) + + self.attn_drop = cfg["attention_dropout"] + + def forward( + self, + x: torch.Tensor, + ): + bsz, seqlen, _ = x.shape + xqkv = self.qkv_proj(x) + xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim) + + if self.use_flash_attention: + + xq, xk, xv = torch.split(xqkv, 1, dim=2) + xq = xq.squeeze(2) + xk = xk.squeeze(2) + xv = xv.squeeze(2) + # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...] 
+ + # (B, num_head, S, head_size) + xq = xq.permute(0, 2, 1, 3) + xk = xk.permute(0, 2, 1, 3) + xv = xv.permute(0, 2, 1, 3) + output = torch.nn.functional.scaled_dot_product_attention( + xq, xk, xv, attn_mask=None + ) + output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1) + else: + xq, xk, xv = torch.split(xqkv, 1, dim=2) + xq = xq.squeeze(2) + xk = xk.squeeze(2) + xv = xv.squeeze(2) + + xq = xq.permute(0, 2, 1, 3) + xk = xk.permute(0, 2, 1, 3) + xv = xv.permute(0, 2, 1, 3) + output = torch.nn.functional.scaled_dot_product_attention( + xq, xk, xv, attn_mask=None + ) + output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1) + output = self.out_proj(output) + return output + + +@torch.jit.script +def quick_gelu(x): + return x * torch.sigmoid(1.702 * x) + + +class NoTPFeedForward(nn.Module): + def __init__( + self, + cfg, + dim: int, + hidden_dim: int, + ): + super().__init__() + + self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True) + self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True) + + def forward(self, x): + output = self.fc2(quick_gelu(self.fc1(x))) + return output + + +class LayerNormfp32(torch.nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class NoTPTransformerBlock(nn.Module): + def __init__(self, cfg, layer_id: int, multiple_of=256): + super().__init__() + + self.n_heads = cfg["num_attention_heads"] + self.dim = cfg["hidden_size"] + self.head_dim = cfg["hidden_size"] // cfg["num_attention_heads"] + self.self_attn = NoTPAttention(cfg) + self.mlp = NoTPFeedForward( + cfg, dim=cfg["hidden_size"], hidden_dim=cfg["ffn_hidden_size"] + ) + self.layer_id = layer_id + self.layer_norm1 = torch.nn.LayerNorm( + cfg["hidden_size"], eps=cfg["layernorm_epsilon"] + ) + self.layer_norm2 = torch.nn.LayerNorm( + cfg["hidden_size"], eps=cfg["layernorm_epsilon"] + ) + + def forward(self, x: torch.Tensor): + residual = self.self_attn.forward(self.layer_norm1(x)) + h = x + residual + out = h + self.mlp.forward(self.layer_norm2(h)) + return out + + +class NoTPTransformer(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + self.num_layers = cfg["num_layers"] + + self.layers = torch.nn.ModuleList() + for layer_id in range(self.num_layers): + self.layers.append( + NoTPTransformerBlock( + cfg, + layer_id + 1, + ) + ) + + def forward( + self, + hidden_states, + ): + + for layer in self.layers: + hidden_states = layer(hidden_states) + + return hidden_states + + +class VitModel(nn.Module): + def __init__(self, cfg, freeze_embed=False, freeze_pre_norm=False) -> None: + super().__init__() + + self.embeddings = CLIPVisionEmbeddings( + hidden_size=cfg["hidden_size"], + image_size=cfg["image_size"], + patch_size=cfg["patch_size"], + ) + + if freeze_embed: + for _, param in self.embeddings.named_parameters(): + param.requires_grad = False + + self.transformer = NoTPTransformer(cfg=cfg) + + if cfg.get("fp32norm", False): + logger.info("Load fp32 layernorm for ViT.") + self.pre_layrnorm = LayerNormfp32( + cfg["hidden_size"], + eps=cfg.get("pre_layernorm_epsilon", 1e-5), + ) + else: + self.pre_layrnorm = torch.nn.LayerNorm( + cfg["hidden_size"], + eps=cfg.get("pre_layernorm_epsilon", 1e-5), + ) + + if freeze_pre_norm: + for _, param in self.pre_layrnorm.named_parameters(): + param.requires_grad = False + + for p in self.parameters(): + p.micro_dp = True + + @property + def dtype(self): + return 
next(self.parameters()).dtype + + def set_input_tensor(self, input_tensor): + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + self.transformer.set_input_tensor(input_tensor[0]) + + def __str__(self) -> str: + return "open_clip" + + def forward(self, x, patch_embeds): + x = self.embeddings(x, patch_embeds) + hidden_states = self.pre_layrnorm(x) + + output = self.transformer(hidden_states) + + return output + + +vit_model_cfg = dict( + num_layers=24, + hidden_size=1024, + num_heads=16, + num_attention_heads=16, + ffn_hidden_size=4096, + seq_length=256, + max_position_embeddings=256, + use_flash_attn=False, + understand_projector_stride=2, + hidden_dropout=0.0, + attention_dropout=0.0, + no_persist_layer_norm=False, + layernorm_epsilon=1e-5, + pre_layernorm_epsilon=1e-5, + image_size=224, + patch_size=14, + recompute_list=[], +) + + +def build_clip_l(): + return VitModel( + cfg=vit_model_cfg, + freeze_embed=False, + freeze_pre_norm=False, + ) + + +class DeepseekOCRForCausalLM(nn.Module): + def __init__( + self, + *, + config: DeepseekVLV2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.config = config + + self.vision_config = config.vision_config + self.projector_config = config.projector_config + self.text_config = config.text_config + + n_embed = 1280 + + self.tile_tag = config.tile_tag + self.global_view_pos = config.global_view_pos + + # special token for image token sequence format + embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) + if self.tile_tag == "2D": + # <|view_separator|>, <|\n|> + self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std) + self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) + else: + raise ValueError( + f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" + ) + + if self.text_config.topk_method == "noaux_tc": + self.model = DeepseekV3ForCausalLM( + config=config.text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "language"), + ) + elif not self.text_config.use_mla: + self.model = DeepseekForCausalLM( + config=config.text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "language"), + ) + else: + self.model = DeepseekV2ForCausalLM( + config=config.text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "language"), + ) + + self.sam_model = build_sam_vit_b() + self.vision_model = build_clip_l() + n_embed = 1280 + self.projector = MlpProjector( + projector_type="linear", + input_dim=2048, + n_embed=n_embed, + ) + + def _parse_and_validate_image_input(self, **kwargs: object): + + pixel_values = kwargs.pop("pixel_values", None) + images_spatial_crop = kwargs.pop("images_spatial_crop", None) + images_crop = kwargs.pop("images_crop", None) + + if pixel_values is None or torch.sum(pixel_values).item() == 0: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}" + ) + + if not isinstance(images_spatial_crop, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of image sizes. " + f"Got type: {type(images_spatial_crop)}" + ) + + if not isinstance(images_crop, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of image crop. 
" f"Got type: {type(images_crop)}" + ) + + return [pixel_values, images_crop, images_spatial_crop] + + raise AssertionError("This line should be unreachable.") + + def _pixel_values_to_embedding( + self, + pixel_values: torch.Tensor, + images_crop: torch.Tensor, + images_spatial_crop: torch.Tensor, + ) -> NestedTensors: + + # Pixel_values (global view): [n_image, batch_size, 3, height, width] + # images_spatial_crop: [n_image, batch_size, [num_tiles_w, num_tiles_h]] + # images_crop (local view): [n_image, batch_size, num_pathes, 3, h, w] + # split the pixel and image_crop, all batch_size = 1 + + images_in_this_batch = [] + + with torch.no_grad(): + for jdx in range(images_spatial_crop.size(0)): + patches = images_crop[jdx][0].to(torch.bfloat16) + image_ori = pixel_values[jdx] + crop_shape = images_spatial_crop[jdx][0] + + if torch.sum(patches).item() != 0: + local_features_1 = self.sam_model(patches) + local_features_2 = self.vision_model(patches, local_features_1) + + local_features = torch.cat( + ( + local_features_2[:, 1:], + local_features_1.flatten(2).permute(0, 2, 1), + ), + dim=-1, + ) + local_features = self.projector(local_features) + + global_features_1 = self.sam_model(image_ori) + global_features_2 = self.vision_model(image_ori, global_features_1) + global_features = torch.cat( + ( + global_features_2[:, 1:], + global_features_1.flatten(2).permute(0, 2, 1), + ), + dim=-1, + ) + global_features = self.projector(global_features) + + _, hw, n_dim = global_features.shape + h = w = int(hw**0.5) + + _2, hw2, n_dim2 = local_features.shape + h2 = w2 = int(hw2**0.5) + + width_crop_num, height_crop_num = int(crop_shape[0]), int( + crop_shape[1] + ) + + global_features = global_features.view(h, w, n_dim) + + global_features = torch.cat( + [ + global_features, + self.image_newline[None, None, :].expand(h, 1, n_dim), + ], + dim=1, + ) + + global_features = global_features.view(-1, n_dim) + + local_features = ( + local_features.view( + height_crop_num, width_crop_num, h2, w2, n_dim2 + ) + .permute(0, 2, 1, 3, 4) + .reshape(height_crop_num * h2, width_crop_num * w2, n_dim2) + ) + local_features = torch.cat( + [ + local_features, + self.image_newline[None, None, :].expand( + height_crop_num * h2, 1, n_dim2 + ), + ], + dim=1, + ) + local_features = local_features.view(-1, n_dim2) + + global_local_features = torch.cat( + [local_features, global_features, self.view_seperator[None, :]], + dim=0, + ) + + else: + global_features_1 = self.sam_model(image_ori) + global_features_2 = self.vision_model(image_ori, global_features_1) + global_features = torch.cat( + ( + global_features_2[:, 1:], + global_features_1.flatten(2).permute(0, 2, 1), + ), + dim=-1, + ) + global_features = self.projector(global_features) + + _, hw, n_dim = global_features.shape + h = w = int(hw**0.5) + + global_features = global_features.view(h, w, n_dim) + + global_features = torch.cat( + [ + global_features, + self.image_newline[None, None, :].expand(h, 1, n_dim), + ], + dim=1, + ) + + global_features = global_features.view(-1, n_dim) + + global_local_features = torch.cat( + [global_features, self.view_seperator[None, :]], dim=0 + ) + + images_in_this_batch.append(global_local_features) + + return images_in_this_batch + + def _process_image_input(self, mm_items: List[MultimodalDataItem]) -> torch.Tensor: + pixel_values = torch.stack([item.feature for item in mm_items], dim=0).type( + self.vision_model.dtype + ) + + images_crop = ( + torch.stack([item.images_crop for item in mm_items], dim=0) + .type(torch.long) + 
.to(device=pixel_values.device) + ) + images_spatial_crop = ( + torch.cat([item.images_spatial_crop for item in mm_items], dim=0) + .type(torch.long) + .to(device=pixel_values.device) + ) + + assert images_crop.dim() == 6 + assert images_spatial_crop.dim() == 3 + + vision_feature_lists = self._pixel_values_to_embedding( + pixel_values=pixel_values, + images_crop=images_crop, + images_spatial_crop=images_spatial_crop, + ) + vision_features = torch.cat(vision_feature_lists, dim=0).type( + self.vision_model.dtype + ) + + return vision_features + + def get_language_model(self) -> torch.nn.Module: + return self.model + + def get_multimodal_embeddings( + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + inputs_embeds = self.model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, self.image_token_id + ) + + return inputs_embeds + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + vision_embeddings = self._process_image_input(items) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs: object, + ): + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.model, + multimodal_model=self, + positions=positions, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name == "lm_head.weight": + name = "model.lm_head.weight" + elif name.startswith("model."): + if ( + "image_newline" in name + or ".projector" in name + or "vision_model" in name + or "sam_model" in name + or "view_seperator" in name + ): + name = name[len("model.") :] + elif not ( + ".projector" in name + or "vision_model" in name + or "sam_model" in name + or "image_newline" in name + ): + name = name.replace("model.", "model.model.") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
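+                # for-else: this branch runs only when the weight name matched no
+                # stacked-param mapping above, so the tensor is loaded unsplit below.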
+ if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + raise RuntimeError( + f"Some weights are not initialized from checkpoints: {unloaded_params}" + ) + + +EntryClass = [DeepseekOCRForCausalLM] diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 779cd8853..891dd51e8 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -200,7 +200,6 @@ _is_flashinfer_available = is_flashinfer_available() _is_sm100_supported = is_cuda() and is_sm100_supported() _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9() - logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 95ee3a486..7a60d6b96 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -178,6 +178,7 @@ class BaseMultimodalProcessor(ABC): "image_attention_mask": Modality.IMAGE, "image_emb_mask": Modality.IMAGE, "images_spatial_crop": Modality.IMAGE, + "images_crop": Modality.IMAGE, "tgt_size": Modality.IMAGE, "image_grid_hws": Modality.IMAGE, "aspect_ratio_ids": Modality.IMAGE, diff --git a/python/sglang/srt/multimodal/processors/deepseek_ocr.py b/python/sglang/srt/multimodal/processors/deepseek_ocr.py new file mode 100644 index 000000000..8f0d583be --- /dev/null +++ b/python/sglang/srt/multimodal/processors/deepseek_ocr.py @@ -0,0 +1,37 @@ +from typing import List, Union + +from sglang.srt.models.deepseek_ocr import DeepseekOCRForCausalLM +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) + + +class DeepseekOCRProcessor(BaseMultimodalProcessor): + models = [DeepseekOCRForCausalLM] + + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + _processor.image_size = 640 + super().__init__(hf_config, server_args, _processor, *args, **kwargs) + self.mm_tokens = MultimodalSpecialTokens( + image_token="", image_token_id=self._processor.image_token_id + ).build(_processor) + + async def process_mm_data_async( + self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs + ): + base_output = self.load_mm_data( + prompt=input_text, + multimodal_tokens=self.mm_tokens, + image_data=image_data, + ) + + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) + + return { + "input_ids": input_ids.tolist(), + "mm_items": mm_items, + "im_token_id": self.mm_tokens.image_token_id, + } diff --git a/python/sglang/srt/parser/conversation.py b/python/sglang/srt/parser/conversation.py index 2d03e1bfa..f83a238e3 100644 --- a/python/sglang/srt/parser/conversation.py +++ b/python/sglang/srt/parser/conversation.py @@ -838,6 +838,19 @@ register_conv_template( ) ) +register_conv_template( + Conversation( + name="deepseek-ocr", + system_message="", + system_template="", + roles=("", ""), + sep="", + sep_style=SeparatorStyle.NO_COLON_SINGLE, + stop_str=["<|end▁of▁sentence|>"], + image_token="", + ) +) + register_conv_template( 
Conversation( name="deepseek-vl2", @@ -981,6 +994,7 @@ MODEL_TYPE_TO_TEMPLATE = { "phi4mm": "phi-4-mm", "minicpmv": "minicpmv", "minicpmo": "minicpmo", + "deepseek-ocr": "deepseek-ocr", } @@ -1057,3 +1071,11 @@ def match_phi_4_mm(model_path: str): return "phi-4-mm" model_type = get_model_type(model_path) return MODEL_TYPE_TO_TEMPLATE.get(model_type) + + +@register_conv_template_matching_function +def match_deepseek_ocr(model_path: str): + if "deepseek-ocr" in model_path.lower(): + return "deepseek-ocr" + model_type = get_model_type(model_path) + return MODEL_TYPE_TO_TEMPLATE.get(model_type) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index d4e8b8562..cf2e26ae0 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -19,7 +19,7 @@ import os import tempfile import warnings from pathlib import Path -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type, Union import torch from huggingface_hub import snapshot_download @@ -51,26 +51,32 @@ from sglang.srt.configs import ( Qwen3NextConfig, Step3VLConfig, ) +from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config from sglang.srt.configs.internvl import InternVLChatConfig from sglang.srt.connector import create_remote_connector from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset -_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { - ChatGLMConfig.model_type: ChatGLMConfig, - DbrxConfig.model_type: DbrxConfig, - ExaoneConfig.model_type: ExaoneConfig, - DeepseekVL2Config.model_type: DeepseekVL2Config, - MultiModalityConfig.model_type: MultiModalityConfig, - KimiVLConfig.model_type: KimiVLConfig, - InternVLChatConfig.model_type: InternVLChatConfig, - Step3VLConfig.model_type: Step3VLConfig, - LongcatFlashConfig.model_type: LongcatFlashConfig, - Olmo3Config.model_type: Olmo3Config, - Qwen3NextConfig.model_type: Qwen3NextConfig, - FalconH1Config.model_type: FalconH1Config, - DotsVLMConfig.model_type: DotsVLMConfig, - DotsOCRConfig.model_type: DotsOCRConfig, - NemotronHConfig.model_type: NemotronHConfig, +_CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [ + ChatGLMConfig, + DbrxConfig, + ExaoneConfig, + DeepseekVL2Config, + MultiModalityConfig, + KimiVLConfig, + InternVLChatConfig, + Step3VLConfig, + LongcatFlashConfig, + Olmo3Config, + Qwen3NextConfig, + FalconH1Config, + DotsVLMConfig, + DotsOCRConfig, + NemotronHConfig, + DeepseekVLV2Config, +] + +_CONFIG_REGISTRY = { + config_cls.model_type: config_cls for config_cls in _CONFIG_REGISTRY } for name, cls in _CONFIG_REGISTRY.items(): @@ -191,6 +197,11 @@ def get_config( config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) + if "deepseek-ai/DeepSeek-OCR" in model: + config.model_type = "deepseek-ocr" + # Due to an unknown reason, Hugging Face’s AutoConfig mistakenly recognizes the configuration of deepseek-ocr as deepseekvl2. + # This is a temporary workaround and will require further optimization. + except ValueError as e: if not "deepseek_v32" in str(e): raise e @@ -213,7 +224,8 @@ def get_config( "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, - "num_hidden_layers": 26, # Model is originally 27-layer, we only need the first 26 layers for feature extraction. + "num_hidden_layers": 26, + # Model is originally 27-layer, we only need the first 26 layers for feature extraction. 
"patch_size": 14, } config.vision_config = SiglipVisionConfig(**vision_config) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 8fd557688..10e1127b3 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -619,7 +619,6 @@ def popen_launch_server( start_time = time.perf_counter() with requests.Session() as session: while time.perf_counter() - start_time < timeout: - return_code = process.poll() if return_code is not None: # Server failed to start (non-zero exit code) or crashed diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index b8eebcce4..6c6989aa0 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -150,6 +150,62 @@ class TestQwen2AudioServer(AudioOpenAITestMixin): model = "Qwen/Qwen2-Audio-7B-Instruct" +class TestDeepseekOCRServer(TestOpenAIMLLMServerBase): + model = "deepseek-ai/DeepSeek-OCR" + trust_remote_code = False + + def verify_single_image_response_for_ocr(self, response): + """Verify DeepSeek-OCR grounding output with coordinates""" + assert response.choices[0].message.role == "assistant" + text = response.choices[0].message.content + assert isinstance(text, str) + + # DeepSeek-OCR uses grounding format, outputs coordinates + assert "text" in text.lower(), f"OCR text: {text}, should contain 'text'" + + # Verify coordinate format [[x1, y1, x2, y2]] + import re + + coord_pattern = r"\[\[[\d\s,]+\]\]" + assert re.search( + coord_pattern, text + ), f"OCR text: {text}, should contain coordinate format [[x1, y1, x2, y2]]" + + # Verify basic response fields + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def test_single_image_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + image_url = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/ocr-text.png" + + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + { + "type": "text", + "text": "<|grounding|>Convert the document to markdown.", + }, + ], + }, + ], + temperature=0, + **(self.get_vision_request_kwargs()), + ) + + self.verify_single_image_response_for_ocr(response) + + if __name__ == "__main__": del ( TestOpenAIMLLMServerBase, diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 392ccf0f8..f737a5699 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -32,6 +32,7 @@ class TestOpenAIMLLMServerBase(CustomTestCase): model: str extra_args: list = [] fixed_args: list = ["--trust-remote-code", "--enable-multimodal"] + trust_remote_code: bool = True @classmethod def setUpClass(cls): @@ -42,7 +43,11 @@ class TestOpenAIMLLMServerBase(CustomTestCase): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, - other_args=cls.extra_args + cls.fixed_args, + other_args=( + cls.extra_args + cls.fixed_args + ["--trust-remote-code"] + if cls.trust_remote_code + else [] + ), ) cls.base_url += "/v1"