Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)

Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
Yury Sulsky
2025-05-16 12:26:15 -07:00
committed by GitHub
parent c23a7072b6
commit f19a9204cd
14 changed files with 592 additions and 125 deletions

View File

@@ -3,16 +3,16 @@ import concurrent.futures
import dataclasses
import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import encode_video, load_audio, load_image
@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput:
input_text: str
# frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None
images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
# audios
audios: Optional[list[np.ndarray]] = None
audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
def normalize(self):
for field_name in ["image_sizes", "images", "audios"]:
for field_name in ["images", "audios"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
@@ -40,12 +40,32 @@ class MultimodalSpecialTokens:
video_token: Optional[str] = None
audio_token: Optional[str] = None
def collect(self) -> list[str]:
return [
token
for token in [self.image_token, self.video_token, self.audio_token]
if token
image_token_regex: Optional[re.Pattern] = None
video_token_regex: Optional[re.Pattern] = None
audio_token_regex: Optional[re.Pattern] = None
def __post_init__(self):
if self.image_token_regex is None and self.image_token is not None:
self.image_token_regex = re.compile(re.escape(self.image_token))
if self.video_token_regex is None and self.video_token is not None:
self.video_token_regex = re.compile(re.escape(self.video_token))
if self.audio_token_regex is None and self.audio_token is not None:
self.audio_token_regex = re.compile(re.escape(self.audio_token))
def collect(self) -> re.Pattern:
tokens = [
self.image_token_regex,
self.video_token_regex,
self.audio_token_regex,
]
patterns = []
flags = 0
for t in tokens:
if t is not None:
patterns.append(t.pattern)
flags |= t.flags
combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
return re.compile(combined, flags)
class BaseMultimodalProcessor(ABC):
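
For context, a minimal sketch of what collect() now returns and how load_mm_data uses it with re.split. The token string is borrowed from the Qwen2.5-VL processor further down in this commit; the snippet itself is illustrative and not part of the diff.

```python
import re

# Assume only an image token is configured; its pattern is the expanded
# Qwen2.5-VL form used later in this commit (illustrative here).
image_token_regex = re.compile(
    r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)

# collect() wraps each per-modality pattern in a non-capturing group and
# joins them inside a single capturing group, so re.split keeps the matches.
combined = re.compile(
    "(" + "|".join(f"(?:{p})" for p in [image_token_regex.pattern]) + ")"
)

prompt = "Compare <|vision_start|><|image_pad|><|image_pad|><|vision_end|> with the text."
print(re.split(combined, prompt))
# ['Compare ', '<|vision_start|><|image_pad|><|image_pad|><|vision_end|>', ' with the text.']
```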
@@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC):
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
):
"""Static method that can be pickled for multiprocessing"""
if isinstance(data, dict):
return MultimodalDataItem.from_dict(data)
if isinstance(data, MultimodalDataItem):
return data
try:
if is_audio:
return load_audio(data)
@@ -175,7 +199,10 @@ class BaseMultimodalProcessor(ABC):
image_index, audio_index = 0, 0
for text_part in text_parts:
if text_part == multimodal_tokens.image_token:
if (
multimodal_tokens.image_token_regex
and multimodal_tokens.image_token_regex.match(text_part)
):
data = image_data[image_index]
is_video = isinstance(data, str) and data.startswith("video:")
estimated_frames = estimated_frames_list[image_index]
@@ -192,7 +219,10 @@ class BaseMultimodalProcessor(ABC):
)
task_info.append((Modality.IMAGE, data, frame_count_limit))
image_index += 1
elif text_part == multimodal_tokens.audio_token:
elif (
multimodal_tokens.audio_token_regex
and multimodal_tokens.audio_token_regex.match(text_part)
):
data = audio_data[audio_index]
futures.append(
self.io_executor.submit(
@@ -228,17 +258,22 @@ class BaseMultimodalProcessor(ABC):
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if not return_text:
raise NotImplementedError()
if image_data is None:
image_data = []
if isinstance(multimodal_tokens.image_token, int):
multimodal_tokens.image_token = (
self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token
multimodal_tokens.image_token = re.compile(
re.escape(
self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token
)
)
)
else:
multimodal_tokens.image_token = multimodal_tokens.image_token
multimodal_tokens_pattern = multimodal_tokens.collect()
if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
@@ -247,16 +282,8 @@ class BaseMultimodalProcessor(ABC):
prompt = prompt
assert isinstance(prompt, str)
if return_text:
import re
pattern = (
"("
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+ ")"
)
# split text into list of normal text and special tokens
text_parts = re.split(pattern, prompt)
# split text into list of normal text and special tokens
text_parts = re.split(multimodal_tokens_pattern, prompt)
futures, task_info = self.submit_data_loading_tasks(
text_parts=text_parts,
@@ -266,26 +293,40 @@ class BaseMultimodalProcessor(ABC):
discard_alpha_channel=discard_alpha_channel,
)
# Process results
image_sizes, images, audios = [], [], []
images, audios = [], []
new_text = ""
task_ptr = 0
for text_part in text_parts:
if text_part in multimodal_tokens.collect():
if multimodal_tokens_pattern.match(text_part):
task_type, data, frame_limit = task_info[task_ptr]
result = futures[task_ptr].result()
task_ptr += 1
if task_type == Modality.IMAGE:
# If data is already processed it will be a
# dictionary. In this case we want to keep the
# expanded tokens in text_part. Otherwise, we will
# call the processor code, so keep only a single image
# token.
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.image_token
)
frames = [result] if not isinstance(result, list) else result
if frames:
image_sizes += frames[0].size * len(frames)
images += frames
new_text += multimodal_tokens.image_token * len(frames)
new_text += mm_tokens * len(frames)
elif task_type == Modality.AUDIO:
# audio
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.audio_token
)
audios.append(result)
new_text += multimodal_tokens.audio_token
new_text += mm_tokens
# TODO: handle video
else:
new_text += text_part
@@ -297,3 +338,16 @@ class BaseMultimodalProcessor(ABC):
)
out.normalize()
return out
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
if not mm_inputs:
return True
ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
if ret and not all(
isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
):
raise ValueError(
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
)
return ret
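
The practical upshot of the base-processor changes is that callers can now place already-processed entries in image_data. A hedged sketch of what such an entry might look like; the exact keys accepted by MultimodalDataItem.from_dict are an assumption, and only fields used elsewhere in this commit are shown.

```python
import torch

# Hypothetical precomputed entry; key names are assumed to mirror the
# MultimodalDataItem fields used in this commit (precomputed_features,
# pixel_values, image_grid_thws, modality), and the feature shape is
# made up for illustration.
precomputed_image = {
    "modality": "IMAGE",
    "precomputed_features": torch.randn(256, 1152),
}

# When such a dict appears in image_data, _load_single_item short-circuits
# to MultimodalDataItem.from_dict(precomputed_image) instead of calling
# load_image, and mm_inputs_are_preprocessed() later returns True for the
# batch, so the HF image processor is skipped entirely.
```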

View File

@@ -1,4 +1,5 @@
from typing import List, Union
import re
from typing import Dict, List, Union
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
# The single, pre-expanded image token.
self.IMAGE_TOKEN = "<start_of_image>"
# The regex that matches expanded image tokens.
self.IMAGE_TOKEN_REGEX = re.compile(
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
)
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
@@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
image_token_regex = self.IMAGE_TOKEN_REGEX
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
multimodal_tokens=MultimodalSpecialTokens(
image_token=image_token, image_token_regex=image_token_regex
),
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
ret = self.process_mm_data(
input_text=base_output.input_text, images=base_output.images
input_text=base_output.input_text,
images=None if images_are_preprocessed else base_output.images,
)
items = []
for i, image in enumerate(base_output.images):
if images_are_preprocessed:
pixel_values = image.pixel_values
precomputed_features = image.precomputed_features
else:
pixel_values = ret["pixel_values"][i]
precomputed_features = None
item = MultimodalDataItem(
pixel_values=ret["pixel_values"][i],
pixel_values=pixel_values,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
)
items += [item]
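
A quick, illustrative check of the Gemma3 pattern introduced above: it matches both the bare pre-expansion token and an already-expanded run of soft tokens, which is what lets prompts accompanying precomputed features pass through load_mm_data unchanged.

```python
import re

IMAGE_TOKEN_REGEX = re.compile(
    r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
)

# Bare token, as emitted for raw images that still need processing.
assert IMAGE_TOKEN_REGEX.fullmatch("<start_of_image>")

# Pre-expanded form, as supplied alongside precomputed features.
assert IMAGE_TOKEN_REGEX.fullmatch(
    "<start_of_image>" + "<image_soft_token>" * 4 + "<end_of_image>"
)
```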

View File

@@ -1,6 +1,7 @@
import asyncio
import math
from typing import List, Union
import re
from typing import Dict, List, Union
import torch
from PIL import Image
@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
# The single, pre-expanded image token.
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
# The regex that matches expanded image tokens.
self.IMAGE_TOKEN_REGEX = re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.image_token_id = hf_config.image_token_id
@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
),
max_req_input_len=max_req_input_len,
)
@@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def resize_image_async(image):
return resize_image(image)
if base_output.images:
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
if base_output.images and not images_are_preprocessed:
resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
images=None if images_are_preprocessed else base_output.images,
)
input_ids = ret["input_ids"].flatten().tolist()
image_grid_thw = None
video_grid_thw = None # TODO
items = []
input_ids = ret["input_ids"].flatten().tolist()
if "pixel_values" in ret:
if base_output.images:
if images_are_preprocessed:
image_grid_thw = torch.concat(
[
torch.as_tensor(item.image_grid_thws)
for item in base_output.images
]
)
all_pixel_values = [
item.pixel_values
for item in base_output.images
if item.pixel_values is not None
]
all_precomputed_features = [
item.precomputed_features
for item in base_output.images
if item.precomputed_features is not None
]
pixel_values = (
torch.concat(all_pixel_values) if all_pixel_values else None
)
precomputed_features = (
torch.concat(all_precomputed_features)
if all_precomputed_features
else None
)
else:
image_grid_thw = ret["image_grid_thw"]
pixel_values = ret["pixel_values"]
precomputed_features = None
items += [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=torch.concat([ret["image_grid_thw"]]),
# TODO
video_grid_thws=None,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
pixel_values=pixel_values,
image_grid_thws=image_grid_thw,
video_grid_thws=video_grid_thw,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
)
]
@@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=torch.tensor(input_ids).unsqueeze(0),
image_grid_thw=ret.get("image_grid_thw", None),
video_grid_thw=ret.get("video_grid_thw", None),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
)
mrope_positions = mrope_positions.squeeze(1)
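
To tie the Qwen2.5-VL changes together, a hedged end-to-end sketch of the preprocessed path: the dict keys mirror the MultimodalDataItem fields above, while the grid values, merge factor, and feature width are assumptions made for illustration only.

```python
import torch

grid_thw = torch.tensor([[1, 16, 16]])   # assumed (t, h, w) for one image
num_pads = int(grid_thw.prod()) // 4     # assumes a 2x2 spatial merge of patches
prompt = (
    "Describe this image: "
    + "<|vision_start|>" + "<|image_pad|>" * num_pads + "<|vision_end|>"
)

image_data = [{
    "modality": "IMAGE",                                  # assumed enum spelling
    "image_grid_thws": grid_thw,
    "precomputed_features": torch.randn(num_pads, 3584),  # assumed feature width
}]

# IMAGE_TOKEN_REGEX matches the expanded run of <|image_pad|> tokens, so the
# prompt is kept verbatim; process_mm_data_async then concatenates
# image_grid_thws and precomputed_features across the items instead of
# re-running the HF image processor, and MRoPE positions are computed from
# the same image_grid_thw tensor.
```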