diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py
index dc26ccb1e..6daf5db6f 100644
--- a/benchmark/mmmu/eval_utils.py
+++ b/benchmark/mmmu/eval_utils.py
@@ -26,6 +26,7 @@ class EvalArgs:
     backend: str = "engine"
     seed: int = 42
     split: str = "validation"
+    # Default limit chosen so the benchmark fits on an A100 for most 7B models
    image_pixels_limit: int = 4300000
     result_filename: str = ""
     prompt_format_file: str = "prompt_format.yaml"
@@ -38,6 +39,7 @@ class EvalArgs:
         parser.add_argument(
             "--result-filename", type=str, default=EvalArgs.result_filename
         )
+
        parser.add_argument(
            "--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit
        )
diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md
index 396746a0d..ef668e817 100644
--- a/docs/references/supported_models.md
+++ b/docs/references/supported_models.md
@@ -31,6 +31,7 @@
 - Phi-3 / Phi-4
 - Phi-3-Small
 - IBM Granite 3
+- Janus-Pro-1B / Janus-Pro-7B

 ## Embedding Models
diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index 49107a0f4..e17168b90 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -230,6 +230,29 @@ register_chat_template(
     )
 )

+register_chat_template(
+    ChatTemplate(
+        name="janus-pro",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "",
+                "",
+            ),
+            "user": (
+                "<|User|>",
+                "",
+            ),
+            "assistant": (
+                "<|Assistant|>",
+                "<|end▁of▁sentence|>",
+            ),
+        },
+        stop_str=("<|end▁of▁sentence|>",),
+        image_token="<image_placeholder>\n",
+    )
+)
+
 # The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
@@ -384,6 +407,12 @@ def match_deepseek(model_path: str):
     return get_chat_template("deepseek-v3")


+@register_chat_template_matching_function
+def match_deepseek_janus_pro(model_path: str):
+    if "janus" in model_path.lower():
+        return get_chat_template("janus-pro")
+
+
 @register_chat_template_matching_function
 def match_dbrx(model_path: str):
     if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py
index c33d5e0a2..f7ae0108e 100644
--- a/python/sglang/srt/configs/__init__.py
+++ b/python/sglang/srt/configs/__init__.py
@@ -1,6 +1,7 @@
 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.exaone import ExaoneConfig
+from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.qwen2_5_vl_config import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
@@ -12,4 +13,5 @@ __all__ = [
     "DbrxConfig",
     "Qwen2_5_VLConfig",
     "Qwen2_5_VLVisionConfig",
+    "MultiModalityConfig",
 ]
diff --git a/python/sglang/srt/configs/janus_pro.py b/python/sglang/srt/configs/janus_pro.py
new file mode 100644
index 000000000..fad9a34a5
--- /dev/null
+++ b/python/sglang/srt/configs/janus_pro.py
@@ -0,0 +1,629 @@
+# Adapted from:
+# https://github.com/deepseek-ai/Janus/tree/main/janus/models
+
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from PIL.Image import Image
+from transformers import (
+    AutoImageProcessor,
+    AutoProcessor,
+    BaseImageProcessor,
+    BatchFeature,
+    LlamaConfig,
+    LlamaTokenizerFast,
+    PretrainedConfig,
+    ProcessorMixin,
+)
+from transformers.image_utils import to_numpy_array
+
+from sglang.srt.mm_utils import expand2square
+
+
+class DictToObject(dict):
+    def __init__(self, dictionary):
+        super().__init__(dictionary)
+
+        for key, value in dictionary.items():
+            if isinstance(value, dict):
+                value = DictToObject(value)
+            setattr(self, key, value)
+
+
+class VisionConfig(PretrainedConfig):
+    model_type = "vision"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenAlignerConfig(PretrainedConfig):
+    model_type = "gen_aligner"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenHeadConfig(PretrainedConfig):
+    model_type = "gen_head"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class AlignerConfig(PretrainedConfig):
+    model_type = "aligner"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenVisionConfig(PretrainedConfig):
+    model_type = "gen_vision"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+@dataclass
+class SigLIPVisionCfg:
+    width: int = 1152
+    layers: Union[Tuple[int, int, int, int], int] = 27
+    heads: int = 16
+    patch_size: int = 14
+    image_size: Union[Tuple[int, int], int] = 336
+    global_pool: str = "map"
+    mlp_ratio: float = 3.7362
+    class_token: bool = False
+    num_classes: int = 0
+    use_checkpoint: bool = False
+
+
+class MultiModalityConfig(PretrainedConfig):
+    model_type = "multi_modality"
+    vision_config: VisionConfig
+    aligner_config: AlignerConfig
+
+    gen_vision_config: GenVisionConfig
+    gen_aligner_config: GenAlignerConfig
+    gen_head_config: GenHeadConfig
+
+    language_config: LlamaConfig
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        vision_config = kwargs.get("vision_config", {})
+        self.vision_config = VisionConfig(**vision_config)
+
+        aligner_config = kwargs.get("aligner_config", {})
+        self.aligner_config = AlignerConfig(**aligner_config)
+
+        gen_vision_config = kwargs.get("gen_vision_config", {})
+        self.gen_vision_config = GenVisionConfig(**gen_vision_config)
+
+        gen_aligner_config = kwargs.get("gen_aligner_config", {})
+        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
+
+        gen_head_config = kwargs.get("gen_head_config", {})
+        self.gen_head_config = GenHeadConfig(**gen_head_config)
+
+        language_config = kwargs.get("language_config", {})
+        if isinstance(language_config, LlamaConfig):
+            self.language_config = language_config
+        else:
+            self.language_config = LlamaConfig(**language_config)
+
+
+class VLMImageProcessor(BaseImageProcessor):
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        image_size: int,
+        min_size: int = 14,
+        image_mean: Union[Tuple[float, float, float], List[float]] = (
+            0.48145466,
+            0.4578275,
+            0.40821073,
+        ),
+        image_std: Union[Tuple[float,
float, float], List[float]] = ( + 0.26862954, + 0.26130258, + 0.27577711, + ), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.rescale_factor = rescale_factor + self.image_mean = image_mean + self.image_std = image_std + self.min_size = min_size + self.do_normalize = do_normalize + + if image_mean is None: + self.background_color = (127, 127, 127) + else: + self.background_color = tuple([int(x * 255) for x in image_mean]) + + def resize(self, pil_img: Image) -> np.ndarray: + """ + + Args: + pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB + + Returns: + x (np.ndarray): [3, self.image_size, self.image_size] + """ + + width, height = pil_img.size + max_size = max(width, height) + + size = [ + max(int(height / max_size * self.image_size), self.min_size), + max(int(width / max_size * self.image_size), self.min_size), + ] + + if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0: + # print(f"orig size = {pil_img.size}, new size = {size}") + raise ValueError("Invalid size!") + + def resize( + pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True + ): + if isinstance(size, int): + w, h = pil_img.size + if (w <= h and w == size) or (h <= w and h == size): + return pil_img + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + size = (ow, oh) + else: + size = (size[1], size[0]) + + return pil_img.resize( + size, resample=interpolation, reducing_gap=None if antialias else 3.0 + ) + + pil_img = resize( + pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True + ) + + pil_img = expand2square(pil_img, self.background_color) + x = to_numpy_array(pil_img) + + # [H, W, 3] -> [3, H, W] + x = np.transpose(x, (2, 0, 1)) + + return x + + def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature: + # resize and pad to [self.image_size, self.image_size] + # then convert from [H, W, 3] to [3, H, W] + if not isinstance(images, list): + images = [images] + images: List[np.ndarray] = [self.resize(image) for image in images] + images = [image[:3, ...] 
for image in images]
+
+        # rescale from [0, 255] -> [0, 1]
+        images = [
+            self.rescale(
+                image=image,
+                scale=self.rescale_factor,
+                input_data_format="channels_first",
+            )
+            for image in images
+        ]
+
+        # normalize
+        if self.do_normalize:
+            images = [
+                self.normalize(
+                    image=image,
+                    mean=self.image_mean,
+                    std=self.image_std,
+                    input_data_format="channels_first",
+                )
+                for image in images
+            ]
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    @property
+    def default_shape(self):
+        return [3, self.image_size, self.image_size]
+
+
+class DictOutput(object):
+    def keys(self):
+        return self.__dict__.keys()
+
+    def __getitem__(self, item):
+        return self.__dict__[item]
+
+    def __setitem__(self, key, value):
+        self.__dict__[key] = value
+
+
+@dataclass
+class VLChatProcessorOutput(DictOutput):
+    sft_format: str
+    input_ids: torch.Tensor
+    pixel_values: torch.Tensor
+    num_image_tokens: torch.IntTensor
+
+    def __len__(self):
+        return len(self.input_ids)
+
+
+@dataclass
+class BatchedVLChatProcessorOutput(DictOutput):
+    sft_format: List[str]
+    input_ids: torch.Tensor
+    pixel_values: torch.Tensor
+    attention_mask: torch.Tensor
+    images_seq_mask: torch.BoolTensor
+    images_emb_mask: torch.BoolTensor
+
+
+# FIXME: the official processor has to be placed here, since the image_processor
+# module is not imported in all threads, so the AutoProcessor registration would
+# not be effective in some cases
+class VLChatProcessor(ProcessorMixin):
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+
+    attributes = ["image_processor", "tokenizer"]
+
+    def __init__(
+        self,
+        image_processor: VLMImageProcessor,
+        tokenizer: LlamaTokenizerFast,
+        image_tag: str = "<image_placeholder>",
+        image_start_tag: str = "<begin_of_image>",
+        image_end_tag: str = "<end_of_image>",
+        pad_tag: str = "<|▁pad▁|>",
+        num_image_tokens: int = 576,
+        add_special_token: bool = False,
+        sft_format: str = "deepseek",
+        mask_prompt: bool = True,
+        ignore_id: int = -100,
+        **kwargs,
+    ):
+        self.image_processor = image_processor
+        self.tokenizer = tokenizer
+
+        image_id = self.tokenizer.vocab.get(image_tag)
+        if image_id is None:
+            special_tokens = [image_tag]
+            special_tokens_dict = {"additional_special_tokens": special_tokens}
+            self.tokenizer.add_special_tokens(special_tokens_dict)
+            # print(f"Add image tag = {image_tag} to the tokenizer")
+
+        self.image_tag = image_tag
+        self.image_start_tag = image_start_tag
+        self.image_end_tag = image_end_tag
+        self.pad_tag = pad_tag
+
+        self.num_image_tokens = num_image_tokens
+        self.add_special_token = add_special_token
+        self.sft_format = sft_format
+        self.ignore_id = ignore_id
+
+        super().__init__(
+            image_processor,
+            tokenizer,
+            **kwargs,
+        )
+
+    @property
+    def image_token(self):
+        return self.image_tag
+
+    @property
+    def image_id(self) -> int:
+        image_id = self.tokenizer.vocab.get(self.image_tag)
+        return image_id
+
+    @property
+    def image_start_id(self):
+        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
+        return image_start_id
+
+    @property
+    def image_end_id(self):
+        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
+        return image_end_id
+
+    @property
+    def image_start_token(self):
+        return self.image_start_tag
+
+    @property
+    def image_end_token(self):
+        return self.image_end_tag
+
+    @property
+    def pad_id(self):
+        pad_id = self.tokenizer.vocab.get(self.pad_tag)
+        return pad_id
+
+    def add_image_token(
+        self,
+        image_indices: List[int],
+        input_ids: torch.LongTensor,
+    ):
+        """
+
+        Args:
+            image_indices
(List[int]): [index_0, index_1, ..., index_j] + input_ids (torch.LongTensor): [N] + + Returns: + input_ids (torch.LongTensor): [N + image tokens] + num_image_tokens (torch.IntTensor): [n_images] + """ + + input_slices = [] + + start = 0 + for index in image_indices: + if self.add_special_token: + end = index + 1 + else: + end = index + + # original text tokens + input_slices.append(input_ids[start:end]) + + # add boi, image tokens, eoi and set the mask as False + input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long)) + input_slices.append( + self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long) + ) + input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long)) + start = index + 1 + + # the left part + input_slices.append(input_ids[start:]) + + # concat all slices + input_ids = torch.cat(input_slices, dim=0) + num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices)) + + return input_ids, num_image_tokens + + def process_one( + self, + prompt: str = None, + images: List[Image] = None, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + images (List[ImageType]): the list of images; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + sft_format = prompt + # tokenize + input_ids = self.tokenizer.encode(sft_format) + input_ids = torch.LongTensor(input_ids) + + # add image tokens to the input_ids + image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool) + image_indices = image_token_mask.nonzero() + input_ids, num_image_tokens = self.add_image_token( + image_indices=image_indices, + input_ids=input_ids, + ) + + # load images + images_outputs = self.image_processor(images, return_tensors="pt") + + prepare = VLChatProcessorOutput( + sft_format=sft_format, + input_ids=input_ids, + pixel_values=images_outputs.pixel_values, + num_image_tokens=num_image_tokens, + ) + + return prepare + + def __call__( + self, + *, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image] = None, + force_batchify: bool = True, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + force_batchify (bool): force batchify the inputs; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + prepare = self.process_one( + prompt=prompt, conversations=conversations, images=images + ) + + if force_batchify: + prepare = self.batchify([prepare]) + + return prepare + + def batchify( + self, prepare_list: List[VLChatProcessorOutput] + ) -> BatchedVLChatProcessorOutput: + """ + Preprocesses the inputs for multimodal inference. + + Args: + prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput. + + Returns: + BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference. 
+ """ + + batch_size = len(prepare_list) + sft_format = [] + n_images = [] + seq_lens = [] + for prepare in prepare_list: + n_images.append(len(prepare.num_image_tokens)) + seq_lens.append(len(prepare)) + + input_token_max_len = max(seq_lens) + max_n_images = max(1, max(n_images)) + + batched_input_ids = torch.full( + (batch_size, input_token_max_len), self.pad_id + ).long() # FIXME + batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long() + batched_pixel_values = torch.zeros( + (batch_size, max_n_images, *self.image_processor.default_shape) + ).float() + batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool() + batched_images_emb_mask = torch.zeros( + (batch_size, max_n_images, self.num_image_tokens) + ).bool() + + for i, prepare in enumerate(prepare_list): + input_ids = prepare.input_ids + seq_len = len(prepare) + n_image = len(prepare.num_image_tokens) + # left-padding + batched_attention_mask[i, -seq_len:] = 1 + batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids) + batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id + + if n_image > 0: + batched_pixel_values[i, :n_image] = prepare.pixel_values + for j, n_image_tokens in enumerate(prepare.num_image_tokens): + batched_images_emb_mask[i, j, :n_image_tokens] = True + + sft_format.append(prepare.sft_format) + + batched_prepares = BatchedVLChatProcessorOutput( + input_ids=batched_input_ids, + attention_mask=batched_attention_mask, + pixel_values=batched_pixel_values, + images_seq_mask=batched_images_seq_mask, + images_emb_mask=batched_images_emb_mask, + sft_format=sft_format, + ) + + return batched_prepares + + +class VLMImageProcessorConfig(PretrainedConfig): + model_type = "deepseek_vlm" + image_size: int + min_size: int + image_mean: Union[Tuple[float, float, float], List[float]] + image_std: Union[Tuple[float, float, float], List[float]] + rescale_factor: float + do_normalize: bool + + def __init__( + self, + image_size: int, + min_size: int = 14, + image_mean: Union[Tuple[float, float, float], List[float]] = ( + 0.48145466, + 0.4578275, + 0.40821073, + ), + image_std: Union[Tuple[float, float, float], List[float]] = ( + 0.26862954, + 0.26130258, + 0.27577711, + ), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, + **kwargs, + ): + self.image_size = image_size + self.min_size = min_size + self.image_mean = image_mean + self.image_std = image_std + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + + super().__init__(**kwargs) + + +AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True) +AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 028a4519a..13516c5c6 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -408,7 +408,7 @@ def _get_and_verify_dtype( def is_generation_model(model_architectures: List[str], is_embedding: bool = False): # We have two ways to determine whether a model is a generative model. - # 1. Check the model architectue + # 1. Check the model architecture # 2. 
check the `is_embedding` server args

     if (
@@ -424,18 +424,25 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     return not is_embedding


+multimodal_model_archs = [
+    "LlavaLlamaForCausalLM",
+    "LlavaQwenForCausalLM",
+    "LlavaMistralForCausalLM",
+    "LlavaVidForCausalLM",
+    "Grok1VForCausalLM",
+    "Grok1AForCausalLM",
+    "MllamaForConditionalGeneration",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+    "MiniCPMV",
+    "MultiModalityCausalLM",
+]
+
+
 def is_multimodal_model(model_architectures: List[str]):
-    if (
-        "LlavaLlamaForCausalLM" in model_architectures
-        or "LlavaQwenForCausalLM" in model_architectures
-        or "LlavaMistralForCausalLM" in model_architectures
-        or "LlavaVidForCausalLM" in model_architectures
-        or "Grok1VForCausalLM" in model_architectures
-        or "Grok1AForCausalLM" in model_architectures
-        or "MllamaForConditionalGeneration" in model_architectures
-        or "Qwen2VLForConditionalGeneration" in model_architectures
-        or "Qwen2_5_VLForConditionalGeneration" in model_architectures
-        or "MiniCPMV" in model_architectures
+    if any(
+        multi_model_arch in model_architectures
+        for multi_model_arch in multimodal_model_archs
     ):
         return True
     else:
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index b6db4a2da..5c8f37643 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -631,3 +631,18 @@ register_conv_template(
         image_token="(./)",
     )
 )
+
+# Reference: https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janus-pro
+register_conv_template(
+    Conversation(
+        name="janus-pro",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language",
+        system_template="{system_message}.",
+        roles=("User", "Assistant"),
+        sep="\n\n",
+        sep2="<|end▁of▁sentence|>",
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        stop_str=["<|User|>", "<|end▁of▁sentence|>"],
+        image_token="<image_placeholder>",
+    )
+)
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 6c1efe31d..e5aa5a62e 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -30,13 +30,20 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

-from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2_5_VLConfig
+from sglang.srt.configs import (
+    ChatGLMConfig,
+    DbrxConfig,
+    ExaoneConfig,
+    MultiModalityConfig,
+    Qwen2_5_VLConfig,
+)

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
     Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
+    MultiModalityConfig.model_type: MultiModalityConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
@@ -67,6 +74,13 @@ def get_config(
         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )

+    # FIXME: lift the fields of janus-pro's language_config to the top-level config
+    if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
+        assert hasattr(config, "language_config")
+        for key, val in config.language_config.__dict__.items():
+            setattr(config, key, val)
+        setattr(config, "architectures", ["MultiModalityCausalLM"])
+
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = 
config_class.from_pretrained(model, revision=revision) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 3aea3b7ae..f27d8c781 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from sglang.srt.distributed import parallel_state from sglang.srt.distributed import utils as dist_utils diff --git a/python/sglang/srt/managers/image_processors/base_image_processor.py b/python/sglang/srt/managers/image_processors/base_image_processor.py index b5365b8ae..ec1799b69 100644 --- a/python/sglang/srt/managers/image_processors/base_image_processor.py +++ b/python/sglang/srt/managers/image_processors/base_image_processor.py @@ -13,6 +13,7 @@ from PIL import Image from sglang.srt.server_args import ServerArgs from sglang.srt.utils import load_image +from sglang.utils import logger global global_processor @@ -22,6 +23,13 @@ def get_global_processor(): return global_processor +def init_global_processor(sglang_image_processor, server_args: ServerArgs): + """Init the global processor for multi-modal models.""" + global global_processor + transformers.logging.set_verbosity_error() + global_processor = sglang_image_processor._build_processor(server_args=server_args) + + @dataclasses.dataclass class BaseImageProcessorOutput: image_hashes: list[int] @@ -119,6 +127,11 @@ class BaseImageProcessor(ABC): ) -> BaseImageProcessorOutput: """ Each frame of video/image will be replaced by a single image token + + Args: + + discard_alpha_channel: if True, discards the alpha channel in the returned images + """ image_hashes, image_sizes = [], [] all_frames = [] @@ -133,7 +146,7 @@ class BaseImageProcessor(ABC): if return_text: text_parts = input_text.split(image_token) - # roughly calculate the max number of frames under the max_req_input_len limit + # TODO(mick): load from server_args, env, or sampling_params MAX_NUM_FRAMES = 30 estimated_frames_list = self.get_estimated_frames_list(image_data=image_data) total_frame_count = sum(estimated_frames_list) diff --git a/python/sglang/srt/managers/image_processors/janus_pro.py b/python/sglang/srt/managers/image_processors/janus_pro.py new file mode 100644 index 000000000..a3d25c989 --- /dev/null +++ b/python/sglang/srt/managers/image_processors/janus_pro.py @@ -0,0 +1,79 @@ +import asyncio +from typing import List, Union + +from sglang.srt.managers.image_processors.base_image_processor import ( + BaseImageProcessor as SGLangBaseImageProcessor, +) +from sglang.srt.managers.image_processors.base_image_processor import ( + get_global_processor, +) +from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM + + +class JanusProProcessor(SGLangBaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + + @staticmethod + def _process_images_task(images, input_text): + processor = get_global_processor() + result = processor.__call__( + prompt=input_text, images=images, return_tensors="pt" + ) + return { + "input_ids": result["input_ids"], + "pixel_values": result["pixel_values"], + "images_emb_mask": result["images_emb_mask"], + "im_start_id": processor.image_start_id, + "im_end_id": processor.image_end_id, + "im_token_id": processor.image_id, + } + + async def _process_images(self, images, input_text): 
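+        # The processor call is blocking (PIL decode/resize plus tokenization), so
+        # when an executor is available it runs there; the worker side fetches the
+        # processor instance via get_global_processor(), which init_global_processor
+        # set up, keeping the async event loop responsive.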
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                JanusProProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                prompt=input_text, images=images, return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_ids,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        base_out = self.load_images(
+            input_ids, image_data, "<image_placeholder>", max_req_input_len
+        )
+        images = base_out.all_frames
+        res = await self._process_images(images=images, input_text=base_out.input_text)
+
+        return {
+            "input_ids": res["input_ids"].flatten().tolist(),
+            "pixel_values": res["pixel_values"],
+            "images_emb_mask": res["images_emb_mask"],
+            "image_hashes": base_out.image_hashes,
+            "im_start_id": res["im_start_id"],
+            "im_end_id": res["im_end_id"],
+            "im_token_id": res["im_token_id"],
+        }
+
+
+ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}
diff --git a/python/sglang/srt/models/deepseek_janus_pro.py b/python/sglang/srt/models/deepseek_janus_pro.py
new file mode 100644
index 000000000..7b0dac16c
--- /dev/null
+++ b/python/sglang/srt/models/deepseek_janus_pro.py
@@ -0,0 +1,2127 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== + +# Copied and Adapted from: +# https://github.com/deepseek-ai/Janus + + +import collections +import math +import os +from dataclasses import field +from enum import Enum +from functools import partial +from itertools import repeat +from typing import ( + Callable, + Final, + Iterable, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, _assert, nn +from torch.nn.init import trunc_normal_ +from transformers import AutoModel, PreTrainedModel + +from sglang.srt.configs.janus_pro import * +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.managers.multi_modality_padding import ( + MultiModalityDataPaddingPatternTokenPairs, +) +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.llama import LlamaForCausalLM +from sglang.utils import logger + +################################################################################# +# VQ Model Configs # +################################################################################# + + +# Copied from: +# https://github.com/deepseek-ai/Janus/tree/main/janus/models/vq_model.py +@dataclass +class ModelArgs: + codebook_size: int = 16384 + codebook_embed_dim: int = 8 + codebook_l2_norm: bool = True + codebook_show_usage: bool = True + commit_loss_beta: float = 0.25 + entropy_loss_ratio: float = 0.0 + + encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + z_channels: int = 256 + dropout_p: float = 0.0 + + +def named_apply( + fn: Callable, + module: nn.Module, + name="", + depth_first: bool = True, + include_root: bool = False, +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def VQ_16(**kwargs): + return VQModel( + ModelArgs( + encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs + ) + ) + + +VQ_models = {"VQ-16": VQ_16} + +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +): + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +to_2tuple = _ntuple(2) + + +class Format(str, Enum): + NCHW = "NCHW" + NHWC = "NHWC" + NCL = "NCL" + NLC = "NLC" + + +def nchw_to(x: torch.Tensor, fmt: Format): + if fmt == Format.NHWC: + x = x.permute(0, 2, 3, 1) + elif fmt == Format.NLC: + x = x.flatten(2).transpose(1, 2) + elif fmt == Format.NCL: + x = x.flatten(2) + return x + + +def resample_patch_embed( + patch_embed, + new_size: List[int], + interpolation: str = "bicubic", + antialias: bool = True, + verbose: bool = False, +): + """Resample the weights of the patch embedding kernel to target resolution. + We resample the patch embedding kernel by approximately inverting the effect + of patch resizing. + + Code based on: + https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py + + With this resizing, we can for example load a B/8 filter into a B/16 model + and, on 2x larger input image, the result will match. + + Args: + patch_embed: original parameter to be resized. + new_size (tuple(int, int): target shape (height, width)-only. + interpolation (str): interpolation for resize + antialias (bool): use anti-aliasing filter in resize + verbose (bool): log operation + Returns: + Resized patch embedding kernel. 
+ """ + import numpy as np + + try: + from torch import vmap + except ImportError: + from functorch import vmap + + assert len(patch_embed.shape) == 4, "Four dimensions expected" + assert len(new_size) == 2, "New shape should only be hw" + old_size = patch_embed.shape[-2:] + if tuple(old_size) == tuple(new_size): + return patch_embed + + if verbose: + logger.info( + f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation." + ) + + def resize(x_np, _new_size): + x_tf = torch.Tensor(x_np)[None, None, ...] + x_upsampled = F.interpolate( + x_tf, size=_new_size, mode=interpolation, antialias=antialias + )[0, 0, ...].numpy() + return x_upsampled + + def get_resize_mat(_old_size, _new_size): + mat = [] + for i in range(np.prod(_old_size)): + basis_vec = np.zeros(_old_size) + basis_vec[np.unravel_index(i, _old_size)] = 1.0 + mat.append(resize(basis_vec, _new_size).reshape(-1)) + return np.stack(mat).T + + resize_mat = get_resize_mat(old_size, new_size) + resize_mat_pinv = torch.tensor( + np.linalg.pinv(resize_mat.T), device=patch_embed.device + ) + + def resample_kernel(kernel): + resampled_kernel = resize_mat_pinv @ kernel.reshape(-1) + return resampled_kernel.reshape(new_size) + + v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1) + orig_dtype = patch_embed.dtype + patch_embed = patch_embed.float() + patch_embed = v_resample_kernel(patch_embed) + patch_embed = patch_embed.to(orig_dtype) + return patch_embed + + +# Copied from: +# https://github.com/deepseek-ai/Janus/tree/main/janus/models/siglip_vit.py +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + output_fmt: Format + dynamic_img_pad: torch.jit.Final[bool] + + def __init__( + self, + img_size: Optional[int] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + output_fmt: Optional[str] = None, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, + ): + super().__init__() + self.patch_size = tuple(to_2tuple(patch_size)) + self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size) + + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.output_fmt = Format.NCHW + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def _init_img_size(self, img_size: Union[int, Tuple[int, int]]): + assert self.patch_size + if img_size is None: + return None, None, None + img_size = to_2tuple(img_size) + grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)]) + num_patches = grid_size[0] * grid_size[1] + return img_size, grid_size, num_patches + + def set_input_size( + self, + img_size: Optional[Union[int, Tuple[int, int]]] = None, + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + ): + new_patch_size = None + if patch_size is not None: + new_patch_size = to_2tuple(patch_size) + if new_patch_size is not None and new_patch_size != self.patch_size: + with torch.no_grad(): + new_proj = nn.Conv2d( + self.proj.in_channels, + self.proj.out_channels, + kernel_size=new_patch_size, + stride=new_patch_size, + bias=self.proj.bias is not None, + ) + new_proj.weight.copy_( + 
resample_patch_embed(self.proj.weight, new_patch_size, verbose=True) + ) + if self.proj.bias is not None: + new_proj.bias.copy_(self.proj.bias) + self.proj = new_proj + self.patch_size = new_patch_size + img_size = img_size or self.img_size + if img_size != self.img_size or new_patch_size is not None: + self.img_size, self.grid_size, self.num_patches = self._init_img_size( + img_size + ) + + def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: + if as_scalar: + return max(self.patch_size) + else: + return self.patch_size + + def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: + """Get grid (feature) size for given image size taking account of dynamic padding. + NOTE: must be torchscript compatible so using fixed tuple indexing + """ + if self.dynamic_img_pad: + return math.ceil(img_size[0] / self.patch_size[0]), math.ceil( + img_size[1] / self.patch_size[1] + ) + else: + return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] + + def forward(self, x): + B, C, H, W = x.shape + if self.img_size is not None: + if self.strict_img_size: + _assert( + H == self.img_size[0], + f"Input height ({H}) doesn't match model ({self.img_size[0]}).", + ) + _assert( + W == self.img_size[1], + f"Input width ({W}) doesn't match model ({self.img_size[1]}).", + ) + elif not self.dynamic_img_pad: + _assert( + H % self.patch_size[0] == 0, + f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).", + ) + _assert( + W % self.patch_size[1] == 0, + f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).", + ) + if self.dynamic_img_pad: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + elif self.output_fmt != Format.NCHW: + x = nchw_to(x, self.output_fmt) + x = self.norm(x) + return x + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks + + NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected. + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + ) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob, 3):0.3f}" + + +class VisionTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + use_qkv_parallel=True, + use_context_forward=False, + softmax_in_single_precision=False, + dropout=attn_drop, + ) + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +LayerType = Union[str, Callable, Type[torch.nn.Module]] + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220 + """ + + return_indices: torch.jit.Final[bool] + + def __init__( + self, + prob: float = 0.5, + num_prefix_tokens: int = 1, + ordered: bool = False, + return_indices: bool = False, + ): + super().__init__() + assert 0 <= prob < 1.0 + self.prob = prob + self.num_prefix_tokens = ( + num_prefix_tokens # exclude CLS token (or other prefix tokens) + ) + self.ordered = ordered + self.return_indices = return_indices + + def forward( + self, x + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + if not self.training or self.prob == 0.0: + if self.return_indices: + return x, None + return x + + if self.num_prefix_tokens: + prefix_tokens, x = ( + x[:, : self.num_prefix_tokens], + x[:, self.num_prefix_tokens :], + ) + else: + prefix_tokens = None + + B = x.shape[0] + L = x.shape[1] + num_keep = max(1, int(L * (1.0 - self.prob))) + keep_indices = torch.argsort(torch.randn(B, L, device=x.device), 
dim=-1)[ + :, :num_keep + ] + if self.ordered: + # NOTE does not need to maintain patch order in typical transformer use, + # but possibly useful for debug / visualization + keep_indices = keep_indices.sort(dim=-1)[0] + x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) + + if prefix_tokens is not None: + x = torch.cat((prefix_tokens, x), dim=1) + + if self.return_indices: + return x, keep_indices + return x + + +def resample_abs_pos_embed( + posemb: torch.Tensor, + new_size: List[int], + old_size: Optional[List[int]] = None, + num_prefix_tokens: int = 1, + interpolation: str = "bicubic", + antialias: bool = True, + verbose: bool = False, +): + # sort out sizes, assume square if old size not provided + num_pos_tokens = posemb.shape[1] + num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens + if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: + return posemb + + if old_size is None: + hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) + old_size = hw, hw + + if num_prefix_tokens: + posemb_prefix, posemb = ( + posemb[:, :num_prefix_tokens], + posemb[:, num_prefix_tokens:], + ) + else: + posemb_prefix, posemb = None, posemb + + # do the interpolation + embed_dim = posemb.shape[-1] + orig_dtype = posemb.dtype + posemb = posemb.float() # interpolate needs float32 + posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) + posemb = F.interpolate( + posemb, size=new_size, mode=interpolation, antialias=antialias + ) + posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) + posemb = posemb.to(orig_dtype) + + # add back extra (class, etc) prefix tokens + if posemb_prefix is not None: + posemb = torch.cat([posemb_prefix, posemb], dim=1) + + if not torch.jit.is_scripting() and verbose: + logger.info(f"Resized position embedding: {old_size} to {new_size}.") + + return posemb + + +def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_(self.latent, std=self.latent_dim**-0.5) + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + embed_layer: Callable = 
PatchEmbed,
+        _norm_layer: Optional[LayerType] = None,
+        _act_layer: Optional[LayerType] = None,
+        block_fn: Type[nn.Module] = VisionTransformerBlock,
+        mlp_layer: Type[nn.Module] = Mlp,
+        ignore_head: bool = False,
+    ) -> None:
+        """
+        Args:
+            img_size: Input image size.
+            patch_size: Patch size.
+            in_chans: Number of image input channels.
+            num_classes: Number of classes for classification head.
+            global_pool: Type of global pooling for final sequence (default: 'token').
+            embed_dim: Transformer embedding dimension.
+            depth: Depth of transformer.
+            num_heads: Number of attention heads.
+            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
+            qkv_bias: Enable bias for qkv projections if True.
+            init_values: Layer-scale init values (layer-scale enabled if not None).
+            class_token: Use class token.
+            no_embed_class: Don't include position embeddings for class (or reg) tokens.
+            reg_tokens: Number of register tokens.
+            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+            drop_rate: Head dropout rate.
+            pos_drop_rate: Position embedding dropout rate.
+            attn_drop_rate: Attention dropout rate.
+            drop_path_rate: Stochastic depth rate.
+            weight_init: Weight initialization scheme.
+            embed_layer: Patch embedding layer.
+            _norm_layer: Normalization layer.
+            _act_layer: MLP activation layer.
+            block_fn: Transformer block layer.
+        """
+        super().__init__()
+        assert global_pool in ("", "avg", "token", "map")
+        assert class_token or global_pool != "token"
+        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
+        # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        # act_layer = get_act_layer(act_layer) or nn.GELU
+        norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        act_layer = nn.GELU
+
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        self.num_features = self.embed_dim = (
+            embed_dim  # num_features for consistency with other models
+        )
+        self.num_prefix_tokens = 1 if class_token else 0
+        self.num_prefix_tokens += reg_tokens
+        self.num_reg_tokens = reg_tokens
+        self.has_class_token = class_token
+        self.no_embed_class = (
+            no_embed_class  # don't embed prefix positions (includes reg)
+        )
+        self.dynamic_img_size = dynamic_img_size
+        self.grad_checkpointing = False
+        self.ignore_head = ignore_head
+
+        embed_args = {}
+        if dynamic_img_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
+        self.patch_embed = embed_layer(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            bias=not pre_norm,  # disable bias if pre-norm is used (e.g.
CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == "map": + AttentionPoolLatent.init_weights = init_weights + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + if weight_init != "skip": + self.init_weights(weight_init) + + def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." 
+            elif global_pool != "map" and self.attn_pool is not None:
+                self.attn_pool = None  # remove attention pooling
+            self.global_pool = global_pool
+        self.head = (
+            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        )
+
+    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+        if self.dynamic_img_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(
+                self.pos_embed,
+                [H, W],
+                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+            )
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
+
+        to_cat = []
+        if self.cls_token is not None:
+            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+        if self.reg_token is not None:
+            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + pos_embed
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
+            x = x + pos_embed
+
+        return self.pos_drop(x)
+
+    def _intermediate_layers(
+        self,
+        x: torch.Tensor,
+        n: Union[int, Sequence] = 1,
+    ) -> List[torch.Tensor]:
+        outputs, num_blocks = [], len(self.blocks)
+        take_indices = set(
+            range(num_blocks - n, num_blocks) if isinstance(n, int) else n
+        )
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = self._pos_embed(x)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in take_indices:
+                outputs.append(x)
+
+        return outputs
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.patch_embed(x)
+        x = self._pos_embed(x)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        x = self.blocks(x)
+        x = self.norm(x)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        if self.attn_pool is not None:
+            x = self.attn_pool(x)
+        elif self.global_pool == "avg":
+            x = x[:, self.num_prefix_tokens :].mean(dim=1)
+        elif self.global_pool:
+            x = x[:, 0]  # class token
+        x = self.fc_norm(x)
+        x = self.head_drop(x)
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        if not self.ignore_head:
+            x = self.forward_head(x)
+        return x
+
+
+def model_name_to_cls(cls_name):
+    if "MlpProjector" in cls_name:
+        cls = MlpProjector
+
+    elif "CLIPVisionTower" in cls_name:
+        cls = CLIPVisionTower
+
+    elif "VQ" in cls_name:
+        cls = VQ_models[cls_name]
+    elif "vision_head" in cls_name:
+        cls = vision_head
+    else:
+        raise ValueError(f"class_name {cls_name} is invalid.")
+
+    return cls
+
+
+class vision_head(torch.nn.Module):
+    def __init__(self, params):
+        super().__init__()
+        self.output_mlp_projector = torch.nn.Linear(
+            params["n_embed"], params["image_token_embed"]
+        )
+        self.vision_activation = torch.nn.GELU()
+        self.vision_head = torch.nn.Linear(
+            params["image_token_embed"], params["image_token_size"]
+        )
+
+    def forward(self, x):
+        x = self.output_mlp_projector(x)
+        x = self.vision_activation(x)
+        x = self.vision_head(x)
+        return x
+
+
+SigLIP_MODEL_CONFIG = {
+    "siglip_so400m_patch14_384": {
+        "image_size": 336,
+        "patch_size": 14,
+        "width": 1152,
+        "layers": 27,
+        "heads": 16,
+        "mlp_ratio": 3.7362,
+        "global_pool": "map",
+        "use_checkpoint": False,
+    },
+    "siglip_so400m_patch14_224": {
+        "image_size": 224,
+        "patch_size": 14,
+        "width": 1152,
+        "layers": 27,
+        "heads":
+        "mlp_ratio": 3.7362,
+        "global_pool": "map",
+        "use_checkpoint": False,
+    },
+    "siglip_large_patch16_384": {
+        "image_size": 384,
+        "patch_size": 16,
+        "width": 1024,
+        "layers": 24,
+        "heads": 16,
+        "mlp_ratio": 4,
+        "global_pool": "map",
+        "use_checkpoint": False,
+    },
+}
+
+
+def create_siglip_vit(
+    model_name: str = "siglip_so400m_patch14_384",
+    image_size: int = 384,
+    select_layer: int = -1,
+    ckpt_path: str = "",
+    **kwargs,
+):
+    assert (
+        model_name in SigLIP_MODEL_CONFIG.keys()
+    ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
+
+    vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
+
+    if select_layer <= 0:
+        layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
+    else:
+        layers = min(vision_cfg.layers, select_layer)
+
+    model = VisionTransformer(
+        img_size=image_size,
+        patch_size=vision_cfg.patch_size,
+        embed_dim=vision_cfg.width,
+        depth=layers,
+        num_heads=vision_cfg.heads,
+        mlp_ratio=vision_cfg.mlp_ratio,
+        class_token=vision_cfg.class_token,
+        global_pool=vision_cfg.global_pool,
+        ignore_head=kwargs.get("ignore_head", True),
+        weight_init=kwargs.get("weight_init", "skip"),
+        num_classes=0,
+    )
+
+    if ckpt_path:
+        state_dict = torch.load(ckpt_path, map_location="cpu")
+        incompatible_keys = model.load_state_dict(state_dict, strict=False)
+        print(
+            f"SigLIP-ViT restores from {ckpt_path},\n"
+            f"\tincompatible_keys: {incompatible_keys}."
+        )
+
+    return model
+
+
+# Renamed from torchvision's `Normalize` so it is not shadowed by the
+# `Normalize(in_channels, norm_type)` norm-layer factory defined further below.
+class ImageNormalize(torch.nn.Module):
+    """Normalize a tensor image with mean and standard deviation.
+    This transform does not support PIL Image.
+    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
+    channels, this transform will normalize each channel of the input
+    ``torch.*Tensor`` i.e.,
+    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+    .. note::
+        This transform acts out of place, i.e., it does not mutate the input tensor.
+
+    Args:
+        mean (sequence): Sequence of means for each channel.
+        std (sequence): Sequence of standard deviations for each channel.
+        inplace(bool,optional): Bool to make this operation in-place.
+
+    """
+
+    def __init__(self, mean, std, inplace=False):
+        super().__init__()
+        self.mean = mean
+        self.std = std
+        self.inplace = inplace
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            tensor (Tensor): Tensor image to be normalized.
+
+        Returns:
+            Tensor: Normalized Tensor image.
+ """ + return F.normalize(tensor, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})" + + +class CLIPVisionTower(nn.Module): + def __init__( + self, + model_name: str = "siglip_large_patch16_384", + image_size: Union[Tuple[int, int], int] = 336, + select_feature: str = "patch", + select_layer: int = -2, + select_layers: list = None, + ckpt_path: str = "", + pixel_mean: Optional[List[float]] = None, + pixel_std: Optional[List[float]] = None, + **kwargs, + ): + super().__init__() + + self.model_name = model_name + self.select_feature = select_feature + self.select_layer = select_layer + self.select_layers = select_layers + + vision_tower_params = { + "model_name": model_name, + "image_size": image_size, + "ckpt_path": ckpt_path, + "select_layer": select_layer, + } + vision_tower_params.update(kwargs) + self.vision_tower, self.forward_kwargs = self.build_vision_tower( + vision_tower_params + ) + + if pixel_mean is not None and pixel_std is not None: + image_norm = Normalize(mean=pixel_mean, std=pixel_std) + else: + image_norm = None + + self.image_norm = image_norm + + @property + def device(self) -> torch.device: + return next(self.vision_tower.parameters()).device + + @property + def dtype(self): + return next(self.vision_tower.parameters()).dtype + + def build_vision_tower(self, vision_tower_params): + if self.model_name.startswith("siglip"): + self.select_feature = "same" + vision_tower = create_siglip_vit(**vision_tower_params) + forward_kwargs = dict() + + elif self.model_name.startswith("sam"): + # vision_tower = create_sam_vit(**vision_tower_params) + forward_kwargs = dict() + + else: # huggingface + from transformers import CLIPVisionModel + + vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params) + forward_kwargs = dict(output_hidden_states=True) + + return vision_tower, forward_kwargs + + def feature_select(self, image_forward_outs): + if isinstance(image_forward_outs, torch.Tensor): + # the output has been the self.select_layer"s features + image_features = image_forward_outs + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if self.select_feature == "patch": + # if the output has cls_token + image_features = image_features[:, 1:] + elif self.select_feature == "cls_patch": + image_features = image_features + elif self.select_feature == "same": + image_features = image_features + + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + return image_features + + def forward(self, images): + """ + + Args: + images (torch.Tensor): [b, 3, H, W] + + Returns: + image_features (torch.Tensor): [b, n_patch, d] + """ + + if self.image_norm is not None: + images = self.image_norm(images) + + image_forward_outs = self.vision_tower(images, **self.forward_kwargs) + image_features = self.feature_select(image_forward_outs) + return image_features + + +class MlpProjector(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + + if cfg["projector_type"] == "identity": + modules = nn.Identity() + + elif cfg["projector_type"] == "linear": + modules = nn.Linear(cfg["input_dim"], cfg["n_embed"]) + + elif cfg["projector_type"] == "mlp_gelu": + mlp_depth = cfg.get("depth", 1) + modules = [nn.Linear(cfg["input_dim"], cfg["n_embed"])] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"])) + modules = nn.Sequential(*modules) + + elif cfg["projector_type"] == 
"low_high_hybrid_split_mlp_gelu": + mlp_depth = cfg.get("depth", 1) + self.high_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2) + self.low_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2) + + modules = [] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"])) + modules = nn.Sequential(*modules) + + else: + raise ValueError(f"Unknown projector type: {cfg['projector_type']}") + + self.layers = modules + + def forward( + self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] + ): + """ + + Args: + x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor, + then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x); + otherwise it is the feature from the single vision encoder. + + Returns: + x (torch.Tensor): [b, s, c] + """ + + if isinstance(x_or_tuple, tuple): + # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu": + high_x, low_x = x_or_tuple + high_x = self.high_up_proj(high_x) + low_x = self.low_up_proj(low_x) + x = torch.concat([high_x, low_x], dim=-1) + else: + x = x_or_tuple + + return self.layers(x) + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +# use torch.scaled_dot_product_attention where possible +_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention") +if "TIMM_FUSED_ATTN" in os.environ: + _USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"]) +else: + _USE_FUSED_ATTN = ( + 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use) + ) + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + + +def use_fused_attn(experimental: bool = False) -> bool: + # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0 + if not _HAS_FUSED_ATTN or _EXPORTABLE: + return False + if experimental: + return _USE_FUSED_ATTN > 1 + return _USE_FUSED_ATTN > 0 + + +class AttentionPoolLatent(nn.Module): + """Attention pooling w/ latent query""" + + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 8, + feat_size: Optional[int] = None, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + latent_len: int = 1, + latent_dim: int = None, + pos_embed: str = "", + pool_type: str = "token", + norm_layer: Optional[nn.Module] = None, + drop: float = 0.0, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.feat_size = feat_size + self.scale = self.head_dim**-0.5 + self.pool = pool_type + self.fused_attn = use_fused_attn() + + if pos_embed == "abs": + assert feat_size is not None + self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features)) + else: + self.pos_embed = None + + self.latent_dim = latent_dim or embed_dim + self.latent_len = latent_len + self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) + + self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) + 
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.proj = nn.Linear(embed_dim, embed_dim)
+        self.proj_drop = nn.Dropout(drop)
+
+        self.norm = (
+            norm_layer(out_features) if norm_layer is not None else nn.Identity()
+        )
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
+
+        self.init_weights()
+
+    def init_weights(self):
+        if self.pos_embed is not None:
+            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
+        trunc_normal_tf_(self.latent, std=self.latent_dim**-0.5)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        if self.pos_embed is not None:
+            # FIXME interpolate
+            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+        q_latent = self.latent.expand(B, -1, -1)
+        q = (
+            self.q(q_latent)
+            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+
+        kv = (
+            self.kv(x)
+            .reshape(B, N, 2, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        k, v = kv.unbind(0)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        x = x + self.mlp(self.norm(x))
+
+        # optional pool if latent seq_len > 1 and pooled output is desired
+        if self.pool == "token":
+            x = x[:, 0]
+        elif self.pool == "avg":
+            x = x.mean(1)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        in_channels=3,
+        ch=128,
+        ch_mult=(1, 1, 2, 2, 4),
+        num_res_blocks=2,
+        norm_type="group",
+        dropout=0.0,
+        resamp_with_conv=True,
+        z_channels=256,
+    ):
+        super().__init__()
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
+
+        # downsampling
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.conv_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            conv_block = nn.Module()
+            # res & attn
+            res_block = nn.ModuleList()
+            attn_block = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks):
+                res_block.append(
+                    ResnetBlock(
+                        block_in, block_out, dropout=dropout, norm_type=norm_type
+                    )
+                )
+                block_in = block_out
+                # attention is only applied at the lowest resolution, one block
+                # per res block so the indices line up in forward()
+                if i_level == self.num_resolutions - 1:
+                    attn_block.append(AttnBlock(block_in, norm_type))
+            conv_block.res = res_block
+            conv_block.attn = attn_block
+            # downsample
+            if i_level != self.num_resolutions - 1:
+                conv_block.downsample = Downsample(block_in, resamp_with_conv)
+            self.conv_blocks.append(conv_block)
+
+        # middle
+        self.mid = nn.ModuleList()
+        self.mid.append(
+            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+        )
+        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
+        self.mid.append(
+            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+        )
+
+        # end
+        self.norm_out = Normalize(block_in, norm_type)
+        self.conv_out = nn.Conv2d(
+            block_in, z_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x):
+        h = self.conv_in(x)
+        # downsampling
+        for i_level, block in enumerate(self.conv_blocks):
+            for i_block in range(self.num_res_blocks):
+                h = block.res[i_block](h)
+                if len(block.attn) > 0:
+                    h = block.attn[i_block](h)
+            if i_level != self.num_resolutions - 1:
+                h = block.downsample(h)
+
+        # middle
+        for mid_block in self.mid:
+            h = mid_block(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        z_channels=256,
+        ch=128,
+        ch_mult=(1, 1, 2, 2, 4),
+        num_res_blocks=2,
+        norm_type="group",
+        dropout=0.0,
+        resamp_with_conv=True,
+        out_channels=3,
+    ):
+        super().__init__()
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        # z to block_in
+        self.conv_in = nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = nn.ModuleList()
+        self.mid.append(
+            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+        )
+        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
+        self.mid.append(
+            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+        )
+
+        # upsampling
+        self.conv_blocks = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            conv_block = nn.Module()
+            # res & attn
+            res_block = nn.ModuleList()
+            attn_block = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks + 1):
+                res_block.append(
+                    ResnetBlock(
+                        block_in, block_out, dropout=dropout, norm_type=norm_type
+                    )
+                )
+                block_in = block_out
+                if i_level == self.num_resolutions - 1:
+                    attn_block.append(AttnBlock(block_in, norm_type))
+            conv_block.res = res_block
+            conv_block.attn = attn_block
+            # upsample
+            if i_level != 0:
+                conv_block.upsample = Upsample(block_in, resamp_with_conv)
+            self.conv_blocks.append(conv_block)
+
+        # end
+        self.norm_out = Normalize(block_in, norm_type)
+        self.conv_out = nn.Conv2d(
+            block_in, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    @property
+    def last_layer(self):
+        return self.conv_out.weight
+
+    def forward(self, z):
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        for mid_block in self.mid:
+            h = mid_block(h)
+
+        # upsampling (blocks are stored highest-resolution-last, so the final
+        # block has no upsample attribute)
+        for i_level, block in enumerate(self.conv_blocks):
+            for i_block in range(self.num_res_blocks + 1):
+                h = block.res[i_block](h)
+                if len(block.attn) > 0:
+                    h = block.attn[i_block](h)
+            if i_level != self.num_resolutions - 1:
+                h = block.upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class VectorQuantizer(nn.Module):
+    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
+        super().__init__()
+        self.n_e = n_e
+        self.e_dim = e_dim
+        self.beta = beta
+        self.entropy_loss_ratio = entropy_loss_ratio
+        self.l2_norm = l2_norm
+        self.show_usage = show_usage
+
+        self.embedding = nn.Embedding(self.n_e, self.e_dim)
+        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+        if self.l2_norm:
+            self.embedding.weight.data = F.normalize(
+                self.embedding.weight.data, p=2, dim=-1
+            )
+        if self.show_usage:
+            # self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
+            self.codebook_used = nn.Parameter(torch.zeros(65536))
+
+    def forward(self, z):
+        # reshape z -> (batch, height, width, channel) and flatten
+        z = torch.einsum("b c h w -> b h w c", z).contiguous()
+        z_flattened = z.view(-1, self.e_dim)
+        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
+
+        if self.l2_norm:
+            z = F.normalize(z, p=2, dim=-1)
+            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
+            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+        else:
+            embedding = self.embedding.weight
+
+        d = (
+            torch.sum(z_flattened**2, dim=1, keepdim=True)
+            + torch.sum(embedding**2, dim=1)
+            - 2
+            * torch.einsum(
+                "bd,dn->bn",
z_flattened, torch.einsum("n d -> d n", embedding) + ) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = embedding[min_encoding_indices].view(z.shape) + perplexity = None + min_encodings = None + vq_loss = None + commit_loss = None + entropy_loss = None + + # compute loss for embedding + if self.training: + vq_loss = torch.mean((z_q - z.detach()) ** 2) + commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = torch.einsum("b h w c -> b c h w", z_q) + + return ( + z_q, + (vq_loss, commit_loss, entropy_loss), + (perplexity, min_encodings, min_encoding_indices), + ) + + def get_codebook_entry(self, indices, shape=None, channel_first=True): + # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel) + if self.l2_norm: + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + z_q = embedding[indices] # (b*h*w, c) + + if shape is not None: + if channel_first: + z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + else: + z_q = z_q.view(shape) + return z_q + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + norm_type="group", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels, norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.norm = Normalize(in_channels, norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to 
values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, norm_type="group"): + assert norm_type in ["group", "batch"] + if norm_type == "group": + return nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + elif norm_type == "batch": + return nn.SyncBatchNorm(in_channels) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + if x.dtype != torch.float32: + x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to( + torch.bfloat16 + ) + else: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + probs = F.softmax(flat_affinity, dim=-1) + log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) + if loss_type == "softmax": + target_probs = probs + else: + raise ValueError("Entropy loss {} not supported".format(loss_type)) + avg_probs = torch.mean(target_probs, dim=0) + avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) + sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1)) + loss = sample_entropy - avg_entropy + return loss + + +class VQModel(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.encoder = Encoder( + ch_mult=config.encoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + self.decoder = Decoder( + ch_mult=config.decoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + + self.quantize = VectorQuantizer( + config.codebook_size, + config.codebook_embed_dim, + config.commit_loss_beta, + config.entropy_loss_ratio, + config.codebook_l2_norm, + config.codebook_show_usage, + ) + self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) + self.post_quant_conv = nn.Conv2d( + config.codebook_embed_dim, config.z_channels, 1 + ) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b, shape=None, channel_first=True): + quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, 
diff + + +class MultiModalityPreTrainedModel(PreTrainedModel): + config_class = MultiModalityConfig + base_model_prefix = "multi_modality" + _no_split_modules = [] + _skip_keys_device_placement = "past_key_values" + + +# Copied and adapted from: +# https://github.com/deepseek-ai/Janus/tree/main/janus/models/modeling_vlm.py +class MultiModalityCausalLM(MultiModalityPreTrainedModel): + + def __init__( + self, + config: MultiModalityConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config) + + vision_config = config.vision_config + vision_cls = model_name_to_cls(vision_config.cls) + self.vision_model = vision_cls(**vision_config.params) + + aligner_config = config.aligner_config + aligner_cls = model_name_to_cls(aligner_config.cls) + self.aligner = aligner_cls(aligner_config.params) + + gen_vision_config = config.gen_vision_config + gen_vision_cls = model_name_to_cls(gen_vision_config.cls) + self.gen_vision_model = gen_vision_cls() + + gen_aligner_config = config.gen_aligner_config + gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls) + self.gen_aligner = gen_aligner_cls(gen_aligner_config.params) + + gen_head_config = config.gen_head_config + gen_head_cls = model_name_to_cls(gen_head_config.cls) + self.gen_head = gen_head_cls(gen_head_config.params) + + self.gen_embed = torch.nn.Embedding( + gen_vision_config.params["image_token_size"], + gen_vision_config.params["n_embed"], + ) + + language_config = config.language_config + self.language_model = LlamaForCausalLM( + language_config, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) + + def prepare_images_seq_mask( + self, input_ids: torch.Tensor, image_inputs: ImageInputs + ) -> Optional[torch.LongTensor]: + images_seq_mask = torch.isin( + input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device) + ) + if images_seq_mask.sum() == 0: + # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache + return None + else: + return images_seq_mask + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + + inputs_embeds = None + if ( + forward_batch.image_inputs is not None + and len(forward_batch.image_inputs) != 0 + and forward_batch.image_inputs[0] is not None + ): + + image_inputs = forward_batch.image_inputs[0] + + images_seq_mask = self.prepare_images_seq_mask( + input_ids=input_ids, image_inputs=image_inputs + ) + + if images_seq_mask is not None: + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + inputs_embeds = self.prepare_inputs_embeds( + input_ids=input_ids, + pixel_values=image_inputs.pixel_values, + images_seq_mask=images_seq_mask, + images_emb_mask=image_inputs.images_emb_mask, + ) + input_ids = None + + if input_ids is not None: + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + + return self.language_model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=inputs_embeds, + get_embedding=False, + ) + + def prepare_inputs_embeds( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + images_seq_mask: torch.LongTensor, + images_emb_mask: torch.BoolTensor, + **_kwargs, + ): + """ + + Args: + input_ids (torch.LongTensor): [b, T] + pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] + images_seq_mask (torch.BoolTensor): [b, T] + images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] + + assert torch.sum(images_seq_mask) 
== torch.sum(images_emb_mask)
+
+        Returns:
+            input_embeds (torch.Tensor): [b, T, D]
+        """
+
+        bs, n = pixel_values.shape[0:2]
+        pixel_values = pixel_values.to(
+            device=self.vision_model.device, dtype=self.vision_model.dtype
+        )
+        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+
+        # [b x n, T2, D]
+        images_embeds = self.aligner(self.vision_model(images))
+
+        # [b x n, T2, D] -> [b, n x T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+        # [b, n, T2] -> [b, n x T2]
+        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
+
+        # [b, T, D]
+        # zero out the (negative) image placeholder ids before the embedding
+        # lookup; these positions are overwritten with image embeddings below
+        input_ids[input_ids < 0] = 0
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # replace with the image embeddings
+        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
+
+        return inputs_embeds
+
+    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
+        return self.gen_aligner(self.gen_embed(image_ids))
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        im_start_id = image_inputs.im_start_id
+        im_end_id = image_inputs.im_end_id
+        media_token_pairs = [(im_start_id, im_end_id)]
+
+        helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return helper.pad_input_tokens(input_ids, image_inputs)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            # skip the generation sub-model
+            if "gen" in name:
+                continue
+
+            # adapt to VisionAttention
+            name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
+            if "vision_model.vision_tower" in name:
+                name = name.replace("attn.qkv", "attn.qkv_proj")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # replace the name and load with the customized loader
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", None)
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +AutoModel.register(config_class=MultiModalityConfig, model_class=MultiModalityCausalLM) +EntryClass = [MultiModalityCausalLM] diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index a88f0e65c..c0e360468 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -512,5 +512,29 @@ class TestMinicpmvServer(TestOpenAIVisionServer): cls.base_url += "/v1" +class TestJanusProServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "deepseek-ai/Janus-Pro-7B" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--chat-template", + "janus-pro", + "--mem-fraction-static", + "0.4", + ], + ) + cls.base_url += "/v1" + + def test_video_chat_completion(self): + pass + + if __name__ == "__main__": unittest.main()
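
Usage note (illustrative, not part of the patch): with this change applied, Janus-Pro can be served and queried through SGLang's OpenAI-compatible API, mirroring TestJanusProServer above. The port, API key, image URL, and prompt below are assumptions, not values taken from the patch.

# Assumes a server launched roughly as the test does, e.g.:
#   python -m sglang.launch_server --model-path deepseek-ai/Janus-Pro-7B \
#       --trust-remote-code --chat-template janus-pro --mem-fraction-static 0.4
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")
response = client.chat.completions.create(
    model="deepseek-ai/Janus-Pro-7B",
    messages=[
        {
            "role": "user",
            "content": [
                # placeholder URL; any reachable image works here
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)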