[Feature] Support Deepseek-VL2 (#2798)
Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: Yi Zhang <1109276519@qq.com>
This commit is contained in:
@@ -32,6 +32,7 @@
|
|||||||
- Phi-3-Small
|
- Phi-3-Small
|
||||||
- IBM Granite 3
|
- IBM Granite 3
|
||||||
- Janus-Pro-1B / Janus-Pro-7B
|
- Janus-Pro-1B / Janus-Pro-7B
|
||||||
|
- Deepseek-VL2 / Deepseek-VL2-small
|
||||||
- Gemma 3 (it)
|
- Gemma 3 (it)
|
||||||
|
|
||||||
## Embedding Models
|
## Embedding Models
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from sglang.srt.configs.chatglm import ChatGLMConfig
|
from sglang.srt.configs.chatglm import ChatGLMConfig
|
||||||
from sglang.srt.configs.dbrx import DbrxConfig
|
from sglang.srt.configs.dbrx import DbrxConfig
|
||||||
|
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
|
||||||
from sglang.srt.configs.exaone import ExaoneConfig
|
from sglang.srt.configs.exaone import ExaoneConfig
|
||||||
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
|
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
|
||||||
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
||||||
@@ -12,6 +13,7 @@ __all__ = [
|
|||||||
"ExaoneConfig",
|
"ExaoneConfig",
|
||||||
"ChatGLMConfig",
|
"ChatGLMConfig",
|
||||||
"DbrxConfig",
|
"DbrxConfig",
|
||||||
|
"DeepseekVL2Config",
|
||||||
"Qwen2_5_VLConfig",
|
"Qwen2_5_VLConfig",
|
||||||
"Qwen2_5_VLVisionConfig",
|
"Qwen2_5_VLVisionConfig",
|
||||||
"MultiModalityConfig",
|
"MultiModalityConfig",
|
||||||
|
|||||||
667
python/sglang/srt/configs/deepseekvl2.py
Normal file
667
python/sglang/srt/configs/deepseekvl2.py
Normal file
@@ -0,0 +1,667 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from transformers import (
|
||||||
|
AutoProcessor,
|
||||||
|
LlamaTokenizerFast,
|
||||||
|
PretrainedConfig,
|
||||||
|
ProcessorMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def select_best_resolution(image_size, candidate_resolutions):
|
||||||
|
# used for cropping
|
||||||
|
original_width, original_height = image_size
|
||||||
|
best_fit = None
|
||||||
|
max_effective_resolution = 0
|
||||||
|
min_wasted_resolution = float("inf")
|
||||||
|
|
||||||
|
for width, height in candidate_resolutions:
|
||||||
|
scale = min(width / original_width, height / original_height)
|
||||||
|
downscaled_width, downscaled_height = int(original_width * scale), int(
|
||||||
|
original_height * scale
|
||||||
|
)
|
||||||
|
effective_resolution = min(
|
||||||
|
downscaled_width * downscaled_height, original_width * original_height
|
||||||
|
)
|
||||||
|
wasted_resolution = (width * height) - effective_resolution
|
||||||
|
|
||||||
|
if effective_resolution > max_effective_resolution or (
|
||||||
|
effective_resolution == max_effective_resolution
|
||||||
|
and wasted_resolution < min_wasted_resolution
|
||||||
|
):
|
||||||
|
max_effective_resolution = effective_resolution
|
||||||
|
min_wasted_resolution = wasted_resolution
|
||||||
|
best_fit = (width, height)
|
||||||
|
|
||||||
|
return best_fit
|
||||||
|
|
||||||
|
|
||||||
|
class DictOutput(object):
|
||||||
|
def keys(self):
|
||||||
|
return self.__dict__.keys()
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.__dict__[item]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self.__dict__[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VLChatProcessorOutput(DictOutput):
|
||||||
|
input_ids: torch.LongTensor
|
||||||
|
target_ids: torch.LongTensor
|
||||||
|
images: torch.Tensor
|
||||||
|
images_seq_mask: torch.BoolTensor
|
||||||
|
images_spatial_crop: torch.LongTensor
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.input_ids)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageTransform(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
|
||||||
|
std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
|
||||||
|
normalize: bool = True,
|
||||||
|
):
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
self.normalize = normalize
|
||||||
|
|
||||||
|
transform_pipelines = [T.ToTensor()]
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
transform_pipelines.append(T.Normalize(mean, std))
|
||||||
|
|
||||||
|
self.transform = T.Compose(transform_pipelines)
|
||||||
|
|
||||||
|
def __call__(self, pil_img: Image.Image):
|
||||||
|
x = self.transform(pil_img)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekVLV2Processor(ProcessorMixin):
|
||||||
|
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||||
|
attributes = ["tokenizer"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: LlamaTokenizerFast,
|
||||||
|
candidate_resolutions: Tuple[Tuple[int, int]],
|
||||||
|
patch_size: int,
|
||||||
|
downsample_ratio: int,
|
||||||
|
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
|
||||||
|
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
|
||||||
|
normalize: bool = True,
|
||||||
|
image_token: str = "<image>",
|
||||||
|
pad_token: str = "<|▁pad▁|>",
|
||||||
|
add_special_token: bool = False,
|
||||||
|
sft_format: str = "deepseek",
|
||||||
|
mask_prompt: bool = True,
|
||||||
|
ignore_id: int = -100,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.candidate_resolutions = candidate_resolutions
|
||||||
|
self.image_size = candidate_resolutions[0][0]
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.image_mean = image_mean
|
||||||
|
self.image_std = image_std
|
||||||
|
self.normalize = normalize
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
|
||||||
|
self.image_transform = ImageTransform(
|
||||||
|
mean=image_mean, std=image_std, normalize=normalize
|
||||||
|
)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
# must set this,padding side with make a difference in batch inference
|
||||||
|
self.tokenizer.padding_side = "left"
|
||||||
|
|
||||||
|
# add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.add_special_tokens({"pad_token": pad_token})
|
||||||
|
|
||||||
|
# add image token
|
||||||
|
image_token_id = self.tokenizer.vocab.get(image_token)
|
||||||
|
if image_token_id is None:
|
||||||
|
special_tokens = [image_token]
|
||||||
|
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||||
|
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
|
self.image_token_id = self.tokenizer.vocab.get(image_token)
|
||||||
|
|
||||||
|
# add five special tokens for grounding-related tasks
|
||||||
|
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
|
||||||
|
special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
|
||||||
|
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||||
|
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
|
|
||||||
|
# add special tokens for SFT data
|
||||||
|
special_tokens = ["<|User|>", "<|Assistant|>"]
|
||||||
|
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||||
|
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
|
|
||||||
|
self.image_token = image_token
|
||||||
|
self.pad_token = pad_token
|
||||||
|
self.add_special_token = add_special_token
|
||||||
|
self.sft_format = sft_format
|
||||||
|
self.mask_prompt = mask_prompt
|
||||||
|
self.ignore_id = ignore_id
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
tokenizer,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
|
||||||
|
"""play the role of format_messages_v2 and get_images_info in the last version"""
|
||||||
|
tokenized_data = []
|
||||||
|
masked_tokenized_data = [] # labels
|
||||||
|
images_list = []
|
||||||
|
images_seq_mask = []
|
||||||
|
images_spatial_crop = []
|
||||||
|
|
||||||
|
image_index = 0
|
||||||
|
image_token_cnt = messages.count(self.image_token)
|
||||||
|
tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
|
||||||
|
messages,
|
||||||
|
pil_images[image_index : image_index + image_token_cnt],
|
||||||
|
bos=False,
|
||||||
|
eos=True,
|
||||||
|
cropping=len(pil_images) <= 2,
|
||||||
|
max_req_input_len=max_req_input_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_index = image_token_cnt
|
||||||
|
tokenized_data += tokenized_str
|
||||||
|
if self.mask_prompt:
|
||||||
|
masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
|
||||||
|
else:
|
||||||
|
masked_tokenized_data += tokenized_str
|
||||||
|
images_list += images
|
||||||
|
images_seq_mask += seq_mask
|
||||||
|
images_spatial_crop += spatial_crop
|
||||||
|
|
||||||
|
assert len(tokenized_data) == len(
|
||||||
|
images_seq_mask
|
||||||
|
), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
|
||||||
|
|
||||||
|
return (
|
||||||
|
tokenized_data,
|
||||||
|
masked_tokenized_data,
|
||||||
|
images_list,
|
||||||
|
images_seq_mask,
|
||||||
|
images_spatial_crop,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bos_id(self):
|
||||||
|
return self.tokenizer.bos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_id(self):
|
||||||
|
return self.tokenizer.eos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad_id(self):
|
||||||
|
return self.tokenizer.pad_token_id
|
||||||
|
|
||||||
|
def encode(self, text: str, bos: bool = True, eos: bool = False):
|
||||||
|
t = self.tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
if bos:
|
||||||
|
t = [self.bos_id] + t
|
||||||
|
if eos:
|
||||||
|
t = t + [self.eos_id]
|
||||||
|
|
||||||
|
return t
|
||||||
|
|
||||||
|
def decode(self, t: List[int], **kwargs) -> str:
|
||||||
|
return self.tokenizer.decode(t, **kwargs)
|
||||||
|
|
||||||
|
def process_one(
|
||||||
|
self,
|
||||||
|
prompt: str = None,
|
||||||
|
conversations: List[Dict[str, str]] = None,
|
||||||
|
images: List[Image.Image] = None,
|
||||||
|
apply_sft_format: bool = False,
|
||||||
|
inference_mode: bool = True,
|
||||||
|
system_prompt: str = "",
|
||||||
|
max_req_input_len: int = -1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): the formatted prompt;
|
||||||
|
conversations (List[Dict]): conversations with a list of messages;
|
||||||
|
images (List[ImageType]): the list of images;
|
||||||
|
apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
|
||||||
|
if conversations is not None, then it will always apply the SFT format to conversations;
|
||||||
|
inference_mode (bool): if True, then remove the last eos token;
|
||||||
|
system_prompt (str): the system prompt;
|
||||||
|
**kwargs:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
outputs (BaseProcessorOutput): the output of the processor,
|
||||||
|
- input_ids (torch.LongTensor): [N + image tokens]
|
||||||
|
- target_ids (torch.LongTensor): [N + image tokens]
|
||||||
|
- images (torch.FloatTensor): [n_images, 3, H, W]
|
||||||
|
- image_id (int): the id of the image token
|
||||||
|
- num_image_tokens (List[int]): the number of image tokens
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert (
|
||||||
|
prompt is None or conversations is None
|
||||||
|
), "prompt and conversations cannot be used at the same time."
|
||||||
|
|
||||||
|
(
|
||||||
|
tokenized_str,
|
||||||
|
masked_tokenized_str,
|
||||||
|
images_list,
|
||||||
|
images_seq_mask,
|
||||||
|
images_spatial_crop,
|
||||||
|
) = self.format_messages_v2(conversations, images, max_req_input_len)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
|
||||||
|
), (
|
||||||
|
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
|
||||||
|
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = torch.LongTensor(tokenized_str)
|
||||||
|
target_ids = torch.LongTensor(masked_tokenized_str)
|
||||||
|
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
|
||||||
|
|
||||||
|
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
|
||||||
|
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
|
||||||
|
self.ignore_id
|
||||||
|
)
|
||||||
|
input_ids[input_ids < 0] = self.pad_id
|
||||||
|
|
||||||
|
if inference_mode:
|
||||||
|
assert input_ids[-1] == self.eos_id
|
||||||
|
input_ids = input_ids[:-1]
|
||||||
|
target_ids = target_ids[:-1]
|
||||||
|
images_seq_mask = images_seq_mask[:-1]
|
||||||
|
|
||||||
|
if len(images_list) == 0:
|
||||||
|
images = torch.zeros((1, 3, self.image_size, self.image_size))
|
||||||
|
images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
|
||||||
|
else:
|
||||||
|
images = torch.stack(images_list, dim=0)
|
||||||
|
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
|
||||||
|
|
||||||
|
prepare = VLChatProcessorOutput(
|
||||||
|
input_ids=input_ids,
|
||||||
|
target_ids=target_ids,
|
||||||
|
images=images,
|
||||||
|
images_seq_mask=images_seq_mask,
|
||||||
|
images_spatial_crop=images_spatial_crop,
|
||||||
|
)
|
||||||
|
|
||||||
|
return prepare
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
prompt: str = None,
|
||||||
|
conversations: List[Dict[str, str]] = None,
|
||||||
|
images: List[Image.Image] = None,
|
||||||
|
apply_sft_format: bool = False,
|
||||||
|
inference_mode: bool = True,
|
||||||
|
system_prompt: str = "",
|
||||||
|
max_req_input_len: int = -1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
prepare = self.process_one(
|
||||||
|
prompt=prompt,
|
||||||
|
conversations=conversations,
|
||||||
|
images=images,
|
||||||
|
apply_sft_format=apply_sft_format,
|
||||||
|
inference_mode=inference_mode,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
max_req_input_len=max_req_input_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
return prepare
|
||||||
|
|
||||||
|
def find_all_indices(self, messages, target_value):
|
||||||
|
indices = []
|
||||||
|
for index, item in enumerate(messages):
|
||||||
|
if item == target_value:
|
||||||
|
indices.append(index)
|
||||||
|
return indices
|
||||||
|
|
||||||
|
def tokenize_with_images(
|
||||||
|
self,
|
||||||
|
conversation: str,
|
||||||
|
images: List[Image.Image],
|
||||||
|
bos: bool = True,
|
||||||
|
eos: bool = True,
|
||||||
|
cropping: bool = True,
|
||||||
|
max_req_input_len: int = -1,
|
||||||
|
):
|
||||||
|
"""Tokenize text with <image> tags."""
|
||||||
|
images_list, images_seq_mask, images_spatial_crop = [], [], []
|
||||||
|
text_splits = conversation.split(self.image_token)
|
||||||
|
tokenized_str = []
|
||||||
|
for text_sep, image in zip(text_splits, images):
|
||||||
|
"""encode text_sep"""
|
||||||
|
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
|
||||||
|
tokenized_str += tokenized_sep
|
||||||
|
images_seq_mask += [False] * len(tokenized_sep)
|
||||||
|
|
||||||
|
"""select best resolution for anyres"""
|
||||||
|
if cropping:
|
||||||
|
best_width, best_height = select_best_resolution(
|
||||||
|
image.size, self.candidate_resolutions
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
best_width, best_height = self.image_size, self.image_size
|
||||||
|
# print(image.size, (best_width, best_height)) # check the select_best_resolutions func
|
||||||
|
|
||||||
|
"""process the global view"""
|
||||||
|
global_view = ImageOps.pad(
|
||||||
|
image,
|
||||||
|
(self.image_size, self.image_size),
|
||||||
|
color=tuple(int(x * 255) for x in self.image_transform.mean),
|
||||||
|
)
|
||||||
|
images_list.append(self.image_transform(global_view))
|
||||||
|
|
||||||
|
"""process the local views"""
|
||||||
|
local_view = ImageOps.pad(
|
||||||
|
image,
|
||||||
|
(best_width, best_height),
|
||||||
|
color=tuple(int(x * 255) for x in self.image_transform.mean),
|
||||||
|
)
|
||||||
|
for i in range(0, best_height, self.image_size):
|
||||||
|
for j in range(0, best_width, self.image_size):
|
||||||
|
images_list.append(
|
||||||
|
self.image_transform(
|
||||||
|
local_view.crop(
|
||||||
|
(j, i, j + self.image_size, i + self.image_size)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
"""record height / width crop num"""
|
||||||
|
num_width_tiles, num_height_tiles = (
|
||||||
|
best_width // self.image_size,
|
||||||
|
best_height // self.image_size,
|
||||||
|
)
|
||||||
|
images_spatial_crop.append([num_width_tiles, num_height_tiles])
|
||||||
|
|
||||||
|
"""add image tokens"""
|
||||||
|
h = w = math.ceil(
|
||||||
|
(self.image_size // self.patch_size) / self.downsample_ratio
|
||||||
|
)
|
||||||
|
# global views tokens h * (w + 1), 1 is for line seperator
|
||||||
|
tokenized_image = [self.image_token_id] * h * (w + 1)
|
||||||
|
# add a seperator between global and local views
|
||||||
|
tokenized_image += [self.image_token_id]
|
||||||
|
# local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
|
||||||
|
tokenized_image += (
|
||||||
|
[self.image_token_id]
|
||||||
|
* (num_height_tiles * h)
|
||||||
|
* (num_width_tiles * w + 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenized_str += tokenized_image
|
||||||
|
images_seq_mask += [True] * len(tokenized_image)
|
||||||
|
# print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens
|
||||||
|
|
||||||
|
"""process the last text split"""
|
||||||
|
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
|
||||||
|
# deal with video, limit with request len
|
||||||
|
if max_req_input_len > -1:
|
||||||
|
if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
|
||||||
|
rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
|
||||||
|
tokenized_str = tokenized_str[:rest]
|
||||||
|
images_seq_mask = images_seq_mask[:rest]
|
||||||
|
tokenized_str += tokenized_sep
|
||||||
|
images_seq_mask += [False] * len(tokenized_sep)
|
||||||
|
|
||||||
|
"""add the bos and eos tokens"""
|
||||||
|
if bos:
|
||||||
|
tokenized_str = [self.bos_id] + tokenized_str
|
||||||
|
images_seq_mask = [False] + images_seq_mask
|
||||||
|
if eos:
|
||||||
|
tokenized_str = tokenized_str + [self.eos_id]
|
||||||
|
images_seq_mask = images_seq_mask + [False]
|
||||||
|
|
||||||
|
assert len(tokenized_str) == len(
|
||||||
|
images_seq_mask
|
||||||
|
), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
|
||||||
|
|
||||||
|
return tokenized_str, images_list, images_seq_mask, images_spatial_crop
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
|
||||||
|
model_type: str = "vision"
|
||||||
|
|
||||||
|
model_name: str = "siglip_large_patch16_384"
|
||||||
|
image_size: int = 384
|
||||||
|
patch_size: int = 16
|
||||||
|
width: int = 1024
|
||||||
|
layers: int = 24
|
||||||
|
heads: int = 16
|
||||||
|
mlp_ratio: int = 4
|
||||||
|
global_pool: str = "map"
|
||||||
|
ignore_head: bool = True
|
||||||
|
class_token: bool = False
|
||||||
|
num_classes: int = 0
|
||||||
|
use_checkpoint: bool = False
|
||||||
|
weight_init: str = "skip"
|
||||||
|
deterministic: bool = False
|
||||||
|
num_recomputing_layers: int = 0
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "siglip_large_patch16_384",
|
||||||
|
image_size: int = 384,
|
||||||
|
patch_size: int = 16,
|
||||||
|
width: int = 1024,
|
||||||
|
layers: int = 24,
|
||||||
|
heads: int = 16,
|
||||||
|
mlp_ratio: int = 4,
|
||||||
|
global_pool: str = "map",
|
||||||
|
ignore_head: bool = True,
|
||||||
|
class_token: bool = False,
|
||||||
|
num_classes: int = 0,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.width = width
|
||||||
|
self.layers = layers
|
||||||
|
self.heads = heads
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.global_pool = global_pool
|
||||||
|
self.ignore_head = ignore_head
|
||||||
|
self.class_token = class_token
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
|
||||||
|
model_type = "mlp_projector"
|
||||||
|
projector_type: str = "downsample_mlp_gelu"
|
||||||
|
input_dim: int = 1152
|
||||||
|
n_embed: int = 2048
|
||||||
|
depth: int = 2
|
||||||
|
mlp_ratio: int = 1
|
||||||
|
downsample_ratio: int = 2
|
||||||
|
token_pooling: bool = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
projector_type: str = "downsample_mlp_gelu",
|
||||||
|
input_dim: int = 1152,
|
||||||
|
n_embed: int = 2048,
|
||||||
|
depth: int = 2,
|
||||||
|
mlp_ratio: int = 1,
|
||||||
|
downsample_ratio: int = 2,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.projector_type = projector_type
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.n_embed = n_embed
|
||||||
|
self.depth = depth
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV2Config(PretrainedConfig):
|
||||||
|
|
||||||
|
model_type = "deepseek_v2"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=102400,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
moe_intermediate_size=1407,
|
||||||
|
num_hidden_layers=30,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=32,
|
||||||
|
n_shared_experts=None,
|
||||||
|
n_routed_experts=None,
|
||||||
|
ep_size=1,
|
||||||
|
routed_scaling_factor=1.0,
|
||||||
|
kv_lora_rank=512,
|
||||||
|
q_lora_rank=1536,
|
||||||
|
qk_rope_head_dim=64,
|
||||||
|
v_head_dim=128,
|
||||||
|
qk_nope_head_dim=128,
|
||||||
|
topk_method="gready",
|
||||||
|
n_group=None,
|
||||||
|
topk_group=None,
|
||||||
|
num_experts_per_tok=None,
|
||||||
|
moe_layer_freq=1,
|
||||||
|
first_k_dense_replace=0,
|
||||||
|
norm_topk_prob=False,
|
||||||
|
scoring_func="softmax",
|
||||||
|
aux_loss_alpha=0.001,
|
||||||
|
seq_aux=True,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=100000,
|
||||||
|
eos_token_id=100001,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
use_mla=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.n_shared_experts = n_shared_experts
|
||||||
|
self.n_routed_experts = n_routed_experts
|
||||||
|
self.ep_size = ep_size
|
||||||
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.topk_method = topk_method
|
||||||
|
self.n_group = n_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.moe_layer_freq = moe_layer_freq
|
||||||
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
|
self.norm_topk_prob = norm_topk_prob
|
||||||
|
self.scoring_func = scoring_func
|
||||||
|
self.aux_loss_alpha = aux_loss_alpha
|
||||||
|
self.seq_aux = seq_aux
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = float(rms_norm_eps)
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.use_mla = use_mla
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekVL2Config(PretrainedConfig):
|
||||||
|
model_type = "deepseek_vl_v2"
|
||||||
|
vision_config: DeepseekVL2VisionEncoderConfig
|
||||||
|
projector_config: DeepseekVL2MlpProjectorConfig
|
||||||
|
language_config: DeepseekV2Config
|
||||||
|
|
||||||
|
tile_tag: str = "2D"
|
||||||
|
global_view_pos: str = "head"
|
||||||
|
candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tile_tag: str = "tile_tag",
|
||||||
|
global_view_pos: str = "head",
|
||||||
|
candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
vision_config = kwargs.get("vision_config", {})
|
||||||
|
self.vision_config = DeepseekVL2VisionEncoderConfig(**vision_config)
|
||||||
|
|
||||||
|
projector_config = kwargs.get("projector_config", {})
|
||||||
|
self.projector_config = DeepseekVL2MlpProjectorConfig(**projector_config)
|
||||||
|
|
||||||
|
language_config = kwargs.get("language_config", {})
|
||||||
|
if isinstance(language_config, DeepseekV2Config):
|
||||||
|
self.language_config = language_config
|
||||||
|
else:
|
||||||
|
self.language_config = DeepseekV2Config(**language_config)
|
||||||
|
|
||||||
|
self.tile_tag = tile_tag
|
||||||
|
self.global_view_pos = global_view_pos
|
||||||
|
self.candidate_resolutions = candidate_resolutions
|
||||||
|
self.architectures = ["DeepseekVL2ForCausalLM"]
|
||||||
|
|
||||||
|
|
||||||
|
AutoProcessor.register(DeepseekVL2Config, DeepseekVLV2Processor)
|
||||||
@@ -135,6 +135,11 @@ class ModelConfig:
|
|||||||
self.attention_arch = AttentionArch.MLA
|
self.attention_arch = AttentionArch.MLA
|
||||||
self.kv_lora_rank = self.hf_config.kv_lora_rank
|
self.kv_lora_rank = self.hf_config.kv_lora_rank
|
||||||
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||||
|
elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
|
||||||
|
self.head_dim = 256
|
||||||
|
self.attention_arch = AttentionArch.MLA
|
||||||
|
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
|
||||||
|
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
|
||||||
else:
|
else:
|
||||||
self.attention_arch = AttentionArch.MHA
|
self.attention_arch = AttentionArch.MHA
|
||||||
|
|
||||||
@@ -362,6 +367,8 @@ def get_hf_text_config(config: PretrainedConfig):
|
|||||||
# if transformers config doesn't align with this assumption.
|
# if transformers config doesn't align with this assumption.
|
||||||
assert hasattr(config.text_config, "num_attention_heads")
|
assert hasattr(config.text_config, "num_attention_heads")
|
||||||
return config.text_config
|
return config.text_config
|
||||||
|
if hasattr(config, "language_config"):
|
||||||
|
return config.language_config
|
||||||
else:
|
else:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@@ -465,6 +472,7 @@ multimodal_model_archs = [
|
|||||||
"Qwen2_5_VLForConditionalGeneration",
|
"Qwen2_5_VLForConditionalGeneration",
|
||||||
"MiniCPMV",
|
"MiniCPMV",
|
||||||
"MultiModalityCausalLM",
|
"MultiModalityCausalLM",
|
||||||
|
"DeepseekVL2ForCausalLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class SeparatorStyle(IntEnum):
|
|||||||
CHATGLM3 = auto()
|
CHATGLM3 = auto()
|
||||||
DEEPSEEK_CHAT = auto()
|
DEEPSEEK_CHAT = auto()
|
||||||
METAMATH = auto()
|
METAMATH = auto()
|
||||||
|
DeepSeekVL2 = auto()
|
||||||
QWEN2_VL_EMBED = auto()
|
QWEN2_VL_EMBED = auto()
|
||||||
GEMMA3 = auto()
|
GEMMA3 = auto()
|
||||||
|
|
||||||
@@ -75,6 +76,7 @@ class Conversation:
|
|||||||
|
|
||||||
image_data: Optional[List[str]] = None
|
image_data: Optional[List[str]] = None
|
||||||
modalities: Optional[List[str]] = None
|
modalities: Optional[List[str]] = None
|
||||||
|
stop_token_ids: Optional[int] = None
|
||||||
|
|
||||||
def get_prompt(self) -> str:
|
def get_prompt(self) -> str:
|
||||||
"""Get the prompt for generation."""
|
"""Get the prompt for generation."""
|
||||||
@@ -286,6 +288,18 @@ class Conversation:
|
|||||||
else:
|
else:
|
||||||
ret += role + ":"
|
ret += role + ":"
|
||||||
return ret
|
return ret
|
||||||
|
elif self.sep_style == SeparatorStyle.DeepSeekVL2:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
if system_prompt == "" or system_prompt is None:
|
||||||
|
ret = ""
|
||||||
|
else:
|
||||||
|
ret = system_prompt + seps[0]
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
ret += role + ": " + message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
ret += role + ":"
|
||||||
|
return ret
|
||||||
elif self.sep_style == SeparatorStyle.GEMMA3:
|
elif self.sep_style == SeparatorStyle.GEMMA3:
|
||||||
ret = system_prompt
|
ret = system_prompt
|
||||||
for i, (role, message) in enumerate(self.messages):
|
for i, (role, message) in enumerate(self.messages):
|
||||||
@@ -617,6 +631,23 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="deepseek-vl2",
|
||||||
|
system_template="{system_message}",
|
||||||
|
# system_message="You are a helpful assistant. Please answer truthfully and write out your "
|
||||||
|
# "thinking step by step to be sure you get the right answer.",
|
||||||
|
system_message="",
|
||||||
|
roles=("<|User|>", "<|Assistant|>"),
|
||||||
|
messages=(),
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.DeepSeekVL2,
|
||||||
|
sep="\n\n",
|
||||||
|
sep2="<|end▁of▁sentence|>",
|
||||||
|
stop_str=["User:", "<|end▁of▁sentence|>"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
|
# Reference: https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
|
||||||
register_conv_template(
|
register_conv_template(
|
||||||
Conversation(
|
Conversation(
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_N
|
|||||||
from sglang.srt.configs import (
|
from sglang.srt.configs import (
|
||||||
ChatGLMConfig,
|
ChatGLMConfig,
|
||||||
DbrxConfig,
|
DbrxConfig,
|
||||||
|
DeepseekVL2Config,
|
||||||
ExaoneConfig,
|
ExaoneConfig,
|
||||||
Gemma3Config,
|
Gemma3Config,
|
||||||
Gemma3TextConfig,
|
Gemma3TextConfig,
|
||||||
@@ -47,6 +48,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
|||||||
DbrxConfig.model_type: DbrxConfig,
|
DbrxConfig.model_type: DbrxConfig,
|
||||||
ExaoneConfig.model_type: ExaoneConfig,
|
ExaoneConfig.model_type: ExaoneConfig,
|
||||||
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
|
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
|
||||||
|
DeepseekVL2Config.model_type: DeepseekVL2Config,
|
||||||
MultiModalityConfig.model_type: MultiModalityConfig,
|
MultiModalityConfig.model_type: MultiModalityConfig,
|
||||||
Gemma3Config.model_type: Gemma3Config,
|
Gemma3Config.model_type: Gemma3Config,
|
||||||
Gemma3TextConfig.model_type: Gemma3TextConfig,
|
Gemma3TextConfig.model_type: Gemma3TextConfig,
|
||||||
|
|||||||
104
python/sglang/srt/managers/image_processors/deepseek_vl_v2.py
Normal file
104
python/sglang/srt/managers/image_processors/deepseek_vl_v2.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
# Copyright (c) 2023-2024 DeepSeek.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||||
|
# this software and associated documentation files (the "Software"), to deal in
|
||||||
|
# the Software without restriction, including without limitation the rights to
|
||||||
|
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||||
|
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||||
|
# subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||||
|
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||||
|
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||||
|
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
|
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||||
|
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||||
|
get_global_processor,
|
||||||
|
)
|
||||||
|
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekVL2ImageProcessor(BaseImageProcessor):
|
||||||
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
|
# with contextlib.suppress(ValueError):
|
||||||
|
# AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)
|
||||||
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
self.IMAGE_TOKEN = "<image>"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_images_task(image, input_text, max_req_input_len):
|
||||||
|
return get_global_processor().__call__(
|
||||||
|
conversations=input_text, images=image, max_req_input_len=max_req_input_len
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_images(self, image_data, input_text, max_req_input_len):
|
||||||
|
if self.executor is not None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
image_inputs = await loop.run_in_executor(
|
||||||
|
self.executor,
|
||||||
|
DeepseekVL2ImageProcessor._process_images_task,
|
||||||
|
image_data,
|
||||||
|
input_text,
|
||||||
|
max_req_input_len,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image_inputs = self._process_images_task(
|
||||||
|
image_data, input_text, max_req_input_len
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
|
async def process_images_async(
|
||||||
|
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
|
||||||
|
):
|
||||||
|
if not image_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(image_data, list):
|
||||||
|
image_data = [image_data]
|
||||||
|
|
||||||
|
images, image_hashes, image_sizes = [], [], []
|
||||||
|
|
||||||
|
image_token = self.IMAGE_TOKEN
|
||||||
|
base_output = self.load_images(
|
||||||
|
input_ids, image_data, image_token, max_req_input_len
|
||||||
|
)
|
||||||
|
base_output.all_frames = [img.convert("RGB") for img in base_output.all_frames]
|
||||||
|
res = await self._process_images(
|
||||||
|
base_output.all_frames, base_output.input_text, max_req_input_len
|
||||||
|
)
|
||||||
|
pixel_values = res["images"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
images_seq_mask = res["images_seq_mask"]
|
||||||
|
images_spatial_crop = res["images_spatial_crop"]
|
||||||
|
batched_images_spatial_crop = []
|
||||||
|
batched_images_spatial_crop.append(images_spatial_crop)
|
||||||
|
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids.tolist(),
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
"image_hashes": image_hashes,
|
||||||
|
"image_sizes": image_sizes,
|
||||||
|
"image_seq_mask": images_seq_mask,
|
||||||
|
"image_spatial_crop": batched_images_spatial_crop,
|
||||||
|
"modalities": request_obj.modalities or ["image"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ImageProcessorMapping = {
|
||||||
|
DeepseekVL2ForCausalLM: DeepseekVL2ImageProcessor,
|
||||||
|
}
|
||||||
@@ -160,8 +160,13 @@ class ImageInputs:
|
|||||||
image_grid_thws: List[Tuple[int, int, int]] = None
|
image_grid_thws: List[Tuple[int, int, int]] = None
|
||||||
mrope_position_delta: Optional[torch.Tensor] = None
|
mrope_position_delta: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# deepseek vl2 related
|
||||||
|
image_seq_mask: Optional[List[torch.Tensor]] = None
|
||||||
|
image_spatial_crop: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
# The id of the single-image placeholder token
|
# The id of the single-image placeholder token
|
||||||
im_token_id: Optional[torch.Tensor] = None
|
im_token_id: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# All the images in the batch should share the same special image
|
# All the images in the batch should share the same special image
|
||||||
# bound token ids.
|
# bound token ids.
|
||||||
im_start_id: Optional[int] = None
|
im_start_id: Optional[int] = None
|
||||||
@@ -192,6 +197,8 @@ class ImageInputs:
|
|||||||
"aspect_ratio_ids",
|
"aspect_ratio_ids",
|
||||||
"aspect_ratio_mask",
|
"aspect_ratio_mask",
|
||||||
"image_grid_thws",
|
"image_grid_thws",
|
||||||
|
"image_seq_mask",
|
||||||
|
"image_spatial_crop",
|
||||||
"im_token_id",
|
"im_token_id",
|
||||||
"im_start_id",
|
"im_start_id",
|
||||||
"im_end_id",
|
"im_end_id",
|
||||||
@@ -228,6 +235,8 @@ class ImageInputs:
|
|||||||
"aspect_ratio_ids",
|
"aspect_ratio_ids",
|
||||||
"aspect_ratio_mask",
|
"aspect_ratio_mask",
|
||||||
"image_grid_thws",
|
"image_grid_thws",
|
||||||
|
"image_seq_mask",
|
||||||
|
"image_spatial_crop",
|
||||||
]
|
]
|
||||||
for arg in optional_args:
|
for arg in optional_args:
|
||||||
if getattr(self, arg, None) is not None:
|
if getattr(self, arg, None) is not None:
|
||||||
|
|||||||
@@ -266,6 +266,14 @@ class ModelRunner:
|
|||||||
server_args.chunked_prefill_size = -1
|
server_args.chunked_prefill_size = -1
|
||||||
server_args.disable_radix_cache = True
|
server_args.disable_radix_cache = True
|
||||||
|
|
||||||
|
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
|
||||||
|
# TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
|
||||||
|
logger.info(
|
||||||
|
"Automatically turn off --chunked-prefill-size and disable radix cache for deekseek-vl2."
|
||||||
|
)
|
||||||
|
server_args.chunked_prefill_size = -1
|
||||||
|
server_args.disable_radix_cache = True
|
||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
logger.info("Init torch distributed begin.")
|
logger.info("Init torch distributed begin.")
|
||||||
|
|
||||||
|
|||||||
@@ -1021,6 +1021,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
# Gather
|
# Gather
|
||||||
@@ -1035,7 +1036,11 @@ class DeepseekV2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
|
dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
if input_embeds is None:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
else:
|
||||||
|
hidden_states = input_embeds
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
@@ -1076,8 +1081,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
|
||||||
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
|
|
||||||
if self.dp_size != 1:
|
if self.dp_size != 1:
|
||||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||||
|
|||||||
391
python/sglang/srt/models/deepseek_vl2.py
Normal file
391
python/sglang/srt/models/deepseek_vl2.py
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
import collections
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from sglang.srt.configs import DeepseekVL2Config
|
||||||
|
from sglang.srt.configs.deepseekvl2 import (
|
||||||
|
DeepseekVL2Config,
|
||||||
|
DeepseekVL2MlpProjectorConfig,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
|
from sglang.srt.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
LinearBase,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekVL2MlpProjector(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: DeepseekVL2MlpProjectorConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
if config.projector_type == "identity":
|
||||||
|
modules = nn.Identity()
|
||||||
|
|
||||||
|
elif config.projector_type == "linear":
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ReplicatedLinear(
|
||||||
|
config.input_dim,
|
||||||
|
config.n_embed,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif config.projector_type == "mlp_gelu":
|
||||||
|
mlp_depth = config.depth
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ReplicatedLinear(
|
||||||
|
config.input_dim,
|
||||||
|
config.n_embed,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for _ in range(1, mlp_depth):
|
||||||
|
self.layers.append(nn.GELU())
|
||||||
|
self.layers.append(
|
||||||
|
ReplicatedLinear(
|
||||||
|
config.n_embed,
|
||||||
|
config.n_embed,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif config.projector_type == "downsample_mlp_gelu":
|
||||||
|
mlp_depth = config.depth
|
||||||
|
mlp_ratio = config.mlp_ratio
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ReplicatedLinear(
|
||||||
|
config.input_dim
|
||||||
|
* config.downsample_ratio
|
||||||
|
* config.downsample_ratio,
|
||||||
|
config.n_embed * mlp_ratio,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for _ in range(1, mlp_depth - 1):
|
||||||
|
self.layers.append(nn.GELU())
|
||||||
|
self.layers.append(
|
||||||
|
ReplicatedLinear(
|
||||||
|
config.n_embed * mlp_ratio,
|
||||||
|
config.n_embed * mlp_ratio,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.layers.append(nn.GELU())
|
||||||
|
self.layers.append(
|
||||||
|
ReplicatedLinear(
|
||||||
|
config.n_embed * mlp_ratio,
|
||||||
|
config.n_embed,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown projector type: {config.projector_type}")
|
||||||
|
|
||||||
|
if config.token_pooling:
|
||||||
|
self.token_pooling_layer = ReplicatedLinear(
|
||||||
|
config.input_dim * 4, config.input_dim, quant_config=quant_config
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.config.token_pooling:
|
||||||
|
batch_size, wxh, channels = x.shape
|
||||||
|
w = h = int(wxh**0.5)
|
||||||
|
x = x.view(batch_size, w, h, channels)
|
||||||
|
x = x.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
|
||||||
|
batch_size, channels, h_patches, w_patches, _, _ = patches.size()
|
||||||
|
patches = patches.contiguous().view(
|
||||||
|
batch_size, channels, h_patches * w_patches, -1
|
||||||
|
)
|
||||||
|
patches = patches.permute(0, 2, 1, 3).contiguous()
|
||||||
|
patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
|
||||||
|
|
||||||
|
x = self.token_pooling_layer(patches)[0]
|
||||||
|
|
||||||
|
elif self.config.projector_type == "downsample_mlp_gelu":
|
||||||
|
bs, hw, input_dim = x.shape
|
||||||
|
h = w = int((hw) ** 0.5)
|
||||||
|
|
||||||
|
"""compute padding"""
|
||||||
|
if h % self.config.downsample_ratio:
|
||||||
|
pad = self.config.downsample_ratio - h % self.config.downsample_ratio
|
||||||
|
else:
|
||||||
|
pad = 0
|
||||||
|
x = x.reshape(bs, h, w, input_dim)
|
||||||
|
if pad > 0:
|
||||||
|
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
|
||||||
|
|
||||||
|
"""4 to 1 concat"""
|
||||||
|
x = x.permute(0, 3, 1, 2) # B, C, H, W
|
||||||
|
x = F.unfold(
|
||||||
|
x,
|
||||||
|
kernel_size=self.config.downsample_ratio,
|
||||||
|
stride=self.config.downsample_ratio,
|
||||||
|
padding=0,
|
||||||
|
) # B, C*4, HW // 4
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x)
|
||||||
|
if isinstance(x, tuple):
|
||||||
|
x = x[0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# todo
|
||||||
|
class DeepseekVL2ForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: DeepseekVL2Config,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# ----------- vision encoder ------------
|
||||||
|
vision_config = config.vision_config
|
||||||
|
self.vision = self._init_vision_module(vision_config, quant_config)
|
||||||
|
|
||||||
|
# ----------- vl projector ------------
|
||||||
|
projector_config = config.projector_config
|
||||||
|
self.projector = DeepseekVL2MlpProjector(projector_config, quant_config)
|
||||||
|
|
||||||
|
self.tile_tag = config.tile_tag
|
||||||
|
self.global_view_pos = config.global_view_pos
|
||||||
|
|
||||||
|
embed_std = 1 / torch.sqrt(
|
||||||
|
torch.tensor(projector_config.n_embed, dtype=torch.float32)
|
||||||
|
)
|
||||||
|
if self.tile_tag == "2D":
|
||||||
|
self.image_newline = nn.Parameter(
|
||||||
|
torch.randn(projector_config.n_embed) * embed_std
|
||||||
|
)
|
||||||
|
self.view_seperator = nn.Parameter(
|
||||||
|
torch.randn(projector_config.n_embed) * embed_std
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"tile tag should be 2D, but got {self.tile_tag}")
|
||||||
|
|
||||||
|
# ----------- language model ------------
|
||||||
|
language_config = config.language_config
|
||||||
|
self.language_model = DeepseekV2ForCausalLM(language_config)
|
||||||
|
|
||||||
|
def _init_vision_module(
|
||||||
|
self, vision_config, quant_config: Optional[QuantizationConfig]
|
||||||
|
) -> nn.Module:
|
||||||
|
# TODO: refactor vision model through timm wrapper from transformers
|
||||||
|
try:
|
||||||
|
import timm
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install timm") from ImportError
|
||||||
|
|
||||||
|
model = timm.create_model(
|
||||||
|
"vit_so400m_patch14_siglip_384.webli",
|
||||||
|
pretrained=False,
|
||||||
|
num_classes=0,
|
||||||
|
dynamic_img_size=True,
|
||||||
|
dynamic_img_pad=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = model.to(dtype=torch.get_default_dtype())
|
||||||
|
return model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
**kwargs: object,
|
||||||
|
):
|
||||||
|
|
||||||
|
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
|
if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
|
||||||
|
None
|
||||||
|
]:
|
||||||
|
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||||
|
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
|
||||||
|
for idx, image in enumerate(forward_batch.image_inputs):
|
||||||
|
if image is None:
|
||||||
|
continue
|
||||||
|
start_idx = extend_start_loc_cpu[idx]
|
||||||
|
end_idx = start_idx + extend_seq_lens_cpu[idx]
|
||||||
|
pixel_values = image.pixel_values.to(
|
||||||
|
device="cuda", dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
image_seq_mask = image.image_seq_mask.to(device="cuda")
|
||||||
|
image_spatial_crop = image.image_spatial_crop
|
||||||
|
input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
|
||||||
|
pixel_values,
|
||||||
|
image_seq_mask,
|
||||||
|
image_spatial_crop,
|
||||||
|
input_embeds[start_idx:end_idx],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = self.language_model.forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
input_embeds=input_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
weights = list(weights)
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "language" in name:
|
||||||
|
name = name.replace("language.", "")
|
||||||
|
self.language_model.load_weights([(name, loaded_weight)])
|
||||||
|
else:
|
||||||
|
param = params_dict[name]
|
||||||
|
weights_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weights_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
def prepare_inputs_embeds(
|
||||||
|
self,
|
||||||
|
pixel_values,
|
||||||
|
images_seq_mask,
|
||||||
|
images_spatial_crop,
|
||||||
|
input_embeds,
|
||||||
|
):
|
||||||
|
image_feature = self.vision.forward_features(pixel_values)
|
||||||
|
images_embeds = self.projector(image_feature)
|
||||||
|
_, hw, n_dim = images_embeds.shape
|
||||||
|
h = w = int(hw**0.5)
|
||||||
|
|
||||||
|
tile_index = 0
|
||||||
|
images_in_this_batch = []
|
||||||
|
for jdx in range(images_spatial_crop.shape[1]):
|
||||||
|
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
|
||||||
|
if num_width_tiles == 0 or num_height_tiles == 0:
|
||||||
|
break
|
||||||
|
num_tiles_in_image = num_width_tiles * num_height_tiles
|
||||||
|
|
||||||
|
# [hw, D]
|
||||||
|
global_features = images_embeds[tile_index]
|
||||||
|
|
||||||
|
# [num_height_tiles * num_width_tiles, hw, D]
|
||||||
|
local_features = images_embeds[
|
||||||
|
tile_index + 1 : tile_index + 1 + num_tiles_in_image
|
||||||
|
]
|
||||||
|
tile_index += num_tiles_in_image + 1
|
||||||
|
|
||||||
|
# format global and local features
|
||||||
|
# ----------------- global view add newline -----------------
|
||||||
|
# [hw, D] -> [h, w, D]
|
||||||
|
global_features = global_features.view(h, w, n_dim)
|
||||||
|
|
||||||
|
# [D] -> [h, 1, D]
|
||||||
|
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
||||||
|
|
||||||
|
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
||||||
|
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
|
||||||
|
|
||||||
|
# [h, w + 1, D] -> [h * (w + 1), D]
|
||||||
|
global_features = global_features.view(-1, n_dim)
|
||||||
|
|
||||||
|
# ----------------- local view add newline -----------------
|
||||||
|
# [num_height_tiles * num_width_tiles, h * w, D] ->
|
||||||
|
# [num_height_tiles * h, num_width_tiles * w, D]
|
||||||
|
local_features = rearrange(
|
||||||
|
local_features,
|
||||||
|
"(th tw) (h w) d -> (th h) (tw w) d",
|
||||||
|
th=num_height_tiles,
|
||||||
|
tw=num_width_tiles,
|
||||||
|
h=h,
|
||||||
|
w=w,
|
||||||
|
)
|
||||||
|
|
||||||
|
# [D] -> [num_height_tiles * h, 1, D]
|
||||||
|
new_lines_in_local = repeat(
|
||||||
|
self.image_newline,
|
||||||
|
"d -> (th h) 1 d",
|
||||||
|
th=num_height_tiles,
|
||||||
|
h=h,
|
||||||
|
)
|
||||||
|
|
||||||
|
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
||||||
|
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
|
||||||
|
|
||||||
|
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
||||||
|
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
|
||||||
|
local_features = local_features.view(-1, n_dim)
|
||||||
|
|
||||||
|
# merge global and local tiles
|
||||||
|
if self.global_view_pos == "head":
|
||||||
|
global_local_features = torch.cat(
|
||||||
|
[
|
||||||
|
global_features,
|
||||||
|
self.view_seperator[None, :],
|
||||||
|
local_features,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
global_local_features = torch.cat(
|
||||||
|
[
|
||||||
|
local_features,
|
||||||
|
self.view_seperator[None, :],
|
||||||
|
global_features,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
images_in_this_batch.append(global_local_features)
|
||||||
|
|
||||||
|
if len(images_in_this_batch) > 0:
|
||||||
|
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
||||||
|
input_embeds.masked_scatter_(
|
||||||
|
images_seq_mask.unsqueeze(-1), images_in_this_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
return input_embeds
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = DeepseekVL2ForCausalLM
|
||||||
@@ -24,3 +24,6 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa
|
|||||||
|
|
||||||
# For compling xgrammar kernels
|
# For compling xgrammar kernels
|
||||||
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
||||||
|
|
||||||
|
# For DeepSeek-VL2
|
||||||
|
pip install timm
|
||||||
|
|||||||
@@ -513,6 +513,30 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
|
|||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeepseekVL2Server(TestOpenAIVisionServer):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "deepseek-ai/deepseek-vl2-small"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--chat-template",
|
||||||
|
"deepseek-vl2",
|
||||||
|
"--context-length",
|
||||||
|
"4096",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
def test_video_chat_completion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestJanusProServer(TestOpenAIVisionServer):
|
class TestJanusProServer(TestOpenAIVisionServer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user