diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index 643a4b6d0..8a24af2e0 100644 --- a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -1,13 +1,14 @@ """ - Bench the sglang-hosted vLM with benchmark MMMU +Bench the sglang-hosted vLM with benchmark MMMU - Usage: - python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl +Usage: + python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl - The eval output will be logged +The eval output will be logged """ import argparse +import time import openai from data_utils import save_json @@ -37,6 +38,7 @@ def eval_mmmu(args): # had to use an openai server, since SglImage doesn't support image data client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1") + start = time.time() for i, sample in enumerate(tqdm(samples)): prompt = sample["final_input_prompt"] prefix = prompt.split("<")[0] @@ -73,6 +75,8 @@ def eval_mmmu(args): response = response.choices[0].message.content process_result(response, sample, answer_dict, out_samples) + print(f"Benchmark time: {time.time() - start}") + args.output_path = f"./val_sglang.json" save_json(args.output_path, out_samples) eval_result(model_answer_path=args.output_path, answer_dict=answer_dict) diff --git a/python/sglang/srt/configs/janus_pro.py b/python/sglang/srt/configs/janus_pro.py index fad9a34a5..ad254edc4 100644 --- a/python/sglang/srt/configs/janus_pro.py +++ b/python/sglang/srt/configs/janus_pro.py @@ -9,8 +9,6 @@ import PIL import torch from PIL.Image import Image from transformers import ( - AutoImageProcessor, - AutoProcessor, BaseImageProcessor, BatchFeature, LlamaConfig, @@ -20,6 +18,7 @@ from transformers import ( ) from transformers.image_utils import to_numpy_array +from sglang.srt.configs.utils import register_image_processor, register_processor from sglang.srt.mm_utils import expand2square @@ -625,5 +624,5 @@ 
class VLMImageProcessorConfig(PretrainedConfig): super().__init__(**kwargs) -AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True) -AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None) +register_processor(MultiModalityConfig, VLChatProcessor) +register_image_processor(MultiModalityConfig, VLMImageProcessor) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index ad1f8f48e..0b4c48c5a 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -460,6 +460,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal multimodal_model_archs = [ + "DeepseekVL2ForCausalLM", "LlavaLlamaForCausalLM", "LlavaQwenForCausalLM", "LlavaMistralForCausalLM", @@ -472,7 +473,6 @@ multimodal_model_archs = [ "Qwen2_5_VLForConditionalGeneration", "MiniCPMV", "MultiModalityCausalLM", - "DeepseekVL2ForCausalLM", ] diff --git a/python/sglang/srt/configs/utils.py b/python/sglang/srt/configs/utils.py new file mode 100644 index 000000000..8403d95da --- /dev/null +++ b/python/sglang/srt/configs/utils.py @@ -0,0 +1,25 @@ +from typing import Type + +from transformers import ( + AutoImageProcessor, + AutoProcessor, + BaseImageProcessor, + PretrainedConfig, + ProcessorMixin, +) + + +def register_image_processor( + config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor] +): + """ + register customized hf image processor while removing hf impl + """ + AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True) + + +def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]): + """ + register customized hf processor while removing hf impl + """ + AutoProcessor.register(config, processor, exist_ok=True) diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 7e580da54..9a7cc31b0 100644 --- a/python/sglang/srt/conversation.py +++ 
b/python/sglang/srt/conversation.py @@ -653,7 +653,7 @@ register_conv_template( Conversation( name="gemma-it", system_message="You are a helpful assistant.", - system_template="user{system_message}\n\n", + system_template="user{system_message}\n\n", roles=("user\n", "model\n"), sep="\n", sep_style=SeparatorStyle.GEMMA3, diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 7b23ae82e..c76188e5e 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -143,9 +143,14 @@ class VisionAttention(nn.Module): if position_embeddings is not None: cos, sin = position_embeddings original_shape = q.shape - q, k = q.view(s, head, -1), k.view(s, head, -1) + # [total_tokens, head, head_size] + q = q.view(-1, head, self.head_size) + k = k.view(-1, head, self.head_size) + q, k = apply_rotary_pos_emb(q, k, cos, sin) - q, k = q.reshape(original_shape), k.reshape(original_shape) + + q = q.view(original_shape) + k = k.view(original_shape) if self.use_qkv_parallel: pass diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index c3ef9b51c..794d6034d 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -1,9 +1,12 @@ # TODO: also move pad_input_ids into this module import importlib +import inspect import logging import pkgutil from functools import lru_cache +from typing import Union +from torch import Tensor from transformers import IMAGE_PROCESSOR_MAPPING from sglang.srt.managers.image_processors.base_image_processor import ( @@ -18,9 +21,7 @@ logger = logging.getLogger(__name__) IMAGE_PROCESSOR_MAPPING = {} -def get_image_processor( - hf_config, server_args: ServerArgs, processor -) -> BaseImageProcessor: +def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor: for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items(): if 
model_cls.__name__ in hf_config.architectures: return processor_cls(hf_config, server_args, processor) @@ -42,13 +43,18 @@ def import_image_processors(): try: module = importlib.import_module(name) except Exception as e: - logger.warning(f"Ignore import error when loading {name}: " f"{e}") + logger.warning(f" Ignore import error when loading {name}: " f"{e}") continue - if hasattr(module, "ImageProcessorMapping"): - entry = module.ImageProcessorMapping - if isinstance(entry, dict): - for processor_name, cls in entry.items(): - IMAGE_PROCESSOR_MAPPING[processor_name] = cls + all_members = inspect.getmembers(module, inspect.isclass) + classes = [ + member + for name, member in all_members + if member.__module__ == module.__name__ + ] + for cls in classes: + if issubclass(cls, BaseImageProcessor): + for arch in getattr(cls, "models"): + IMAGE_PROCESSOR_MAPPING[arch] = cls # also register processors diff --git a/python/sglang/srt/managers/image_processors/base_image_processor.py b/python/sglang/srt/managers/image_processors/base_image_processor.py index 86bacc2f4..deac9ed14 100644 --- a/python/sglang/srt/managers/image_processors/base_image_processor.py +++ b/python/sglang/srt/managers/image_processors/base_image_processor.py @@ -4,14 +4,14 @@ import dataclasses import multiprocessing as mp import os from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Union import PIL import transformers from decord import VideoReader, cpu +from openai import BadRequestError from PIL import Image -from sglang.srt.server_args import ServerArgs from sglang.srt.utils import load_image from sglang.utils import logger @@ -31,8 +31,16 @@ class BaseImageProcessorOutput: # input_text, with each frame of video/image represented as an image_token input_text: str + def normalize(self): + for field_name in ["data_hashes", "image_sizes", "all_frames"]: + field = getattr(self, field_name, None) + if field is not None and isinstance(field, list) and 
len(field) == 0: + setattr(self, field_name, None) + class BaseImageProcessor(ABC): + models = [] + def __init__(self, hf_config, server_args, _processor): self.hf_config = hf_config self._processor = _processor @@ -40,6 +48,9 @@ class BaseImageProcessor(ABC): # FIXME: not accurate, model and image specific self.NUM_TOKEN_PER_FRAME = 330 + # Initialize global processor first + init_global_processor(self, server_args) + self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, mp_context=mp.get_context("fork"), @@ -113,7 +124,7 @@ class BaseImageProcessor(ABC): self, input_ids: list[int], image_data, - image_token: str, + image_token: Union[int, str], max_req_input_len: int, return_text: Optional[bool] = True, discard_alpha_channel: bool = True, @@ -122,9 +133,16 @@ class BaseImageProcessor(ABC): Each frame of video/image will be replaced by a single image token Args: + image_token: The token ID representing the image placeholder. discard_alpha_channel: if True, discards the alpha channel in the returned images """ + if isinstance(image_token, int): + image_token_str = self._processor.tokenizer.convert_ids_to_tokens( + image_token + ) + else: + image_token_str = image_token if isinstance(input_ids, list) and return_text: assert len(input_ids) and isinstance(input_ids[0], int) @@ -190,13 +208,11 @@ class BaseImageProcessor(ABC): new_text += text_part except Exception as e: - import openai logger.error(f"An exception occurred while loading images: {e}") raise BadRequestError( f"An exception occurred while loading images: {e}" ) - continue return BaseImageProcessorOutput( image_hashes=hashes, @@ -204,6 +220,8 @@ class BaseImageProcessor(ABC): all_frames=images, input_text=new_text, ) + out.normalize() + return out class DummyImageProcessor(BaseImageProcessor): @@ -214,9 +232,7 @@ class DummyImageProcessor(BaseImageProcessor): return None -def init_global_processor( - sglang_image_processor: BaseImageProcessor, server_args: ServerArgs 
-): +def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args): """Init the global processor for multi-modal models.""" global global_processor transformers.logging.set_verbosity_error() diff --git a/python/sglang/srt/managers/image_processors/deepseek_vl_v2.py b/python/sglang/srt/managers/image_processors/deepseek_vl_v2.py index f19cf247a..5de4029b7 100644 --- a/python/sglang/srt/managers/image_processors/deepseek_vl_v2.py +++ b/python/sglang/srt/managers/image_processors/deepseek_vl_v2.py @@ -16,13 +16,9 @@ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - import asyncio -import math -from typing import List, Union import torch -from PIL import Image, ImageOps from sglang.srt.managers.image_processor import BaseImageProcessor from sglang.srt.managers.image_processors.base_image_processor import ( @@ -32,18 +28,24 @@ from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM class DeepseekVL2ImageProcessor(BaseImageProcessor): + models = [DeepseekVL2ForCausalLM] + def __init__(self, hf_config, server_args, _processor): - # with contextlib.suppress(ValueError): - # AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor) super().__init__(hf_config, server_args, _processor) self.IMAGE_TOKEN = "" @staticmethod def _process_images_task(image, input_text, max_req_input_len): - return get_global_processor().__call__( + processor = get_global_processor() + res = processor.__call__( conversations=input_text, images=image, max_req_input_len=max_req_input_len ) + image_token_id = processor.image_token_id + + res["im_token_id"] = image_token_id + return res + async def _process_images(self, image_data, input_text, max_req_input_len): if self.executor is not None: loop = asyncio.get_event_loop() @@ -70,18 +72,15 @@ class 
DeepseekVL2ImageProcessor(BaseImageProcessor): if not isinstance(image_data, list): image_data = [image_data] - images, image_hashes, image_sizes = [], [], [] + images, image_sizes = [], [] image_token = self.IMAGE_TOKEN base_output = self.load_images( input_ids, image_data, image_token, max_req_input_len ) - base_output.all_frames = [img.convert("RGB") for img in base_output.all_frames] res = await self._process_images( base_output.all_frames, base_output.input_text, max_req_input_len ) - pixel_values = res["images"] - input_ids = res["input_ids"] images_seq_mask = res["images_seq_mask"] images_spatial_crop = res["images_spatial_crop"] batched_images_spatial_crop = [] @@ -89,16 +88,12 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor): batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0) return { - "input_ids": input_ids.tolist(), - "pixel_values": pixel_values, - "image_hashes": image_hashes, + "input_ids": res["input_ids"].tolist(), + "pixel_values": res["images"], + "im_token_id": res["im_token_id"], + "image_hashes": base_output.image_hashes, "image_sizes": image_sizes, - "image_seq_mask": images_seq_mask, + "images_emb_mask": images_seq_mask, "image_spatial_crop": batched_images_spatial_crop, "modalities": request_obj.modalities or ["image"], } - - -ImageProcessorMapping = { - DeepseekVL2ForCausalLM: DeepseekVL2ImageProcessor, -} diff --git a/python/sglang/srt/managers/image_processors/gemma3.py b/python/sglang/srt/managers/image_processors/gemma3.py index a54efb8d9..56fb988ca 100644 --- a/python/sglang/srt/managers/image_processors/gemma3.py +++ b/python/sglang/srt/managers/image_processors/gemma3.py @@ -17,14 +17,15 @@ logger = logging.get_logger(__name__) class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor): + models = [Gemma3ForConditionalGeneration] + def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) self.IMAGE_TOKEN = "" self.IM_START_TOKEN_ID = 
hf_config.boi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index - @staticmethod - def _process_images_task(images, input_text, _hf_config): + async def _process_single_image(self, images, input_text) -> dict: if isinstance(images, list) and len(images) == 0: images = None processor = get_global_processor() @@ -46,19 +47,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor): "pixel_values": pixel_values, } - async def _process_images(self, images, input_text) -> dict: - if self.executor is not None: - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, - Gemma3SGLangImageProcessor._process_images_task, - images, - input_text, - self.hf_config, - ) - else: - return self._process_images_task(images, input_text, self.hf_config) - async def process_images_async( self, image_data: List[Union[str, bytes]], @@ -82,7 +70,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor): discard_alpha_channel=True, ) - ret = await self._process_images( + ret = await self._process_single_image( input_text=base_output.input_text, images=base_output.all_frames ) @@ -93,8 +81,3 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor): "im_start_id": self.IM_START_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID, } - - -ImageProcessorMapping = { - Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor, -} diff --git a/python/sglang/srt/managers/image_processors/janus_pro.py b/python/sglang/srt/managers/image_processors/janus_pro.py index 36db528d3..368729e73 100644 --- a/python/sglang/srt/managers/image_processors/janus_pro.py +++ b/python/sglang/srt/managers/image_processors/janus_pro.py @@ -11,6 +11,8 @@ from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM class JanusProProcessor(SGLangBaseImageProcessor): + models = [MultiModalityCausalLM] + def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) @@ -77,6 +79,3 @@ class 
JanusProProcessor(SGLangBaseImageProcessor): "im_end_id": res["im_end_id"], "im_token_id": res["im_token_id"], } - - -ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor} diff --git a/python/sglang/srt/managers/image_processors/llava.py b/python/sglang/srt/managers/image_processors/llava.py index eee08ff40..e153215e4 100644 --- a/python/sglang/srt/managers/image_processors/llava.py +++ b/python/sglang/srt/managers/image_processors/llava.py @@ -15,6 +15,8 @@ from sglang.utils import get_exception_traceback class LlavaImageProcessor(BaseImageProcessor): + models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM] + def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) @@ -143,10 +145,3 @@ class LlavaImageProcessor(BaseImageProcessor): "image_sizes": image_sizes, "modalities": request_obj.modalities or ["image"], } - - -ImageProcessorMapping = { - LlavaVidForCausalLM: LlavaImageProcessor, - LlavaQwenForCausalLM: LlavaImageProcessor, - LlavaMistralForCausalLM: LlavaImageProcessor, -} diff --git a/python/sglang/srt/managers/image_processors/minicpmv.py b/python/sglang/srt/managers/image_processors/minicpmv.py index 4e10092bf..b47621501 100644 --- a/python/sglang/srt/managers/image_processors/minicpmv.py +++ b/python/sglang/srt/managers/image_processors/minicpmv.py @@ -1,6 +1,8 @@ import asyncio from typing import List, Union +import torch + from sglang.srt.managers.image_processor import BaseImageProcessor from sglang.srt.managers.image_processors.base_image_processor import ( get_global_processor, @@ -9,6 +11,8 @@ from sglang.srt.models.minicpmv import MiniCPMV class MiniCPMVImageProcessor(BaseImageProcessor): + models = [MiniCPMV] + def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) self.IMAGE_TOKEN = "(./)" @@ -69,21 +73,57 @@ class MiniCPMVImageProcessor(BaseImageProcessor): # Collect special token ids tokenizer = 
self._processor.tokenizer im_start_id = tokenizer.im_start_id + im_token_id = tokenizer.unk_token_id im_end_id = tokenizer.im_end_id if tokenizer.slice_start_id: slice_start_id = tokenizer.slice_start_id slice_end_id = tokenizer.slice_end_id + + pixel_values = res["pixel_values"] + tgt_sizes = res["tgt_sizes"] + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}" + ) + + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}" + ) + + if len(pixel_values) != len(tgt_sizes): + raise ValueError( + "Inconsistent batch lengths, found: " + f"{len(pixel_values)} vs. {len(tgt_sizes)}" + ) + + # tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)] + # tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) + pixel_values_flat: List[torch.Tensor] = [] + tgt_sizes_flat: List[torch.Tensor] = [] + for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): + # per image + if len(pixel_b) != len(tgt_b): + raise ValueError( + "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}" + ) + for pixel_n, tgt_n in zip(pixel_b, tgt_b): + # per patch + pixel_values_flat += [pixel_n] + tgt_sizes_flat += [tgt_n] + + pixel_values = pixel_values_flat + tgt_sizes = torch.stack(tgt_sizes_flat) return { "input_ids": res["input_ids"].flatten().tolist(), - "pixel_values": res["pixel_values"], - "tgt_sizes": res["tgt_sizes"], + "pixel_values": pixel_values, + "tgt_sizes": tgt_sizes, "image_hashes": base_output.image_hashes, "modalities": request_obj.modalities or ["image"], "im_start_id": im_start_id, + "im_token_id": im_token_id, "im_end_id": im_end_id, "slice_start_id": slice_start_id, "slice_end_id": slice_end_id, } - - -ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor} diff --git a/python/sglang/srt/managers/image_processors/mlama.py 
b/python/sglang/srt/managers/image_processors/mlama.py index 8043067d8..c5d12e3bf 100644 --- a/python/sglang/srt/managers/image_processors/mlama.py +++ b/python/sglang/srt/managers/image_processors/mlama.py @@ -10,6 +10,8 @@ from sglang.srt.utils import load_image class MllamaImageProcessor(BaseImageProcessor): + models = [MllamaForConditionalGeneration] + def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) @@ -55,6 +57,3 @@ class MllamaImageProcessor(BaseImageProcessor): image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] return image_inputs - - -ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor} diff --git a/python/sglang/srt/managers/image_processors/qwen_vl.py b/python/sglang/srt/managers/image_processors/qwen_vl.py index 46add1383..d0cec61a7 100644 --- a/python/sglang/srt/managers/image_processors/qwen_vl.py +++ b/python/sglang/srt/managers/image_processors/qwen_vl.py @@ -2,6 +2,7 @@ import asyncio import math from typing import List, Union +import torch from PIL import Image from sglang.srt.managers.image_processor import BaseImageProcessor @@ -14,6 +15,8 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration # Compatible with Qwen2VL and Qwen2_5VL class Qwen2_5VLImageProcessor(BaseImageProcessor): + models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration] + def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" @@ -43,7 +46,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor): "video_grid_thws": getattr(result, "video_grid_thws", None), } - async def _process_images(self, images, input_text) -> dict: + async def _process_single_image(self, images, input_text) -> dict: if self.executor is not None: loop = asyncio.get_event_loop() return await loop.run_in_executor( @@ -138,23 +141,23 @@ class 
Qwen2_5VLImageProcessor(BaseImageProcessor): images = [resize_image(image) for image in base_output.all_frames] - ret = await self._process_images(images, base_output.input_text) + ret = await self._process_single_image( + images=images, input_text=base_output.input_text + ) + + image_grid_thws = torch.concat([ret["image_grid_thw"]]) + video_grid_thws = None + return { "input_ids": ret["input_ids"].flatten().tolist(), "pixel_values": ret["pixel_values"], "image_hashes": base_output.image_hashes, "modalities": request_obj.modalities or ["image"], - "image_grid_thws": ret["image_grid_thw"], - "video_grid_thws": ret["video_grid_thws"], + "image_grid_thws": image_grid_thws, + "video_grid_thws": video_grid_thws, "im_start_id": self.IM_START_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID, "im_token_id": self.image_token_id, "video_token_id": self.video_token_id, "second_per_grid_ts": ret["second_per_grid_ts"], } - - -ImageProcessorMapping = { - Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor, - Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor, -} diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py new file mode 100644 index 000000000..2aa9787e6 --- /dev/null +++ b/python/sglang/srt/managers/mm_utils.py @@ -0,0 +1,303 @@ +""" + Multimodality utils +""" + +from abc import abstractmethod +from typing import Callable, List, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.managers.schedule_batch import ( + ImageInputs, + global_server_args_dict, + logger, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.utils import logger + + +class MultiModalityDataPaddingPattern: + """ + Data tokens (like image tokens) often need special handling during padding + to maintain model compatibility. 
This class provides the interface for + implementing different padding strategies for data tokens + """ + + @abstractmethod + def pad_input_tokens( + self, input_ids: List[int], image_inputs: ImageInputs + ) -> List[int]: + """ + Pad the input ids sequence containing data tokens, and replace them with pad_values + """ + pass + + +class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern): + """In this pattern, data tokens should be enclosed by special token pairs (e.g. ..., data_token_pairs) + + This strategy should be applied when data content is marked by start/end token pairs in the input sequence. + """ + + def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None: + self.data_token_id_pairs = data_token_pairs + + def pad_input_tokens( + self, input_ids: List[int], image_inputs: ImageInputs + ) -> List[int]: + """ + This function will replace the data-tokens inbetween with pad_values accordingly + """ + pad_values = image_inputs.pad_values + data_token_pairs = self.data_token_id_pairs + image_inputs.image_offsets = [] + if data_token_pairs is None: + data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id] + if data_token_pairs is None: + logger.warning( + "No data_token_pairs provided, RadixAttention might be influenced." 
+ ) + return input_ids + start_token_ids = [s for s, _e in data_token_pairs] + end_tokens_ids = [e for _s, e in data_token_pairs] + # First start token marks new data + data_start_token = start_token_ids[0] + + padded_ids = [] + last_idx = 0 + data_idx = -1 + + start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids] + end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids] + + if len(start_indices) != len(end_indices): + return input_ids + + for start_idx, end_idx in zip(start_indices, end_indices): + padded_ids.extend(input_ids[last_idx : start_idx + 1]) + + if input_ids[start_idx] == data_start_token: + data_idx += 1 + image_inputs.image_offsets += [start_idx] + + num_tokens = end_idx - start_idx - 1 + pad_value = pad_values[data_idx] + padded_ids.extend([pad_value] * num_tokens) + + last_idx = end_idx + + padded_ids.extend(input_ids[last_idx:]) + + assert len(input_ids) == len(padded_ids) + return padded_ids + + +class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern): + """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ), + which needs first to be expanded to multiple tokens, then replaced with their padding values + + This strategy should be used when a single data token represents content that should + be expanded to multiple tokens during processing. + """ + + def __init__( + self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int] + ) -> None: + self.num_data_token_calc_func = num_data_token_calc_func + + def pad_input_tokens( + self, input_ids: List[int], image_inputs: ImageInputs + ) -> List[int]: + """ + This function will follow the procedure of: + 1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func` + 2. 
the padded data tokens will be replaced with their pad_values + """ + image_grid_thws = image_inputs.image_grid_thws + pad_values = image_inputs.pad_values + + image_indices = [ + idx + for idx, token in enumerate(input_ids) + if token == image_inputs.im_token_id + ] + + image_inputs.image_offsets = [] + + input_ids_with_image = [] + for image_cnt, _ in enumerate(image_grid_thws): + num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt]) + if image_cnt == 0: + non_image_tokens = input_ids[: image_indices[image_cnt]] + else: + non_image_tokens = input_ids[ + image_indices[image_cnt - 1] + 1 : image_indices[image_cnt] + ] + input_ids_with_image.extend(non_image_tokens) + image_inputs.image_offsets.append(len(input_ids_with_image)) + pad_ids = pad_values * ( + (num_image_tokens + len(pad_values)) // len(pad_values) + ) + input_ids_with_image.extend(pad_ids[:num_image_tokens]) + input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :]) + + return input_ids_with_image + + +class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern): + """In this pattern, data tokens should be represented as image tokens (e.g. 
....)""" + + def __init__(self, image_token_id: torch.Tensor) -> None: + self.image_token_id = image_token_id + + def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]: + """ + This function will replace the data-tokens in between with pad_values accordingly + """ + pad_values = image_inputs.pad_values + assert len(pad_values) != 0 + + input_ids_tensor = torch.tensor(input_ids) + mask = torch.isin(input_ids_tensor, self.image_token_id) + + num_image_tokens = mask.sum().item() + repeated_pad_values = torch.tensor(pad_values).repeat( + num_image_tokens // len(pad_values) + 1 + )[:num_image_tokens] + + input_ids_tensor[mask] = repeated_pad_values + return input_ids_tensor.tolist() + + +def embed_image_inputs( + image_input: ImageInputs, + input_ids: torch.Tensor, + input_embedding: nn.Embedding, + image_embedding_func, + placeholder_token_ids: List[int] = None, +) -> Optional[torch.Tensor]: + """ + Calculate the image embeddings if necessary, then scatter the result with + the help of a boolean mask denoting the embed locations + + Returns: + final embedding: Optional[torch.Tensor] + """ + if image_input is None: + return None + + placeholder_token_ids = placeholder_token_ids or image_input.pad_values + + # boolean masking the special tokens + special_image_mask = torch.isin( + input_ids, + torch.tensor(placeholder_token_ids, device=input_ids.device), + ).unsqueeze(-1) + + num_image_tokens_in_input_ids = special_image_mask.sum() + + if num_image_tokens_in_input_ids == 0: + # unexpected + inputs_embeds = input_embedding(input_ids) + else: + image_embedding = image_embedding_func(image_input) + + if image_embedding.dim() == 2: + num_image_tokens_in_embedding = image_embedding.shape[0] + else: + num_image_tokens_in_embedding = ( + image_embedding.shape[0] * image_embedding.shape[1] + ) + if num_image_tokens_in_input_ids != num_image_tokens_in_embedding: + num_image = num_image_tokens_in_input_ids // image_embedding.shape[1] + image_embedding = 
image_embedding[:num_image, :] + logger.warning( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} " + "tokens from image embeddings." + ) + + # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding + # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with + # extend_start_loc and extend_seq_lens + if num_image_tokens_in_input_ids > num_image_tokens_in_embedding: + chunked_prefill_size = global_server_args_dict["chunked_prefill_size"] + if chunked_prefill_size != -1: + logger.warning( + "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill" + ) + + vocab_size = input_embedding.num_embeddings + # Important: clamp after getting original image regions + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. 
+ input_ids.clamp_(min=0, max=vocab_size - 1) + inputs_embeds = input_embedding(input_ids) + + special_image_mask = special_image_mask.expand_as(inputs_embeds).to( + inputs_embeds.device + ) + + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, + image_embedding.to(inputs_embeds.device, inputs_embeds.dtype), + ) + return inputs_embeds + + +def embed_image_embedding( + inputs_embeds: torch.Tensor, + image_embedding: torch.Tensor, + image_bounds: torch.Tensor, +) -> torch.Tensor: + """ + scatter image_embedding into inputs_embeds according to image_bounds + """ + if len(image_bounds) > 0: + image_indices = torch.stack( + [ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ] + ).to(inputs_embeds.device) + + inputs_embeds.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]), + image_embedding.view(-1, image_embedding.shape[-1]), + ) + return inputs_embeds + + +def general_mm_embed_routine( + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + embed_tokens: nn.Embedding, + image_embedding_func: Callable[[ImageInputs], torch.Tensor], + placeholder_token_ids: List[int] = None, +): + """ + a general wrapper function to get final input embeds from multimodal models + with a language model as causal model + """ + if ( + forward_batch.forward_mode.is_decode() + or not forward_batch.contains_image_inputs() + ): + inputs_embeds = embed_tokens(input_ids) + else: + image = forward_batch.merge_image_inputs() + inputs_embeds = embed_image_inputs( + image_input=image, + input_ids=input_ids, + input_embedding=embed_tokens, + image_embedding_func=image_embedding_func, + placeholder_token_ids=placeholder_token_ids, + ) + # once used, image_inputs is useless + # just being defensive here + forward_batch.image_inputs = None + return inputs_embeds diff --git a/python/sglang/srt/managers/multi_modality_padding.py b/python/sglang/srt/managers/multi_modality_padding.py deleted 
file mode 100644 index b0b662b7c..000000000 --- a/python/sglang/srt/managers/multi_modality_padding.py +++ /dev/null @@ -1,134 +0,0 @@ -from abc import abstractmethod -from typing import Callable, List, Optional, Tuple - -from sglang.srt.managers.schedule_batch import ImageInputs -from sglang.utils import logger - - -class MultiModalityDataPaddingPattern: - """ - Data tokens (like image tokens) often need special handling during padding - to maintain model compatibility. This class provides the interface for - implementing different padding strategies for data tokens - """ - - @abstractmethod - def pad_input_tokens( - self, input_ids: List[int], image_inputs: ImageInputs - ) -> List[int]: - """ - Pad the input ids sequence containing data tokens, and replace them with pad_values - """ - pass - - -class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern): - """In this pattern, data tokens should be enclosed by special token pairs (e.g. ..., data_token_pairs) - - This strategy should be applied when data content is marked by start/end token pairs in the input sequence. - """ - - def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None: - self.data_token_id_pairs = data_token_pairs - - def pad_input_tokens( - self, input_ids: List[int], image_inputs: ImageInputs - ) -> List[int]: - """ - This function will replace the data-tokens inbetween with pad_values accordingly - """ - pad_values = image_inputs.pad_values - data_token_pairs = self.data_token_id_pairs - image_inputs.image_offsets = [] - if data_token_pairs is None: - data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id] - if data_token_pairs is None: - logger.warning( - "No data_token_pairs provided, RadixAttention might be influenced." 
- ) - return input_ids - start_token_ids = [s for s, _e in data_token_pairs] - end_tokens_ids = [e for _s, e in data_token_pairs] - # First start token marks new data - data_start_token = start_token_ids[0] - - padded_ids = [] - last_idx = 0 - data_idx = -1 - - start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids] - end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids] - - if len(start_indices) != len(end_indices): - return input_ids - - for start_idx, end_idx in zip(start_indices, end_indices): - padded_ids.extend(input_ids[last_idx : start_idx + 1]) - - if input_ids[start_idx] == data_start_token: - data_idx += 1 - image_inputs.image_offsets += [start_idx] - - num_tokens = end_idx - start_idx - 1 - pad_value = pad_values[data_idx] - padded_ids.extend([pad_value] * num_tokens) - - last_idx = end_idx - - padded_ids.extend(input_ids[last_idx:]) - - assert len(input_ids) == len(padded_ids) - return padded_ids - - -class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern): - """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ), - which needs first to be expanded to multiple tokens, then replaced with their padding values - - This strategy should be used when a single data token represents content that should - be expanded to multiple tokens during processing. - """ - - def __init__( - self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int] - ) -> None: - self.num_data_token_calc_func = num_data_token_calc_func - - def pad_input_tokens( - self, input_ids: List[int], image_inputs: ImageInputs - ) -> List[int]: - """ - This function will follow the procedure of: - 1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func` - 2. 
the padded data tokens will be replaced with their pad_values - """ - image_grid_thws = image_inputs.image_grid_thws - pad_values = image_inputs.pad_values - - image_indices = [ - idx - for idx, token in enumerate(input_ids) - if token == image_inputs.im_token_id - ] - - image_inputs.image_offsets = [] - - input_ids_with_image = [] - for image_cnt, _ in enumerate(image_grid_thws): - print(f"image_cnt {image_cnt}") - num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt]) - if image_cnt == 0: - non_image_tokens = input_ids[: image_indices[image_cnt]] - else: - non_image_tokens = input_ids[ - image_indices[image_cnt - 1] + 1 : image_indices[image_cnt] - ] - input_ids_with_image.extend(non_image_tokens) - image_inputs.image_offsets.append(len(input_ids_with_image)) - pad_ids = pad_values * ( - (num_image_tokens + len(pad_values)) // len(pad_values) - ) - input_ids_with_image.extend(pad_ids[:num_image_tokens]) - input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :]) - - return input_ids_with_image diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 77472f97b..8721cbc09 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -77,6 +77,7 @@ global_server_args_dict = { "enable_flashmla": ServerArgs.enable_flashmla, "disable_radix_cache": ServerArgs.disable_radix_cache, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, + "chunked_prefill_size": ServerArgs.chunked_prefill_size, } logger = logging.getLogger(__name__) @@ -160,7 +161,8 @@ class ImageInputs: aspect_ratio_mask: Optional[List[torch.Tensor]] = None # QWen2-VL related - image_grid_thws: List[Tuple[int, int, int]] = None + # [num_of_images, t, h, w] + image_grid_thws: torch.Tensor = None mrope_position_delta: Optional[torch.Tensor] = None # Qwen2-VL video related video_token_id: Optional[int] = None @@ -168,7 +170,7 @@ class ImageInputs: 
second_per_grid_ts: Optional[List[torch.Tensor]] = None # deepseek vl2 related - image_seq_mask: Optional[List[torch.Tensor]] = None + images_emb_mask: Optional[List[torch.Tensor]] = None image_spatial_crop: Optional[List[torch.Tensor]] = None # The id of the single-image placeholder token @@ -182,9 +184,6 @@ class ImageInputs: slice_end_id: Optional[int] = None tgt_sizes: Optional[list] = None - # denotes the number of valid image tokens in each image - images_emb_mask: Optional[torch.BoolTensor] = None - @staticmethod def from_dict(obj: dict): ret = ImageInputs( @@ -204,7 +203,7 @@ class ImageInputs: "aspect_ratio_ids", "aspect_ratio_mask", "image_grid_thws", - "image_seq_mask", + "images_emb_mask", "image_spatial_crop", "im_token_id", "im_start_id", @@ -212,20 +211,58 @@ class ImageInputs: "slice_start_id", "slice_end_id", "tgt_sizes", - "images_emb_mask", ] for arg in optional_args: if arg in obj: setattr(ret, arg, obj[arg]) + # validate + assert ( + isinstance(ret.pixel_values, torch.Tensor) + or isinstance(ret.pixel_values, np.ndarray) + or isinstance(ret.pixel_values, list) + ) + return ret - def merge(self, other): + def merge(self, other: ImageInputs): """ merge image inputs when requests are being merged """ - assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:] - self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values]) + if isinstance(self.pixel_values, list): + # in some rare cases, pixel values are list of patches with different shapes + # e.g. 
minicpm + self.pixel_values += other.pixel_values + else: + assert ( + self.pixel_values.shape[1:] == other.pixel_values.shape[1:] + ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}" + self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values]) + + # args would be stacked along first dim + # usually these are already tensors + stack_args = [ + # TODO: merge with image_grid_thws, basically the same thing + "tgt_sizes", + "image_spatial_crop", + ] + for arg in stack_args: + if getattr(self, arg, None) is None: + setattr(self, arg, getattr(other, arg, None)) + elif getattr(other, arg, None) is not None: + # self and other both not None + setattr( + self, + arg, + torch.cat([getattr(self, arg), getattr(other, arg)], dim=0), + ) + + if self.image_grid_thws is None: + self.image_grid_thws = other.image_grid_thws + elif other.image_grid_thws is not None: + self.image_grid_thws = torch.concat( + [self.image_grid_thws, other.image_grid_thws] + ) # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache. # Please note that if the `input_ids` is later used in the model forward, @@ -233,7 +270,7 @@ class ImageInputs: # errors in cuda kernels. See also llava.py for example. 
self.image_hashes += other.image_hashes self.pad_values = [x % (1 << 30) for x in self.image_hashes] - + # args needed to be merged optional_args = [ "image_sizes", "image_offsets", @@ -241,13 +278,13 @@ class ImageInputs: # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images "aspect_ratio_ids", "aspect_ratio_mask", - "image_grid_thws", - "image_seq_mask", - "image_spatial_crop", + "images_emb_mask", ] for arg in optional_args: - if getattr(self, arg, None) is not None: - setattr(self, arg, getattr(self, arg) + getattr(other, arg)) + self_arg = getattr(self, arg, None) + if self_arg is not None: + setattr(self, arg, self_arg + getattr(other, arg)) + # other args would be kept intact class Req: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index bbc1cbbbc..04082ab58 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -179,7 +179,7 @@ class TokenizerManager: ) # We want to parallelize the image pre-processing so we create an executor for it - # We creat image_processor for any skip_tokenizer_init to make sure we still encode + # We create image_processor for any skip_tokenizer_init to make sure we still encode # images even with skip_tokenizer_init=False. self.image_processor = get_image_processor( self.model_config.hf_config, server_args, _processor diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 80d1a447b..ade31e773 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -332,7 +332,7 @@ class ForwardBatch: return ret - def get_merged_image_inputs(self) -> Optional[ImageInputs]: + def merge_image_inputs(self) -> Optional[ImageInputs]: """ Merge all image inputs in the batch into a single ImageInputs object. 
@@ -358,6 +358,16 @@ class ForwardBatch: return merged + def contains_image_inputs(self) -> bool: + """Return True only when at least one request in the batch carries non-empty image data.""" + if self.image_inputs is None: + return False + return any( + image_input.pixel_values is not None and len(image_input.pixel_values) > 0 + for image_input in self.image_inputs + if image_input is not None + ) + def _compute_mrope_positions( self, model_runner: ModelRunner, batch: ModelWorkerBatch ): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index eaaf2637f..6ae2af0df 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -273,7 +273,7 @@ class ModelRunner: if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]: # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically logger.info( - "Automatically turn off --chunked-prefill-size and disable radix cache for deekseek-vl2." + "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
) server_args.chunked_prefill_size = -1 server_args.disable_radix_cache = True diff --git a/python/sglang/srt/models/deepseek_janus_pro.py b/python/sglang/srt/models/deepseek_janus_pro.py index 75a88b13f..2b657c1d2 100644 --- a/python/sglang/srt/models/deepseek_janus_pro.py +++ b/python/sglang/srt/models/deepseek_janus_pro.py @@ -47,8 +47,9 @@ from sglang.srt.configs.janus_pro import * from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization import QuantizationConfig -from sglang.srt.managers.multi_modality_padding import ( +from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, + general_mm_embed_routine, ) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -1958,82 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel): ) self.logits_processor = LogitsProcessor(config) - def prepare_images_seq_mask( - self, input_ids: torch.Tensor, image_inputs: ImageInputs - ) -> Optional[torch.LongTensor]: - images_seq_mask = torch.isin( - input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device) - ) - if images_seq_mask.sum() == 0: - # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache - return None - else: - return images_seq_mask - - @torch.no_grad() - def forward( - self, - input_ids: torch.LongTensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - ) -> torch.Tensor: - - inputs_embeds = None - if ( - forward_batch.image_inputs is not None - and len(forward_batch.image_inputs) != 0 - and forward_batch.image_inputs[0] is not None - ): - - image_inputs = forward_batch.image_inputs[0] - - images_seq_mask = self.prepare_images_seq_mask( - input_ids=input_ids, image_inputs=image_inputs - ) - - if images_seq_mask is not None: - input_ids.clamp_(min=0, max=self.config.vocab_size - 
1) - inputs_embeds = self.prepare_inputs_embeds( - input_ids=input_ids, - pixel_values=image_inputs.pixel_values, - images_seq_mask=images_seq_mask, - images_emb_mask=image_inputs.images_emb_mask, - ) - input_ids = None - - if input_ids is not None: - input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - - return self.language_model( - input_ids=input_ids, - positions=positions, - forward_batch=forward_batch, - input_embeds=inputs_embeds, - get_embedding=False, - ) - - def prepare_inputs_embeds( - self, - input_ids: torch.LongTensor, - pixel_values: torch.FloatTensor, - images_seq_mask: torch.LongTensor, - images_emb_mask: torch.BoolTensor, - **_kwargs, - ): - """ - - Args: - input_ids (torch.LongTensor): [b, T] - pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] - images_seq_mask (torch.BoolTensor): [b, T] - images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] - - assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) - - Returns: - input_embeds (torch.Tensor): [b, T, D] - """ - + def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor: + pixel_values = image_input.pixel_values bs, n = pixel_values.shape[0:2] pixel_values = pixel_values.to( device=self.vision_model.device, dtype=self.vision_model.dtype @@ -2045,18 +1972,35 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel): # [b x n, T2, D] -> [b, n x T2, D] images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) - # [b, n, T2] -> [b, n x T2] - images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") - # [b, T, D] - # ignore the image embeddings - input_ids[input_ids < 0] = 0 - inputs_embeds = self.language_model.model.embed_tokens(input_ids) + return images_embeds - # replace with the image embeddings - inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] + def get_input_embeddings(self) -> nn.Embedding: + return self.language_model.model.embed_tokens - return inputs_embeds + @torch.no_grad() + def forward( + self, + 
input_ids: torch.LongTensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + + inputs_embeds = general_mm_embed_routine( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + embed_tokens=self.get_input_embeddings(), + image_embedding_func=self.get_image_feature, + ) + + return self.language_model( + input_ids=None, + positions=positions, + forward_batch=forward_batch, + input_embeds=inputs_embeds, + get_embedding=False, + ) def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): return self.gen_aligner(self.gen_embed(image_ids)) diff --git a/python/sglang/srt/models/deepseek_vl2.py b/python/sglang/srt/models/deepseek_vl2.py index 5fe5cd394..625927f7e 100644 --- a/python/sglang/srt/models/deepseek_vl2.py +++ b/python/sglang/srt/models/deepseek_vl2.py @@ -1,34 +1,16 @@ -import collections -import itertools -import math -import warnings -from enum import Enum -from functools import partial -from typing import Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable, List, Optional, Tuple import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import nn -from sglang.srt.configs import DeepseekVL2Config from sglang.srt.configs.deepseekvl2 import ( DeepseekVL2Config, DeepseekVL2MlpProjectorConfig, ) -from sglang.srt.layers.attention.vision import VisionAttention -from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.linear import ( - ColumnParallelLinear, - LinearBase, - ReplicatedLinear, - RowParallelLinear, -) +from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import 
default_weight_loader @@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module): forward_batch: ForwardBatch, **kwargs: object, ): - input_embeds = self.language_model.model.embed_tokens(input_ids) - if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [ - None - ]: + input_embeds = self.language_model.model.embed_tokens(input_ids) + if (forward_batch.forward_mode.is_extend() + and forward_batch.contains_image_inputs() + ): extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy() for idx, image in enumerate(forward_batch.image_inputs): @@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module): continue start_idx = extend_start_loc_cpu[idx] end_idx = start_idx + extend_seq_lens_cpu[idx] - pixel_values = image.pixel_values.to( - device="cuda", dtype=torch.bfloat16 - ) - image_seq_mask = image.image_seq_mask.to(device="cuda") - image_spatial_crop = image.image_spatial_crop - input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds( - pixel_values, - image_seq_mask, - image_spatial_crop, - input_embeds[start_idx:end_idx], - ) + images_emb_mask = image.images_emb_mask.to(device="cuda") + image_features = self.get_image_feature(image) + input_embeds[start_idx:end_idx] = input_embeds[ + start_idx:end_idx + ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features) outputs = self.language_model.forward( input_ids=input_ids, @@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module): def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): return input_ids - def prepare_inputs_embeds( - self, - pixel_values, - images_seq_mask, - images_spatial_crop, - input_embeds, - ): + def get_image_feature(self, image_input: ImageInputs): + pixel_values = image_input.pixel_values.type( + next(self.vision.parameters()).dtype + ).to(device=next(self.vision.parameters()).device) image_feature = self.vision.forward_features(pixel_values) images_embeds = self.projector(image_feature) _, hw, n_dim = images_embeds.shape h
= w = int(hw**0.5) - tile_index = 0 images_in_this_batch = [] + images_spatial_crop = image_input.image_spatial_crop for jdx in range(images_spatial_crop.shape[1]): num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx] if num_width_tiles == 0 or num_height_tiles == 0: @@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module): images_in_this_batch.append(global_local_features) - if len(images_in_this_batch) > 0: - images_in_this_batch = torch.cat(images_in_this_batch, dim=0) - input_embeds.masked_scatter_( - images_seq_mask.unsqueeze(-1), images_in_this_batch - ) - - return input_embeds + return torch.cat(images_in_this_batch, dim=0) EntryClass = DeepseekVL2ForCausalLM diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 489b15798..d892c5152 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -37,11 +37,8 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) +from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import ( default_weight_loader, @@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel): else: hidden_states = input_embeds - if len(positions.shape) == 1: + if positions.dim() == 1: positions = einops.rearrange(positions, "s -> 1 s") position_embeddings_global = self.rotary_emb(hidden_states, positions) @@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel): ) 
self.post_init() - def get_input_embeddings(self): + def get_input_embeddings(self) -> nn.Embedding: return self.model.embed_tokens def dtype(self) -> torch.dtype: - return self.model.layers[0].mlp.gate_up_proj.weight.dtype + return next(self.parameters()).dtype @torch.no_grad() def forward( diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 401e65731..9be13ba64 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -34,8 +34,9 @@ from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.layers.layernorm import Gemma3RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.managers.multi_modality_padding import ( +from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, + general_mm_embed_routine, ) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -264,10 +265,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): kwargs["local_attn_masks"] = local_attn_masks return kwargs - def get_input_embeddings(self): + def get_input_embeddings(self) -> nn.Embedding: return self.language_model.get_input_embeddings() - def get_image_features(self, pixel_values: torch.Tensor): + def get_image_feature(self, image_input: ImageInputs): """ Projects the last hidden state from the vision model into language model space. @@ -277,6 +278,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
""" + pixel_values = image_input.pixel_values pixel_values = pixel_values.to("cuda") pixel_values = pixel_values.to(dtype=self.language_model.dtype()) @@ -305,7 +307,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): return inputs_embeds else: # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}") - image_features = self.get_image_features(image_input.pixel_values) + image_features = self.get_image_feature(image_input.pixel_values) # print(f"image tokens from image embeddings: {image_features.numel()}") num_image_tokens_in_embedding = ( @@ -397,20 +399,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): else: llm_input_ids = input_ids - merged_image_input = forward_batch.get_merged_image_inputs() - - if ( - not forward_batch.forward_mode.is_decode() - and merged_image_input is not None - ): - inputs_embeds = self.embed_image_inputs( - input_ids=llm_input_ids, - forward_batch=forward_batch, - image_input=merged_image_input, - ) - else: - llm_input_ids.clamp_(min=0, max=self.vocab_size - 1) - inputs_embeds = self.get_input_embeddings()(llm_input_ids) + inputs_embeds = general_mm_embed_routine( + input_ids=llm_input_ids, + positions=positions, + forward_batch=forward_batch, + embed_tokens=self.get_input_embeddings(), + image_embedding_func=self.get_image_feature, + ) outputs = self.language_model( input_ids=None, diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 00ae8fa01..4db3b8c98 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -50,8 +50,9 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.managers.multi_modality_padding import ( +from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, + embed_image_inputs, ) from sglang.srt.managers.schedule_batch 
import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module): ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - def get_input_embeddings(self): + def get_input_embeddings(self) -> nn.Embedding: return self.embeddings def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: @@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module): valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device) return valid_pairs_tensor - def get_embedding( - self, - input_ids: torch.Tensor, - image_inputs: Optional[MiniCPMVImageInputs], - ) -> Tuple[torch.Tensor, torch.Tensor]: - vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) - - if image_inputs is None: # No image - vision_hidden_states = torch.tensor([], device=input_ids.device) - else: - if image_inputs["type"] == "image_embeds": - vision_hidden_states = ( - image_inputs["data"] - .type(vlm_embedding.dtype) - .to(vlm_embedding.device) - ) - else: - vision_hidden_states = self.get_vision_hidden_states(image_inputs) - # See NOTE in _parse_and_validate_inputs - image_bounds = image_inputs["image_bounds"] - if len(image_bounds) > 0: - image_indices = torch.stack( - [ - torch.arange(start, end, dtype=torch.long) - for start, end in image_bounds.tolist() - ] - ).to(vlm_embedding.device) - - vlm_embedding.scatter_( - 0, - image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), - vision_hidden_states.view(-1, vision_hidden_states.shape[-1]), - ) - - return vlm_embedding, vision_hidden_states - def _parse_and_validate_inputs( self, input_ids: torch.Tensor, @@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module): type="image_embeds", ) - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of pixel values. 
" f"Got type: {type(pixel_values)}" - ) - - if not isinstance(tgt_sizes, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}" - ) - - if len(pixel_values) != len(tgt_sizes): - raise ValueError( - "Inconsistent batch lengths, found: " - f"{len(pixel_values)} vs. {len(tgt_sizes)}" - ) - - pixel_values_flat: List[torch.Tensor] = [] - tgt_sizes_flat: List[torch.Tensor] = [] - for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): - if len(pixel_b) != len(tgt_b): - raise ValueError( - "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}" - ) - - for pixel_n, tgt_n in zip(pixel_b, tgt_b): - pixel_values_flat += pixel_n - tgt_sizes_flat += tgt_n - - # NOTE: Input IDs does not contain image tokens during memory profiling, - # so we allow it to be empty - if len(pixel_values_flat) != len(tgt_sizes_flat): - raise ValueError( - "Inconsistent flattened lengths, found: " - f"{len(pixel_values_flat)} vs. " - f"{len(tgt_sizes_flat)}" - ) - - if len(pixel_values_flat) == 0: - return None - image_bounds = self._get_image_bounds( input_ids=input_ids, pad_values=pad_values, @@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module): ) return MiniCPMVImagePixelInputs( image_bounds=image_bounds.to(device=input_ids.device), - data=pixel_values_flat, - tgt_sizes=torch.stack(tgt_sizes_flat), + data=pixel_values, + tgt_sizes=tgt_sizes, type="pixel_values", ) + def get_embedding( + self, + input_ids: torch.Tensor, + image_inputs: Optional[MiniCPMVImageInputs], + ) -> Tuple[torch.Tensor, torch.Tensor]: + vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) + + if image_inputs is None: # No image + vision_hidden_states = torch.tensor([], device=input_ids.device) + else: + if image_inputs["type"] == "image_embeds": + vision_hidden_states = ( + image_inputs["data"] + .type(vlm_embedding.dtype) + .to(vlm_embedding.device) + ) + else: + vision_hidden_states = self.get_vision_hidden_states(image_inputs) + # See 
NOTE in _parse_and_validate_inputs + image_bounds = image_inputs["image_bounds"] + if len(image_bounds) > 0: + image_indices = torch.stack( + [ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ] + ).to(vlm_embedding.device) + + vlm_embedding.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), + vision_hidden_states.view(-1, vision_hidden_states.shape[-1]), + ) + + return vlm_embedding, vision_hidden_states + + def get_input_embeddings(self) -> nn.Embedding: + return self.llm.get_input_embedding() + def forward( self, input_ids: torch.Tensor, @@ -899,58 +863,29 @@ class MiniCPMVBaseModel(nn.Module): **kwargs: Any, ) -> torch.Tensor: if ( - forward_batch.image_inputs is not None - and len(forward_batch.image_inputs) > 0 - and forward_batch.image_inputs[0] is not None + forward_batch.forward_mode.is_decode() + or not forward_batch.contains_image_inputs() ): - # TODO: bath - kwargs.update( - { - "pixel_values": ( - None - if forward_batch.image_inputs is None - else [ - i.pixel_values - for i in forward_batch.image_inputs - if i is not None - ] - ), - "tgt_sizes": ( - None - if forward_batch.image_inputs is None - else [ - i.tgt_sizes - for i in forward_batch.image_inputs - if i is not None - ] - ), - "im_start_id": forward_batch.image_inputs[0].im_start_id, - "im_end_id": forward_batch.image_inputs[0].im_end_id, - "slice_start_id": forward_batch.image_inputs[0].slice_start_id, - "slice_end_id": forward_batch.image_inputs[0].slice_end_id, - "pad_values": forward_batch.image_inputs[0].pad_values, - } + inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids) + else: + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. 
+ image_inputs = forward_batch.merge_image_inputs() + inputs_embeds = embed_image_inputs( + image_input=image_inputs, + input_ids=input_ids, + input_embedding=self.get_input_embeddings(), + image_embedding_func=self.get_image_features, + placeholder_token_ids=[image_inputs.im_token_id] + + image_inputs.pad_values, ) - image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) - - # Clamp input ids. This is because the input_ids for the image tokens are - # filled with the hash values of the image for the prefix matching in the radix attention. - # There values are useless because their embeddings will be replaced by vision embeddings anyway. - input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - - vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) - - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None - hidden_states = self.llm.model( - input_ids=input_ids, + input_ids=None, positions=positions, forward_batch=forward_batch, - input_embeds=vlm_embeddings, + input_embeds=inputs_embeds, ) return self.logits_processor( @@ -990,7 +925,7 @@ class MiniCPMVBaseModel(nn.Module): ) -> torch.Tensor: raise NotImplementedError - def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: + def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor: raise NotImplementedError @@ -1100,12 +1035,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ) return vision_embedding - def get_vision_hidden_states( + def get_image_features( self, - data: MiniCPMVImageInputs, + image_inputs: ImageInputs, ) -> torch.Tensor: - pixel_values = data["data"] - tgt_sizes = data["tgt_sizes"] + # list of tensors + pixel_values = image_inputs.pixel_values + + tgt_sizes = image_inputs.tgt_sizes device = self.vpm.embeddings.position_embedding.weight.device dtype = self.vpm.embeddings.position_embedding.weight.dtype diff --git 
a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 6e38ec6ca..46ff810b4 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def get_input_embedding(self) -> nn.Embedding: + return self.model.embed_tokens + @torch.no_grad() def forward( self, diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 95f926356..77d3ea54e 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -26,7 +26,6 @@ import logging from functools import lru_cache, partial from typing import Iterable, List, Optional, Tuple, Type -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -54,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.multi_modality_padding import ( +from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, + general_mm_embed_routine, ) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Model -from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs +from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs from sglang.srt.utils import add_prefix logger = logging.getLogger(__name__) @@ -326,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module): ) def get_window_index(self, grid_thw): - window_index: list = [] cu_window_seqlens: list = 
[0] window_index_id = 0 vit_merger_window_size = ( self.window_size // self.spatial_merge_size // self.patch_size ) - + window_index: list = [] for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( grid_h // self.spatial_merge_size, @@ -369,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module): cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) - return window_index, cu_window_seqlens @property @@ -382,8 +380,10 @@ class Qwen2_5_VisionTransformer(nn.Module): def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: pos_ids = [] - for t, h, w in grid_thw: + for i in range(grid_thw.size(0)): + t, h, w = grid_thw[i].tolist() hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( h // self.spatial_merge_size, self.spatial_merge_size, @@ -402,6 +402,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ) wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() @@ -443,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module): position_embeddings = (emb.cos(), emb.sin()) # compute cu_seqlens - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = torch.cat( + [ + torch.tensor([0], device=grid_thw.device), + (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0), + ] + ) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers @@ -509,18 +513,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]): - processor = cached_get_processor(self.config._name_or_path) - grid_t, 
grid_h, grid_w = image_grid_thw - num_image_tokens = ( - grid_t - * grid_h - * grid_w - // processor.image_processor.merge_size - // processor.image_processor.merge_size - ) - return num_image_tokens - def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): # Get all special token IDs im_start_id: int = image_inputs.im_start_id @@ -531,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): return pattern.pad_input_tokens(input_ids, image_inputs) - def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) + def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor: + pixel_values = image_input.pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws) return image_embeds def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: @@ -543,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): ) return video_embeds + def get_input_embeddings(self): + return self.model.embed_tokens + def forward( self, input_ids: torch.Tensor, @@ -565,86 +560,26 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": positions = forward_batch.mrope_positions - image_inputs = None - if forward_batch.image_inputs is not None: - image_inputs = [ - img for img in forward_batch.image_inputs if img is not None - ] - - if ( + if not ( forward_batch.forward_mode.is_decode() - or image_inputs is None - or len(image_inputs) == 0 + or not forward_batch.contains_image_inputs() ): - inputs_embeds = self.model.embed_tokens(input_ids) - else: if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": assert positions.ndim == 2 and positions.size(0) == 3, ( "multimodal section rotary embedding requires " f"(3, 
seq_len) positions, but got {positions.size()}" ) - # Clamp input ids. This is because the input_ids for the image tokens are - # filled with the hash values of the image for the prefix matching in the radix attention. - # There values are useless because their embeddings will be replaced by vision embeddings anyway. - input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - # [B, s, hidden_size] - inputs_embeds = self.model.embed_tokens(input_ids) - extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() - prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu - for i, image in enumerate(forward_batch.image_inputs): - if image is None or image.pixel_values is None: - continue - start_idx = extend_start_loc_cpu[i] - prefix_len = prefix_lens_cpu[i] - - pixel_values = image.pixel_values.to(device="cuda") - - image_grid_thws = torch.tensor( - np.array(image.image_grid_thws), device="cuda" - ) - image_offsets = image.image_offsets - image_input = Qwen2VLImageInputs( - pixel_values=pixel_values, image_grid_thw=image_grid_thws - ) - image_embeds = self._process_image_input(image_input) - - image_embeds_offset = 0 - for idx, image_offset in enumerate(image_offsets): - if image_offset < prefix_len: - continue - num_image_tokens = self.calculate_num_image_tokens( - image_grid_thws[idx] - ) - - left_idx = start_idx + (image_offset - prefix_len) - right_idx = left_idx + num_image_tokens - - tp_size = get_tensor_model_parallel_world_size() - - hidden_size = image_embeds.shape[-1] - - if hidden_size % tp_size != 0: - padding_size = tp_size - (hidden_size % tp_size) - image_embeds = F.pad(image_embeds, (0, padding_size)) - inputs_embeds = F.pad(inputs_embeds, (0, padding_size)) - - hidden_chunk_size = image_embeds.shape[-1] // tp_size - rank = get_tensor_model_parallel_rank() - start_dim = rank * hidden_chunk_size - end_dim = (rank + 1) * hidden_chunk_size - inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = ( - image_embeds[ - image_embeds_offset : 
image_embeds_offset - + num_image_tokens, - ..., - start_dim:end_dim, - ] - ) - image_embeds_offset += num_image_tokens + inputs_embeds = general_mm_embed_routine( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + embed_tokens=self.get_input_embeddings(), + image_embedding_func=self.get_image_feature, + ) hidden_states = self.model( - input_ids=input_ids, + input_ids=None, positions=positions, forward_batch=forward_batch, input_embeds=inputs_embeds, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 1bfcf526b..c929b006c 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -26,7 +26,6 @@ import logging from functools import lru_cache, partial from typing import Iterable, List, Optional, Tuple, Type, TypedDict -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -42,8 +41,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.multi_modality_padding import ( +from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, + general_mm_embed_routine, ) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module): @property def dtype(self) -> torch.dtype: - return self.blocks[0].mlp.fc2.weight.dtype + return next(self.parameters()).dtype @property def device(self) -> torch.device: @@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module): def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: pos_ids = [] - for t, h, w in grid_thw: + for i in range(grid_thw.size(0)): + t, h, w = grid_thw[i].tolist() hpos_ids = 
torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) hpos_ids = ( @@ -480,9 +481,9 @@ class Qwen2VLForConditionalGeneration(nn.Module): pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) return pattern.pad_input_tokens(input_ids, image_inputs) - def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) + def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor: + pixel_values = image_input.pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws) return image_embeds def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: @@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module): ) return video_embeds + def get_input_embeddings(self): + return self.model.embed_tokens + def forward( self, input_ids: torch.Tensor, @@ -514,67 +518,26 @@ class Qwen2VLForConditionalGeneration(nn.Module): if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": positions = forward_batch.mrope_positions - image_inputs = None - if forward_batch.image_inputs is not None: - image_inputs = [ - img for img in forward_batch.image_inputs if img is not None - ] - - if ( + if not ( forward_batch.forward_mode.is_decode() - or image_inputs is None - or len(image_inputs) == 0 + or not forward_batch.contains_image_inputs() ): - inputs_embeds = self.model.embed_tokens(input_ids) - else: if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": assert positions.ndim == 2 and positions.size(0) == 3, ( "multimodal section rotary embedding requires " f"(3, seq_len) positions, but got {positions.size()}" ) - # Clamp input ids. 
This is because the input_ids for the image tokens are - # filled with the hash values of the image for the prefix matching in the radix attention. - # There values are useless because their embeddings will be replaced by vision embeddings anyway. - input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - - inputs_embeds = self.model.embed_tokens(input_ids) - extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() - prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu - for i, image in enumerate(forward_batch.image_inputs): - if image is None or image.pixel_values is None: - continue - start_idx = extend_start_loc_cpu[i] - prefix_len = prefix_lens_cpu[i] - pixel_values = image.pixel_values.clone() - - image_grid_thws = torch.tensor( - np.array(image.image_grid_thws), device="cuda" - ) - image_offsets = image.image_offsets - image_input = Qwen2VLImageInputs( - pixel_values=pixel_values, image_grid_thw=image_grid_thws - ) - image_embeds = self._process_image_input(image_input) - - image_embeds_offset = 0 - for idx, image_offset in enumerate(image_offsets): - if image_offset < prefix_len: - continue - num_image_tokens = self.calculate_num_image_tokens( - image_grid_thws[idx] - ) - - left_idx = start_idx + (image_offset - prefix_len + 1) - right_idx = left_idx + num_image_tokens - inputs_embeds[left_idx:right_idx] = image_embeds[ - image_embeds_offset : image_embeds_offset + num_image_tokens - ] - image_embeds_offset += num_image_tokens - input_ids = None + inputs_embeds = general_mm_embed_routine( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + embed_tokens=self.get_input_embeddings(), + image_embedding_func=self.get_image_feature, + ) hidden_states = self.model( - input_ids=input_ids, + input_ids=None, positions=positions, forward_batch=forward_batch, input_embeds=inputs_embeds, diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index d24e901c0..587aa2d1d 100644 --- 
a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -23,6 +23,17 @@ from sglang.test.test_utils import ( popen_launch_server, ) +# image +IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png" +IMAGE_SGL_LOGO_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/sgl_logo.png" + +# video +VIDEO_JOBS_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4" + +# audio +AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/Trump_WEF_2018_10s.mp3" +AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3" + class TestOpenAIVisionServer(unittest.TestCase): @classmethod @@ -58,9 +69,7 @@ class TestOpenAIVisionServer(unittest.TestCase): "content": [ { "type": "image_url", - "image_url": { - "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" - }, + "image_url": {"url": IMAGE_MAN_IRONING_URL}, }, { "type": "text", @@ -96,9 +105,7 @@ class TestOpenAIVisionServer(unittest.TestCase): "content": [ { "type": "image_url", - "image_url": { - "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" - }, + "image_url": {"url": IMAGE_MAN_IRONING_URL}, }, { "type": "text", @@ -153,9 +160,7 @@ class TestOpenAIVisionServer(unittest.TestCase): }, { "type": "image_url", - "image_url": { - "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" - }, + "image_url": {"url": IMAGE_SGL_LOGO_URL}, "modalities": "multi-images", }, { @@ -242,10 +247,12 @@ class TestOpenAIVisionServer(unittest.TestCase): ] return messages - def test_video_chat_completion(self): - url = 
"https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + def get_or_download_file(self, url: str) -> str: cache_dir = os.path.expanduser("~/.cache") - file_path = os.path.join(cache_dir, "jobs.mp4") + if url is None: + raise ValueError() + file_name = url.split("/")[-1] + file_path = os.path.join(cache_dir, file_name) os.makedirs(cache_dir, exist_ok=True) if not os.path.exists(file_path): @@ -254,6 +261,11 @@ class TestOpenAIVisionServer(unittest.TestCase): with open(file_path, "wb") as f: f.write(response.content) + return file_path + + def test_video_chat_completion(self): + url = VIDEO_JOBS_URL + file_path = self.get_or_download_file(url) client = openai.Client(api_key=self.api_key, base_url=self.base_url) @@ -289,6 +301,7 @@ class TestOpenAIVisionServer(unittest.TestCase): "present" in video_response or "examine" in video_response or "display" in video_response + or "hold" in video_response ) assert "black" in video_response or "dark" in video_response self.assertIsNotNone(video_response) @@ -312,9 +325,7 @@ class TestOpenAIVisionServer(unittest.TestCase): "content": [ { "type": "image_url", - "image_url": { - "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" - }, + "image_url": {"url": IMAGE_MAN_IRONING_URL}, }, { "type": "text", @@ -344,18 +355,14 @@ class TestOpenAIVisionServer(unittest.TestCase): content.append( { "type": "image_url", - "image_url": { - "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" - }, + "image_url": {"url": IMAGE_MAN_IRONING_URL}, } ) elif image_id == 1: content.append( { "type": "image_url", - "image_url": { - "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" - }, + "image_url": {"url": IMAGE_SGL_LOGO_URL}, } ) else: @@ -465,9 +472,7 @@ class TestVLMContextLengthIssue(unittest.TestCase): "content": [ { "type": "image_url", - "image_url": { - "url": 
"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" - }, + "image_url": {"url": IMAGE_MAN_IRONING_URL}, }, { "type": "text", diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index 01ff6c445..87985e3fd 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -13,6 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer from sglang.srt.configs.model_config import ModelConfig from sglang.srt.conversation import generate_chat_conv +from sglang.srt.managers.mm_utils import embed_image_inputs +from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.openai_api.protocol import ChatCompletionRequest from sglang.srt.server_args import ServerArgs @@ -168,10 +170,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): ).eval() cls.model.to(cls.device) - async def test_encode_output(self): + async def test_vlm_embedding_output(self): + """ + Compares the embedding output of vlm + """ inputs = self.get_processor_output() with torch.no_grad(): + # hf model_inputs = { "input_ids": inputs.input_ids, "image_bound": inputs.image_bound, @@ -183,22 +189,20 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): ) hf_output = hf_output.squeeze(0) - with torch.no_grad(): + # sglang model = self.get_sglang_model() input_ids = inputs["input_ids"].to(self.device).flatten() - image_inputs = model._parse_and_validate_inputs( + sglang_output = embed_image_inputs( + image_input=ImageInputs( + pixel_values=inputs["pixel_values"][0], + tgt_sizes=inputs["tgt_sizes"][0], + ), input_ids=input_ids, - **{ - "pixel_values": [inputs["pixel_values"]], - "tgt_sizes": [inputs["tgt_sizes"]], - "im_start_id": self.tokenizer.im_start_id, - "im_end_id": self.tokenizer.im_end_id, - "slice_start_id": self.tokenizer.slice_start_id, - "slice_end_id": self.tokenizer.slice_end_id, - }, - ) - (sglang_output, _) = model.get_embedding( - 
input_ids=input_ids, image_inputs=image_inputs + input_embedding=model.get_input_embeddings(), + image_embedding_func=model.get_image_features, + placeholder_token_ids=[ + self.processor.tokenizer.unk_token_id, + ], ) self.compare_outputs(sglang_output, hf_output)