Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)
Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
This commit is contained in:
163
docs/backend/vlm_query.ipynb
Normal file
163
docs/backend/vlm_query.ipynb
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Querying Qwen-VL"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import nest_asyncio\n",
|
||||||
|
"\n",
|
||||||
|
"nest_asyncio.apply() # Run this first.\n",
|
||||||
|
"\n",
|
||||||
|
"model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
|
||||||
|
"chat_template = \"qwen2-vl\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Let's create a prompt.\n",
|
||||||
|
"\n",
|
||||||
|
"from io import BytesIO\n",
|
||||||
|
"import requests\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"\n",
|
||||||
|
"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
|
||||||
|
"from sglang.srt.conversation import chat_templates\n",
|
||||||
|
"\n",
|
||||||
|
"image = Image.open(\n",
|
||||||
|
" BytesIO(\n",
|
||||||
|
" requests.get(\n",
|
||||||
|
" \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
|
||||||
|
" ).content\n",
|
||||||
|
" )\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"conv = chat_templates[chat_template].copy()\n",
|
||||||
|
"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
|
||||||
|
"conv.append_message(conv.roles[1], \"\")\n",
|
||||||
|
"conv.image_data = [image]\n",
|
||||||
|
"\n",
|
||||||
|
"print(conv.get_prompt())\n",
|
||||||
|
"image"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "3",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Query via the offline Engine API"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sglang import Engine\n",
|
||||||
|
"\n",
|
||||||
|
"llm = Engine(\n",
|
||||||
|
" model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
|
||||||
|
"print(out[\"text\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Query via the offline Engine API, but send precomputed embeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Compute the image embeddings using Huggingface.\n",
|
||||||
|
"\n",
|
||||||
|
"from transformers import AutoProcessor\n",
|
||||||
|
"from transformers import Qwen2_5_VLForConditionalGeneration\n",
|
||||||
|
"\n",
|
||||||
|
"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
|
||||||
|
"vision = (\n",
|
||||||
|
" Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"processed_prompt = processor(\n",
|
||||||
|
" images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
|
||||||
|
")\n",
|
||||||
|
"input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
|
||||||
|
"precomputed_features = vision(\n",
|
||||||
|
" processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"mm_item = dict(\n",
|
||||||
|
" modality=\"IMAGE\",\n",
|
||||||
|
" image_grid_thws=processed_prompt[\"image_grid_thw\"],\n",
|
||||||
|
" precomputed_features=precomputed_features,\n",
|
||||||
|
")\n",
|
||||||
|
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
|
||||||
|
"print(out[\"text\"])"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"jupytext": {
|
||||||
|
"cell_metadata_filter": "-all",
|
||||||
|
"custom_cell_magics": "kql",
|
||||||
|
"encoding": "# -*- coding: utf-8 -*-"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
|
ImageDataItem,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
@@ -150,9 +151,9 @@ class Engine(EngineBase):
|
|||||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||||
image_data: Optional[
|
image_data: Optional[
|
||||||
Union[
|
Union[
|
||||||
List[List[Union[Image, str]]],
|
List[List[ImageDataItem]],
|
||||||
List[Union[Image, str]],
|
List[ImageDataItem],
|
||||||
Union[Image, str],
|
ImageDataItem,
|
||||||
]
|
]
|
||||||
] = None,
|
] = None,
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
@@ -221,9 +222,9 @@ class Engine(EngineBase):
|
|||||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||||
image_data: Optional[
|
image_data: Optional[
|
||||||
Union[
|
Union[
|
||||||
List[List[Union[Image, str]]],
|
List[List[ImageDataItem]],
|
||||||
List[Union[Image, str]],
|
List[ImageDataItem],
|
||||||
Union[Image, str],
|
ImageDataItem,
|
||||||
]
|
]
|
||||||
] = None,
|
] = None,
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
|
|||||||
@@ -40,6 +40,10 @@ class SessionParams:
|
|||||||
replace: Optional[bool] = None
|
replace: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
AudioDataItem = Union[str, Dict]
|
||||||
|
ImageDataItem = Union[Image, str, Dict]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerateReqInput:
|
class GenerateReqInput:
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
@@ -55,10 +59,10 @@ class GenerateReqInput:
|
|||||||
# - List of lists of images (multiple images per request)
|
# - List of lists of images (multiple images per request)
|
||||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||||
image_data: Optional[
|
image_data: Optional[
|
||||||
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
|
||||||
] = None
|
] = None
|
||||||
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||||
audio_data: Optional[Union[List[str], str]] = None
|
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
|
||||||
# The sampling_params. See descriptions below.
|
# The sampling_params. See descriptions below.
|
||||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
# The request id.
|
# The request id.
|
||||||
|
|||||||
@@ -368,13 +368,13 @@ def general_mm_embed_routine(
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
language_model: nn.Module,
|
language_model: nn.Module,
|
||||||
image_data_embedding_func: Callable[
|
image_data_embedding_func: Optional[
|
||||||
[List[MultimodalDataItem]], torch.Tensor
|
Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||||
] = None,
|
] = None,
|
||||||
audio_data_embedding_func: Callable[
|
audio_data_embedding_func: Optional[
|
||||||
[List[MultimodalDataItem]], torch.Tensor
|
Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||||
] = None,
|
] = None,
|
||||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -389,7 +389,6 @@ def general_mm_embed_routine(
|
|||||||
forwarded hidden states
|
forwarded hidden states
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert hasattr(language_model, "get_input_embeddings")
|
assert hasattr(language_model, "get_input_embeddings")
|
||||||
embed_tokens = language_model.get_input_embeddings()
|
embed_tokens = language_model.get_input_embeddings()
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -3,16 +3,16 @@ import concurrent.futures
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import BaseImageProcessorFast
|
from transformers import BaseImageProcessorFast
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import Modality
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.utils import encode_video, load_audio, load_image
|
from sglang.srt.utils import encode_video, load_audio, load_image
|
||||||
|
|
||||||
|
|
||||||
@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput:
|
|||||||
input_text: str
|
input_text: str
|
||||||
|
|
||||||
# frames loaded from image and video, in given order
|
# frames loaded from image and video, in given order
|
||||||
images: Optional[list[PIL.Image]] = None
|
images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
|
||||||
|
|
||||||
# audios
|
# audios
|
||||||
audios: Optional[list[np.ndarray]] = None
|
audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
|
||||||
|
|
||||||
def normalize(self):
|
def normalize(self):
|
||||||
for field_name in ["image_sizes", "images", "audios"]:
|
for field_name in ["images", "audios"]:
|
||||||
field = getattr(self, field_name, None)
|
field = getattr(self, field_name, None)
|
||||||
if field is not None and isinstance(field, list) and len(field) == 0:
|
if field is not None and isinstance(field, list) and len(field) == 0:
|
||||||
setattr(self, field_name, None)
|
setattr(self, field_name, None)
|
||||||
@@ -40,12 +40,32 @@ class MultimodalSpecialTokens:
|
|||||||
video_token: Optional[str] = None
|
video_token: Optional[str] = None
|
||||||
audio_token: Optional[str] = None
|
audio_token: Optional[str] = None
|
||||||
|
|
||||||
def collect(self) -> list[str]:
|
image_token_regex: Optional[re.Pattern] = None
|
||||||
return [
|
video_token_regex: Optional[re.Pattern] = None
|
||||||
token
|
audio_token_regex: Optional[re.Pattern] = None
|
||||||
for token in [self.image_token, self.video_token, self.audio_token]
|
|
||||||
if token
|
def __post_init__(self):
    """Derive any missing per-modality regex from its literal token.

    For the image, video and audio modalities in turn: when no regex was
    supplied but a literal token was, the regex becomes a pattern matching
    exactly that token (the token text is escaped, so regex metacharacters
    in it are taken literally). A regex that was passed in explicitly is
    left untouched, as is a modality with neither token nor regex.
    """
    for modality in ("image", "video", "audio"):
        regex_attr = f"{modality}_token_regex"
        if getattr(self, regex_attr) is not None:
            # An explicit regex always wins over the literal token.
            continue
        literal = getattr(self, f"{modality}_token")
        if literal is None:
            continue
        setattr(self, regex_attr, re.compile(re.escape(literal)))
|
||||||
|
|
||||||
|
def collect(self) -> re.Pattern:
    """Combine the per-modality token regexes into one splitting pattern.

    The image, video and audio regexes that are set are OR-ed together
    inside a single capturing group, so that ``re.split`` with the result
    keeps each matched multimodal token as its own item between the
    plain-text pieces. The flags of all sub-patterns are merged.

    Returns:
        The combined compiled pattern. When no modality regex is set, a
        pattern that can never match is returned, so splitting yields the
        input unchanged and no text is mistaken for a multimodal token.
    """
    token_regexes = [
        regex
        for regex in (
            self.image_token_regex,
            self.video_token_regex,
            self.audio_token_regex,
        )
        if regex is not None
    ]
    if not token_regexes:
        # An empty alternation would compile to "()", which matches the
        # empty string at every position; "(?!)" never matches anything.
        return re.compile(r"(?!)")

    flags = 0
    for regex in token_regexes:
        flags |= regex.flags
    alternatives = "|".join(f"(?:{regex.pattern})" for regex in token_regexes)
    return re.compile(f"({alternatives})", flags)
|
||||||
|
|
||||||
|
|
||||||
class BaseMultimodalProcessor(ABC):
|
class BaseMultimodalProcessor(ABC):
|
||||||
@@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
|
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
|
||||||
):
|
):
|
||||||
"""Static method that can be pickled for multiprocessing"""
|
"""Static method that can be pickled for multiprocessing"""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return MultimodalDataItem.from_dict(data)
|
||||||
|
if isinstance(data, MultimodalDataItem):
|
||||||
|
return data
|
||||||
try:
|
try:
|
||||||
if is_audio:
|
if is_audio:
|
||||||
return load_audio(data)
|
return load_audio(data)
|
||||||
@@ -175,7 +199,10 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
image_index, audio_index = 0, 0
|
image_index, audio_index = 0, 0
|
||||||
|
|
||||||
for text_part in text_parts:
|
for text_part in text_parts:
|
||||||
if text_part == multimodal_tokens.image_token:
|
if (
|
||||||
|
multimodal_tokens.image_token_regex
|
||||||
|
and multimodal_tokens.image_token_regex.match(text_part)
|
||||||
|
):
|
||||||
data = image_data[image_index]
|
data = image_data[image_index]
|
||||||
is_video = isinstance(data, str) and data.startswith("video:")
|
is_video = isinstance(data, str) and data.startswith("video:")
|
||||||
estimated_frames = estimated_frames_list[image_index]
|
estimated_frames = estimated_frames_list[image_index]
|
||||||
@@ -192,7 +219,10 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
)
|
)
|
||||||
task_info.append((Modality.IMAGE, data, frame_count_limit))
|
task_info.append((Modality.IMAGE, data, frame_count_limit))
|
||||||
image_index += 1
|
image_index += 1
|
||||||
elif text_part == multimodal_tokens.audio_token:
|
elif (
|
||||||
|
multimodal_tokens.audio_token_regex
|
||||||
|
and multimodal_tokens.audio_token_regex.match(text_part)
|
||||||
|
):
|
||||||
data = audio_data[audio_index]
|
data = audio_data[audio_index]
|
||||||
futures.append(
|
futures.append(
|
||||||
self.io_executor.submit(
|
self.io_executor.submit(
|
||||||
@@ -228,17 +258,22 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
discard_alpha_channel: if True, discards the alpha channel in the returned images
|
discard_alpha_channel: if True, discards the alpha channel in the returned images
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if not return_text:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
if image_data is None:
|
if image_data is None:
|
||||||
image_data = []
|
image_data = []
|
||||||
if isinstance(multimodal_tokens.image_token, int):
|
if isinstance(multimodal_tokens.image_token, int):
|
||||||
multimodal_tokens.image_token = (
|
multimodal_tokens.image_token = re.compile(
|
||||||
|
re.escape(
|
||||||
self._processor.tokenizer.convert_ids_to_tokens(
|
self._processor.tokenizer.convert_ids_to_tokens(
|
||||||
multimodal_tokens.image_token
|
multimodal_tokens.image_token
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
multimodal_tokens.image_token = multimodal_tokens.image_token
|
||||||
|
multimodal_tokens_pattern = multimodal_tokens.collect()
|
||||||
|
|
||||||
if isinstance(prompt, list) and return_text:
|
if isinstance(prompt, list) and return_text:
|
||||||
assert len(prompt) and isinstance(prompt[0], int)
|
assert len(prompt) and isinstance(prompt[0], int)
|
||||||
@@ -247,16 +282,8 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
prompt = prompt
|
prompt = prompt
|
||||||
|
|
||||||
assert isinstance(prompt, str)
|
assert isinstance(prompt, str)
|
||||||
if return_text:
|
|
||||||
import re
|
|
||||||
|
|
||||||
pattern = (
|
|
||||||
"("
|
|
||||||
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
|
|
||||||
+ ")"
|
|
||||||
)
|
|
||||||
# split text into list of normal text and special tokens
|
# split text into list of normal text and special tokens
|
||||||
text_parts = re.split(pattern, prompt)
|
text_parts = re.split(multimodal_tokens_pattern, prompt)
|
||||||
|
|
||||||
futures, task_info = self.submit_data_loading_tasks(
|
futures, task_info = self.submit_data_loading_tasks(
|
||||||
text_parts=text_parts,
|
text_parts=text_parts,
|
||||||
@@ -266,26 +293,40 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
discard_alpha_channel=discard_alpha_channel,
|
discard_alpha_channel=discard_alpha_channel,
|
||||||
)
|
)
|
||||||
# Process results
|
# Process results
|
||||||
image_sizes, images, audios = [], [], []
|
images, audios = [], []
|
||||||
new_text = ""
|
new_text = ""
|
||||||
task_ptr = 0
|
task_ptr = 0
|
||||||
|
|
||||||
for text_part in text_parts:
|
for text_part in text_parts:
|
||||||
if text_part in multimodal_tokens.collect():
|
if multimodal_tokens_pattern.match(text_part):
|
||||||
task_type, data, frame_limit = task_info[task_ptr]
|
task_type, data, frame_limit = task_info[task_ptr]
|
||||||
result = futures[task_ptr].result()
|
result = futures[task_ptr].result()
|
||||||
task_ptr += 1
|
task_ptr += 1
|
||||||
|
|
||||||
if task_type == Modality.IMAGE:
|
if task_type == Modality.IMAGE:
|
||||||
|
# If data is already processed it will be a
|
||||||
|
# dictionary. In this case we want to keep the
|
||||||
|
# expanded tokens in text_part. Otherwise, we will
|
||||||
|
# call the processor code, so keep only a single image
|
||||||
|
# token.
|
||||||
|
mm_tokens = (
|
||||||
|
text_part
|
||||||
|
if isinstance(data, dict)
|
||||||
|
else multimodal_tokens.image_token
|
||||||
|
)
|
||||||
frames = [result] if not isinstance(result, list) else result
|
frames = [result] if not isinstance(result, list) else result
|
||||||
if frames:
|
if frames:
|
||||||
image_sizes += frames[0].size * len(frames)
|
|
||||||
images += frames
|
images += frames
|
||||||
new_text += multimodal_tokens.image_token * len(frames)
|
new_text += mm_tokens * len(frames)
|
||||||
elif task_type == Modality.AUDIO:
|
elif task_type == Modality.AUDIO:
|
||||||
# audio
|
# audio
|
||||||
|
mm_tokens = (
|
||||||
|
text_part
|
||||||
|
if isinstance(data, dict)
|
||||||
|
else multimodal_tokens.audio_token
|
||||||
|
)
|
||||||
audios.append(result)
|
audios.append(result)
|
||||||
new_text += multimodal_tokens.audio_token
|
new_text += mm_tokens
|
||||||
# TODO: handle video
|
# TODO: handle video
|
||||||
else:
|
else:
|
||||||
new_text += text_part
|
new_text += text_part
|
||||||
@@ -297,3 +338,16 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
)
|
)
|
||||||
out.normalize()
|
out.normalize()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
    """Report whether the multimodal inputs have already been preprocessed.

    An input counts as preprocessed when it is a ``MultimodalDataItem``.

    Returns:
        True when ``mm_inputs`` is empty or ``None``, or when every entry is
        a ``MultimodalDataItem``; False when none of the entries is.

    Raises:
        ValueError: If some, but not all, entries are ``MultimodalDataItem``.
    """
    if not mm_inputs:
        return True

    preprocessed_flags = [
        isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
    ]
    if not any(preprocessed_flags):
        return False
    if not all(preprocessed_flags):
        # Raw and preprocessed inputs cannot be combined in one request.
        raise ValueError(
            "Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
        )
    return True
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import List, Union
|
import re
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processor import (
|
from sglang.srt.managers.multimodal_processor import (
|
||||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||||
@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
|||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
# The single, pre-expanded image token.
|
||||||
self.IMAGE_TOKEN = "<start_of_image>"
|
self.IMAGE_TOKEN = "<start_of_image>"
|
||||||
|
# The regex that matches expanded image tokens.
|
||||||
|
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||||
|
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
|
||||||
|
)
|
||||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes, Dict]],
|
||||||
input_text,
|
input_text,
|
||||||
request_obj,
|
request_obj,
|
||||||
max_req_input_len,
|
max_req_input_len,
|
||||||
@@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
|||||||
image_data = [image_data]
|
image_data = [image_data]
|
||||||
|
|
||||||
image_token = self.IMAGE_TOKEN
|
image_token = self.IMAGE_TOKEN
|
||||||
|
image_token_regex = self.IMAGE_TOKEN_REGEX
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
prompt=input_text,
|
prompt=input_text,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
multimodal_tokens=MultimodalSpecialTokens(
|
||||||
|
image_token=image_token, image_token_regex=image_token_regex
|
||||||
|
),
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
discard_alpha_channel=True,
|
discard_alpha_channel=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
||||||
ret = self.process_mm_data(
|
ret = self.process_mm_data(
|
||||||
input_text=base_output.input_text, images=base_output.images
|
input_text=base_output.input_text,
|
||||||
|
images=None if images_are_preprocessed else base_output.images,
|
||||||
)
|
)
|
||||||
|
|
||||||
items = []
|
items = []
|
||||||
for i, image in enumerate(base_output.images):
|
for i, image in enumerate(base_output.images):
|
||||||
|
if images_are_preprocessed:
|
||||||
|
pixel_values = image.pixel_values
|
||||||
|
precomputed_features = image.precomputed_features
|
||||||
|
else:
|
||||||
|
pixel_values = ret["pixel_values"][i]
|
||||||
|
precomputed_features = None
|
||||||
|
|
||||||
item = MultimodalDataItem(
|
item = MultimodalDataItem(
|
||||||
pixel_values=ret["pixel_values"][i],
|
pixel_values=pixel_values,
|
||||||
|
precomputed_features=precomputed_features,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
)
|
)
|
||||||
items += [item]
|
items += [item]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
from typing import List, Union
|
import re
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
|
|
||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
# The single, pre-expanded image token.
|
||||||
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||||
|
# The regex that matches expanded image tokens.
|
||||||
|
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||||
|
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
||||||
|
)
|
||||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||||
self.image_token_id = hf_config.image_token_id
|
self.image_token_id = hf_config.image_token_id
|
||||||
@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes, Dict]],
|
||||||
input_text,
|
input_text,
|
||||||
request_obj,
|
request_obj,
|
||||||
max_req_input_len,
|
max_req_input_len,
|
||||||
@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
if isinstance(image_data, str):
|
if isinstance(image_data, str):
|
||||||
image_data = [image_data]
|
image_data = [image_data]
|
||||||
|
|
||||||
image_token = self.IMAGE_TOKEN
|
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
prompt=input_text,
|
prompt=input_text,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
multimodal_tokens=MultimodalSpecialTokens(
|
||||||
|
image_token=self.IMAGE_TOKEN,
|
||||||
|
image_token_regex=self.IMAGE_TOKEN_REGEX,
|
||||||
|
),
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
async def resize_image_async(image):
|
async def resize_image_async(image):
|
||||||
return resize_image(image)
|
return resize_image(image)
|
||||||
|
|
||||||
if base_output.images:
|
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
||||||
|
if base_output.images and not images_are_preprocessed:
|
||||||
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
||||||
base_output.images = await asyncio.gather(*resize_tasks)
|
base_output.images = await asyncio.gather(*resize_tasks)
|
||||||
|
|
||||||
ret = self.process_mm_data(
|
ret = self.process_mm_data(
|
||||||
input_text=base_output.input_text,
|
input_text=base_output.input_text,
|
||||||
images=base_output.images,
|
images=None if images_are_preprocessed else base_output.images,
|
||||||
)
|
)
|
||||||
|
input_ids = ret["input_ids"].flatten().tolist()
|
||||||
|
image_grid_thw = None
|
||||||
|
video_grid_thw = None # TODO
|
||||||
items = []
|
items = []
|
||||||
|
|
||||||
input_ids = ret["input_ids"].flatten().tolist()
|
if base_output.images:
|
||||||
if "pixel_values" in ret:
|
if images_are_preprocessed:
|
||||||
|
image_grid_thw = torch.concat(
|
||||||
|
[
|
||||||
|
torch.as_tensor(item.image_grid_thws)
|
||||||
|
for item in base_output.images
|
||||||
|
]
|
||||||
|
)
|
||||||
|
all_pixel_values = [
|
||||||
|
item.pixel_values
|
||||||
|
for item in base_output.images
|
||||||
|
if item.pixel_values is not None
|
||||||
|
]
|
||||||
|
all_precomputed_features = [
|
||||||
|
item.precomputed_features
|
||||||
|
for item in base_output.images
|
||||||
|
if item.precomputed_features is not None
|
||||||
|
]
|
||||||
|
pixel_values = (
|
||||||
|
torch.concat(all_pixel_values) if all_pixel_values else None
|
||||||
|
)
|
||||||
|
precomputed_features = (
|
||||||
|
torch.concat(all_precomputed_features)
|
||||||
|
if all_precomputed_features
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image_grid_thw = ret["image_grid_thw"]
|
||||||
|
pixel_values = ret["pixel_values"]
|
||||||
|
precomputed_features = None
|
||||||
items += [
|
items += [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=ret["pixel_values"],
|
pixel_values=pixel_values,
|
||||||
image_grid_thws=torch.concat([ret["image_grid_thw"]]),
|
image_grid_thws=image_grid_thw,
|
||||||
# TODO
|
video_grid_thws=video_grid_thw,
|
||||||
video_grid_thws=None,
|
precomputed_features=precomputed_features,
|
||||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
self.hf_config.vision_config, "tokens_per_second", None
|
self.hf_config.vision_config, "tokens_per_second", None
|
||||||
),
|
),
|
||||||
input_ids=torch.tensor(input_ids).unsqueeze(0),
|
input_ids=torch.tensor(input_ids).unsqueeze(0),
|
||||||
image_grid_thw=ret.get("image_grid_thw", None),
|
image_grid_thw=image_grid_thw,
|
||||||
video_grid_thw=ret.get("video_grid_thw", None),
|
video_grid_thw=video_grid_thw,
|
||||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||||
)
|
)
|
||||||
mrope_positions = mrope_positions.squeeze(1)
|
mrope_positions = mrope_positions.squeeze(1)
|
||||||
|
|||||||
@@ -177,10 +177,10 @@ class MultimodalDataItem:
|
|||||||
image_offsets: Optional[list] = None
|
image_offsets: Optional[list] = None
|
||||||
|
|
||||||
# the real data, pixel_values or audio_features
|
# the real data, pixel_values or audio_features
|
||||||
# data: Union[List[torch.Tensor], List[np.array]]
|
# data: Union[List[torch.Tensor], List[np.ndarray]]
|
||||||
pixel_values: Union[torch.Tensor, np.array] = None
|
pixel_values: Union[torch.Tensor, np.ndarray] = None
|
||||||
image_grid_thws: Union[torch.Tensor, np.array] = None
|
image_grid_thws: Union[torch.Tensor, np.ndarray] = None
|
||||||
video_grid_thws: Union[torch.Tensor, np.array] = None
|
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
|
||||||
|
|
||||||
image_emb_mask: Optional[torch.Tensor] = None
|
image_emb_mask: Optional[torch.Tensor] = None
|
||||||
image_spatial_crop: Optional[torch.Tensor] = None
|
image_spatial_crop: Optional[torch.Tensor] = None
|
||||||
@@ -189,9 +189,11 @@ class MultimodalDataItem:
|
|||||||
# [num_images, (n, w, h)]
|
# [num_images, (n, w, h)]
|
||||||
tgt_size: Tuple[int, int] = None
|
tgt_size: Tuple[int, int] = None
|
||||||
|
|
||||||
audio_features: Union[torch.Tensor, np.array] = None
|
audio_features: Union[torch.Tensor, np.ndarray] = None
|
||||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
|
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_empty_list(l):
|
def is_empty_list(l):
|
||||||
if l is None:
|
if l is None:
|
||||||
@@ -249,7 +251,9 @@ class MultimodalDataItem:
|
|||||||
return tensor_hash([f])
|
return tensor_hash([f])
|
||||||
return data_hash(f)
|
return data_hash(f)
|
||||||
|
|
||||||
if self.is_audio():
|
if self.precomputed_features is not None:
|
||||||
|
self.hash = hash_feature(self.precomputed_features)
|
||||||
|
elif self.is_audio():
|
||||||
self.hash = hash_feature(self.audio_features)
|
self.hash = hash_feature(self.audio_features)
|
||||||
else:
|
else:
|
||||||
self.hash = hash_feature(self.pixel_values)
|
self.hash = hash_feature(self.pixel_values)
|
||||||
@@ -258,19 +262,24 @@ class MultimodalDataItem:
|
|||||||
self.pad_value = self.hash % (1 << 30)
|
self.pad_value = self.hash % (1 << 30)
|
||||||
|
|
||||||
def is_audio(self):
|
def is_audio(self):
|
||||||
return (
|
return (self.modality == Modality.AUDIO) and (
|
||||||
self.modality == Modality.AUDIO
|
self.precomputed_features is not None
|
||||||
) and not MultimodalDataItem.is_empty_list(self.audio_features)
|
or not MultimodalDataItem.is_empty_list(self.audio_features)
|
||||||
|
)
|
||||||
|
|
||||||
def is_image(self):
|
def is_image(self):
|
||||||
return (
|
return (
|
||||||
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
|
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
|
||||||
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
|
) and (
|
||||||
|
self.precomputed_features is not None
|
||||||
|
or not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||||
|
)
|
||||||
|
|
||||||
def is_video(self):
|
def is_video(self):
|
||||||
return (
|
return (self.modality == Modality.VIDEO) and (
|
||||||
self.modality == Modality.VIDEO
|
self.precomputed_features is not None
|
||||||
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
|
or not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||||
|
)
|
||||||
|
|
||||||
def is_valid(self) -> bool:
|
def is_valid(self) -> bool:
|
||||||
return self.is_image() or self.is_video() or self.is_audio()
|
return self.is_image() or self.is_video() or self.is_audio()
|
||||||
@@ -279,6 +288,16 @@ class MultimodalDataItem:
|
|||||||
...
|
...
|
||||||
# TODO
|
# TODO
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(obj: dict):
|
||||||
|
kwargs = dict(obj)
|
||||||
|
modality = kwargs.pop("modality")
|
||||||
|
if isinstance(modality, str):
|
||||||
|
modality = Modality[modality]
|
||||||
|
ret = MultimodalDataItem(modality=modality, **kwargs)
|
||||||
|
ret.validate()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MultimodalInputs:
|
class MultimodalInputs:
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class SessionReqNode:
|
|||||||
prefix += " -- " + self.childs[0].req.rid
|
prefix += " -- " + self.childs[0].req.rid
|
||||||
ret = self.childs[0]._str_helper(prefix)
|
ret = self.childs[0]._str_helper(prefix)
|
||||||
for child in self.childs[1:]:
|
for child in self.childs[1:]:
|
||||||
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
|
prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
|
||||||
ret += child._str_helper(prefix)
|
ret += child._str_helper(prefix)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|||||||
@@ -278,6 +278,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
Returns:
|
Returns:
|
||||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
if any(item.precomputed_features is not None for item in items):
|
||||||
|
if not all(item.precomputed_features is not None for item in items):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"MM inputs where only some items are precomputed."
|
||||||
|
)
|
||||||
|
return torch.concat([item.precomputed_features for item in items])
|
||||||
pixel_values = torch.stack(
|
pixel_values = torch.stack(
|
||||||
flatten_nested_list([item.pixel_values for item in items]), dim=0
|
flatten_nested_list([item.pixel_values for item in items]), dim=0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -497,6 +497,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
|
if any(item.precomputed_features is not None for item in items):
|
||||||
|
if not all(item.precomputed_features is not None for item in items):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"MM inputs where only some items are precomputed."
|
||||||
|
)
|
||||||
|
return torch.concat([item.precomputed_features for item in items])
|
||||||
# in qwen-vl, last dim is the same
|
# in qwen-vl, last dim is the same
|
||||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||||
self.visual.dtype
|
self.visual.dtype
|
||||||
|
|||||||
@@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||||
|
|
||||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
|
if any(item.precomputed_features is not None for item in items):
|
||||||
|
if not all(item.precomputed_features is not None for item in items):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"MM inputs where only some items are precomputed."
|
||||||
|
)
|
||||||
|
return torch.concat([item.precomputed_features for item in items])
|
||||||
# in qwen-vl, last dim is the same
|
# in qwen-vl, last dim is the same
|
||||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||||
self.visual.dtype
|
self.visual.dtype
|
||||||
|
|||||||
@@ -54,21 +54,17 @@ class TestSkipTokenizerInit(CustomTestCase):
|
|||||||
):
|
):
|
||||||
input_ids = self.get_input_ids(prompt_text)
|
input_ids = self.get_input_ids(prompt_text)
|
||||||
|
|
||||||
|
request = self.get_request_json(
|
||||||
|
input_ids=input_ids,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
stream=False,
|
||||||
|
n=n,
|
||||||
|
)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={
|
json=request,
|
||||||
"input_ids": input_ids,
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0 if n == 1 else 0.5,
|
|
||||||
"max_new_tokens": max_new_tokens,
|
|
||||||
"n": n,
|
|
||||||
"stop_token_ids": [self.tokenizer.eos_token_id],
|
|
||||||
},
|
|
||||||
"stream": False,
|
|
||||||
"return_logprob": return_logprob,
|
|
||||||
"top_logprobs_num": top_logprobs_num,
|
|
||||||
"logprob_start_len": 0,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
ret = response.json()
|
ret = response.json()
|
||||||
print(json.dumps(ret, indent=2))
|
print(json.dumps(ret, indent=2))
|
||||||
@@ -87,9 +83,12 @@ class TestSkipTokenizerInit(CustomTestCase):
|
|||||||
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
|
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
|
num_input_logprobs = len(input_ids) - request["logprob_start_len"]
|
||||||
|
if num_input_logprobs > len(input_ids):
|
||||||
|
num_input_logprobs -= len(input_ids)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(item["meta_info"]["input_token_logprobs"]),
|
len(item["meta_info"]["input_token_logprobs"]),
|
||||||
len(input_ids),
|
num_input_logprobs,
|
||||||
f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
|
f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -113,19 +112,14 @@ class TestSkipTokenizerInit(CustomTestCase):
|
|||||||
requests.post(self.base_url + "/flush_cache")
|
requests.post(self.base_url + "/flush_cache")
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={
|
json=self.get_request_json(
|
||||||
"input_ids": input_ids,
|
input_ids=input_ids,
|
||||||
"sampling_params": {
|
max_new_tokens=max_new_tokens,
|
||||||
"temperature": 0 if n == 1 else 0.5,
|
return_logprob=return_logprob,
|
||||||
"max_new_tokens": max_new_tokens,
|
top_logprobs_num=top_logprobs_num,
|
||||||
"n": n,
|
stream=False,
|
||||||
"stop_token_ids": self.eos_token_id,
|
n=n,
|
||||||
},
|
),
|
||||||
"stream": False,
|
|
||||||
"return_logprob": return_logprob,
|
|
||||||
"top_logprobs_num": top_logprobs_num,
|
|
||||||
"logprob_start_len": 0,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
ret = response.json()
|
ret = response.json()
|
||||||
print(json.dumps(ret))
|
print(json.dumps(ret))
|
||||||
@@ -137,19 +131,13 @@ class TestSkipTokenizerInit(CustomTestCase):
|
|||||||
requests.post(self.base_url + "/flush_cache")
|
requests.post(self.base_url + "/flush_cache")
|
||||||
response_stream = requests.post(
|
response_stream = requests.post(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={
|
json=self.get_request_json(
|
||||||
"input_ids": input_ids,
|
input_ids=input_ids,
|
||||||
"sampling_params": {
|
return_logprob=return_logprob,
|
||||||
"temperature": 0 if n == 1 else 0.5,
|
top_logprobs_num=top_logprobs_num,
|
||||||
"max_new_tokens": max_new_tokens,
|
stream=True,
|
||||||
"n": n,
|
n=n,
|
||||||
"stop_token_ids": self.eos_token_id,
|
),
|
||||||
},
|
|
||||||
"stream": True,
|
|
||||||
"return_logprob": return_logprob,
|
|
||||||
"top_logprobs_num": top_logprobs_num,
|
|
||||||
"logprob_start_len": 0,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response_stream_json = []
|
response_stream_json = []
|
||||||
@@ -188,6 +176,29 @@ class TestSkipTokenizerInit(CustomTestCase):
|
|||||||
].tolist()
|
].tolist()
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
|
def get_request_json(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
max_new_tokens=32,
|
||||||
|
return_logprob=False,
|
||||||
|
top_logprobs_num=0,
|
||||||
|
stream=False,
|
||||||
|
n=1,
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"n": n,
|
||||||
|
"stop_token_ids": self.eos_token_id,
|
||||||
|
},
|
||||||
|
"stream": stream,
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
"logprob_start_len": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
|
class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -218,6 +229,14 @@ class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
|
|||||||
|
|
||||||
return inputs.input_ids[0].tolist()
|
return inputs.input_ids[0].tolist()
|
||||||
|
|
||||||
|
def get_request_json(self, *args, **kwargs):
|
||||||
|
ret = super().get_request_json(*args, **kwargs)
|
||||||
|
ret["image_data"] = [self.image_url]
|
||||||
|
ret["logprob_start_len"] = (
|
||||||
|
-1
|
||||||
|
) # Do not try to calculate logprobs of image embeddings.
|
||||||
|
return ret
|
||||||
|
|
||||||
def test_simple_decode_stream(self):
|
def test_simple_decode_stream(self):
|
||||||
# TODO mick
|
# TODO mick
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -3,15 +3,22 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
from transformers import (
|
||||||
|
AutoModel,
|
||||||
|
AutoProcessor,
|
||||||
|
AutoTokenizer,
|
||||||
|
Gemma3ForConditionalGeneration,
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang import Engine
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.conversation import generate_chat_conv
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
||||||
@@ -100,7 +107,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
np.testing.assert_allclose(hf_np, sg_np)
|
np.testing.assert_allclose(hf_np, sg_np)
|
||||||
|
|
||||||
def get_processor_output(self):
|
def get_completion_request(self) -> ChatCompletionRequest:
|
||||||
json_str = f"""
|
json_str = f"""
|
||||||
{{
|
{{
|
||||||
"model": "{self.model_path}",
|
"model": "{self.model_path}",
|
||||||
@@ -124,10 +131,12 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
req = ChatCompletionRequest.model_validate_json(json_str)
|
return ChatCompletionRequest.model_validate_json(json_str)
|
||||||
|
|
||||||
|
def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
|
||||||
|
if req is None:
|
||||||
|
req = self.get_completion_request()
|
||||||
conv = generate_chat_conv(req, template_name=self.chat_template)
|
conv = generate_chat_conv(req, template_name=self.chat_template)
|
||||||
|
|
||||||
text = conv.get_prompt()
|
text = conv.get_prompt()
|
||||||
|
|
||||||
# Process inputs using processor
|
# Process inputs using processor
|
||||||
@@ -239,5 +248,129 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
|||||||
self.compare_outputs(sglang_output, hf_output)
|
self.compare_outputs(sglang_output, hf_output)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super().setUpClass()
|
||||||
|
cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||||
|
cls.chat_template = "qwen2-vl"
|
||||||
|
cls.processor = AutoProcessor.from_pretrained(
|
||||||
|
cls.model_path, trust_remote_code=True, use_fast=True
|
||||||
|
)
|
||||||
|
cls.visual = (
|
||||||
|
Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
cls.model_path, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.visual.to(cls.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.engine = Engine(
|
||||||
|
model_path=self.model_path,
|
||||||
|
chat_template=self.chat_template,
|
||||||
|
device=self.device.type,
|
||||||
|
mem_fraction_static=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.engine.shutdown()
|
||||||
|
|
||||||
|
async def test_qwen_vl_understands_image(self):
|
||||||
|
req = self.get_completion_request()
|
||||||
|
conv = generate_chat_conv(req, template_name=self.chat_template)
|
||||||
|
text = conv.get_prompt()
|
||||||
|
output = await self.engine.async_generate(
|
||||||
|
prompt=text,
|
||||||
|
image_data=[self.main_image],
|
||||||
|
sampling_params=dict(temperature=0.0),
|
||||||
|
)
|
||||||
|
self.assertIn("taxi", output["text"].lower())
|
||||||
|
|
||||||
|
async def test_qwen_vl_understands_precomputed_features(self):
|
||||||
|
req = self.get_completion_request()
|
||||||
|
processor_output = self.get_processor_output(req=req)
|
||||||
|
with torch.inference_mode():
|
||||||
|
precomputed_features = self.visual(
|
||||||
|
processor_output["pixel_values"], processor_output["image_grid_thw"]
|
||||||
|
)
|
||||||
|
output = await self.engine.async_generate(
|
||||||
|
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
|
||||||
|
image_data=[
|
||||||
|
dict(
|
||||||
|
modality="IMAGE",
|
||||||
|
image_grid_thws=processor_output["image_grid_thw"],
|
||||||
|
precomputed_features=precomputed_features,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
sampling_params=dict(temperature=0.0),
|
||||||
|
)
|
||||||
|
self.assertIn("taxi", output["text"].lower())
|
||||||
|
|
||||||
|
|
||||||
|
class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super().setUpClass()
|
||||||
|
cls.model_path = "google/gemma-3-4b-it"
|
||||||
|
cls.chat_template = "gemma-it"
|
||||||
|
cls.processor = AutoProcessor.from_pretrained(
|
||||||
|
cls.model_path, trust_remote_code=True, use_fast=True
|
||||||
|
)
|
||||||
|
model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||||
|
cls.model_path, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
cls.vision_tower = model.vision_tower.eval().to(cls.device)
|
||||||
|
cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def visual(cls, pixel_values):
|
||||||
|
vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
|
||||||
|
image_features = cls.mm_projector(vision_outputs)
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.engine = Engine(
|
||||||
|
model_path=self.model_path,
|
||||||
|
chat_template=self.chat_template,
|
||||||
|
device=self.device.type,
|
||||||
|
mem_fraction_static=0.5,
|
||||||
|
enable_multimodal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.engine.shutdown()
|
||||||
|
|
||||||
|
async def test_gemma_understands_image(self):
|
||||||
|
req = self.get_completion_request()
|
||||||
|
conv = generate_chat_conv(req, template_name=self.chat_template)
|
||||||
|
text = conv.get_prompt()
|
||||||
|
output = await self.engine.async_generate(
|
||||||
|
prompt=text,
|
||||||
|
image_data=[self.main_image],
|
||||||
|
sampling_params=dict(temperature=0.0),
|
||||||
|
)
|
||||||
|
self.assertIn("taxi", output["text"].lower())
|
||||||
|
|
||||||
|
async def test_gemma_understands_precomputed_features(self):
|
||||||
|
req = self.get_completion_request()
|
||||||
|
processor_output = self.get_processor_output(req=req)
|
||||||
|
with torch.inference_mode():
|
||||||
|
precomputed_features = self.visual(processor_output["pixel_values"])
|
||||||
|
output = await self.engine.async_generate(
|
||||||
|
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
|
||||||
|
image_data=[
|
||||||
|
dict(
|
||||||
|
modality="IMAGE",
|
||||||
|
precomputed_features=precomputed_features,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
sampling_params=dict(temperature=0.0),
|
||||||
|
)
|
||||||
|
self.assertIn("taxi", output["text"].lower())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user