refactor: bug fixes and refactor for vlm (#4661)

This commit is contained in:
Mick
2025-03-23 13:48:49 +08:00
committed by GitHub
parent ca75741e86
commit 11577cedb7
31 changed files with 770 additions and 735 deletions

View File

@@ -1,13 +1,14 @@
""" """
Bench the sglang-hosted vLM with benchmark MMMU Bench the sglang-hosted vLM with benchmark MMMU
Usage: Usage:
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
The eval output will be logged The eval output will be logged
""" """
import argparse import argparse
import time
import openai import openai
from data_utils import save_json from data_utils import save_json
@@ -37,6 +38,7 @@ def eval_mmmu(args):
# had to use an openai server, since SglImage doesn't support image data # had to use an openai server, since SglImage doesn't support image data
client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1") client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
start = time.time()
for i, sample in enumerate(tqdm(samples)): for i, sample in enumerate(tqdm(samples)):
prompt = sample["final_input_prompt"] prompt = sample["final_input_prompt"]
prefix = prompt.split("<")[0] prefix = prompt.split("<")[0]
@@ -73,6 +75,8 @@ def eval_mmmu(args):
response = response.choices[0].message.content response = response.choices[0].message.content
process_result(response, sample, answer_dict, out_samples) process_result(response, sample, answer_dict, out_samples)
print(f"Benchmark time: {time.time() - start}")
args.output_path = f"./val_sglang.json" args.output_path = f"./val_sglang.json"
save_json(args.output_path, out_samples) save_json(args.output_path, out_samples)
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict) eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)

View File

@@ -9,8 +9,6 @@ import PIL
import torch import torch
from PIL.Image import Image from PIL.Image import Image
from transformers import ( from transformers import (
AutoImageProcessor,
AutoProcessor,
BaseImageProcessor, BaseImageProcessor,
BatchFeature, BatchFeature,
LlamaConfig, LlamaConfig,
@@ -20,6 +18,7 @@ from transformers import (
) )
from transformers.image_utils import to_numpy_array from transformers.image_utils import to_numpy_array
from sglang.srt.configs.utils import register_image_processor, register_processor
from sglang.srt.mm_utils import expand2square from sglang.srt.mm_utils import expand2square
@@ -625,5 +624,5 @@ class VLMImageProcessorConfig(PretrainedConfig):
super().__init__(**kwargs) super().__init__(**kwargs)
AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True) register_processor(MultiModalityConfig, VLChatProcessor)
AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None) register_image_processor(MultiModalityConfig, VLMImageProcessor)

View File

@@ -460,6 +460,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
multimodal_model_archs = [ multimodal_model_archs = [
"DeepseekVL2ForCausalLM",
"LlavaLlamaForCausalLM", "LlavaLlamaForCausalLM",
"LlavaQwenForCausalLM", "LlavaQwenForCausalLM",
"LlavaMistralForCausalLM", "LlavaMistralForCausalLM",
@@ -472,7 +473,6 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration",
"MiniCPMV", "MiniCPMV",
"MultiModalityCausalLM", "MultiModalityCausalLM",
"DeepseekVL2ForCausalLM",
] ]

View File

@@ -0,0 +1,25 @@
from typing import Type
from transformers import (
AutoImageProcessor,
AutoProcessor,
BaseImageProcessor,
PretrainedConfig,
ProcessorMixin,
)
def register_image_processor(
    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
    """
    Register a customized HF image processor class for the given config class.

    Passes ``None`` for both the slow and fast HF implementations so only the
    custom ``image_processor`` is used; ``exist_ok=True`` allows overriding a
    registration transformers (or a previous call) already made.
    """
    AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
    """
    Register a customized HF processor class for the given config class.

    ``exist_ok=True`` allows overriding a registration transformers (or a
    previous call) already made for this config.
    """
    AutoProcessor.register(config, processor, exist_ok=True)

View File

@@ -653,7 +653,7 @@ register_conv_template(
Conversation( Conversation(
name="gemma-it", name="gemma-it",
system_message="You are a helpful assistant.", system_message="You are a helpful assistant.",
system_template="<bos><start_of_turn>user{system_message}\n\n", system_template="<start_of_turn>user{system_message}\n\n",
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
sep="<end_of_turn>\n", sep="<end_of_turn>\n",
sep_style=SeparatorStyle.GEMMA3, sep_style=SeparatorStyle.GEMMA3,

View File

@@ -143,9 +143,14 @@ class VisionAttention(nn.Module):
if position_embeddings is not None: if position_embeddings is not None:
cos, sin = position_embeddings cos, sin = position_embeddings
original_shape = q.shape original_shape = q.shape
q, k = q.view(s, head, -1), k.view(s, head, -1) # [total_tokens, head, head_size]
q = q.view(-1, head, self.head_size)
k = k.view(-1, head, self.head_size)
q, k = apply_rotary_pos_emb(q, k, cos, sin) q, k = apply_rotary_pos_emb(q, k, cos, sin)
q, k = q.reshape(original_shape), k.reshape(original_shape)
q = q.view(original_shape)
k = k.view(original_shape)
if self.use_qkv_parallel: if self.use_qkv_parallel:
pass pass

View File

@@ -1,9 +1,12 @@
# TODO: also move pad_input_ids into this module # TODO: also move pad_input_ids into this module
import importlib import importlib
import inspect
import logging import logging
import pkgutil import pkgutil
from functools import lru_cache from functools import lru_cache
from typing import Union
from torch import Tensor
from transformers import IMAGE_PROCESSOR_MAPPING from transformers import IMAGE_PROCESSOR_MAPPING
from sglang.srt.managers.image_processors.base_image_processor import ( from sglang.srt.managers.image_processors.base_image_processor import (
@@ -18,9 +21,7 @@ logger = logging.getLogger(__name__)
IMAGE_PROCESSOR_MAPPING = {} IMAGE_PROCESSOR_MAPPING = {}
def get_image_processor( def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor:
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items(): for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures: if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor) return processor_cls(hf_config, server_args, processor)
@@ -42,13 +43,18 @@ def import_image_processors():
try: try:
module = importlib.import_module(name) module = importlib.import_module(name)
except Exception as e: except Exception as e:
logger.warning(f"Ignore import error when loading {name}: " f"{e}") logger.warning(f" Ignore import error when loading {name}: " f"{e}")
continue continue
if hasattr(module, "ImageProcessorMapping"): all_members = inspect.getmembers(module, inspect.isclass)
entry = module.ImageProcessorMapping classes = [
if isinstance(entry, dict): member
for processor_name, cls in entry.items(): for name, member in all_members
IMAGE_PROCESSOR_MAPPING[processor_name] = cls if member.__module__ == module.__name__
]
for cls in classes:
if issubclass(cls, BaseImageProcessor):
for arch in getattr(cls, "models"):
IMAGE_PROCESSOR_MAPPING[arch] = cls
# also register processors # also register processors

View File

@@ -4,14 +4,14 @@ import dataclasses
import multiprocessing as mp import multiprocessing as mp
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional, Union
import PIL import PIL
import transformers import transformers
from decord import VideoReader, cpu from decord import VideoReader, cpu
from openai import BadRequestError
from PIL import Image from PIL import Image
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image from sglang.srt.utils import load_image
from sglang.utils import logger from sglang.utils import logger
@@ -31,8 +31,16 @@ class BaseImageProcessorOutput:
# input_text, with each frame of video/image represented as an image_token # input_text, with each frame of video/image represented as an image_token
input_text: str input_text: str
def normalize(self):
for field_name in ["data_hashes", "image_sizes", "all_frames"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
class BaseImageProcessor(ABC): class BaseImageProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config self.hf_config = hf_config
self._processor = _processor self._processor = _processor
@@ -40,6 +48,9 @@ class BaseImageProcessor(ABC):
# FIXME: not accurate, model and image specific # FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330 self.NUM_TOKEN_PER_FRAME = 330
# Initialize global processor first
init_global_processor(self, server_args)
self.executor = concurrent.futures.ProcessPoolExecutor( self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor, initializer=init_global_processor,
mp_context=mp.get_context("fork"), mp_context=mp.get_context("fork"),
@@ -113,7 +124,7 @@ class BaseImageProcessor(ABC):
self, self,
input_ids: list[int], input_ids: list[int],
image_data, image_data,
image_token: str, image_token: Union[int, str],
max_req_input_len: int, max_req_input_len: int,
return_text: Optional[bool] = True, return_text: Optional[bool] = True,
discard_alpha_channel: bool = True, discard_alpha_channel: bool = True,
@@ -122,9 +133,16 @@ class BaseImageProcessor(ABC):
Each frame of video/image will be replaced by a single image token Each frame of video/image will be replaced by a single image token
Args: Args:
image_token: The token ID representing the image placeholder.
discard_alpha_channel: if True, discards the alpha channel in the returned images discard_alpha_channel: if True, discards the alpha channel in the returned images
""" """
if isinstance(image_token, int):
image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
image_token
)
else:
image_token_str = image_token
if isinstance(input_ids, list) and return_text: if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int) assert len(input_ids) and isinstance(input_ids[0], int)
@@ -190,13 +208,11 @@ class BaseImageProcessor(ABC):
new_text += text_part new_text += text_part
except Exception as e: except Exception as e:
import openai
logger.error(f"An exception occurred while loading images: {e}") logger.error(f"An exception occurred while loading images: {e}")
raise BadRequestError( raise BadRequestError(
f"An exception occurred while loading images: {e}" f"An exception occurred while loading images: {e}"
) )
continue
return BaseImageProcessorOutput( return BaseImageProcessorOutput(
image_hashes=hashes, image_hashes=hashes,
@@ -204,6 +220,8 @@ class BaseImageProcessor(ABC):
all_frames=images, all_frames=images,
input_text=new_text, input_text=new_text,
) )
out.normalize()
return out
class DummyImageProcessor(BaseImageProcessor): class DummyImageProcessor(BaseImageProcessor):
@@ -214,9 +232,7 @@ class DummyImageProcessor(BaseImageProcessor):
return None return None
def init_global_processor( def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
):
"""Init the global processor for multi-modal models.""" """Init the global processor for multi-modal models."""
global global_processor global global_processor
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()

View File

@@ -16,13 +16,9 @@
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import asyncio import asyncio
import math
from typing import List, Union
import torch import torch
from PIL import Image, ImageOps
from sglang.srt.managers.image_processor import BaseImageProcessor from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import ( from sglang.srt.managers.image_processors.base_image_processor import (
@@ -32,18 +28,24 @@ from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
class DeepseekVL2ImageProcessor(BaseImageProcessor): class DeepseekVL2ImageProcessor(BaseImageProcessor):
models = [DeepseekVL2ForCausalLM]
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
# with contextlib.suppress(ValueError):
# AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image>" self.IMAGE_TOKEN = "<image>"
@staticmethod @staticmethod
def _process_images_task(image, input_text, max_req_input_len): def _process_images_task(image, input_text, max_req_input_len):
return get_global_processor().__call__( processor = get_global_processor()
res = processor.__call__(
conversations=input_text, images=image, max_req_input_len=max_req_input_len conversations=input_text, images=image, max_req_input_len=max_req_input_len
) )
image_token_id = processor.image_token_id
res["im_token_id"] = image_token_id
return res
async def _process_images(self, image_data, input_text, max_req_input_len): async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None: if self.executor is not None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@@ -70,18 +72,15 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
if not isinstance(image_data, list): if not isinstance(image_data, list):
image_data = [image_data] image_data = [image_data]
images, image_hashes, image_sizes = [], [], [] images, image_sizes = [], []
image_token = self.IMAGE_TOKEN image_token = self.IMAGE_TOKEN
base_output = self.load_images( base_output = self.load_images(
input_ids, image_data, image_token, max_req_input_len input_ids, image_data, image_token, max_req_input_len
) )
base_output.all_frames = [img.convert("RGB") for img in base_output.all_frames]
res = await self._process_images( res = await self._process_images(
base_output.all_frames, base_output.input_text, max_req_input_len base_output.all_frames, base_output.input_text, max_req_input_len
) )
pixel_values = res["images"]
input_ids = res["input_ids"]
images_seq_mask = res["images_seq_mask"] images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"] images_spatial_crop = res["images_spatial_crop"]
batched_images_spatial_crop = [] batched_images_spatial_crop = []
@@ -89,16 +88,12 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0) batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
return { return {
"input_ids": input_ids.tolist(), "input_ids": res["input_ids"].tolist(),
"pixel_values": pixel_values, "pixel_values": res["images"],
"image_hashes": image_hashes, "im_token_id": res["im_token_id"],
"image_hashes": base_output.image_hashes,
"image_sizes": image_sizes, "image_sizes": image_sizes,
"image_seq_mask": images_seq_mask, "images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop, "image_spatial_crop": batched_images_spatial_crop,
"modalities": request_obj.modalities or ["image"], "modalities": request_obj.modalities or ["image"],
} }
ImageProcessorMapping = {
DeepseekVL2ForCausalLM: DeepseekVL2ImageProcessor,
}

View File

@@ -17,14 +17,15 @@ logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor): class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
models = [Gemma3ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<start_of_image>" self.IMAGE_TOKEN = "<start_of_image>"
self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index
@staticmethod async def _process_single_image(self, images, input_text) -> dict:
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0: if isinstance(images, list) and len(images) == 0:
images = None images = None
processor = get_global_processor() processor = get_global_processor()
@@ -46,19 +47,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
"pixel_values": pixel_values, "pixel_values": pixel_values,
} }
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Gemma3SGLangImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async( async def process_images_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
@@ -82,7 +70,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
discard_alpha_channel=True, discard_alpha_channel=True,
) )
ret = await self._process_images( ret = await self._process_single_image(
input_text=base_output.input_text, images=base_output.all_frames input_text=base_output.input_text, images=base_output.all_frames
) )
@@ -93,8 +81,3 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
"im_start_id": self.IM_START_TOKEN_ID, "im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID,
} }
ImageProcessorMapping = {
Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}

View File

@@ -11,6 +11,8 @@ from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
class JanusProProcessor(SGLangBaseImageProcessor): class JanusProProcessor(SGLangBaseImageProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
@@ -77,6 +79,3 @@ class JanusProProcessor(SGLangBaseImageProcessor):
"im_end_id": res["im_end_id"], "im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"], "im_token_id": res["im_token_id"],
} }
ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}

View File

@@ -15,6 +15,8 @@ from sglang.utils import get_exception_traceback
class LlavaImageProcessor(BaseImageProcessor): class LlavaImageProcessor(BaseImageProcessor):
models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
@@ -143,10 +145,3 @@ class LlavaImageProcessor(BaseImageProcessor):
"image_sizes": image_sizes, "image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"], "modalities": request_obj.modalities or ["image"],
} }
ImageProcessorMapping = {
LlavaVidForCausalLM: LlavaImageProcessor,
LlavaQwenForCausalLM: LlavaImageProcessor,
LlavaMistralForCausalLM: LlavaImageProcessor,
}

View File

@@ -1,6 +1,8 @@
import asyncio import asyncio
from typing import List, Union from typing import List, Union
import torch
from sglang.srt.managers.image_processor import BaseImageProcessor from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import ( from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor, get_global_processor,
@@ -9,6 +11,8 @@ from sglang.srt.models.minicpmv import MiniCPMV
class MiniCPMVImageProcessor(BaseImageProcessor): class MiniCPMVImageProcessor(BaseImageProcessor):
models = [MiniCPMV]
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)" self.IMAGE_TOKEN = "(<image>./</image>)"
@@ -69,21 +73,57 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
# Collect special token ids # Collect special token ids
tokenizer = self._processor.tokenizer tokenizer = self._processor.tokenizer
im_start_id = tokenizer.im_start_id im_start_id = tokenizer.im_start_id
im_token_id = tokenizer.unk_token_id
im_end_id = tokenizer.im_end_id im_end_id = tokenizer.im_end_id
if tokenizer.slice_start_id: if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id slice_end_id = tokenizer.slice_end_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
# tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
# tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
# per patch
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
tgt_sizes = torch.stack(tgt_sizes_flat)
return { return {
"input_ids": res["input_ids"].flatten().tolist(), "input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"], "pixel_values": pixel_values,
"tgt_sizes": res["tgt_sizes"], "tgt_sizes": tgt_sizes,
"image_hashes": base_output.image_hashes, "image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"], "modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id, "im_start_id": im_start_id,
"im_token_id": im_token_id,
"im_end_id": im_end_id, "im_end_id": im_end_id,
"slice_start_id": slice_start_id, "slice_start_id": slice_start_id,
"slice_end_id": slice_end_id, "slice_end_id": slice_end_id,
} }
ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}

View File

@@ -10,6 +10,8 @@ from sglang.srt.utils import load_image
class MllamaImageProcessor(BaseImageProcessor): class MllamaImageProcessor(BaseImageProcessor):
models = [MllamaForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
@@ -55,6 +57,3 @@ class MllamaImageProcessor(BaseImageProcessor):
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs return image_inputs
ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}

View File

@@ -2,6 +2,7 @@ import asyncio
import math import math
from typing import List, Union from typing import List, Union
import torch
from PIL import Image from PIL import Image
from sglang.srt.managers.image_processor import BaseImageProcessor from sglang.srt.managers.image_processor import BaseImageProcessor
@@ -14,6 +15,8 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
# Compatible with Qwen2VL and Qwen2_5VL # Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(BaseImageProcessor): class Qwen2_5VLImageProcessor(BaseImageProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
@@ -43,7 +46,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
"video_grid_thws": getattr(result, "video_grid_thws", None), "video_grid_thws": getattr(result, "video_grid_thws", None),
} }
async def _process_images(self, images, input_text) -> dict: async def _process_single_image(self, images, input_text) -> dict:
if self.executor is not None: if self.executor is not None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
@@ -138,23 +141,23 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
images = [resize_image(image) for image in base_output.all_frames] images = [resize_image(image) for image in base_output.all_frames]
ret = await self._process_images(images, base_output.input_text) ret = await self._process_single_image(
images=images, input_text=base_output.input_text
)
image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return { return {
"input_ids": ret["input_ids"].flatten().tolist(), "input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"], "pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes, "image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"], "modalities": request_obj.modalities or ["image"],
"image_grid_thws": ret["image_grid_thw"], "image_grid_thws": image_grid_thws,
"video_grid_thws": ret["video_grid_thws"], "video_grid_thws": video_grid_thws,
"im_start_id": self.IM_START_TOKEN_ID, "im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id, "im_token_id": self.image_token_id,
"video_token_id": self.video_token_id, "video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"], "second_per_grid_ts": ret["second_per_grid_ts"],
} }
ImageProcessorMapping = {
Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
}

View File

@@ -0,0 +1,303 @@
"""
Multimodality utils
"""
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
import torch
from torch import nn
from sglang.srt.managers.schedule_batch import (
ImageInputs,
global_server_args_dict,
logger,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.utils import logger
class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them
        with the ``pad_values`` carried by ``image_inputs``.

        Returns the padded token-id list (same length semantics are up to the
        concrete strategy).
        """
        pass
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs
    (e.g. ``<image>...</image>``), given as ``data_token_pairs``.

    This strategy should be applied when data content is marked by start/end
    token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        # Each entry is a (start_token_id, end_token_id) pair.
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        Replace the tokens between each start/end marker pair with the
        corresponding entry of ``image_inputs.pad_values``; the markers
        themselves are kept. The start positions of data segments are recorded
        in ``image_inputs.image_offsets``.

        Returns the padded token-id list (same length as ``input_ids``), or the
        original ``input_ids`` when the markers are missing or unbalanced.
        """
        pad_values = image_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        image_inputs.image_offsets = []
        if data_token_pairs is None:
            # Fall back to the single (im_start_id, im_end_id) pair carried by
            # the image inputs. BUGFIX: this must be a list of *pairs* — the
            # previous flat [start, end] list broke the `for s, _e in ...`
            # unpacking below.
            start_id = getattr(image_inputs, "im_start_id", None)
            end_id = getattr(image_inputs, "im_end_id", None)
            if start_id is None or end_id is None:
                # BUGFIX: this warning was previously unreachable (the None
                # check was re-run only after data_token_pairs had already
                # been reassigned to a non-None list).
                logger.warning(
                    "No data_token_pairs provided, RadixAttention might be influenced."
                )
                return input_ids
            data_token_pairs = [(start_id, end_id)]

        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]
        # First start token marks new data
        data_start_token = start_token_ids[0]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            # Unbalanced markers: leave the sequence untouched.
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            # Copy everything up to and including the start marker.
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] == data_start_token:
                data_idx += 1
                image_inputs.image_offsets += [start_idx]

            # Replace the enclosed data tokens with this segment's pad value.
            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids)

        return padded_ids
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented with a single special token id
    (``image_inputs.im_token_id``), which first needs to be expanded to
    multiple tokens and then replaced with the padding values.

    Use this strategy when one data token stands for content that expands to
    several tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        # Maps an image grid (t, h, w) to the number of expanded tokens.
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        Procedure:
        1. each data token is expanded; the expanded count is computed by
           ``num_data_token_calc_func`` from the matching grid entry
        2. the expanded span is filled with the (cycled) pad values

        The offset of each expanded span is recorded in
        ``image_inputs.image_offsets``.
        """
        grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values

        # Positions of the single-token image placeholders in the input.
        token_positions = [
            pos
            for pos, tok in enumerate(input_ids)
            if tok == image_inputs.im_token_id
        ]

        image_inputs.image_offsets = []

        result: List[int] = []
        prev_end = 0
        for img_idx, grid_thw in enumerate(grid_thws):
            expanded = self.num_data_token_calc_func(grid_thw)
            pos = token_positions[img_idx]

            # Copy the non-image tokens since the previous placeholder.
            result.extend(input_ids[prev_end:pos])
            image_inputs.image_offsets.append(len(result))

            # Cycle pad_values until the expanded span is covered.
            repeats = (expanded + len(pad_values)) // len(pad_values)
            result.extend((pad_values * repeats)[:expanded])

            prev_end = pos + 1

        result.extend(input_ids[token_positions[-1] + 1 :])

        return result
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""

    def __init__(self, image_token_id: torch.Tensor) -> None:
        # Tensor of the image-placeholder token id(s) to match against.
        self.image_token_id = image_token_id

    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
        """
        Replace every image token in ``input_ids`` with the pad values,
        cycling through ``image_inputs.pad_values`` as needed.
        """
        pad_values = image_inputs.pad_values
        assert len(pad_values) != 0

        ids = torch.tensor(input_ids)
        is_image = torch.isin(ids, self.image_token_id)
        n_image = is_image.sum().item()

        # Tile the pad values so there are at least n_image of them, then trim.
        cycles = n_image // len(pad_values) + 1
        fill = torch.tensor(pad_values).repeat(cycles)[:n_image]

        ids[is_image] = fill
        return ids.tolist()
def embed_image_inputs(
    image_input: ImageInputs,
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    image_embedding_func: Callable[[ImageInputs], torch.Tensor],
    placeholder_token_ids: Optional[List[int]] = None,
) -> Optional[torch.Tensor]:
    """
    Calculate the image embeddings if necessary, then scatter the result with
    the help of a boolean mask denoting the embed locations

    Args:
        image_input: merged image inputs of the batch; ``None`` means text-only.
        input_ids: token ids whose image positions are filled with placeholder
            (pad) values.
        input_embedding: the language model's token embedding module.
        image_embedding_func: maps ``image_input`` to an image embedding tensor,
            either 2D ``[num_tokens, dim]`` or 3D
            ``[num_images, tokens_per_image, dim]``.
        placeholder_token_ids: token ids that mark image positions; defaults to
            ``image_input.pad_values``.

    Returns:
        final embedding: Optional[torch.Tensor]
    """
    if image_input is None:
        return None
    placeholder_token_ids = placeholder_token_ids or image_input.pad_values
    # boolean masking the special tokens
    special_image_mask = torch.isin(
        input_ids,
        torch.tensor(placeholder_token_ids, device=input_ids.device),
    ).unsqueeze(-1)
    num_image_tokens_in_input_ids = special_image_mask.sum()
    if num_image_tokens_in_input_ids == 0:
        # unexpected: no placeholder token present (e.g. consumed by prefix
        # caching) — fall back to plain text embedding
        inputs_embeds = input_embedding(input_ids)
    else:
        image_embedding = image_embedding_func(image_input)
        if image_embedding.dim() == 2:
            # flat [num_tokens, dim] layout
            num_image_tokens_in_embedding = image_embedding.shape[0]
        else:
            # [num_images, tokens_per_image, dim] layout
            num_image_tokens_in_embedding = (
                image_embedding.shape[0] * image_embedding.shape[1]
            )
        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
            # Truncate to whole images so the scatter below lines up.
            # NOTE(review): this arithmetic assumes the 3D layout — for a 2D
            # embedding, shape[1] is the hidden dim, so num_image would be
            # meaningless; confirm callers only reach this path with 3D input.
            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
            image_embedding = image_embedding[:num_image, :]
            logger.warning(
                f"Number of images does not match number of special image tokens in the input text. "
                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
                "tokens from image embeddings."
            )
            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
            # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
            # extend_start_loc and extend_seq_lens
            if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
                chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
                if chunked_prefill_size != -1:
                    logger.warning(
                        "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
                    )

        vocab_size = input_embedding.num_embeddings
        # Important: clamp after getting original image regions
        # Clamp input ids. This is because the input_ids for the image tokens are
        # filled with the hash values of the image for the prefix matching in the radix attention.
        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
        input_ids.clamp_(min=0, max=vocab_size - 1)
        inputs_embeds = input_embedding(input_ids)
        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
            inputs_embeds.device
        )
        inputs_embeds = inputs_embeds.masked_scatter(
            special_image_mask,
            image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
        )
    return inputs_embeds
def embed_image_embedding(
    inputs_embeds: torch.Tensor,
    image_embedding: torch.Tensor,
    image_bounds: torch.Tensor,
) -> torch.Tensor:
    """
    Scatter `image_embedding` rows into `inputs_embeds` (in place) at the row
    ranges described by `image_bounds`.

    Args:
        inputs_embeds: [seq_len, hidden] embeddings, mutated in place.
        image_embedding: image features, flattened to [num_image_tokens, hidden].
        image_bounds: [num_images, 2] tensor of (start, end) half-open ranges.

    Returns:
        inputs_embeds, with the image rows overwritten.
    """
    if len(image_bounds) == 0:
        return inputs_embeds

    # Expand each (start, end) pair into the explicit row indices it covers.
    # NOTE(review): torch.stack assumes every image spans the same number of
    # rows — presumably guaranteed by the caller; confirm.
    spans = [
        torch.arange(begin, stop, dtype=torch.long)
        for begin, stop in image_bounds.tolist()
    ]
    row_indices = torch.stack(spans).to(inputs_embeds.device)

    hidden_dim = inputs_embeds.shape[-1]
    scatter_index = row_indices.view(-1, 1).repeat(1, hidden_dim)
    flat_image_rows = image_embedding.view(-1, image_embedding.shape[-1])
    inputs_embeds.scatter_(0, scatter_index, flat_image_rows)
    return inputs_embeds
def general_mm_embed_routine(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    embed_tokens: nn.Embedding,
    image_embedding_func: Callable[[ImageInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
):
    """
    a general wrapper function to get final input embeds from multimodal models
    with a language model as causal model

    Decode steps and image-free batches take the plain text-embedding path;
    otherwise the batch's image inputs are merged and scattered into the text
    embeddings.
    """
    is_text_only = (
        forward_batch.forward_mode.is_decode()
        or not forward_batch.contains_image_inputs()
    )
    if is_text_only:
        return embed_tokens(input_ids)

    merged_image_input = forward_batch.merge_image_inputs()
    inputs_embeds = embed_image_inputs(
        image_input=merged_image_input,
        input_ids=input_ids,
        input_embedding=embed_tokens,
        image_embedding_func=image_embedding_func,
        placeholder_token_ids=placeholder_token_ids,
    )
    # once used, image_inputs is useless
    # just being defensive here
    forward_batch.image_inputs = None
    return inputs_embeds

View File

@@ -1,134 +0,0 @@
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.utils import logger
class MultiModalityDataPaddingPattern:
    """
    Interface for padding strategies applied to multimodal data tokens.

    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility; each subclass implements one concrete
    strategy.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """Pad the input ids sequence containing data tokens, and replace them with pad_values"""
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function will replace the data-tokens inbetween with pad_values accordingly
        """
        pad_values = image_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        image_inputs.image_offsets = []
        if data_token_pairs is None:
            # BUGFIX: the fallback must be a list of (start, end) pairs, not a
            # flat [start, end] list — the unpacking below would raise
            # TypeError otherwise. Also warn-and-return when the pair is
            # unavailable (previously this check was unreachable).
            if image_inputs.im_start_id is None or image_inputs.im_end_id is None:
                logger.warning(
                    "No data_token_pairs provided, RadixAttention might be influenced."
                )
                return input_ids
            data_token_pairs = [(image_inputs.im_start_id, image_inputs.im_end_id)]
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]
        # First start token marks new data
        data_start_token = start_token_ids[0]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            # Unbalanced markers (e.g. a pair split by truncation): leave the
            # sequence untouched.
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            # Copy text tokens up to and including the start marker.
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] == data_start_token:
                data_idx += 1
                image_inputs.image_offsets += [start_idx]

            # Replace the enclosed data tokens with this image's pad value.
            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids)
        return padded_ids
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
    which needs first to be expanded to multiple tokens, then replaced with their padding values

    This strategy should be used when a single data token represents content that should
    be expanded to multiple tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        # Maps one image's grid (t, h, w) to the number of tokens it expands to.
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function will follow the procedure of:
            1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
            2. the padded data tokens will be replaced with their pad_values
        """
        image_grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values

        image_indices = [
            idx
            for idx, token in enumerate(input_ids)
            if token == image_inputs.im_token_id
        ]

        image_inputs.image_offsets = []

        # Guard: with no image tokens in the sequence (e.g. text-only input or
        # tokens consumed by prefix caching) or no image grids, there is nothing
        # to expand; the unguarded code below would crash on image_indices[-1].
        if not image_indices or image_grid_thws is None or len(image_grid_thws) == 0:
            return input_ids

        input_ids_with_image = []
        # BUGFIX: removed stray debug `print(f"image_cnt {image_cnt}")` that
        # spammed stdout on every request.
        for image_cnt, _ in enumerate(image_grid_thws):
            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
            # Text tokens between the previous image token and this one.
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = input_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
            image_inputs.image_offsets.append(len(input_ids_with_image))
            # Cycle pad_values until it covers num_image_tokens, then truncate.
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
        # Trailing text tokens after the last image token.
        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])

        return input_ids_with_image

View File

@@ -77,6 +77,7 @@ global_server_args_dict = {
"enable_flashmla": ServerArgs.enable_flashmla, "enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache, "disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
} }
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -160,7 +161,8 @@ class ImageInputs:
aspect_ratio_mask: Optional[List[torch.Tensor]] = None aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# QWen2-VL related # QWen2-VL related
image_grid_thws: List[Tuple[int, int, int]] = None # [num_of_images, t, h, w]
image_grid_thws: torch.Tensor = None
mrope_position_delta: Optional[torch.Tensor] = None mrope_position_delta: Optional[torch.Tensor] = None
# Qwen2-VL video related # Qwen2-VL video related
video_token_id: Optional[int] = None video_token_id: Optional[int] = None
@@ -168,7 +170,7 @@ class ImageInputs:
second_per_grid_ts: Optional[List[torch.Tensor]] = None second_per_grid_ts: Optional[List[torch.Tensor]] = None
# deepseek vl2 related # deepseek vl2 related
image_seq_mask: Optional[List[torch.Tensor]] = None images_emb_mask: Optional[List[torch.Tensor]] = None
image_spatial_crop: Optional[List[torch.Tensor]] = None image_spatial_crop: Optional[List[torch.Tensor]] = None
# The id of the single-image placeholder token # The id of the single-image placeholder token
@@ -182,9 +184,6 @@ class ImageInputs:
slice_end_id: Optional[int] = None slice_end_id: Optional[int] = None
tgt_sizes: Optional[list] = None tgt_sizes: Optional[list] = None
# denotes the number of valid image tokens in each image
images_emb_mask: Optional[torch.BoolTensor] = None
@staticmethod @staticmethod
def from_dict(obj: dict): def from_dict(obj: dict):
ret = ImageInputs( ret = ImageInputs(
@@ -204,7 +203,7 @@ class ImageInputs:
"aspect_ratio_ids", "aspect_ratio_ids",
"aspect_ratio_mask", "aspect_ratio_mask",
"image_grid_thws", "image_grid_thws",
"image_seq_mask", "images_emb_mask",
"image_spatial_crop", "image_spatial_crop",
"im_token_id", "im_token_id",
"im_start_id", "im_start_id",
@@ -212,20 +211,58 @@ class ImageInputs:
"slice_start_id", "slice_start_id",
"slice_end_id", "slice_end_id",
"tgt_sizes", "tgt_sizes",
"images_emb_mask",
] ]
for arg in optional_args: for arg in optional_args:
if arg in obj: if arg in obj:
setattr(ret, arg, obj[arg]) setattr(ret, arg, obj[arg])
# validate
assert (
isinstance(ret.pixel_values, torch.Tensor)
or isinstance(ret.pixel_values, np.ndarray)
or isinstance(ret.pixel_values, list)
)
return ret return ret
def merge(self, other): def merge(self, other: ImageInputs):
""" """
merge image inputs when requests are being merged merge image inputs when requests are being merged
""" """
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:] if isinstance(self.pixel_values, list):
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values]) # in some rare cases, pixel values are list of patches with different shapes
# e.g. minicpm
self.pixel_values += other.pixel_values
else:
assert (
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
# args would be stacked along first dim
# usually these are already tensors
stack_args = [
# TODO: merge with image_grid_thws, basically the same thing
"tgt_sizes",
"image_spatial_crop",
]
for arg in stack_args:
if getattr(self, arg, None) is None:
setattr(self, arg, getattr(other, arg, None))
elif getattr(other, arg, None) is not None:
# self and other both not None
setattr(
self,
arg,
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
)
if self.image_grid_thws is None:
self.image_grid_thws = other.image_grid_thws
elif other.image_grid_thws is not None:
self.image_grid_thws = torch.concat(
[self.image_grid_thws, other.image_grid_thws]
)
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache. # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward, # Please note that if the `input_ids` is later used in the model forward,
@@ -233,7 +270,7 @@ class ImageInputs:
# errors in cuda kernels. See also llava.py for example. # errors in cuda kernels. See also llava.py for example.
self.image_hashes += other.image_hashes self.image_hashes += other.image_hashes
self.pad_values = [x % (1 << 30) for x in self.image_hashes] self.pad_values = [x % (1 << 30) for x in self.image_hashes]
# args needed to be merged
optional_args = [ optional_args = [
"image_sizes", "image_sizes",
"image_offsets", "image_offsets",
@@ -241,13 +278,13 @@ class ImageInputs:
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids", "aspect_ratio_ids",
"aspect_ratio_mask", "aspect_ratio_mask",
"image_grid_thws", "images_emb_mask",
"image_seq_mask",
"image_spatial_crop",
] ]
for arg in optional_args: for arg in optional_args:
if getattr(self, arg, None) is not None: self_arg = getattr(self, arg, None)
setattr(self, arg, getattr(self, arg) + getattr(other, arg)) if self_arg is not None:
setattr(self, arg, self_arg + getattr(other, arg))
# other args would be kept intact
class Req: class Req:

View File

@@ -179,7 +179,7 @@ class TokenizerManager:
) )
# We want to parallelize the image pre-processing so we create an executor for it # We want to parallelize the image pre-processing so we create an executor for it
# We creat image_processor for any skip_tokenizer_init to make sure we still encode # We create image_processor for any skip_tokenizer_init to make sure we still encode
# images even with skip_tokenizer_init=False. # images even with skip_tokenizer_init=False.
self.image_processor = get_image_processor( self.image_processor = get_image_processor(
self.model_config.hf_config, server_args, _processor self.model_config.hf_config, server_args, _processor

View File

@@ -332,7 +332,7 @@ class ForwardBatch:
return ret return ret
def get_merged_image_inputs(self) -> Optional[ImageInputs]: def merge_image_inputs(self) -> Optional[ImageInputs]:
""" """
Merge all image inputs in the batch into a single ImageInputs object. Merge all image inputs in the batch into a single ImageInputs object.
@@ -358,6 +358,16 @@ class ForwardBatch:
return merged return merged
def contains_image_inputs(self) -> bool:
""" """
if self.image_inputs is None:
return True
return any(
image_input.pixel_values is not None and image_input.pixel_values is not []
for image_input in self.image_inputs
if image_input is not None
)
def _compute_mrope_positions( def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch self, model_runner: ModelRunner, batch: ModelWorkerBatch
): ):

View File

@@ -273,7 +273,7 @@ class ModelRunner:
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]: if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
# TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
logger.info( logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for deekseek-vl2." "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
) )
server_args.chunked_prefill_size = -1 server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True server_args.disable_radix_cache = True

View File

@@ -47,8 +47,9 @@ from sglang.srt.configs.janus_pro import *
from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import ( from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -1958,82 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
def prepare_images_seq_mask( def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
self, input_ids: torch.Tensor, image_inputs: ImageInputs pixel_values = image_input.pixel_values
) -> Optional[torch.LongTensor]:
images_seq_mask = torch.isin(
input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
)
if images_seq_mask.sum() == 0:
# sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
return None
else:
return images_seq_mask
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = None
if (
forward_batch.image_inputs is not None
and len(forward_batch.image_inputs) != 0
and forward_batch.image_inputs[0] is not None
):
image_inputs = forward_batch.image_inputs[0]
images_seq_mask = self.prepare_images_seq_mask(
input_ids=input_ids, image_inputs=image_inputs
)
if images_seq_mask is not None:
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
inputs_embeds = self.prepare_inputs_embeds(
input_ids=input_ids,
pixel_values=image_inputs.pixel_values,
images_seq_mask=images_seq_mask,
images_emb_mask=image_inputs.images_emb_mask,
)
input_ids = None
if input_ids is not None:
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
return self.language_model(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
get_embedding=False,
)
def prepare_inputs_embeds(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
images_seq_mask: torch.LongTensor,
images_emb_mask: torch.BoolTensor,
**_kwargs,
):
"""
Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
bs, n = pixel_values.shape[0:2] bs, n = pixel_values.shape[0:2]
pixel_values = pixel_values.to( pixel_values = pixel_values.to(
device=self.vision_model.device, dtype=self.vision_model.dtype device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -2045,18 +1972,35 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
# [b x n, T2, D] -> [b, n x T2, D] # [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
# [b, T, D] return images_embeds
# ignore the image embeddings
input_ids[input_ids < 0] = 0
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
# replace with the image embeddings def get_input_embeddings(self) -> nn.Embedding:
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] return self.language_model.model.embed_tokens
return inputs_embeds @torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
return self.language_model(
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
get_embedding=False,
)
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
return self.gen_aligner(self.gen_embed(image_ids)) return self.gen_aligner(self.gen_embed(image_ids))

View File

@@ -1,34 +1,16 @@
import collections from typing import Iterable, List, Optional, Tuple
import itertools
import math
import warnings
from enum import Enum
from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import nn from torch import nn
from sglang.srt.configs import DeepseekVL2Config
from sglang.srt.configs.deepseekvl2 import ( from sglang.srt.configs.deepseekvl2 import (
DeepseekVL2Config, DeepseekVL2Config,
DeepseekVL2MlpProjectorConfig, DeepseekVL2MlpProjectorConfig,
) )
from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
LinearBase,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
**kwargs: object, **kwargs: object,
): ):
input_embeds = self.language_model.model.embed_tokens(input_ids) input_embeds = self.language_model.model.embed_tokens(input_ids)
if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [ if (
None forward_batch.forward_mode.is_extend()
]: and forward_batch.contains_image_inputs()
):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy() extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.image_inputs): for idx, image in enumerate(forward_batch.image_inputs):
@@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
continue continue
start_idx = extend_start_loc_cpu[idx] start_idx = extend_start_loc_cpu[idx]
end_idx = start_idx + extend_seq_lens_cpu[idx] end_idx = start_idx + extend_seq_lens_cpu[idx]
pixel_values = image.pixel_values.to( images_emb_mask = image.images_emb_mask.to(device="cuda")
device="cuda", dtype=torch.bfloat16 image_features = self.get_image_feature(image)
) input_embeds[start_idx:end_idx] = input_embeds[
image_seq_mask = image.image_seq_mask.to(device="cuda") start_idx:end_idx
image_spatial_crop = image.image_spatial_crop ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
pixel_values,
image_seq_mask,
image_spatial_crop,
input_embeds[start_idx:end_idx],
)
outputs = self.language_model.forward( outputs = self.language_model.forward(
input_ids=input_ids, input_ids=input_ids,
@@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module):
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
return input_ids return input_ids
def prepare_inputs_embeds( def get_image_feature(self, image_input: ImageInputs):
self, pixel_values = image_input.pixel_values.type(
pixel_values, next(self.vision.parameters()).dtype
images_seq_mask, ).to(device=next(self.vision.parameters()).device)
images_spatial_crop,
input_embeds,
):
image_feature = self.vision.forward_features(pixel_values) image_feature = self.vision.forward_features(pixel_values)
images_embeds = self.projector(image_feature) images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape _, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5) h = w = int(hw**0.5)
tile_index = 0 tile_index = 0
images_in_this_batch = [] images_in_this_batch = []
images_spatial_crop = image_input.image_spatial_crop
for jdx in range(images_spatial_crop.shape[1]): for jdx in range(images_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx] num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0: if num_width_tiles == 0 or num_height_tiles == 0:
@@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
images_in_this_batch.append(global_local_features) images_in_this_batch.append(global_local_features)
if len(images_in_this_batch) > 0: return torch.cat(images_in_this_batch, dim=0)
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
input_embeds.masked_scatter_(
images_seq_mask.unsqueeze(-1), images_in_this_batch
)
return input_embeds
EntryClass = DeepseekVL2ForCausalLM EntryClass = DeepseekVL2ForCausalLM

View File

@@ -37,11 +37,8 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
@@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel):
else: else:
hidden_states = input_embeds hidden_states = input_embeds
if len(positions.shape) == 1: if positions.dim() == 1:
positions = einops.rearrange(positions, "s -> 1 s") positions = einops.rearrange(positions, "s -> 1 s")
position_embeddings_global = self.rotary_emb(hidden_states, positions) position_embeddings_global = self.rotary_emb(hidden_states, positions)
@@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel):
) )
self.post_init() self.post_init()
def get_input_embeddings(self): def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens return self.model.embed_tokens
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
return self.model.layers[0].mlp.gate_up_proj.weight.dtype return next(self.parameters()).dtype
@torch.no_grad() @torch.no_grad()
def forward( def forward(

View File

@@ -34,8 +34,9 @@ from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import ( from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -264,10 +265,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
kwargs["local_attn_masks"] = local_attn_masks kwargs["local_attn_masks"] = local_attn_masks
return kwargs return kwargs
def get_input_embeddings(self): def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings() return self.language_model.get_input_embeddings()
def get_image_features(self, pixel_values: torch.Tensor): def get_image_feature(self, image_input: ImageInputs):
""" """
Projects the last hidden state from the vision model into language model space. Projects the last hidden state from the vision model into language model space.
@@ -277,6 +278,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
Returns: Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
""" """
pixel_values = image_input.pixel_values
pixel_values = pixel_values.to("cuda") pixel_values = pixel_values.to("cuda")
pixel_values = pixel_values.to(dtype=self.language_model.dtype()) pixel_values = pixel_values.to(dtype=self.language_model.dtype())
@@ -305,7 +307,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
return inputs_embeds return inputs_embeds
else: else:
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}") # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
image_features = self.get_image_features(image_input.pixel_values) image_features = self.get_image_feature(image_input.pixel_values)
# print(f"image tokens from image embeddings: {image_features.numel()}") # print(f"image tokens from image embeddings: {image_features.numel()}")
num_image_tokens_in_embedding = ( num_image_tokens_in_embedding = (
@@ -397,20 +399,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
else: else:
llm_input_ids = input_ids llm_input_ids = input_ids
merged_image_input = forward_batch.get_merged_image_inputs() inputs_embeds = general_mm_embed_routine(
input_ids=llm_input_ids,
if ( positions=positions,
not forward_batch.forward_mode.is_decode() forward_batch=forward_batch,
and merged_image_input is not None embed_tokens=self.get_input_embeddings(),
): image_embedding_func=self.get_image_feature,
inputs_embeds = self.embed_image_inputs( )
input_ids=llm_input_ids,
forward_batch=forward_batch,
image_input=merged_image_input,
)
else:
llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
outputs = self.language_model( outputs = self.language_model(
input_ids=None, input_ids=None,

View File

@@ -50,8 +50,9 @@ from sglang.srt.layers.linear import (
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import ( from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
embed_image_inputs,
) )
from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module):
) )
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def get_input_embeddings(self): def get_input_embeddings(self) -> nn.Embedding:
return self.embeddings return self.embeddings
def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
@@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module):
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device) valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor return valid_pairs_tensor
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (
image_inputs["data"]
.type(vlm_embedding.dtype)
.to(vlm_embedding.device)
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding, vision_hidden_states
def _parse_and_validate_inputs( def _parse_and_validate_inputs(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module):
type="image_embeds", type="image_embeds",
) )
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += pixel_n
tgt_sizes_flat += tgt_n
# NOTE: Input IDs does not contain image tokens during memory profiling,
# so we allow it to be empty
if len(pixel_values_flat) != len(tgt_sizes_flat):
raise ValueError(
"Inconsistent flattened lengths, found: "
f"{len(pixel_values_flat)} vs. "
f"{len(tgt_sizes_flat)}"
)
if len(pixel_values_flat) == 0:
return None
image_bounds = self._get_image_bounds( image_bounds = self._get_image_bounds(
input_ids=input_ids, input_ids=input_ids,
pad_values=pad_values, pad_values=pad_values,
@@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module):
) )
return MiniCPMVImagePixelInputs( return MiniCPMVImagePixelInputs(
image_bounds=image_bounds.to(device=input_ids.device), image_bounds=image_bounds.to(device=input_ids.device),
data=pixel_values_flat, data=pixel_values,
tgt_sizes=torch.stack(tgt_sizes_flat), tgt_sizes=tgt_sizes,
type="pixel_values", type="pixel_values",
) )
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (
image_inputs["data"]
.type(vlm_embedding.dtype)
.to(vlm_embedding.device)
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding, vision_hidden_states
def get_input_embeddings(self) -> nn.Embedding:
return self.llm.get_input_embedding()
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@@ -899,58 +863,29 @@ class MiniCPMVBaseModel(nn.Module):
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
if ( if (
forward_batch.image_inputs is not None forward_batch.forward_mode.is_decode()
and len(forward_batch.image_inputs) > 0 or not forward_batch.contains_image_inputs()
and forward_batch.image_inputs[0] is not None
): ):
# TODO: bath inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids)
kwargs.update( else:
{ # Clamp input ids. This is because the input_ids for the image tokens are
"pixel_values": ( # filled with the hash values of the image for the prefix matching in the radix attention.
None # There values are useless because their embeddings will be replaced by vision embeddings anyway.
if forward_batch.image_inputs is None image_inputs = forward_batch.merge_image_inputs()
else [ inputs_embeds = embed_image_inputs(
i.pixel_values image_input=image_inputs,
for i in forward_batch.image_inputs input_ids=input_ids,
if i is not None input_embedding=self.get_input_embeddings(),
] image_embedding_func=self.get_image_features,
), placeholder_token_ids=[image_inputs.im_token_id]
"tgt_sizes": ( + image_inputs.pad_values,
None
if forward_batch.image_inputs is None
else [
i.tgt_sizes
for i in forward_batch.image_inputs
if i is not None
]
),
"im_start_id": forward_batch.image_inputs[0].im_start_id,
"im_end_id": forward_batch.image_inputs[0].im_end_id,
"slice_start_id": forward_batch.image_inputs[0].slice_start_id,
"slice_end_id": forward_batch.image_inputs[0].slice_end_id,
"pad_values": forward_batch.image_inputs[0].pad_values,
}
) )
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
hidden_states = self.llm.model( hidden_states = self.llm.model(
input_ids=input_ids, input_ids=None,
positions=positions, positions=positions,
forward_batch=forward_batch, forward_batch=forward_batch,
input_embeds=vlm_embeddings, input_embeds=inputs_embeds,
) )
return self.logits_processor( return self.logits_processor(
@@ -990,7 +925,7 @@ class MiniCPMVBaseModel(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
@@ -1100,12 +1035,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
) )
return vision_embedding return vision_embedding
def get_vision_hidden_states( def get_image_features(
self, self,
data: MiniCPMVImageInputs, image_inputs: ImageInputs,
) -> torch.Tensor: ) -> torch.Tensor:
pixel_values = data["data"] # list of tensors
tgt_sizes = data["tgt_sizes"] pixel_values = image_inputs.pixel_values
tgt_sizes = image_inputs.tgt_sizes
device = self.vpm.embeddings.position_embedding.weight.device device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype dtype = self.vpm.embeddings.position_embedding.weight.dtype

View File

@@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
def get_input_embedding(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,

View File

@@ -26,7 +26,6 @@ import logging
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type from typing import Iterable, List, Optional, Tuple, Type
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@@ -54,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import ( from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -326,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
) )
def get_window_index(self, grid_thw): def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0] cu_window_seqlens: list = [0]
window_index_id = 0 window_index_id = 0
vit_merger_window_size = ( vit_merger_window_size = (
self.window_size // self.spatial_merge_size // self.patch_size self.window_size // self.spatial_merge_size // self.patch_size
) )
window_index: list = []
for grid_t, grid_h, grid_w in grid_thw: for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = ( llm_grid_h, llm_grid_w = (
grid_h // self.spatial_merge_size, grid_h // self.spatial_merge_size,
@@ -369,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0) window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens return window_index, cu_window_seqlens
@property @property
@@ -382,8 +380,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = [] pos_ids = []
for t, h, w in grid_thw: for i in range(grid_thw.size(0)):
t, h, w = grid_thw[i].tolist()
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape( hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size, h // self.spatial_merge_size,
self.spatial_merge_size, self.spatial_merge_size,
@@ -402,6 +402,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
) )
wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten() wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0) pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max() max_grid_size = grid_thw[:, 1:].max()
@@ -443,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
position_embeddings = (emb.cos(), emb.sin()) position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens # compute cu_seqlens
cu_seqlens = torch.repeat_interleave( cu_seqlens = torch.cat(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] [
).cumsum(dim=0, dtype=torch.int32) torch.tensor([0], device=grid_thw.device),
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
]
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers # transformers
@@ -509,18 +513,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
processor = cached_get_processor(self.config._name_or_path)
grid_t, grid_h, grid_w = image_grid_thw
num_image_tokens = (
grid_t
* grid_h
* grid_w
// processor.image_processor.merge_size
// processor.image_processor.merge_size
)
return num_image_tokens
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
# Get all special token IDs # Get all special token IDs
im_start_id: int = image_inputs.im_start_id im_start_id: int = image_inputs.im_start_id
@@ -531,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, image_inputs) return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype) pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
return image_embeds return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -543,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
) )
return video_embeds return video_embeds
def get_input_embeddings(self):
return self.model.embed_tokens
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@@ -565,86 +560,26 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions positions = forward_batch.mrope_positions
image_inputs = None if not (
if forward_batch.image_inputs is not None:
image_inputs = [
img for img in forward_batch.image_inputs if img is not None
]
if (
forward_batch.forward_mode.is_decode() forward_batch.forward_mode.is_decode()
or image_inputs is None or not forward_batch.contains_image_inputs()
or len(image_inputs) == 0
): ):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, ( assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires " "multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}" f"(3, seq_len) positions, but got {positions.size()}"
) )
# Clamp input ids. This is because the input_ids for the image tokens are inputs_embeds = general_mm_embed_routine(
# filled with the hash values of the image for the prefix matching in the radix attention. input_ids=input_ids,
# There values are useless because their embeddings will be replaced by vision embeddings anyway. positions=positions,
input_ids.clamp_(min=0, max=self.config.vocab_size - 1) forward_batch=forward_batch,
# [B, s, hidden_size] embed_tokens=self.get_input_embeddings(),
inputs_embeds = self.model.embed_tokens(input_ids) image_embedding_func=self.get_image_feature,
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() )
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.to(device="cuda")
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
image_offsets = image.image_offsets
image_input = Qwen2VLImageInputs(
pixel_values=pixel_values, image_grid_thw=image_grid_thws
)
image_embeds = self._process_image_input(image_input)
image_embeds_offset = 0
for idx, image_offset in enumerate(image_offsets):
if image_offset < prefix_len:
continue
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len)
right_idx = left_idx + num_image_tokens
tp_size = get_tensor_model_parallel_world_size()
hidden_size = image_embeds.shape[-1]
if hidden_size % tp_size != 0:
padding_size = tp_size - (hidden_size % tp_size)
image_embeds = F.pad(image_embeds, (0, padding_size))
inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
hidden_chunk_size = image_embeds.shape[-1] // tp_size
rank = get_tensor_model_parallel_rank()
start_dim = rank * hidden_chunk_size
end_dim = (rank + 1) * hidden_chunk_size
inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
image_embeds[
image_embeds_offset : image_embeds_offset
+ num_image_tokens,
...,
start_dim:end_dim,
]
)
image_embeds_offset += num_image_tokens
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=None,
positions=positions, positions=positions,
forward_batch=forward_batch, forward_batch=forward_batch,
input_embeds=inputs_embeds, input_embeds=inputs_embeds,

View File

@@ -26,7 +26,6 @@ import logging
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type, TypedDict from typing import Iterable, List, Optional, Tuple, Type, TypedDict
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@@ -42,8 +41,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import ( from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype return next(self.parameters()).dtype
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
@@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = [] pos_ids = []
for t, h, w in grid_thw: for i in range(grid_thw.size(0)):
t, h, w = grid_thw[i].tolist()
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = ( hpos_ids = (
@@ -480,9 +481,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs) return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype) pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
return image_embeds return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
) )
return video_embeds return video_embeds
def get_input_embeddings(self):
return self.model.embed_tokens
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@@ -514,67 +518,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions positions = forward_batch.mrope_positions
image_inputs = None if not (
if forward_batch.image_inputs is not None:
image_inputs = [
img for img in forward_batch.image_inputs if img is not None
]
if (
forward_batch.forward_mode.is_decode() forward_batch.forward_mode.is_decode()
or image_inputs is None or not forward_batch.contains_image_inputs()
or len(image_inputs) == 0
): ):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, ( assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires " "multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}" f"(3, seq_len) positions, but got {positions.size()}"
) )
# Clamp input ids. This is because the input_ids for the image tokens are inputs_embeds = general_mm_embed_routine(
# filled with the hash values of the image for the prefix matching in the radix attention. input_ids=input_ids,
# There values are useless because their embeddings will be replaced by vision embeddings anyway. positions=positions,
input_ids.clamp_(min=0, max=self.config.vocab_size - 1) forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
inputs_embeds = self.model.embed_tokens(input_ids) image_embedding_func=self.get_image_feature,
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() )
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.clone()
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
image_offsets = image.image_offsets
image_input = Qwen2VLImageInputs(
pixel_values=pixel_values, image_grid_thw=image_grid_thws
)
image_embeds = self._process_image_input(image_input)
image_embeds_offset = 0
for idx, image_offset in enumerate(image_offsets):
if image_offset < prefix_len:
continue
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len + 1)
right_idx = left_idx + num_image_tokens
inputs_embeds[left_idx:right_idx] = image_embeds[
image_embeds_offset : image_embeds_offset + num_image_tokens
]
image_embeds_offset += num_image_tokens
input_ids = None
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=None,
positions=positions, positions=positions,
forward_batch=forward_batch, forward_batch=forward_batch,
input_embeds=inputs_embeds, input_embeds=inputs_embeds,

View File

@@ -23,6 +23,17 @@ from sglang.test.test_utils import (
popen_launch_server, popen_launch_server,
) )
# image
IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
IMAGE_SGL_LOGO_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/sgl_logo.png"
# video
VIDEO_JOBS_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4"
# audio
AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/Trump_WEF_2018_10s.mp3"
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIVisionServer(unittest.TestCase): class TestOpenAIVisionServer(unittest.TestCase):
@classmethod @classmethod
@@ -58,9 +69,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_MAN_IRONING_URL},
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
}, },
{ {
"type": "text", "type": "text",
@@ -96,9 +105,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_MAN_IRONING_URL},
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
}, },
{ {
"type": "text", "type": "text",
@@ -153,9 +160,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
}, },
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_SGL_LOGO_URL},
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
"modalities": "multi-images", "modalities": "multi-images",
}, },
{ {
@@ -242,10 +247,12 @@ class TestOpenAIVisionServer(unittest.TestCase):
] ]
return messages return messages
def test_video_chat_completion(self): def get_or_download_file(self, url: str) -> str:
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache") cache_dir = os.path.expanduser("~/.cache")
file_path = os.path.join(cache_dir, "jobs.mp4") if url is None:
raise ValueError()
file_name = url.split("/")[-1]
file_path = os.path.join(cache_dir, file_name)
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path): if not os.path.exists(file_path):
@@ -254,6 +261,11 @@ class TestOpenAIVisionServer(unittest.TestCase):
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(response.content) f.write(response.content)
return file_path
def test_video_chat_completion(self):
url = VIDEO_JOBS_URL
file_path = self.get_or_download_file(url)
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -289,6 +301,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
"present" in video_response "present" in video_response
or "examine" in video_response or "examine" in video_response
or "display" in video_response or "display" in video_response
or "hold" in video_response
) )
assert "black" in video_response or "dark" in video_response assert "black" in video_response or "dark" in video_response
self.assertIsNotNone(video_response) self.assertIsNotNone(video_response)
@@ -312,9 +325,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_MAN_IRONING_URL},
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
}, },
{ {
"type": "text", "type": "text",
@@ -344,18 +355,14 @@ class TestOpenAIVisionServer(unittest.TestCase):
content.append( content.append(
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_MAN_IRONING_URL},
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
} }
) )
elif image_id == 1: elif image_id == 1:
content.append( content.append(
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_SGL_LOGO_URL},
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
} }
) )
else: else:
@@ -465,9 +472,7 @@ class TestVLMContextLengthIssue(unittest.TestCase):
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_MAN_IRONING_URL},
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
}, },
{ {
"type": "text", "type": "text",

View File

@@ -13,6 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_image_inputs
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
@@ -168,10 +170,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
).eval() ).eval()
cls.model.to(cls.device) cls.model.to(cls.device)
async def test_encode_output(self): async def test_vlm_embedding_output(self):
"""
Compares the embedding output of vlm
"""
inputs = self.get_processor_output() inputs = self.get_processor_output()
with torch.no_grad(): with torch.no_grad():
# hf
model_inputs = { model_inputs = {
"input_ids": inputs.input_ids, "input_ids": inputs.input_ids,
"image_bound": inputs.image_bound, "image_bound": inputs.image_bound,
@@ -183,22 +189,20 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
) )
hf_output = hf_output.squeeze(0) hf_output = hf_output.squeeze(0)
with torch.no_grad(): # sglang
model = self.get_sglang_model() model = self.get_sglang_model()
input_ids = inputs["input_ids"].to(self.device).flatten() input_ids = inputs["input_ids"].to(self.device).flatten()
image_inputs = model._parse_and_validate_inputs( sglang_output = embed_image_inputs(
image_input=ImageInputs(
pixel_values=inputs["pixel_values"][0],
tgt_sizes=inputs["tgt_sizes"][0],
),
input_ids=input_ids, input_ids=input_ids,
**{ input_embedding=model.get_input_embeddings(),
"pixel_values": [inputs["pixel_values"]], image_embedding_func=model.get_image_features,
"tgt_sizes": [inputs["tgt_sizes"]], placeholder_token_ids=[
"im_start_id": self.tokenizer.im_start_id, self.processor.tokenizer.unk_token_id,
"im_end_id": self.tokenizer.im_end_id, ],
"slice_start_id": self.tokenizer.slice_start_id,
"slice_end_id": self.tokenizer.slice_end_id,
},
)
(sglang_output, _) = model.get_embedding(
input_ids=input_ids, image_inputs=image_inputs
) )
self.compare_outputs(sglang_output, hf_output) self.compare_outputs(sglang_output, hf_output)