Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)
Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
docs/backend/vlm_query.ipynb (new file, 163 lines)
@@ -0,0 +1,163 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "0",
+   "metadata": {},
+   "source": [
+    "# Querying Qwen-VL"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nest_asyncio\n",
+    "\n",
+    "nest_asyncio.apply()  # Run this first.\n",
+    "\n",
+    "model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
+    "chat_template = \"qwen2-vl\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Let's create a prompt.\n",
+    "\n",
+    "from io import BytesIO\n",
+    "import requests\n",
+    "from PIL import Image\n",
+    "\n",
+    "from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
+    "from sglang.srt.conversation import chat_templates\n",
+    "\n",
+    "image = Image.open(\n",
+    "    BytesIO(\n",
+    "        requests.get(\n",
+    "            \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
+    "        ).content\n",
+    "    )\n",
+    ")\n",
+    "\n",
+    "conv = chat_templates[chat_template].copy()\n",
+    "conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
+    "conv.append_message(conv.roles[1], \"\")\n",
+    "conv.image_data = [image]\n",
+    "\n",
+    "print(conv.get_prompt())\n",
+    "image"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3",
+   "metadata": {},
+   "source": [
+    "## Query via the offline Engine API"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sglang import Engine\n",
+    "\n",
+    "llm = Engine(\n",
+    "    model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
+    "print(out[\"text\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6",
+   "metadata": {},
+   "source": [
+    "## Query via the offline Engine API, but send precomputed embeddings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Compute the image embeddings using Huggingface.\n",
+    "\n",
+    "from transformers import AutoProcessor\n",
+    "from transformers import Qwen2_5_VLForConditionalGeneration\n",
+    "\n",
+    "processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
+    "vision = (\n",
+    "    Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "processed_prompt = processor(\n",
+    "    images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
+    ")\n",
+    "input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
+    "precomputed_features = vision(\n",
+    "    processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
+    ")\n",
+    "\n",
+    "mm_item = dict(\n",
+    "    modality=\"IMAGE\",\n",
+    "    image_grid_thws=processed_prompt[\"image_grid_thw\"],\n",
+    "    precomputed_features=precomputed_features,\n",
+    ")\n",
+    "out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
+    "print(out[\"text\"])"
+   ]
+  }
+ ],
+ "metadata": {
+  "jupytext": {
+   "cell_metadata_filter": "-all",
+   "custom_cell_magics": "kql",
+   "encoding": "# -*- coding: utf-8 -*-"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
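The notebook above exercises Qwen2.5-VL; the same precomputed-feature path works for Gemma3, where image features come from the vision tower followed by the multi-modal projector. The following is a minimal sketch adapted from the tests added in this commit, not part of the notebook itself: it assumes an Engine `llm` launched with `google/gemma-3-4b-it`, a `prompt` built with the `gemma-it` chat template, and `image` loaded as in the notebook.

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

gemma_path = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(gemma_path, use_fast=True)
model = Gemma3ForConditionalGeneration.from_pretrained(gemma_path).eval()

processed = processor(images=[image], text=prompt, return_tensors="pt")
with torch.inference_mode():
    # Vision tower + projector reproduce what the server would compute itself.
    hidden = model.vision_tower(pixel_values=processed["pixel_values"]).last_hidden_state
    features = model.multi_modal_projector(hidden)

out = llm.generate(
    input_ids=processed["input_ids"][0].tolist(),
    image_data=[dict(modality="IMAGE", precomputed_features=features)],
)
print(out["text"])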
@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
+    ImageDataItem,
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
@@ -150,9 +151,9 @@ class Engine(EngineBase):
         # See also python/sglang/srt/utils.py:load_image for more details.
         image_data: Optional[
             Union[
-                List[List[Union[Image, str]]],
-                List[Union[Image, str]],
-                Union[Image, str],
+                List[List[ImageDataItem]],
+                List[ImageDataItem],
+                ImageDataItem,
             ]
         ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -221,9 +222,9 @@ class Engine(EngineBase):
         # See also python/sglang/srt/utils.py:load_image for more details.
         image_data: Optional[
             Union[
-                List[List[Union[Image, str]]],
-                List[Union[Image, str]],
-                Union[Image, str],
+                List[List[ImageDataItem]],
+                List[ImageDataItem],
+                ImageDataItem,
             ]
         ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -40,6 +40,10 @@ class SessionParams:
     replace: Optional[bool] = None


+AudioDataItem = Union[str, Dict]
+ImageDataItem = Union[Image, str, Dict]
+
+
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
@@ -55,10 +59,10 @@ class GenerateReqInput:
     # - List of lists of images (multiple images per request)
     # See also python/sglang/srt/utils.py:load_image for more details.
     image_data: Optional[
-        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
+        Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
     ] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
-    audio_data: Optional[Union[List[str], str]] = None
+    audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
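With `ImageDataItem = Union[Image, str, Dict]`, each entry of `image_data` may now be a PIL image, a URL/path/base64 string, or a dict in the preprocessed form shown in the notebook above. A short illustration; the `features` tensor and the URL are stand-ins, not real inputs:

import torch
from PIL import Image

features = torch.randn(256, 2048)  # stand-in for a vision-tower output

image_data = [
    "https://example.com/cat.png",  # URL, file path, or base64 string
    Image.new("RGB", (64, 64)),     # an in-memory PIL image
    {                               # preprocessed item; parsed by
        "modality": "IMAGE",        # MultimodalDataItem.from_dict
        "precomputed_features": features,
    },
]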
@@ -368,13 +368,13 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    image_data_embedding_func: Optional[
+        Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    audio_data_embedding_func: Optional[
+        Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    placeholder_tokens: dict[Modality, List[int]] = None,
+    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -389,7 +389,6 @@
         forwarded hidden states

     """
-
     assert hasattr(language_model, "get_input_embeddings")
     embed_tokens = language_model.get_input_embeddings()
     if (
@@ -3,16 +3,16 @@ import concurrent.futures
 import dataclasses
 import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Union

 import numpy as np
 import PIL
 import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast

-from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image
@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput:
     input_text: str

     # frames loaded from image and video, in given order
-    images: Optional[list[PIL.Image]] = None
+    images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None

     # audios
-    audios: Optional[list[np.ndarray]] = None
+    audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None

     def normalize(self):
-        for field_name in ["image_sizes", "images", "audios"]:
+        for field_name in ["images", "audios"]:
             field = getattr(self, field_name, None)
             if field is not None and isinstance(field, list) and len(field) == 0:
                 setattr(self, field_name, None)
@@ -40,12 +40,32 @@ class MultimodalSpecialTokens:
     video_token: Optional[str] = None
     audio_token: Optional[str] = None

-    def collect(self) -> list[str]:
-        return [
-            token
-            for token in [self.image_token, self.video_token, self.audio_token]
-            if token
-        ]
+    image_token_regex: Optional[re.Pattern] = None
+    video_token_regex: Optional[re.Pattern] = None
+    audio_token_regex: Optional[re.Pattern] = None
+
+    def __post_init__(self):
+        if self.image_token_regex is None and self.image_token is not None:
+            self.image_token_regex = re.compile(re.escape(self.image_token))
+        if self.video_token_regex is None and self.video_token is not None:
+            self.video_token_regex = re.compile(re.escape(self.video_token))
+        if self.audio_token_regex is None and self.audio_token is not None:
+            self.audio_token_regex = re.compile(re.escape(self.audio_token))
+
+    def collect(self) -> re.Pattern:
+        tokens = [
+            self.image_token_regex,
+            self.video_token_regex,
+            self.audio_token_regex,
+        ]
+        patterns = []
+        flags = 0
+        for t in tokens:
+            if t is not None:
+                patterns.append(t.pattern)
+                flags |= t.flags
+        combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
+        return re.compile(combined, flags)


 class BaseMultimodalProcessor(ABC):
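The capturing group around the alternation is what lets `re.split` keep the matched multimodal tokens as their own list entries, which the loading code below relies on. A small standalone illustration, using the Qwen2.5-VL pattern introduced later in this commit:

import re

image_token_regex = re.compile(
    r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)
# Same shape as the pattern collect() builds: a capturing group of (?:...) parts.
combined = re.compile("(" + f"(?:{image_token_regex.pattern})" + ")")

prompt = "What's shown here: <|vision_start|><|image_pad|><|vision_end|>?"
print(re.split(combined, prompt))
# ["What's shown here: ", '<|vision_start|><|image_pad|><|vision_end|>', '?']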
@@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC):
         data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
     ):
         """Static method that can be pickled for multiprocessing"""
+        if isinstance(data, dict):
+            return MultimodalDataItem.from_dict(data)
+        if isinstance(data, MultimodalDataItem):
+            return data
         try:
             if is_audio:
                 return load_audio(data)
@@ -175,7 +199,10 @@
         image_index, audio_index = 0, 0

         for text_part in text_parts:
-            if text_part == multimodal_tokens.image_token:
+            if (
+                multimodal_tokens.image_token_regex
+                and multimodal_tokens.image_token_regex.match(text_part)
+            ):
                 data = image_data[image_index]
                 is_video = isinstance(data, str) and data.startswith("video:")
                 estimated_frames = estimated_frames_list[image_index]
@@ -192,7 +219,10 @@
                 )
                 task_info.append((Modality.IMAGE, data, frame_count_limit))
                 image_index += 1
-            elif text_part == multimodal_tokens.audio_token:
+            elif (
+                multimodal_tokens.audio_token_regex
+                and multimodal_tokens.audio_token_regex.match(text_part)
+            ):
                 data = audio_data[audio_index]
                 futures.append(
                     self.io_executor.submit(
@@ -228,17 +258,22 @@
             discard_alpha_channel: if True, discards the alpha channel in the returned images

         """
         if not return_text:
             raise NotImplementedError()

         if image_data is None:
             image_data = []
         if isinstance(multimodal_tokens.image_token, int):
-            multimodal_tokens.image_token = (
-                self._processor.tokenizer.convert_ids_to_tokens(
-                    multimodal_tokens.image_token
-                )
-            )
+            multimodal_tokens.image_token = re.compile(
+                re.escape(
+                    self._processor.tokenizer.convert_ids_to_tokens(
+                        multimodal_tokens.image_token
+                    )
+                )
+            )
         else:
             multimodal_tokens.image_token = multimodal_tokens.image_token
+        multimodal_tokens_pattern = multimodal_tokens.collect()

         if isinstance(prompt, list) and return_text:
             assert len(prompt) and isinstance(prompt[0], int)
@@ -247,16 +282,8 @@
             prompt = prompt

         assert isinstance(prompt, str)
-        if return_text:
-            import re
-
-            pattern = (
-                "("
-                + "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
-                + ")"
-            )
-            # split text into list of normal text and special tokens
-            text_parts = re.split(pattern, prompt)
+        # split text into list of normal text and special tokens
+        text_parts = re.split(multimodal_tokens_pattern, prompt)

         futures, task_info = self.submit_data_loading_tasks(
             text_parts=text_parts,
@@ -266,26 +293,40 @@
             discard_alpha_channel=discard_alpha_channel,
         )
         # Process results
-        image_sizes, images, audios = [], [], []
+        images, audios = [], []
         new_text = ""
         task_ptr = 0

         for text_part in text_parts:
-            if text_part in multimodal_tokens.collect():
+            if multimodal_tokens_pattern.match(text_part):
                 task_type, data, frame_limit = task_info[task_ptr]
                 result = futures[task_ptr].result()
                 task_ptr += 1

                 if task_type == Modality.IMAGE:
+                    # If data is already processed it will be a
+                    # dictionary. In this case we want to keep the
+                    # expanded tokens in text_part. Otherwise, we will
+                    # call the processor code, so keep only a single image
+                    # token.
+                    mm_tokens = (
+                        text_part
+                        if isinstance(data, dict)
+                        else multimodal_tokens.image_token
+                    )
                     frames = [result] if not isinstance(result, list) else result
                     if frames:
-                        image_sizes += frames[0].size * len(frames)
                         images += frames
-                        new_text += multimodal_tokens.image_token * len(frames)
+                        new_text += mm_tokens * len(frames)
                 elif task_type == Modality.AUDIO:
                     # audio
+                    mm_tokens = (
+                        text_part
+                        if isinstance(data, dict)
+                        else multimodal_tokens.audio_token
+                    )
                     audios.append(result)
-                    new_text += multimodal_tokens.audio_token
+                    new_text += mm_tokens
                 # TODO: handle video
             else:
                 new_text += text_part
@@ -297,3 +338,16 @@
         )
         out.normalize()
         return out
+
+    def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
+        """Returns true if all images are preprocessed, false if all are not, and error otherwise."""
+        if not mm_inputs:
+            return True
+        ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
+        if ret and not all(
+            isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
+        ):
+            raise ValueError(
+                "Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
+            )
+        return ret
@@ -1,4 +1,5 @@
-from typing import List, Union
+import re
+from typing import Dict, List, Union

 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):

     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        # The single, pre-expanded image token.
         self.IMAGE_TOKEN = "<start_of_image>"
+        # The regex that matches expanded image tokens.
+        self.IMAGE_TOKEN_REGEX = re.compile(
+            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+        )
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index

     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, Dict]],
         input_text,
         request_obj,
         max_req_input_len,
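IMAGE_TOKEN_REGEX makes the trailing expansion optional, so the same pattern matches a bare `<start_of_image>` in a fresh prompt and the fully expanded token run in a prompt whose features were precomputed. A quick standalone check:

import re

pat = re.compile(r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?")

assert pat.match("<start_of_image>")  # pre-expanded form
assert pat.match(  # expanded form produced by the processor
    "<start_of_image>" + "<image_soft_token>" * 3 + "<end_of_image>"
)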
@@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
+        image_token_regex = self.IMAGE_TOKEN_REGEX
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=image_token, image_token_regex=image_token_regex
+            ),
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )

+        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
         ret = self.process_mm_data(
-            input_text=base_output.input_text, images=base_output.images
+            input_text=base_output.input_text,
+            images=None if images_are_preprocessed else base_output.images,
         )

         items = []
         for i, image in enumerate(base_output.images):
+            if images_are_preprocessed:
+                pixel_values = image.pixel_values
+                precomputed_features = image.precomputed_features
+            else:
+                pixel_values = ret["pixel_values"][i]
+                precomputed_features = None
+
             item = MultimodalDataItem(
-                pixel_values=ret["pixel_values"][i],
+                pixel_values=pixel_values,
+                precomputed_features=precomputed_features,
                 modality=Modality.IMAGE,
             )
             items += [item]
@@ -1,6 +1,7 @@
 import asyncio
 import math
-from typing import List, Union
+import re
+from typing import Dict, List, Union

 import torch
 from PIL import Image
@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):

     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        # The single, pre-expanded image token.
         self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
+        # The regex that matches expanded image tokens.
+        self.IMAGE_TOKEN_REGEX = re.compile(
+            r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
+        )
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):

     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, Dict]],
         input_text,
         request_obj,
         max_req_input_len,
@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         if isinstance(image_data, str):
             image_data = [image_data]

-        image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.IMAGE_TOKEN,
+                image_token_regex=self.IMAGE_TOKEN_REGEX,
+            ),
             max_req_input_len=max_req_input_len,
         )
@@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)

-        if base_output.images:
+        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
+        if base_output.images and not images_are_preprocessed:
             resize_tasks = [resize_image_async(image) for image in base_output.images]
             base_output.images = await asyncio.gather(*resize_tasks)

         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=base_output.images,
+            images=None if images_are_preprocessed else base_output.images,
         )

-        input_ids = ret["input_ids"].flatten().tolist()
+        image_grid_thw = None
+        video_grid_thw = None  # TODO
         items = []
-        if "pixel_values" in ret:
+
+        input_ids = ret["input_ids"].flatten().tolist()
+        if base_output.images:
+            if images_are_preprocessed:
+                image_grid_thw = torch.concat(
+                    [
+                        torch.as_tensor(item.image_grid_thws)
+                        for item in base_output.images
+                    ]
+                )
+                all_pixel_values = [
+                    item.pixel_values
+                    for item in base_output.images
+                    if item.pixel_values is not None
+                ]
+                all_precomputed_features = [
+                    item.precomputed_features
+                    for item in base_output.images
+                    if item.precomputed_features is not None
+                ]
+                pixel_values = (
+                    torch.concat(all_pixel_values) if all_pixel_values else None
+                )
+                precomputed_features = (
+                    torch.concat(all_precomputed_features)
+                    if all_precomputed_features
+                    else None
+                )
+            else:
+                image_grid_thw = ret["image_grid_thw"]
+                pixel_values = ret["pixel_values"]
+                precomputed_features = None
             items += [
                 MultimodalDataItem(
-                    pixel_values=ret["pixel_values"],
-                    image_grid_thws=torch.concat([ret["image_grid_thw"]]),
-                    # TODO
-                    video_grid_thws=None,
                     second_per_grid_ts=ret.get("second_per_grid_ts", None),
+                    pixel_values=pixel_values,
+                    image_grid_thws=image_grid_thw,
+                    video_grid_thws=video_grid_thw,
+                    precomputed_features=precomputed_features,
                     modality=Modality.IMAGE,
                 )
             ]
@@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
             input_ids=torch.tensor(input_ids).unsqueeze(0),
-            image_grid_thw=ret.get("image_grid_thw", None),
-            video_grid_thw=ret.get("video_grid_thw", None),
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
             second_per_grid_ts=ret.get("second_per_grid_ts", None),
         )
         mrope_positions = mrope_positions.squeeze(1)
@@ -177,10 +177,10 @@ class MultimodalDataItem:
     image_offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
-    # data: Union[List[torch.Tensor], List[np.array]]
-    pixel_values: Union[torch.Tensor, np.array] = None
-    image_grid_thws: Union[torch.Tensor, np.array] = None
-    video_grid_thws: Union[torch.Tensor, np.array] = None
+    # data: Union[List[torch.Tensor], List[np.ndarray]]
+    pixel_values: Union[torch.Tensor, np.ndarray] = None
+    image_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    video_grid_thws: Union[torch.Tensor, np.ndarray] = None

     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
@@ -189,9 +189,11 @@
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None

-    audio_features: Union[torch.Tensor, np.array] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None

+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+
     @staticmethod
     def is_empty_list(l):
         if l is None:
@@ -249,7 +251,9 @@
                 return tensor_hash([f])
             return data_hash(f)

-        if self.is_audio():
+        if self.precomputed_features is not None:
+            self.hash = hash_feature(self.precomputed_features)
+        elif self.is_audio():
             self.hash = hash_feature(self.audio_features)
         else:
             self.hash = hash_feature(self.pixel_values)
@@ -258,19 +262,24 @@
         self.pad_value = self.hash % (1 << 30)

     def is_audio(self):
-        return (
-            self.modality == Modality.AUDIO
-        ) and not MultimodalDataItem.is_empty_list(self.audio_features)
+        return (self.modality == Modality.AUDIO) and (
+            self.precomputed_features is not None
+            or not MultimodalDataItem.is_empty_list(self.audio_features)
+        )

     def is_image(self):
         return (
             self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
-        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+        ) and (
+            self.precomputed_features is not None
+            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+        )

     def is_video(self):
-        return (
-            self.modality == Modality.VIDEO
-        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+        return (self.modality == Modality.VIDEO) and (
+            self.precomputed_features is not None
+            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+        )

     def is_valid(self) -> bool:
         return self.is_image() or self.is_video() or self.is_audio()
@@ -279,6 +288,16 @@
         ...
         # TODO

+    @staticmethod
+    def from_dict(obj: dict):
+        kwargs = dict(obj)
+        modality = kwargs.pop("modality")
+        if isinstance(modality, str):
+            modality = Modality[modality]
+        ret = MultimodalDataItem(modality=modality, **kwargs)
+        ret.validate()
+        return ret
+

 @dataclasses.dataclass
 class MultimodalInputs:
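`from_dict` is what turns the dict form of an `image_data` entry into a validated item; a string modality is looked up by name in the `Modality` enum. A small usage sketch (the feature tensor shape here is arbitrary; it assumes an sglang checkout on the import path):

import torch

from sglang.srt.managers.schedule_batch import MultimodalDataItem

item = MultimodalDataItem.from_dict(
    {"modality": "IMAGE", "precomputed_features": torch.randn(256, 2048)}
)
assert item.is_image()  # valid via precomputed_features, without pixel_values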
@@ -54,7 +54,7 @@ class SessionReqNode:
             prefix += " -- " + self.childs[0].req.rid
             ret = self.childs[0]._str_helper(prefix)
         for child in self.childs[1:]:
-            prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+            prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
             ret += child._str_helper(prefix)
         return ret
@@ -278,6 +278,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
         """
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
         pixel_values = torch.stack(
             flatten_nested_list([item.pixel_values for item in items]), dim=0
         )
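When every item carries precomputed features, the vision tower is bypassed entirely and the per-item feature tensors are concatenated along the first dimension; the same all-or-nothing guard appears in all three models below. A toy illustration of the tensor-level contract (shapes are made up):

import torch

# Two "images" worth of precomputed features, (image_length, embed_dim) each.
feats = [torch.randn(256, 2560), torch.randn(256, 2560)]
image_features = torch.concat(feats)  # (512, 2560); no vision tower involved
assert image_features.shape == (512, 2560)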
@@ -497,6 +497,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
@@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
@@ -54,21 +54,17 @@ class TestSkipTokenizerInit(CustomTestCase):
     ):
         input_ids = self.get_input_ids(prompt_text)

+        request = self.get_request_json(
+            input_ids=input_ids,
+            return_logprob=return_logprob,
+            top_logprobs_num=top_logprobs_num,
+            max_new_tokens=max_new_tokens,
+            stream=False,
+            n=n,
+        )
        response = requests.post(
             self.base_url + "/generate",
-            json={
-                "input_ids": input_ids,
-                "sampling_params": {
-                    "temperature": 0 if n == 1 else 0.5,
-                    "max_new_tokens": max_new_tokens,
-                    "n": n,
-                    "stop_token_ids": [self.tokenizer.eos_token_id],
-                },
-                "stream": False,
-                "return_logprob": return_logprob,
-                "top_logprobs_num": top_logprobs_num,
-                "logprob_start_len": 0,
-            },
+            json=request,
         )
         ret = response.json()
         print(json.dumps(ret, indent=2))
@@ -87,9 +83,12 @@ class TestSkipTokenizerInit(CustomTestCase):
             self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))

             if return_logprob:
+                num_input_logprobs = len(input_ids) - request["logprob_start_len"]
+                if num_input_logprobs > len(input_ids):
+                    num_input_logprobs -= len(input_ids)
                 self.assertEqual(
                     len(item["meta_info"]["input_token_logprobs"]),
-                    len(input_ids),
+                    num_input_logprobs,
                     f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
                 )
                 self.assertEqual(
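The `num_input_logprobs` arithmetic exists because VLM requests set `logprob_start_len = -1` (see the override further down): `len(input_ids) - (-1)` overshoots by the full prompt length, and the correction leaves a single input logprob, which is assumed here to correspond to the final prompt token. A worked example of that arithmetic:

input_ids = list(range(10))
logprob_start_len = -1  # as set by TestSkipTokenizerInitVLM.get_request_json

num_input_logprobs = len(input_ids) - logprob_start_len  # 10 - (-1) = 11
if num_input_logprobs > len(input_ids):
    num_input_logprobs -= len(input_ids)  # 11 - 10 = 1
assert num_input_logprobs == 1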
@@ -113,19 +112,14 @@ class TestSkipTokenizerInit(CustomTestCase):
         requests.post(self.base_url + "/flush_cache")
         response = requests.post(
             self.base_url + "/generate",
-            json={
-                "input_ids": input_ids,
-                "sampling_params": {
-                    "temperature": 0 if n == 1 else 0.5,
-                    "max_new_tokens": max_new_tokens,
-                    "n": n,
-                    "stop_token_ids": self.eos_token_id,
-                },
-                "stream": False,
-                "return_logprob": return_logprob,
-                "top_logprobs_num": top_logprobs_num,
-                "logprob_start_len": 0,
-            },
+            json=self.get_request_json(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                return_logprob=return_logprob,
+                top_logprobs_num=top_logprobs_num,
+                stream=False,
+                n=n,
+            ),
         )
         ret = response.json()
         print(json.dumps(ret))
@@ -137,19 +131,13 @@ class TestSkipTokenizerInit(CustomTestCase):
         requests.post(self.base_url + "/flush_cache")
         response_stream = requests.post(
             self.base_url + "/generate",
-            json={
-                "input_ids": input_ids,
-                "sampling_params": {
-                    "temperature": 0 if n == 1 else 0.5,
-                    "max_new_tokens": max_new_tokens,
-                    "n": n,
-                    "stop_token_ids": self.eos_token_id,
-                },
-                "stream": True,
-                "return_logprob": return_logprob,
-                "top_logprobs_num": top_logprobs_num,
-                "logprob_start_len": 0,
-            },
+            json=self.get_request_json(
+                input_ids=input_ids,
+                return_logprob=return_logprob,
+                top_logprobs_num=top_logprobs_num,
+                stream=True,
+                n=n,
+            ),
         )

         response_stream_json = []
@@ -188,6 +176,29 @@ class TestSkipTokenizerInit(CustomTestCase):
         ].tolist()
         return input_ids

+    def get_request_json(
+        self,
+        input_ids,
+        max_new_tokens=32,
+        return_logprob=False,
+        top_logprobs_num=0,
+        stream=False,
+        n=1,
+    ):
+        return {
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": 0 if n == 1 else 0.5,
+                "max_new_tokens": max_new_tokens,
+                "n": n,
+                "stop_token_ids": self.eos_token_id,
+            },
+            "stream": stream,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+            "logprob_start_len": 0,
+        }
+

 class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
     @classmethod
@@ -218,6 +229,14 @@ class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):

         return inputs.input_ids[0].tolist()

+    def get_request_json(self, *args, **kwargs):
+        ret = super().get_request_json(*args, **kwargs)
+        ret["image_data"] = [self.image_url]
+        ret["logprob_start_len"] = (
+            -1
+        )  # Do not try to calculate logprobs of image embeddings.
+        return ret
+
     def test_simple_decode_stream(self):
         # TODO mick
         pass
@@ -3,15 +3,22 @@

 import unittest
 from io import BytesIO
-from typing import List
+from typing import List, Optional

 import numpy as np
 import requests
 import torch
 import torch.nn.functional as F
 from PIL import Image
-from transformers import AutoModel, AutoProcessor, AutoTokenizer
+from transformers import (
+    AutoModel,
+    AutoProcessor,
+    AutoTokenizer,
+    Gemma3ForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+)

 from sglang import Engine
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.managers.mm_utils import embed_mm_inputs
@@ -100,7 +107,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):

         np.testing.assert_allclose(hf_np, sg_np)

-    def get_processor_output(self):
+    def get_completion_request(self) -> ChatCompletionRequest:
         json_str = f"""
         {{
             "model": "{self.model_path}",
@@ -124,10 +131,12 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
         }}
         """

-        req = ChatCompletionRequest.model_validate_json(json_str)
+        return ChatCompletionRequest.model_validate_json(json_str)
+
+    def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
+        if req is None:
+            req = self.get_completion_request()
         conv = generate_chat_conv(req, template_name=self.chat_template)

         text = conv.get_prompt()

         # Process inputs using processor
@@ -239,5 +248,129 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         self.compare_outputs(sglang_output, hf_output)


+class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
+        cls.chat_template = "qwen2-vl"
+        cls.processor = AutoProcessor.from_pretrained(
+            cls.model_path, trust_remote_code=True, use_fast=True
+        )
+        cls.visual = (
+            Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                cls.model_path, torch_dtype=torch.bfloat16
+            )
+            .eval()
+            .visual.to(cls.device)
+        )
+
+    def setUp(self):
+        self.engine = Engine(
+            model_path=self.model_path,
+            chat_template=self.chat_template,
+            device=self.device.type,
+            mem_fraction_static=0.8,
+        )
+
+    def tearDown(self):
+        self.engine.shutdown()
+
+    async def test_qwen_vl_understands_image(self):
+        req = self.get_completion_request()
+        conv = generate_chat_conv(req, template_name=self.chat_template)
+        text = conv.get_prompt()
+        output = await self.engine.async_generate(
+            prompt=text,
+            image_data=[self.main_image],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    async def test_qwen_vl_understands_precomputed_features(self):
+        req = self.get_completion_request()
+        processor_output = self.get_processor_output(req=req)
+        with torch.inference_mode():
+            precomputed_features = self.visual(
+                processor_output["pixel_values"], processor_output["image_grid_thw"]
+            )
+        output = await self.engine.async_generate(
+            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
+            image_data=[
+                dict(
+                    modality="IMAGE",
+                    image_grid_thws=processor_output["image_grid_thw"],
+                    precomputed_features=precomputed_features,
+                )
+            ],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+
+class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.model_path = "google/gemma-3-4b-it"
+        cls.chat_template = "gemma-it"
+        cls.processor = AutoProcessor.from_pretrained(
+            cls.model_path, trust_remote_code=True, use_fast=True
+        )
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            cls.model_path, torch_dtype=torch.bfloat16
+        )
+        cls.vision_tower = model.vision_tower.eval().to(cls.device)
+        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
+
+    @classmethod
+    def visual(cls, pixel_values):
+        vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
+        image_features = cls.mm_projector(vision_outputs)
+        return image_features
+
+    def setUp(self):
+        self.engine = Engine(
+            model_path=self.model_path,
+            chat_template=self.chat_template,
+            device=self.device.type,
+            mem_fraction_static=0.5,
+            enable_multimodal=True,
+        )
+
+    def tearDown(self):
+        self.engine.shutdown()
+
+    async def test_gemma_understands_image(self):
+        req = self.get_completion_request()
+        conv = generate_chat_conv(req, template_name=self.chat_template)
+        text = conv.get_prompt()
+        output = await self.engine.async_generate(
+            prompt=text,
+            image_data=[self.main_image],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    async def test_gemma_understands_precomputed_features(self):
+        req = self.get_completion_request()
+        processor_output = self.get_processor_output(req=req)
+        with torch.inference_mode():
+            precomputed_features = self.visual(processor_output["pixel_values"])
+        output = await self.engine.async_generate(
+            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
+            image_data=[
+                dict(
+                    modality="IMAGE",
+                    precomputed_features=precomputed_features,
+                )
+            ],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+
 if __name__ == "__main__":
     unittest.main()