Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)
Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
This commit is contained in:
@@ -3,16 +3,16 @@ import concurrent.futures
|
||||
import dataclasses
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.utils import encode_video, load_audio, load_image
|
||||
|
||||
|
||||
@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput:
|
||||
input_text: str
|
||||
|
||||
# frames loaded from image and video, in given order
|
||||
images: Optional[list[PIL.Image]] = None
|
||||
images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
|
||||
|
||||
# audios
|
||||
audios: Optional[list[np.ndarray]] = None
|
||||
audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
|
||||
|
||||
def normalize(self):
|
||||
for field_name in ["image_sizes", "images", "audios"]:
|
||||
for field_name in ["images", "audios"]:
|
||||
field = getattr(self, field_name, None)
|
||||
if field is not None and isinstance(field, list) and len(field) == 0:
|
||||
setattr(self, field_name, None)
|
||||
@@ -40,12 +40,32 @@ class MultimodalSpecialTokens:
|
||||
video_token: Optional[str] = None
|
||||
audio_token: Optional[str] = None
|
||||
|
||||
def collect(self) -> list[str]:
|
||||
return [
|
||||
token
|
||||
for token in [self.image_token, self.video_token, self.audio_token]
|
||||
if token
|
||||
image_token_regex: Optional[re.Pattern] = None
|
||||
video_token_regex: Optional[re.Pattern] = None
|
||||
audio_token_regex: Optional[re.Pattern] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.image_token_regex is None and self.image_token is not None:
|
||||
self.image_token_regex = re.compile(re.escape(self.image_token))
|
||||
if self.video_token_regex is None and self.video_token is not None:
|
||||
self.video_token_regex = re.compile(re.escape(self.video_token))
|
||||
if self.audio_token_regex is None and self.audio_token is not None:
|
||||
self.audio_token_regex = re.compile(re.escape(self.audio_token))
|
||||
|
||||
def collect(self) -> re.Pattern:
|
||||
tokens = [
|
||||
self.image_token_regex,
|
||||
self.video_token_regex,
|
||||
self.audio_token_regex,
|
||||
]
|
||||
patterns = []
|
||||
flags = 0
|
||||
for t in tokens:
|
||||
if t is not None:
|
||||
patterns.append(t.pattern)
|
||||
flags |= t.flags
|
||||
combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
|
||||
return re.compile(combined, flags)
|
||||
|
||||
|
||||
class BaseMultimodalProcessor(ABC):
|
||||
@@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC):
|
||||
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
|
||||
):
|
||||
"""Static method that can be pickled for multiprocessing"""
|
||||
if isinstance(data, dict):
|
||||
return MultimodalDataItem.from_dict(data)
|
||||
if isinstance(data, MultimodalDataItem):
|
||||
return data
|
||||
try:
|
||||
if is_audio:
|
||||
return load_audio(data)
|
||||
@@ -175,7 +199,10 @@ class BaseMultimodalProcessor(ABC):
|
||||
image_index, audio_index = 0, 0
|
||||
|
||||
for text_part in text_parts:
|
||||
if text_part == multimodal_tokens.image_token:
|
||||
if (
|
||||
multimodal_tokens.image_token_regex
|
||||
and multimodal_tokens.image_token_regex.match(text_part)
|
||||
):
|
||||
data = image_data[image_index]
|
||||
is_video = isinstance(data, str) and data.startswith("video:")
|
||||
estimated_frames = estimated_frames_list[image_index]
|
||||
@@ -192,7 +219,10 @@ class BaseMultimodalProcessor(ABC):
|
||||
)
|
||||
task_info.append((Modality.IMAGE, data, frame_count_limit))
|
||||
image_index += 1
|
||||
elif text_part == multimodal_tokens.audio_token:
|
||||
elif (
|
||||
multimodal_tokens.audio_token_regex
|
||||
and multimodal_tokens.audio_token_regex.match(text_part)
|
||||
):
|
||||
data = audio_data[audio_index]
|
||||
futures.append(
|
||||
self.io_executor.submit(
|
||||
@@ -228,17 +258,22 @@ class BaseMultimodalProcessor(ABC):
|
||||
discard_alpha_channel: if True, discards the alpha channel in the returned images
|
||||
|
||||
"""
|
||||
if not return_text:
|
||||
raise NotImplementedError()
|
||||
|
||||
if image_data is None:
|
||||
image_data = []
|
||||
if isinstance(multimodal_tokens.image_token, int):
|
||||
multimodal_tokens.image_token = (
|
||||
self._processor.tokenizer.convert_ids_to_tokens(
|
||||
multimodal_tokens.image_token
|
||||
multimodal_tokens.image_token = re.compile(
|
||||
re.escape(
|
||||
self._processor.tokenizer.convert_ids_to_tokens(
|
||||
multimodal_tokens.image_token
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
||||
multimodal_tokens_pattern = multimodal_tokens.collect()
|
||||
|
||||
if isinstance(prompt, list) and return_text:
|
||||
assert len(prompt) and isinstance(prompt[0], int)
|
||||
@@ -247,16 +282,8 @@ class BaseMultimodalProcessor(ABC):
|
||||
prompt = prompt
|
||||
|
||||
assert isinstance(prompt, str)
|
||||
if return_text:
|
||||
import re
|
||||
|
||||
pattern = (
|
||||
"("
|
||||
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
|
||||
+ ")"
|
||||
)
|
||||
# split text into list of normal text and special tokens
|
||||
text_parts = re.split(pattern, prompt)
|
||||
# split text into list of normal text and special tokens
|
||||
text_parts = re.split(multimodal_tokens_pattern, prompt)
|
||||
|
||||
futures, task_info = self.submit_data_loading_tasks(
|
||||
text_parts=text_parts,
|
||||
@@ -266,26 +293,40 @@ class BaseMultimodalProcessor(ABC):
|
||||
discard_alpha_channel=discard_alpha_channel,
|
||||
)
|
||||
# Process results
|
||||
image_sizes, images, audios = [], [], []
|
||||
images, audios = [], []
|
||||
new_text = ""
|
||||
task_ptr = 0
|
||||
|
||||
for text_part in text_parts:
|
||||
if text_part in multimodal_tokens.collect():
|
||||
if multimodal_tokens_pattern.match(text_part):
|
||||
task_type, data, frame_limit = task_info[task_ptr]
|
||||
result = futures[task_ptr].result()
|
||||
task_ptr += 1
|
||||
|
||||
if task_type == Modality.IMAGE:
|
||||
# If data is already processed it will be a
|
||||
# dictionary. In this case we want to keep the
|
||||
# expanded tokens in text_part. Otherwise, we will
|
||||
# call the processor code, so keep only a single image
|
||||
# token.
|
||||
mm_tokens = (
|
||||
text_part
|
||||
if isinstance(data, dict)
|
||||
else multimodal_tokens.image_token
|
||||
)
|
||||
frames = [result] if not isinstance(result, list) else result
|
||||
if frames:
|
||||
image_sizes += frames[0].size * len(frames)
|
||||
images += frames
|
||||
new_text += multimodal_tokens.image_token * len(frames)
|
||||
new_text += mm_tokens * len(frames)
|
||||
elif task_type == Modality.AUDIO:
|
||||
# audio
|
||||
mm_tokens = (
|
||||
text_part
|
||||
if isinstance(data, dict)
|
||||
else multimodal_tokens.audio_token
|
||||
)
|
||||
audios.append(result)
|
||||
new_text += multimodal_tokens.audio_token
|
||||
new_text += mm_tokens
|
||||
# TODO: handle video
|
||||
else:
|
||||
new_text += text_part
|
||||
@@ -297,3 +338,16 @@ class BaseMultimodalProcessor(ABC):
|
||||
)
|
||||
out.normalize()
|
||||
return out
|
||||
|
||||
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
|
||||
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
|
||||
if not mm_inputs:
|
||||
return True
|
||||
ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
|
||||
if ret and not all(
|
||||
isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
|
||||
)
|
||||
return ret
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import List, Union
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
# The single, pre-expanded image token.
|
||||
self.IMAGE_TOKEN = "<start_of_image>"
|
||||
# The regex that matches expanded image tokens.
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
|
||||
)
|
||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
@@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
image_token_regex = self.IMAGE_TOKEN_REGEX
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=image_token, image_token_regex=image_token_regex
|
||||
),
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text, images=base_output.images
|
||||
input_text=base_output.input_text,
|
||||
images=None if images_are_preprocessed else base_output.images,
|
||||
)
|
||||
|
||||
items = []
|
||||
for i, image in enumerate(base_output.images):
|
||||
if images_are_preprocessed:
|
||||
pixel_values = image.pixel_values
|
||||
precomputed_features = image.precomputed_features
|
||||
else:
|
||||
pixel_values = ret["pixel_values"][i]
|
||||
precomputed_features = None
|
||||
|
||||
item = MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"][i],
|
||||
pixel_values=pixel_values,
|
||||
precomputed_features=precomputed_features,
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Union
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
# The single, pre-expanded image token.
|
||||
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
# The regex that matches expanded image tokens.
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
||||
)
|
||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||
self.image_token_id = hf_config.image_token_id
|
||||
@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN,
|
||||
image_token_regex=self.IMAGE_TOKEN_REGEX,
|
||||
),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
@@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
async def resize_image_async(image):
|
||||
return resize_image(image)
|
||||
|
||||
if base_output.images:
|
||||
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
||||
if base_output.images and not images_are_preprocessed:
|
||||
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
||||
base_output.images = await asyncio.gather(*resize_tasks)
|
||||
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
images=None if images_are_preprocessed else base_output.images,
|
||||
)
|
||||
|
||||
input_ids = ret["input_ids"].flatten().tolist()
|
||||
image_grid_thw = None
|
||||
video_grid_thw = None # TODO
|
||||
items = []
|
||||
|
||||
input_ids = ret["input_ids"].flatten().tolist()
|
||||
if "pixel_values" in ret:
|
||||
if base_output.images:
|
||||
if images_are_preprocessed:
|
||||
image_grid_thw = torch.concat(
|
||||
[
|
||||
torch.as_tensor(item.image_grid_thws)
|
||||
for item in base_output.images
|
||||
]
|
||||
)
|
||||
all_pixel_values = [
|
||||
item.pixel_values
|
||||
for item in base_output.images
|
||||
if item.pixel_values is not None
|
||||
]
|
||||
all_precomputed_features = [
|
||||
item.precomputed_features
|
||||
for item in base_output.images
|
||||
if item.precomputed_features is not None
|
||||
]
|
||||
pixel_values = (
|
||||
torch.concat(all_pixel_values) if all_pixel_values else None
|
||||
)
|
||||
precomputed_features = (
|
||||
torch.concat(all_precomputed_features)
|
||||
if all_precomputed_features
|
||||
else None
|
||||
)
|
||||
else:
|
||||
image_grid_thw = ret["image_grid_thw"]
|
||||
pixel_values = ret["pixel_values"]
|
||||
precomputed_features = None
|
||||
items += [
|
||||
MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"],
|
||||
image_grid_thws=torch.concat([ret["image_grid_thw"]]),
|
||||
# TODO
|
||||
video_grid_thws=None,
|
||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thws=image_grid_thw,
|
||||
video_grid_thws=video_grid_thw,
|
||||
precomputed_features=precomputed_features,
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
]
|
||||
@@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
self.hf_config.vision_config, "tokens_per_second", None
|
||||
),
|
||||
input_ids=torch.tensor(input_ids).unsqueeze(0),
|
||||
image_grid_thw=ret.get("image_grid_thw", None),
|
||||
video_grid_thw=ret.get("video_grid_thw", None),
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||
)
|
||||
mrope_positions = mrope_positions.squeeze(1)
|
||||
|
||||
Reference in New Issue
Block a user