From f19a9204cdbd6b8360f85f404463aa45af9ee00b Mon Sep 17 00:00:00 2001
From: Yury Sulsky
Date: Fri, 16 May 2025 12:26:15 -0700
Subject: [PATCH] Support precomputed multimodal features for Qwen-VL and
 Gemma3 models. (#6136)

Co-authored-by: Yury Sulsky
---
 docs/backend/vlm_query.ipynb                  | 163 ++++++++++++++++++
 python/sglang/srt/entrypoints/engine.py       |  13 +-
 python/sglang/srt/managers/io_struct.py       |   8 +-
 python/sglang/srt/managers/mm_utils.py        |  11 +-
 .../multimodal_processors/base_processor.py   | 116 +++++++++----
 .../managers/multimodal_processors/gemma3.py  |  29 +++-
 .../managers/multimodal_processors/qwen_vl.py |  70 ++++++--
 python/sglang/srt/managers/schedule_batch.py  |  45 +++--
 .../sglang/srt/managers/session_controller.py |   2 +-
 python/sglang/srt/models/gemma3_mm.py         |   6 +
 python/sglang/srt/models/qwen2_5_vl.py        |   6 +
 python/sglang/srt/models/qwen2_vl.py          |   6 +
 test/srt/test_skip_tokenizer_init.py          |  99 ++++++-----
 test/srt/test_vlm_accuracy.py                 | 143 ++++++++++++++-
 14 files changed, 592 insertions(+), 125 deletions(-)
 create mode 100644 docs/backend/vlm_query.ipynb

diff --git a/docs/backend/vlm_query.ipynb b/docs/backend/vlm_query.ipynb
new file mode 100644
index 000000000..7aba8dfb8
--- /dev/null
+++ b/docs/backend/vlm_query.ipynb
@@ -0,0 +1,163 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "0",
+   "metadata": {},
+   "source": [
+    "# Querying Qwen-VL"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nest_asyncio\n",
+    "\n",
+    "nest_asyncio.apply()  # Run this first.\n",
+    "\n",
+    "model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
+    "chat_template = \"qwen2-vl\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Let's create a prompt.\n",
+    "\n",
+    "from io import BytesIO\n",
+    "import requests\n",
+    "from PIL import Image\n",
+    "\n",
+    "from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
+    "from sglang.srt.conversation import chat_templates\n",
+    "\n",
+    "image = Image.open(\n",
+    "    BytesIO(\n",
+    "        requests.get(\n",
+    "            \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
+    "        ).content\n",
+    "    )\n",
+    ")\n",
+    "\n",
+    "conv = chat_templates[chat_template].copy()\n",
+    "conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
+    "conv.append_message(conv.roles[1], \"\")\n",
+    "conv.image_data = [image]\n",
+    "\n",
+    "print(conv.get_prompt())\n",
+    "image"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3",
+   "metadata": {},
+   "source": [
+    "## Query via the offline Engine API"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sglang import Engine\n",
+    "\n",
+    "llm = Engine(\n",
+    "    model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
+    "print(out[\"text\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6",
+   "metadata": {},
+   "source": [
+    "## Query via the offline Engine API, but send precomputed embeddings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Compute the image embeddings using Hugging Face.\n",
+    "\n",
+    "from transformers import 
AutoProcessor\n", + "from transformers import Qwen2_5_VLForConditionalGeneration\n", + "\n", + "processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n", + "vision = (\n", + " Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "processed_prompt = processor(\n", + " images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n", + ")\n", + "input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n", + "precomputed_features = vision(\n", + " processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n", + ")\n", + "\n", + "mm_item = dict(\n", + " modality=\"IMAGE\",\n", + " image_grid_thws=processed_prompt[\"image_grid_thw\"],\n", + " precomputed_features=precomputed_features,\n", + ")\n", + "out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n", + "print(out[\"text\"])" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "custom_cell_magics": "kql", + "encoding": "# -*- coding: utf-8 -*-" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index f7b1c23fe..c80133b19 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import ( EmbeddingReqInput, GenerateReqInput, GetWeightsByNameReqInput, + ImageDataItem, InitWeightsUpdateGroupReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, @@ -150,9 +151,9 @@ class Engine(EngineBase): # See also python/sglang/srt/utils.py:load_image for more details. image_data: Optional[ Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], + List[List[ImageDataItem]], + List[ImageDataItem], + ImageDataItem, ] ] = None, return_logprob: Optional[Union[List[bool], bool]] = False, @@ -221,9 +222,9 @@ class Engine(EngineBase): # See also python/sglang/srt/utils.py:load_image for more details. image_data: Optional[ Union[ - List[List[Union[Image, str]]], - List[Union[Image, str]], - Union[Image, str], + List[List[ImageDataItem]], + List[ImageDataItem], + ImageDataItem, ] ] = None, return_logprob: Optional[Union[List[bool], bool]] = False, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5390668cf..dfb3b6eb2 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -40,6 +40,10 @@ class SessionParams: replace: Optional[bool] = None +AudioDataItem = Union[str, Dict] +ImageDataItem = Union[Image, str, Dict] + + @dataclass class GenerateReqInput: # The input prompt. It can be a single prompt or a batch of prompts. @@ -55,10 +59,10 @@ class GenerateReqInput: # - List of lists of images (multiple images per request) # See also python/sglang/srt/utils.py:load_image for more details. image_data: Optional[ - Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]] + Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem] ] = None # The audio input. 
Like image data, it can be a file name, a URL, or a base64-encoded string.
-    audio_data: Optional[Union[List[str], str]] = None
+    audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index b5ef4cbce..2c8cad5ac 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -368,13 +368,13 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    image_data_embedding_func: Optional[
+        Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    audio_data_embedding_func: Optional[
+        Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    placeholder_tokens: dict[Modality, List[int]] = None,
+    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -389,7 +389,6 @@
         forwarded hidden states
     """
 
-    assert hasattr(language_model, "get_input_embeddings")
     embed_tokens = language_model.get_input_embeddings()
 
     if (
diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py
index aafa63c3b..a6070cc0f 100644
--- a/python/sglang/srt/managers/multimodal_processors/base_processor.py
+++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py
@@ -3,16 +3,16 @@
 import concurrent.futures
 import dataclasses
 import multiprocessing as mp
 import os
+import re
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import numpy as np
-import PIL
 import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
-from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image
@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput:
     input_text: str
 
     # frames loaded from image and video, in given order
-    images: Optional[list[PIL.Image]] = None
+    images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
 
     # audios
-    audios: Optional[list[np.ndarray]] = None
+    audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
 
     def normalize(self):
-        for field_name in ["image_sizes", "images", "audios"]:
+        for field_name in ["images", "audios"]:
             field = getattr(self, field_name, None)
             if field is not None and isinstance(field, list) and len(field) == 0:
                 setattr(self, field_name, None)
@@ -40,12 +40,32 @@ class MultimodalSpecialTokens:
     video_token: Optional[str] = None
     audio_token: Optional[str] = None
 
-    def collect(self) -> list[str]:
-        return [
-            token
-            for token in [self.image_token, self.video_token, self.audio_token]
-            if token
+    image_token_regex: Optional[re.Pattern] = None
+    video_token_regex: Optional[re.Pattern] = None
+    audio_token_regex: Optional[re.Pattern] = None
+
+    def __post_init__(self):
+        if self.image_token_regex is None and self.image_token is not None:
+            self.image_token_regex = re.compile(re.escape(self.image_token))
+        if self.video_token_regex is None and self.video_token is not None:
+            self.video_token_regex = 
re.compile(re.escape(self.video_token)) + if self.audio_token_regex is None and self.audio_token is not None: + self.audio_token_regex = re.compile(re.escape(self.audio_token)) + + def collect(self) -> re.Pattern: + tokens = [ + self.image_token_regex, + self.video_token_regex, + self.audio_token_regex, ] + patterns = [] + flags = 0 + for t in tokens: + if t is not None: + patterns.append(t.pattern) + flags |= t.flags + combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")" + return re.compile(combined, flags) class BaseMultimodalProcessor(ABC): @@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC): data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True ): """Static method that can be pickled for multiprocessing""" + if isinstance(data, dict): + return MultimodalDataItem.from_dict(data) + if isinstance(data, MultimodalDataItem): + return data try: if is_audio: return load_audio(data) @@ -175,7 +199,10 @@ class BaseMultimodalProcessor(ABC): image_index, audio_index = 0, 0 for text_part in text_parts: - if text_part == multimodal_tokens.image_token: + if ( + multimodal_tokens.image_token_regex + and multimodal_tokens.image_token_regex.match(text_part) + ): data = image_data[image_index] is_video = isinstance(data, str) and data.startswith("video:") estimated_frames = estimated_frames_list[image_index] @@ -192,7 +219,10 @@ class BaseMultimodalProcessor(ABC): ) task_info.append((Modality.IMAGE, data, frame_count_limit)) image_index += 1 - elif text_part == multimodal_tokens.audio_token: + elif ( + multimodal_tokens.audio_token_regex + and multimodal_tokens.audio_token_regex.match(text_part) + ): data = audio_data[audio_index] futures.append( self.io_executor.submit( @@ -228,17 +258,22 @@ class BaseMultimodalProcessor(ABC): discard_alpha_channel: if True, discards the alpha channel in the returned images """ + if not return_text: + raise NotImplementedError() if image_data is None: image_data = [] if isinstance(multimodal_tokens.image_token, int): - multimodal_tokens.image_token = ( - self._processor.tokenizer.convert_ids_to_tokens( - multimodal_tokens.image_token + multimodal_tokens.image_token = re.compile( + re.escape( + self._processor.tokenizer.convert_ids_to_tokens( + multimodal_tokens.image_token + ) ) ) else: multimodal_tokens.image_token = multimodal_tokens.image_token + multimodal_tokens_pattern = multimodal_tokens.collect() if isinstance(prompt, list) and return_text: assert len(prompt) and isinstance(prompt[0], int) @@ -247,16 +282,8 @@ class BaseMultimodalProcessor(ABC): prompt = prompt assert isinstance(prompt, str) - if return_text: - import re - - pattern = ( - "(" - + "|".join(re.escape(sep) for sep in multimodal_tokens.collect()) - + ")" - ) - # split text into list of normal text and special tokens - text_parts = re.split(pattern, prompt) + # split text into list of normal text and special tokens + text_parts = re.split(multimodal_tokens_pattern, prompt) futures, task_info = self.submit_data_loading_tasks( text_parts=text_parts, @@ -266,26 +293,40 @@ class BaseMultimodalProcessor(ABC): discard_alpha_channel=discard_alpha_channel, ) # Process results - image_sizes, images, audios = [], [], [] + images, audios = [], [] new_text = "" task_ptr = 0 for text_part in text_parts: - if text_part in multimodal_tokens.collect(): + if multimodal_tokens_pattern.match(text_part): task_type, data, frame_limit = task_info[task_ptr] result = futures[task_ptr].result() task_ptr += 1 if task_type == Modality.IMAGE: + # If data is already processed it will be a + # 
dictionary. In this case we want to keep the
+                    # expanded tokens in text_part. Otherwise, we will
+                    # call the processor code, so keep only a single image
+                    # token.
+                    mm_tokens = (
+                        text_part
+                        if isinstance(data, dict)
+                        else multimodal_tokens.image_token
+                    )
                     frames = [result] if not isinstance(result, list) else result
                     if frames:
-                        image_sizes += frames[0].size * len(frames)
                         images += frames
-                        new_text += multimodal_tokens.image_token * len(frames)
+                        new_text += mm_tokens * len(frames)
                 elif task_type == Modality.AUDIO:
                     # audio
+                    mm_tokens = (
+                        text_part
+                        if isinstance(data, dict)
+                        else multimodal_tokens.audio_token
+                    )
                     audios.append(result)
-                    new_text += multimodal_tokens.audio_token
+                    new_text += mm_tokens
                 # TODO: handle video
             else:
                 new_text += text_part
@@ -297,3 +338,16 @@
         )
         out.normalize()
         return out
+
+    def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
+        """Returns True if all inputs are preprocessed, False if none are; raises if they are mixed."""
+        if not mm_inputs:
+            return True
+        ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
+        if ret and not all(
+            isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
+        ):
+            raise ValueError(
+                "Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
+            )
+        return ret
diff --git a/python/sglang/srt/managers/multimodal_processors/gemma3.py b/python/sglang/srt/managers/multimodal_processors/gemma3.py
index 3b45d27d1..481a31718 100644
--- a/python/sglang/srt/managers/multimodal_processors/gemma3.py
+++ b/python/sglang/srt/managers/multimodal_processors/gemma3.py
@@ -1,4 +1,5 @@
-from typing import List, Union
+import re
+from typing import Dict, List, Union
 
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        # The single, pre-expanded image token.
         self.IMAGE_TOKEN = "<start_of_image>"
+        # The regex that matches expanded image tokens.
+        self.IMAGE_TOKEN_REGEX = re.compile(
+            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+ ) self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index async def process_mm_data_async( self, - image_data: List[Union[str, bytes]], + image_data: List[Union[str, bytes, Dict]], input_text, request_obj, max_req_input_len, @@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): image_data = [image_data] image_token = self.IMAGE_TOKEN + image_token_regex = self.IMAGE_TOKEN_REGEX base_output = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens(image_token=image_token), + multimodal_tokens=MultimodalSpecialTokens( + image_token=image_token, image_token_regex=image_token_regex + ), max_req_input_len=max_req_input_len, discard_alpha_channel=True, ) + images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images) ret = self.process_mm_data( - input_text=base_output.input_text, images=base_output.images + input_text=base_output.input_text, + images=None if images_are_preprocessed else base_output.images, ) items = [] for i, image in enumerate(base_output.images): + if images_are_preprocessed: + pixel_values = image.pixel_values + precomputed_features = image.precomputed_features + else: + pixel_values = ret["pixel_values"][i] + precomputed_features = None + item = MultimodalDataItem( - pixel_values=ret["pixel_values"][i], + pixel_values=pixel_values, + precomputed_features=precomputed_features, modality=Modality.IMAGE, ) items += [item] diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py index 962a2acd7..ef7ed44b3 100644 --- a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py +++ b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py @@ -1,6 +1,7 @@ import asyncio import math -from typing import List, Union +import re +from typing import Dict, List, Union import torch from PIL import Image @@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) + # The single, pre-expanded image token. self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" + # The regex that matches expanded image tokens. 
+ self.IMAGE_TOKEN_REGEX = re.compile( + r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>" + ) self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id self.image_token_id = hf_config.image_token_id @@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): async def process_mm_data_async( self, - image_data: List[Union[str, bytes]], + image_data: List[Union[str, bytes, Dict]], input_text, request_obj, max_req_input_len, @@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): if isinstance(image_data, str): image_data = [image_data] - image_token = self.IMAGE_TOKEN base_output = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens(image_token=image_token), + multimodal_tokens=MultimodalSpecialTokens( + image_token=self.IMAGE_TOKEN, + image_token_regex=self.IMAGE_TOKEN_REGEX, + ), max_req_input_len=max_req_input_len, ) @@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): async def resize_image_async(image): return resize_image(image) - if base_output.images: + images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images) + if base_output.images and not images_are_preprocessed: resize_tasks = [resize_image_async(image) for image in base_output.images] base_output.images = await asyncio.gather(*resize_tasks) ret = self.process_mm_data( input_text=base_output.input_text, - images=base_output.images, + images=None if images_are_preprocessed else base_output.images, ) - + input_ids = ret["input_ids"].flatten().tolist() + image_grid_thw = None + video_grid_thw = None # TODO items = [] - input_ids = ret["input_ids"].flatten().tolist() - if "pixel_values" in ret: + if base_output.images: + if images_are_preprocessed: + image_grid_thw = torch.concat( + [ + torch.as_tensor(item.image_grid_thws) + for item in base_output.images + ] + ) + all_pixel_values = [ + item.pixel_values + for item in base_output.images + if item.pixel_values is not None + ] + all_precomputed_features = [ + item.precomputed_features + for item in base_output.images + if item.precomputed_features is not None + ] + pixel_values = ( + torch.concat(all_pixel_values) if all_pixel_values else None + ) + precomputed_features = ( + torch.concat(all_precomputed_features) + if all_precomputed_features + else None + ) + else: + image_grid_thw = ret["image_grid_thw"] + pixel_values = ret["pixel_values"] + precomputed_features = None items += [ MultimodalDataItem( - pixel_values=ret["pixel_values"], - image_grid_thws=torch.concat([ret["image_grid_thw"]]), - # TODO - video_grid_thws=None, - second_per_grid_ts=ret.get("second_per_grid_ts", None), + pixel_values=pixel_values, + image_grid_thws=image_grid_thw, + video_grid_thws=video_grid_thw, + precomputed_features=precomputed_features, modality=Modality.IMAGE, ) ] @@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): self.hf_config.vision_config, "tokens_per_second", None ), input_ids=torch.tensor(input_ids).unsqueeze(0), - image_grid_thw=ret.get("image_grid_thw", None), - video_grid_thw=ret.get("video_grid_thw", None), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, second_per_grid_ts=ret.get("second_per_grid_ts", None), ) mrope_positions = mrope_positions.squeeze(1) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e9bf68b32..c0464541d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ 
b/python/sglang/srt/managers/schedule_batch.py @@ -177,10 +177,10 @@ class MultimodalDataItem: image_offsets: Optional[list] = None # the real data, pixel_values or audio_features - # data: Union[List[torch.Tensor], List[np.array]] - pixel_values: Union[torch.Tensor, np.array] = None - image_grid_thws: Union[torch.Tensor, np.array] = None - video_grid_thws: Union[torch.Tensor, np.array] = None + # data: Union[List[torch.Tensor], List[np.ndarray]] + pixel_values: Union[torch.Tensor, np.ndarray] = None + image_grid_thws: Union[torch.Tensor, np.ndarray] = None + video_grid_thws: Union[torch.Tensor, np.ndarray] = None image_emb_mask: Optional[torch.Tensor] = None image_spatial_crop: Optional[torch.Tensor] = None @@ -189,9 +189,11 @@ class MultimodalDataItem: # [num_images, (n, w, h)] tgt_size: Tuple[int, int] = None - audio_features: Union[torch.Tensor, np.array] = None + audio_features: Union[torch.Tensor, np.ndarray] = None audio_feature_lens: Optional[List[torch.Tensor]] = None + precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None + @staticmethod def is_empty_list(l): if l is None: @@ -249,7 +251,9 @@ class MultimodalDataItem: return tensor_hash([f]) return data_hash(f) - if self.is_audio(): + if self.precomputed_features is not None: + self.hash = hash_feature(self.precomputed_features) + elif self.is_audio(): self.hash = hash_feature(self.audio_features) else: self.hash = hash_feature(self.pixel_values) @@ -258,19 +262,24 @@ class MultimodalDataItem: self.pad_value = self.hash % (1 << 30) def is_audio(self): - return ( - self.modality == Modality.AUDIO - ) and not MultimodalDataItem.is_empty_list(self.audio_features) + return (self.modality == Modality.AUDIO) and ( + self.precomputed_features is not None + or not MultimodalDataItem.is_empty_list(self.audio_features) + ) def is_image(self): return ( self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES - ) and not MultimodalDataItem.is_empty_list(self.pixel_values) + ) and ( + self.precomputed_features is not None + or not MultimodalDataItem.is_empty_list(self.pixel_values) + ) def is_video(self): - return ( - self.modality == Modality.VIDEO - ) and not MultimodalDataItem.is_empty_list(self.pixel_values) + return (self.modality == Modality.VIDEO) and ( + self.precomputed_features is not None + or not MultimodalDataItem.is_empty_list(self.pixel_values) + ) def is_valid(self) -> bool: return self.is_image() or self.is_video() or self.is_audio() @@ -279,6 +288,16 @@ class MultimodalDataItem: ... 
# TODO + @staticmethod + def from_dict(obj: dict): + kwargs = dict(obj) + modality = kwargs.pop("modality") + if isinstance(modality, str): + modality = Modality[modality] + ret = MultimodalDataItem(modality=modality, **kwargs) + ret.validate() + return ret + @dataclasses.dataclass class MultimodalInputs: diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index 0a132adfd..65babbf99 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -54,7 +54,7 @@ class SessionReqNode: prefix += " -- " + self.childs[0].req.rid ret = self.childs[0]._str_helper(prefix) for child in self.childs[1:]: - prefix = " " * len(origin_prefix) + " \- " + child.req.rid + prefix = " " * len(origin_prefix) + r" \- " + child.req.rid ret += child._str_helper(prefix) return ret diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 824c58916..b26a1d603 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -278,6 +278,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + if any(item.precomputed_features is not None for item in items): + if not all(item.precomputed_features is not None for item in items): + raise NotImplementedError( + "MM inputs where only some items are precomputed." + ) + return torch.concat([item.precomputed_features for item in items]) pixel_values = torch.stack( flatten_nested_list([item.pixel_values for item in items]), dim=0 ) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 1d52c92cd..7cc24f182 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -497,6 +497,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): return pattern.pad_input_tokens(input_ids, mm_inputs) def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + if any(item.precomputed_features is not None for item in items): + if not all(item.precomputed_features is not None for item in items): + raise NotImplementedError( + "MM inputs where only some items are precomputed." + ) + return torch.concat([item.precomputed_features for item in items]) # in qwen-vl, last dim is the same pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( self.visual.dtype diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index f653401d8..b4421290e 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module): return pattern.pad_input_tokens(input_ids, mm_inputs) def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + if any(item.precomputed_features is not None for item in items): + if not all(item.precomputed_features is not None for item in items): + raise NotImplementedError( + "MM inputs where only some items are precomputed." 
+ ) + return torch.concat([item.precomputed_features for item in items]) # in qwen-vl, last dim is the same pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( self.visual.dtype diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index f9eee27b6..02b1b40c6 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -54,21 +54,17 @@ class TestSkipTokenizerInit(CustomTestCase): ): input_ids = self.get_input_ids(prompt_text) + request = self.get_request_json( + input_ids=input_ids, + return_logprob=return_logprob, + top_logprobs_num=top_logprobs_num, + max_new_tokens=max_new_tokens, + stream=False, + n=n, + ) response = requests.post( self.base_url + "/generate", - json={ - "input_ids": input_ids, - "sampling_params": { - "temperature": 0 if n == 1 else 0.5, - "max_new_tokens": max_new_tokens, - "n": n, - "stop_token_ids": [self.tokenizer.eos_token_id], - }, - "stream": False, - "return_logprob": return_logprob, - "top_logprobs_num": top_logprobs_num, - "logprob_start_len": 0, - }, + json=request, ) ret = response.json() print(json.dumps(ret, indent=2)) @@ -87,9 +83,12 @@ class TestSkipTokenizerInit(CustomTestCase): self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids)) if return_logprob: + num_input_logprobs = len(input_ids) - request["logprob_start_len"] + if num_input_logprobs > len(input_ids): + num_input_logprobs -= len(input_ids) self.assertEqual( len(item["meta_info"]["input_token_logprobs"]), - len(input_ids), + num_input_logprobs, f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}', ) self.assertEqual( @@ -113,19 +112,14 @@ class TestSkipTokenizerInit(CustomTestCase): requests.post(self.base_url + "/flush_cache") response = requests.post( self.base_url + "/generate", - json={ - "input_ids": input_ids, - "sampling_params": { - "temperature": 0 if n == 1 else 0.5, - "max_new_tokens": max_new_tokens, - "n": n, - "stop_token_ids": self.eos_token_id, - }, - "stream": False, - "return_logprob": return_logprob, - "top_logprobs_num": top_logprobs_num, - "logprob_start_len": 0, - }, + json=self.get_request_json( + input_ids=input_ids, + max_new_tokens=max_new_tokens, + return_logprob=return_logprob, + top_logprobs_num=top_logprobs_num, + stream=False, + n=n, + ), ) ret = response.json() print(json.dumps(ret)) @@ -137,19 +131,13 @@ class TestSkipTokenizerInit(CustomTestCase): requests.post(self.base_url + "/flush_cache") response_stream = requests.post( self.base_url + "/generate", - json={ - "input_ids": input_ids, - "sampling_params": { - "temperature": 0 if n == 1 else 0.5, - "max_new_tokens": max_new_tokens, - "n": n, - "stop_token_ids": self.eos_token_id, - }, - "stream": True, - "return_logprob": return_logprob, - "top_logprobs_num": top_logprobs_num, - "logprob_start_len": 0, - }, + json=self.get_request_json( + input_ids=input_ids, + return_logprob=return_logprob, + top_logprobs_num=top_logprobs_num, + stream=True, + n=n, + ), ) response_stream_json = [] @@ -188,6 +176,29 @@ class TestSkipTokenizerInit(CustomTestCase): ].tolist() return input_ids + def get_request_json( + self, + input_ids, + max_new_tokens=32, + return_logprob=False, + top_logprobs_num=0, + stream=False, + n=1, + ): + return { + "input_ids": input_ids, + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": max_new_tokens, + "n": n, + "stop_token_ids": self.eos_token_id, + }, + "stream": stream, + "return_logprob": return_logprob, + "top_logprobs_num": 
top_logprobs_num, + "logprob_start_len": 0, + } + class TestSkipTokenizerInitVLM(TestSkipTokenizerInit): @classmethod @@ -218,6 +229,14 @@ class TestSkipTokenizerInitVLM(TestSkipTokenizerInit): return inputs.input_ids[0].tolist() + def get_request_json(self, *args, **kwargs): + ret = super().get_request_json(*args, **kwargs) + ret["image_data"] = [self.image_url] + ret["logprob_start_len"] = ( + -1 + ) # Do not try to calculate logprobs of image embeddings. + return ret + def test_simple_decode_stream(self): # TODO mick pass diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index 58b0efa6d..49805230a 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -3,15 +3,22 @@ import unittest from io import BytesIO -from typing import List +from typing import List, Optional import numpy as np import requests import torch import torch.nn.functional as F from PIL import Image -from transformers import AutoModel, AutoProcessor, AutoTokenizer +from transformers import ( + AutoModel, + AutoProcessor, + AutoTokenizer, + Gemma3ForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, +) +from sglang import Engine from sglang.srt.configs.model_config import ModelConfig from sglang.srt.conversation import generate_chat_conv from sglang.srt.managers.mm_utils import embed_mm_inputs @@ -100,7 +107,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): np.testing.assert_allclose(hf_np, sg_np) - def get_processor_output(self): + def get_completion_request(self) -> ChatCompletionRequest: json_str = f""" {{ "model": "{self.model_path}", @@ -124,10 +131,12 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): }} """ - req = ChatCompletionRequest.model_validate_json(json_str) + return ChatCompletionRequest.model_validate_json(json_str) + def get_processor_output(self, req: Optional[ChatCompletionRequest] = None): + if req is None: + req = self.get_completion_request() conv = generate_chat_conv(req, template_name=self.chat_template) - text = conv.get_prompt() # Process inputs using processor @@ -239,5 +248,129 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): self.compare_outputs(sglang_output, hf_output) +class TestQwenVLUnderstandsImage(VisionLLMLogitsBase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct" + cls.chat_template = "qwen2-vl" + cls.processor = AutoProcessor.from_pretrained( + cls.model_path, trust_remote_code=True, use_fast=True + ) + cls.visual = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16 + ) + .eval() + .visual.to(cls.device) + ) + + def setUp(self): + self.engine = Engine( + model_path=self.model_path, + chat_template=self.chat_template, + device=self.device.type, + mem_fraction_static=0.8, + ) + + def tearDown(self): + self.engine.shutdown() + + async def test_qwen_vl_understands_image(self): + req = self.get_completion_request() + conv = generate_chat_conv(req, template_name=self.chat_template) + text = conv.get_prompt() + output = await self.engine.async_generate( + prompt=text, + image_data=[self.main_image], + sampling_params=dict(temperature=0.0), + ) + self.assertIn("taxi", output["text"].lower()) + + async def test_qwen_vl_understands_precomputed_features(self): + req = self.get_completion_request() + processor_output = self.get_processor_output(req=req) + with torch.inference_mode(): + precomputed_features = self.visual( + processor_output["pixel_values"], processor_output["image_grid_thw"] + 
) + output = await self.engine.async_generate( + input_ids=processor_output["input_ids"][0].detach().cpu().tolist(), + image_data=[ + dict( + modality="IMAGE", + image_grid_thws=processor_output["image_grid_thw"], + precomputed_features=precomputed_features, + ) + ], + sampling_params=dict(temperature=0.0), + ) + self.assertIn("taxi", output["text"].lower()) + + +class TestGemmaUnderstandsImage(VisionLLMLogitsBase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_path = "google/gemma-3-4b-it" + cls.chat_template = "gemma-it" + cls.processor = AutoProcessor.from_pretrained( + cls.model_path, trust_remote_code=True, use_fast=True + ) + model = Gemma3ForConditionalGeneration.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16 + ) + cls.vision_tower = model.vision_tower.eval().to(cls.device) + cls.mm_projector = model.multi_modal_projector.eval().to(cls.device) + + @classmethod + def visual(cls, pixel_values): + vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = cls.mm_projector(vision_outputs) + return image_features + + def setUp(self): + self.engine = Engine( + model_path=self.model_path, + chat_template=self.chat_template, + device=self.device.type, + mem_fraction_static=0.5, + enable_multimodal=True, + ) + + def tearDown(self): + self.engine.shutdown() + + async def test_gemma_understands_image(self): + req = self.get_completion_request() + conv = generate_chat_conv(req, template_name=self.chat_template) + text = conv.get_prompt() + output = await self.engine.async_generate( + prompt=text, + image_data=[self.main_image], + sampling_params=dict(temperature=0.0), + ) + self.assertIn("taxi", output["text"].lower()) + + async def test_gemma_understands_precomputed_features(self): + req = self.get_completion_request() + processor_output = self.get_processor_output(req=req) + with torch.inference_mode(): + precomputed_features = self.visual(processor_output["pixel_values"]) + output = await self.engine.async_generate( + input_ids=processor_output["input_ids"][0].detach().cpu().tolist(), + image_data=[ + dict( + modality="IMAGE", + precomputed_features=precomputed_features, + ) + ], + sampling_params=dict(temperature=0.0), + ) + self.assertIn("taxi", output["text"].lower()) + + if __name__ == "__main__": unittest.main()
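
--

Two usage sketches follow; they are supplementary notes, not part of the diff.

The notebook above only demonstrates the precomputed-features path for
Qwen-VL. Below is a minimal sketch of the same flow for Gemma3, adapted from
TestGemmaUnderstandsImage in test/srt/test_vlm_accuracy.py. It assumes a CUDA
device, the "google/gemma-3-4b-it" checkpoint, and that the "gemma-it" chat
template's image token matches the HF processor's <start_of_image> token.

import torch
from io import BytesIO

import requests
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from sglang import Engine
from sglang.srt.conversation import chat_templates

model_path = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
)
vision_tower = model.vision_tower.eval().cuda()
projector = model.multi_modal_projector.eval().cuda()

llm = Engine(
    model_path=model_path, chat_template="gemma-it", enable_multimodal=True
)

image = Image.open(
    BytesIO(
        requests.get(
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        ).content
    )
)

conv = chat_templates["gemma-it"].copy()
conv.append_message(conv.roles[0], f"What's shown here: {conv.image_token}?")
conv.append_message(conv.roles[1], "")

# The HF processor expands the image token; keep its input_ids so they stay
# aligned with the precomputed features.
inputs = processor(images=[image], text=conv.get_prompt(), return_tensors="pt")
with torch.inference_mode():
    features = projector(
        vision_tower(
            pixel_values=inputs["pixel_values"].cuda().to(torch.bfloat16)
        ).last_hidden_state
    )

# Unlike Qwen-VL, Gemma3 needs no grid metadata; the dict is converted into a
# MultimodalDataItem via MultimodalDataItem.from_dict.
out = llm.generate(
    input_ids=inputs["input_ids"][0].tolist(),
    image_data=[dict(modality="IMAGE", precomputed_features=features)],
)
print(out["text"])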
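A second sketch, showing why each processor now carries both a literal image
token and a regex: prompts built from precomputed features arrive with the
image tokens already expanded, and load_mm_data must split on them without
collapsing the expansion. The pattern below is the Qwen2.5-VL regex from
qwen_vl.py; the combined pattern mirrors MultimodalSpecialTokens.collect().

import re

IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
IMAGE_TOKEN_REGEX = re.compile(
    r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)

# collect() wraps each modality regex in a non-capturing group and the whole
# alternation in a capturing group, so re.split keeps the matched tokens as
# separate list entries.
combined = re.compile("(" + f"(?:{IMAGE_TOKEN_REGEX.pattern})" + ")")

# The single literal token and any expanded run both match:
assert combined.match(IMAGE_TOKEN)
expanded = "<|vision_start|>" + "<|image_pad|>" * 4 + "<|vision_end|>"
parts = re.split(combined, f"What's shown here: {expanded}?")
print(parts)
# -> ["What's shown here: ", '<|vision_start|>...<|vision_end|>', '?']

For a matched part whose image_data entry is a precomputed dict, load_mm_data
keeps the expanded text verbatim; otherwise it emits the single literal token
and lets the HF processor re-expand it.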