model: Support Janus-pro (#3203)
@@ -26,6 +26,7 @@ class EvalArgs:
     backend: str = "engine"
     seed: int = 42
     split: str = "validation"
+    # Default setting to make the benchmark available on A100 for most 7B models
     image_pixels_limit: int = 4300000
     result_filename: str = ""
     prompt_format_file: str = "prompt_format.yaml"
@@ -38,6 +39,7 @@ class EvalArgs:
     parser.add_argument(
         "--result-filename", type=str, default=EvalArgs.result_filename
     )
+
     parser.add_argument(
         "--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit
     )
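Note: `image_pixels_limit` caps the total pixel area (width times height) of a benchmark image rather than either dimension alone. A minimal sketch of how a caller might enforce such a budget before submitting images; the helper name is hypothetical and not part of this commit:

```python
import math

from PIL import Image


def cap_image_pixels(img: Image.Image, pixels_limit: int = 4_300_000) -> Image.Image:
    """Downscale img, keeping aspect ratio, until width * height <= pixels_limit."""
    w, h = img.size
    if w * h <= pixels_limit:
        return img  # already under the budget
    scale = math.sqrt(pixels_limit / (w * h))
    # int() floors, so the product stays at or under the limit
    return img.resize((max(1, int(w * scale)), max(1, int(h * scale))))
```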
@@ -31,6 +31,7 @@
 - Phi-3 / Phi-4
 - Phi-3-Small
 - IBM Granite 3
+- Janus-Pro-1B / Janus-Pro-7B

 ## Embedding Models

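Once a Janus-Pro server is running (see `TestJanusProServer` at the end of this diff for the launch flags), requests go through SGLang's OpenAI-compatible vision API. A minimal sketch, assuming a local server on the test URL and key used by the test suite:

```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")
response = client.chat.completions.create(
    model="deepseek-ai/Janus-Pro-7B",
    messages=[
        {
            "role": "user",
            "content": [
                # any reachable image URL works here; this one is a placeholder
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
)
print(response.choices[0].message.content)
```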
@@ -230,6 +230,29 @@ register_chat_template(
     )
 )

+register_chat_template(
+    ChatTemplate(
+        name="janus-pro",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "",
+                "",
+            ),
+            "User": (
+                "<|User|>",
+                "",
+            ),
+            "assistant": (
+                "<|Assistant|>",
+                "<|end▁of▁sentence|>",
+            ),
+        },
+        stop_str=("<|end▁of▁sentence|>",),
+        image_token="<image_placeholder>\n",
+    )
+)
+
 # The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
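For context on how these pairs are consumed: each message is wrapped as prefix + content + suffix and the wrapped turns are concatenated. A rough standalone illustration of the janus-pro wrapping, not SGLang's actual renderer:

```python
# Mirror of the role_prefix_and_suffix mapping registered above.
role_prefix_and_suffix = {
    "system": ("", ""),
    "User": ("<|User|>", ""),
    "assistant": ("<|Assistant|>", "<|end▁of▁sentence|>"),
}


def wrap(role: str, content: str) -> str:
    prefix, suffix = role_prefix_and_suffix[role]
    return prefix + content + suffix


# An open assistant turn ends at the prefix so the model continues from there.
prompt = wrap("User", " What is in this image?") + wrap("assistant", "")
print(prompt)  # "<|User|> What is in this image?<|Assistant|>"
```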
@@ -384,6 +407,12 @@ def match_deepseek(model_path: str):
     return get_chat_template("deepseek-v3")


+@register_chat_template_matching_function
+def match_deepseek_janus_pro(model_path: str):
+    if "janus" in model_path.lower():
+        return get_chat_template("janus-pro")
+
+
 @register_chat_template_matching_function
 def match_dbrx(model_path: str):
     if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
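The matcher keys purely on the substring "janus" in the lowercased model path, so both official checkpoints and finetunes with "janus" in the name resolve to the template. A self-contained analog of that check:

```python
def matches_janus(model_path: str) -> bool:
    """Same substring test the matcher above applies to model paths."""
    return "janus" in model_path.lower()


assert matches_janus("deepseek-ai/Janus-Pro-7B")
assert matches_janus("deepseek-ai/Janus-Pro-1B")
assert not matches_janus("deepseek-ai/DeepSeek-V3")
```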
@@ -1,6 +1,7 @@
 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.exaone import ExaoneConfig
+from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.qwen2_5_vl_config import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
@@ -12,4 +13,5 @@ __all__ = [
     "DbrxConfig",
     "Qwen2_5_VLConfig",
     "Qwen2_5_VLVisionConfig",
+    "MultiModalityConfig",
 ]
python/sglang/srt/configs/janus_pro.py (new file, 629 lines)
@@ -0,0 +1,629 @@
+# Adapted from:
+# https://github.com/deepseek-ai/Janus/tree/main/janus/models
+
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from PIL.Image import Image
+from transformers import (
+    AutoImageProcessor,
+    AutoProcessor,
+    BaseImageProcessor,
+    BatchFeature,
+    LlamaConfig,
+    LlamaTokenizerFast,
+    PretrainedConfig,
+    ProcessorMixin,
+)
+from transformers.image_utils import to_numpy_array
+
+from sglang.srt.mm_utils import expand2square
+
+
+class DictToObject(dict):
+    def __init__(self, dictionary):
+        super().__init__(dictionary)
+
+        for key, value in dictionary.items():
+            if isinstance(value, dict):
+                value = DictToObject(value)
+            setattr(self, key, value)
+
+
+class VisionConfig(PretrainedConfig):
+    model_type = "vision"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenAlignerConfig(PretrainedConfig):
+    model_type = "gen_aligner"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenHeadConfig(PretrainedConfig):
+    model_type = "gen_head"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class AlignerConfig(PretrainedConfig):
+    model_type = "aligner"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenVisionConfig(PretrainedConfig):
+    model_type = "gen_vision"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+@dataclass
+class SigLIPVisionCfg:
+    width: int = 1152
+    layers: Union[Tuple[int, int, int, int], int] = 27
+    heads: int = 16
+    patch_size: int = 14
+    image_size: Union[Tuple[int, int], int] = 336
+    global_pool: str = "map"
+    mlp_ratio: float = 3.7362
+    class_token: bool = False
+    num_classes: int = 0
+    use_checkpoint: bool = False
+
+
+class MultiModalityConfig(PretrainedConfig):
+    model_type = "multi_modality"
+    vision_config: VisionConfig
+    aligner_config: AlignerConfig
+
+    gen_vision_config: GenVisionConfig
+    gen_aligner_config: GenAlignerConfig
+    gen_head_config: GenHeadConfig
+
+    language_config: LlamaConfig
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        vision_config = kwargs.get("vision_config", {})
+        self.vision_config = VisionConfig(**vision_config)
+
+        aligner_config = kwargs.get("aligner_config", {})
+        self.aligner_config = AlignerConfig(**aligner_config)
+
+        gen_vision_config = kwargs.get("gen_vision_config", {})
+        self.gen_vision_config = GenVisionConfig(**gen_vision_config)
+
+        gen_aligner_config = kwargs.get("gen_aligner_config", {})
+        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
+
+        gen_head_config = kwargs.get("gen_head_config", {})
+        self.gen_head_config = GenHeadConfig(**gen_head_config)
+
+        language_config = kwargs.get("language_config", {})
+        if isinstance(language_config, LlamaConfig):
+            self.language_config = language_config
+        else:
+            self.language_config = LlamaConfig(**language_config)
+
+
+class VLMImageProcessor(BaseImageProcessor):
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        image_size: int,
+        min_size: int = 14,
+        image_mean: Union[Tuple[float, float, float], List[float]] = (
+            0.48145466,
+            0.4578275,
+            0.40821073,
+        ),
+        image_std: Union[Tuple[float, float, float], List[float]] = (
+            0.26862954,
+            0.26130258,
+            0.27577711,
+        ),
+        rescale_factor: float = 1.0 / 255.0,
+        do_normalize: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.image_size = image_size
+        self.rescale_factor = rescale_factor
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.min_size = min_size
+        self.do_normalize = do_normalize
+
+        if image_mean is None:
+            self.background_color = (127, 127, 127)
+        else:
+            self.background_color = tuple([int(x * 255) for x in image_mean])
+
+    def resize(self, pil_img: Image) -> np.ndarray:
+        """
+        Args:
+            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
+
+        Returns:
+            x (np.ndarray): [3, self.image_size, self.image_size]
+        """
+
+        width, height = pil_img.size
+        max_size = max(width, height)
+
+        size = [
+            max(int(height / max_size * self.image_size), self.min_size),
+            max(int(width / max_size * self.image_size), self.min_size),
+        ]
+
+        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
+            # print(f"orig size = {pil_img.size}, new size = {size}")
+            raise ValueError("Invalid size!")
+
+        def resize(
+            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
+        ):
+            if isinstance(size, int):
+                w, h = pil_img.size
+                if (w <= h and w == size) or (h <= w and h == size):
+                    return pil_img
+                if w < h:
+                    ow = size
+                    oh = int(size * h / w)
+                else:
+                    oh = size
+                    ow = int(size * w / h)
+                size = (ow, oh)
+            else:
+                size = (size[1], size[0])
+
+            return pil_img.resize(
+                size, resample=interpolation, reducing_gap=None if antialias else 3.0
+            )
+
+        pil_img = resize(
+            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
+        )
+
+        pil_img = expand2square(pil_img, self.background_color)
+        x = to_numpy_array(pil_img)
+
+        # [H, W, 3] -> [3, H, W]
+        x = np.transpose(x, (2, 0, 1))
+
+        return x
+
+    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
+        # resize and pad to [self.image_size, self.image_size]
+        # then convert from [H, W, 3] to [3, H, W]
+        if not isinstance(images, list):
+            images = [images]
+        images: List[np.ndarray] = [self.resize(image) for image in images]
+        images = [image[:3, ...] for image in images]
+
+        # rescale from [0, 255] -> [0, 1]
+        images = [
+            self.rescale(
+                image=image,
+                scale=self.rescale_factor,
+                input_data_format="channels_first",
+            )
+            for image in images
+        ]
+
+        # normalize
+        if self.do_normalize:
+            images = [
+                self.normalize(
+                    image=image,
+                    mean=self.image_mean,
+                    std=self.image_std,
+                    input_data_format="channels_first",
+                )
+                for image in images
+            ]
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    @property
+    def default_shape(self):
+        return [3, self.image_size, self.image_size]
+
+
+class DictOutput(object):
+    def keys(self):
+        return self.__dict__.keys()
+
+    def __getitem__(self, item):
+        return self.__dict__[item]
+
+    def __setitem__(self, key, value):
+        self.__dict__[key] = value
+
+
+@dataclass
+class VLChatProcessorOutput(DictOutput):
+    sft_format: str
+    input_ids: torch.Tensor
+    pixel_values: torch.Tensor
+    num_image_tokens: torch.IntTensor
+
+    def __len__(self):
+        return len(self.input_ids)
+
+
+@dataclass
+class BatchedVLChatProcessorOutput(DictOutput):
+    sft_format: List[str]
+    input_ids: torch.Tensor
+    pixel_values: torch.Tensor
+    attention_mask: torch.Tensor
+    images_seq_mask: torch.BoolTensor
+    images_emb_mask: torch.BoolTensor
+
+
+# FIXME: had to place Official Processor here, since image_processor module would not be imported in all threads,
+# hence AutoProcessor registration would not be effective in some cases
+class VLChatProcessor(ProcessorMixin):
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+
+    attributes = ["image_processor", "tokenizer"]
+
+    def __init__(
+        self,
+        image_processor: VLMImageProcessor,
+        tokenizer: LlamaTokenizerFast,
+        image_tag: str = "<image_placeholder>",
+        image_start_tag: str = "<begin_of_image>",
+        image_end_tag: str = "<end_of_image>",
+        pad_tag: str = "<|▁pad▁|>",
+        num_image_tokens: int = 576,
+        add_special_token: bool = False,
+        sft_format: str = "deepseek",
+        mask_prompt: bool = True,
+        ignore_id: int = -100,
+        **kwargs,
+    ):
+        self.image_processor = image_processor
+        self.tokenizer = tokenizer
+
+        image_id = self.tokenizer.vocab.get(image_tag)
+        if image_id is None:
+            special_tokens = [image_tag]
+            special_tokens_dict = {"additional_special_tokens": special_tokens}
+            self.tokenizer.add_special_tokens(special_tokens_dict)
+            # print(f"Add image tag = {image_tag} to the tokenizer")
+
+        self.image_tag = image_tag
+        self.image_start_tag = image_start_tag
+        self.image_end_tag = image_end_tag
+        self.pad_tag = pad_tag
+
+        self.num_image_tokens = num_image_tokens
+        self.add_special_token = add_special_token
+        self.sft_format = sft_format
+        self.ignore_id = ignore_id
+
+        super().__init__(
+            image_processor,
+            tokenizer,
+            **kwargs,
+        )
+
+    @property
+    def image_token(self):
+        return self.image_tag
+
+    @property
+    def image_id(self) -> int:
+        image_id = self.tokenizer.vocab.get(self.image_tag)
+        return image_id
+
+    @property
+    def image_start_id(self):
+        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
+        return image_start_id
+
+    @property
+    def image_end_id(self):
+        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
+        return image_end_id
+
+    @property
+    def image_start_token(self):
+        return self.image_start_tag
+
+    @property
+    def image_end_token(self):
+        return self.image_end_tag
+
+    @property
+    def pad_id(self):
+        pad_id = self.tokenizer.vocab.get(self.pad_tag)
+        return pad_id
+
+    def add_image_token(
+        self,
+        image_indices: List[int],
+        input_ids: torch.LongTensor,
+    ):
+        """
+        Args:
+            image_indices (List[int]): [index_0, index_1, ..., index_j]
+            input_ids (torch.LongTensor): [N]
+
+        Returns:
+            input_ids (torch.LongTensor): [N + image tokens]
+            num_image_tokens (torch.IntTensor): [n_images]
+        """
+
+        input_slices = []
+
+        start = 0
+        for index in image_indices:
+            if self.add_special_token:
+                end = index + 1
+            else:
+                end = index
+
+            # original text tokens
+            input_slices.append(input_ids[start:end])
+
+            # add boi, image tokens, eoi and set the mask as False
+            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
+            input_slices.append(
+                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
+            )
+            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
+            start = index + 1
+
+        # the left part
+        input_slices.append(input_ids[start:])
+
+        # concat all slices
+        input_ids = torch.cat(input_slices, dim=0)
+        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
+
+        return input_ids, num_image_tokens
+
+    def process_one(
+        self,
+        prompt: str = None,
+        images: List[Image] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            prompt (str): the formatted prompt;
+            images (List[ImageType]): the list of images;
+            **kwargs:
+
+        Returns:
+            outputs (BaseProcessorOutput): the output of the processor,
+                - input_ids (torch.LongTensor): [N + image tokens]
+                - target_ids (torch.LongTensor): [N + image tokens]
+                - images (torch.FloatTensor): [n_images, 3, H, W]
+                - image_id (int): the id of the image token
+                - num_image_tokens (List[int]): the number of image tokens
+        """
+
+        sft_format = prompt
+        # tokenize
+        input_ids = self.tokenizer.encode(sft_format)
+        input_ids = torch.LongTensor(input_ids)
+
+        # add image tokens to the input_ids
+        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
+        image_indices = image_token_mask.nonzero()
+        input_ids, num_image_tokens = self.add_image_token(
+            image_indices=image_indices,
+            input_ids=input_ids,
+        )
+
+        # load images
+        images_outputs = self.image_processor(images, return_tensors="pt")
+
+        prepare = VLChatProcessorOutput(
+            sft_format=sft_format,
+            input_ids=input_ids,
+            pixel_values=images_outputs.pixel_values,
+            num_image_tokens=num_image_tokens,
+        )
+
+        return prepare
+
+    def __call__(
+        self,
+        *,
+        prompt: str = None,
+        conversations: List[Dict[str, str]] = None,
+        images: List[Image] = None,
+        force_batchify: bool = True,
+        **kwargs,
+    ):
+        """
+        Args:
+            prompt (str): the formatted prompt;
+            conversations (List[Dict]): conversations with a list of messages;
+            images (List[ImageType]): the list of images;
+            force_batchify (bool): force batchify the inputs;
+            **kwargs:
+
+        Returns:
+            outputs (BaseProcessorOutput): the output of the processor,
+                - input_ids (torch.LongTensor): [N + image tokens]
+                - images (torch.FloatTensor): [n_images, 3, H, W]
+                - image_id (int): the id of the image token
+                - num_image_tokens (List[int]): the number of image tokens
+        """
+
+        prepare = self.process_one(
+            prompt=prompt, conversations=conversations, images=images
+        )
+
+        if force_batchify:
+            prepare = self.batchify([prepare])
+
+        return prepare
+
+    def batchify(
+        self, prepare_list: List[VLChatProcessorOutput]
+    ) -> BatchedVLChatProcessorOutput:
+        """
+        Preprocesses the inputs for multimodal inference.
+
+        Args:
+            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
+
+        Returns:
+            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
+        """
+
+        batch_size = len(prepare_list)
+        sft_format = []
+        n_images = []
+        seq_lens = []
+        for prepare in prepare_list:
+            n_images.append(len(prepare.num_image_tokens))
+            seq_lens.append(len(prepare))
+
+        input_token_max_len = max(seq_lens)
+        max_n_images = max(1, max(n_images))
+
+        batched_input_ids = torch.full(
+            (batch_size, input_token_max_len), self.pad_id
+        ).long()  # FIXME
+        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
+        batched_pixel_values = torch.zeros(
+            (batch_size, max_n_images, *self.image_processor.default_shape)
+        ).float()
+        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
+        batched_images_emb_mask = torch.zeros(
+            (batch_size, max_n_images, self.num_image_tokens)
+        ).bool()
+
+        for i, prepare in enumerate(prepare_list):
+            input_ids = prepare.input_ids
+            seq_len = len(prepare)
+            n_image = len(prepare.num_image_tokens)
+            # left-padding
+            batched_attention_mask[i, -seq_len:] = 1
+            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
+            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
+
+            if n_image > 0:
+                batched_pixel_values[i, :n_image] = prepare.pixel_values
+                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
+                    batched_images_emb_mask[i, j, :n_image_tokens] = True
+
+            sft_format.append(prepare.sft_format)
+
+        batched_prepares = BatchedVLChatProcessorOutput(
+            input_ids=batched_input_ids,
+            attention_mask=batched_attention_mask,
+            pixel_values=batched_pixel_values,
+            images_seq_mask=batched_images_seq_mask,
+            images_emb_mask=batched_images_emb_mask,
+            sft_format=sft_format,
+        )
+
+        return batched_prepares
+
+
+class VLMImageProcessorConfig(PretrainedConfig):
+    model_type = "deepseek_vlm"
+    image_size: int
+    min_size: int
+    image_mean: Union[Tuple[float, float, float], List[float]]
+    image_std: Union[Tuple[float, float, float], List[float]]
+    rescale_factor: float
+    do_normalize: bool
+
+    def __init__(
+        self,
+        image_size: int,
+        min_size: int = 14,
+        image_mean: Union[Tuple[float, float, float], List[float]] = (
+            0.48145466,
+            0.4578275,
+            0.40821073,
+        ),
+        image_std: Union[Tuple[float, float, float], List[float]] = (
+            0.26862954,
+            0.26130258,
+            0.27577711,
+        ),
+        rescale_factor: float = 1.0 / 255.0,
+        do_normalize: bool = True,
+        **kwargs,
+    ):
+        self.image_size = image_size
+        self.min_size = min_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+
+        super().__init__(**kwargs)
+
+
+AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True)
+AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None)
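The heart of this file for serving is `add_image_token`: every `<image_placeholder>` id in the tokenized prompt is expanded into `<begin_of_image>`, then `num_image_tokens` (576) copies of the image id, then `<end_of_image>`. A standalone sketch of that arithmetic with made-up token ids:

```python
import torch

# Illustrative ids only; the real values come from the Janus-Pro tokenizer.
IMAGE_ID, BOI_ID, EOI_ID, NUM_IMAGE_TOKENS = 100015, 100016, 100017, 576

input_ids = torch.tensor([1, 42, IMAGE_ID, 7])  # "<bos> text <image> text"
image_positions = (input_ids == IMAGE_ID).nonzero().flatten()

slices, start = [], 0
for idx in image_positions.tolist():
    slices.append(input_ids[start:idx])                    # text before the image
    slices.append(torch.tensor([BOI_ID]))                  # <begin_of_image>
    slices.append(torch.full((NUM_IMAGE_TOKENS,), IMAGE_ID))  # 576 image slots
    slices.append(torch.tensor([EOI_ID]))                  # <end_of_image>
    start = idx + 1
slices.append(input_ids[start:])                           # trailing text

expanded = torch.cat(slices)
# one placeholder removed, 1 + 576 + 1 tokens inserted: 4 - 1 + 578 = 581
assert len(expanded) == 581
```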
@@ -408,7 +408,7 @@ def _get_and_verify_dtype(

 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
-    # 1. Check the model architectue
+    # 1. Check the model architecture
     # 2. check the `is_embedding` server args

     if (
@@ -424,18 +424,25 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     return not is_embedding


+multimodal_model_archs = [
+    "LlavaLlamaForCausalLM",
+    "LlavaQwenForCausalLM",
+    "LlavaMistralForCausalLM",
+    "LlavaVidForCausalLM",
+    "Grok1VForCausalLM",
+    "Grok1AForCausalLM",
+    "MllamaForConditionalGeneration",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+    "MiniCPMV",
+    "MultiModalityCausalLM",
+]
+
+
 def is_multimodal_model(model_architectures: List[str]):
-    if (
-        "LlavaLlamaForCausalLM" in model_architectures
-        or "LlavaQwenForCausalLM" in model_architectures
-        or "LlavaMistralForCausalLM" in model_architectures
-        or "LlavaVidForCausalLM" in model_architectures
-        or "Grok1VForCausalLM" in model_architectures
-        or "Grok1AForCausalLM" in model_architectures
-        or "MllamaForConditionalGeneration" in model_architectures
-        or "Qwen2VLForConditionalGeneration" in model_architectures
-        or "Qwen2_5_VLForConditionalGeneration" in model_architectures
-        or "MiniCPMV" in model_architectures
+    if any(
+        multi_model_arch in model_architectures
+        for multi_model_arch in multimodal_model_archs
     ):
         return True
     else:
@@ -631,3 +631,18 @@ register_conv_template(
         image_token="(<image>./</image>)",
     )
 )
+
+# Reference: https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janus-pro
+register_conv_template(
+    Conversation(
+        name="janus-pro",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language",
+        system_template="{system_message}.",
+        roles=("User", "Assistant"),
+        sep="\n\n",
+        sep2="<|end▁of▁sentence|>",
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        stop_str=["<|User|>", "<|end▁of▁sentence|>"],
+        image_token="<image_placeholder>",
+    )
+)
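With `SeparatorStyle.ADD_COLON_TWO`, turns render roughly as "Role: message" joined by `sep`, and `sep2` closes completed assistant turns. A hand-rolled approximation of the resulting prompt; the authoritative rendering lives in `sglang.srt.conversation`:

```python
# Approximation only: real rendering also handles image tokens and sep2
# after finished assistant turns.
system = (
    "You are a helpful language and vision assistant. You are able to understand "
    "the visual content that the user provides, and assist the user with a "
    "variety of tasks using natural language."
)
sep = "\n\n"
prompt = (
    system
    + sep
    + "User: <image_placeholder>What is shown here?"
    + sep
    + "Assistant:"  # generation continues from here until a stop_str
)
print(prompt)
```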
@@ -30,13 +30,20 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

-from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2_5_VLConfig
+from sglang.srt.configs import (
+    ChatGLMConfig,
+    DbrxConfig,
+    ExaoneConfig,
+    MultiModalityConfig,
+    Qwen2_5_VLConfig,
+)

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
     Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
+    MultiModalityConfig.model_type: MultiModalityConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
@@ -67,6 +74,13 @@ def get_config(
         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )

+    # FIXME: Pour contents of janus-pro's language_config to first-level
+    if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
+        assert hasattr(config, "language_config")
+        for key, val in config.language_config.__dict__.items():
+            setattr(config, key, val)
+        setattr(config, "architectures", ["MultiModalityCausalLM"])
+
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
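The flattening above exists because generic model-loading code reads fields such as `hidden_size` or `num_hidden_layers` off the top-level config, while Janus-Pro nests them under `language_config`. A toy illustration of the effect:

```python
from types import SimpleNamespace

# Stand-in objects; the real ones are transformers PretrainedConfig instances.
config = SimpleNamespace(language_config=SimpleNamespace(hidden_size=4096))
for key, val in config.language_config.__dict__.items():
    setattr(config, key, val)

assert config.hidden_size == 4096  # now reachable without the nesting
```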
@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange, repeat
+from einops import rearrange

 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
@@ -13,6 +13,7 @@ from PIL import Image

 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import load_image
+from sglang.utils import logger

 global global_processor

@@ -22,6 +23,13 @@ def get_global_processor():
     return global_processor


+def init_global_processor(sglang_image_processor, server_args: ServerArgs):
+    """Init the global processor for multi-modal models."""
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = sglang_image_processor._build_processor(server_args=server_args)
+
+
 @dataclasses.dataclass
 class BaseImageProcessorOutput:
     image_hashes: list[int]
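`init_global_processor` is shaped like a process-pool worker initializer: each worker process builds its own processor exactly once. Presumably it is wired up along these lines; the wrapper below is an assumption for illustration, not code from this commit:

```python
from concurrent.futures import ProcessPoolExecutor


def init_global_processor(sglang_image_processor, server_args):
    """Stub standing in for the initializer added above."""


# Each of the 4 workers runs the initializer once before taking tasks.
executor = ProcessPoolExecutor(
    max_workers=4,  # illustrative value
    initializer=init_global_processor,
    initargs=(None, None),  # real call passes the SGLang processor and ServerArgs
)
```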
@@ -119,6 +127,11 @@ class BaseImageProcessor(ABC):
     ) -> BaseImageProcessorOutput:
         """
         Each frame of video/image will be replaced by a single image token
+
+        Args:
+
+            discard_alpha_channel: if True, discards the alpha channel in the returned images
+
         """
         image_hashes, image_sizes = [], []
         all_frames = []
@@ -133,7 +146,7 @@ class BaseImageProcessor(ABC):
         if return_text:
             text_parts = input_text.split(image_token)

-        # roughly calculate the max number of frames under the max_req_input_len limit
+        # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30
         estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
         total_frame_count = sum(estimated_frames_list)
python/sglang/srt/managers/image_processors/janus_pro.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+import asyncio
+from typing import List, Union
+
+from sglang.srt.managers.image_processors.base_image_processor import (
+    BaseImageProcessor as SGLangBaseImageProcessor,
+)
+from sglang.srt.managers.image_processors.base_image_processor import (
+    get_global_processor,
+)
+from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
+
+
+class JanusProProcessor(SGLangBaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_images_task(images, input_text):
+        processor = get_global_processor()
+        result = processor.__call__(
+            prompt=input_text, images=images, return_tensors="pt"
+        )
+        return {
+            "input_ids": result["input_ids"],
+            "pixel_values": result["pixel_values"],
+            "images_emb_mask": result["images_emb_mask"],
+            "im_start_id": processor.image_start_id,
+            "im_end_id": processor.image_end_id,
+            "im_token_id": processor.image_id,
+        }
+
+    async def _process_images(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                JanusProProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                images=images, text=input_text, return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_ids,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        base_out = self.load_images(
+            input_ids, image_data, "<image_placeholder>", max_req_input_len
+        )
+        images = base_out.all_frames
+        res = await self._process_images(images=images, input_text=base_out.input_text)
+
+        return {
+            "input_ids": res["input_ids"].flatten().tolist(),
+            "pixel_values": res["pixel_values"],
+            "images_emb_mask": res["images_emb_mask"],
+            "image_hashes": base_out.image_hashes,
+            "im_start_id": res["im_start_id"],
+            "im_end_id": res["im_end_id"],
+            "im_token_id": res["im_token_id"],
+        }
+
+
+ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}
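For a single-image request, `process_images_async` hands the scheduler a plain dict. An illustrative (not actual) instance of its shape, assuming Janus-Pro's 384x384 input resolution and 576 image tokens; the ids and hash are made up:

```python
import torch

example_output = {
    "input_ids": [0] * 600,                          # flattened ids incl. expanded image slots
    "pixel_values": torch.zeros(1, 1, 3, 384, 384),  # [batch, n_images, C, H, W]
    "images_emb_mask": torch.ones(1, 1, 576, dtype=torch.bool),
    "image_hashes": [123456789],
    "im_start_id": 100016,
    "im_end_id": 100017,
    "im_token_id": 100015,
}
```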
python/sglang/srt/models/deepseek_janus_pro.py (new file, 2127 lines)
(File diff suppressed because it is too large.)
@@ -512,5 +512,29 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
         cls.base_url += "/v1"


+class TestJanusProServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "deepseek-ai/Janus-Pro-7B"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--chat-template",
+                "janus-pro",
+                "--mem-fraction-static",
+                "0.4",
+            ],
+        )
+        cls.base_url += "/v1"
+
+    def test_video_chat_completion(self):
+        pass
+
+
 if __name__ == "__main__":
     unittest.main()