Files
sglang/python/sglang/srt/multimodal/processors/base_processor.py
2025-10-16 15:00:03 -07:00

665 lines
24 KiB
Python

import concurrent
import concurrent.futures
import dataclasses
import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
_is_npu = is_npu()
@dataclasses.dataclass
class BaseMultiModalProcessorOutput:
# input_text, with each frame of video/image represented with a image_token
input_text: str
# frames loaded from image, in given order
images: Optional[list[Union[Image.Image, dict]]] = dataclasses.field(
default_factory=list
)
# videos
videos: Optional[list[Union[torch.Tensor, dict]]] = dataclasses.field(
default_factory=list
)
# audios
audios: Optional[list[Union[np.ndarray, dict]]] = dataclasses.field(
default_factory=list
)
def organize_results(self) -> List[Tuple[Modality, Any]]:
"""
:return: a list of results, with their corresponding modalities
"""
return (
[(Modality.IMAGE, data) for data in self.images]
+ [(Modality.VIDEO, data) for data in self.videos]
+ [(Modality.AUDIO, data) for data in self.audios]
)
@dataclasses.dataclass
class MultimodalSpecialTokens:
image_token: Optional[Union[str, List[str]]] = None
video_token: Optional[Union[str, List[str]]] = None
audio_token: Optional[Union[str, List[str]]] = None
image_token_id: Optional[int] = None
video_token_id: Optional[int] = None
audio_token_id: Optional[int] = None
image_token_regex: Optional[re.Pattern] = None
video_token_regex: Optional[re.Pattern] = None
audio_token_regex: Optional[re.Pattern] = None
combined_regex: Optional[re.Pattern] = None
def build(self, processor):
self.convert_to_strs(processor)
self.parse_regex()
self.get_combined_regex()
return self
def convert_to_str(self, token: Union[str, int], processor) -> str:
if token is None:
return token
if isinstance(token, str):
return token
return processor.tokenizer.convert_ids_to_tokens([token])[0]
def convert_to_strs(self, processor):
if not self.image_token:
self.image_token = self.convert_to_str(self.image_token_id, processor)
if not self.video_token:
self.video_token = self.convert_to_str(self.video_token_id, processor)
if not self.audio_token:
self.audio_token = self.convert_to_str(self.audio_token_id, processor)
def get_modality_of_token(self, token: str) -> Optional[Modality]:
"""
:return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
"""
modality = {
self.image_token: Modality.IMAGE,
self.video_token: Modality.VIDEO,
self.audio_token: Modality.AUDIO,
}.get(token)
if modality:
return modality
for regex, modality in [
(self.image_token_regex, Modality.IMAGE),
(self.video_token_regex, Modality.VIDEO),
(self.audio_token_regex, Modality.AUDIO),
]:
if regex and regex.match(token):
return modality
return None
def get_token_id_by_modality(self, modality: Modality) -> Optional[int]:
return {
Modality.IMAGE: self.image_token_id,
Modality.MULTI_IMAGES: self.image_token_id,
Modality.VIDEO: self.video_token_id,
Modality.AUDIO: self.audio_token_id,
}.get(modality)
def parse_regex(self):
if self.image_token_regex is None and self.image_token is not None:
self.image_token_regex = re.compile(re.escape(self.image_token))
if self.video_token_regex is None and self.video_token is not None:
self.video_token_regex = re.compile(re.escape(self.video_token))
if self.audio_token_regex is None and self.audio_token is not None:
self.audio_token_regex = re.compile(re.escape(self.audio_token))
def get_combined_regex(self) -> re.Pattern:
"""
Builds and returns a regex, used to split input str into tokens (with mm special tokens)
"""
if self.combined_regex:
return self.combined_regex
tokens = [
self.image_token_regex,
self.video_token_regex,
self.audio_token_regex,
]
patterns = []
flags = 0
for t in tokens:
if t is not None:
patterns.append(t.pattern)
flags |= t.flags
combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
self.combined_regex = re.compile(combined, flags)
return self.combined_regex
class BaseMultimodalProcessor(ABC):
models = []
def __init__(
self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
):
self.hf_config = hf_config
self._processor = _processor
self.server_args = server_args
self.transport_mode = transport_mode
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
self.io_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
)
self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
mp_context=mp.get_context("fork"),
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
)
# Mapping from attribute names to modality types
self.ATTR_NAME_TO_MODALITY = {
# Image-related attributes
"pixel_values": Modality.IMAGE,
"image_sizes": Modality.IMAGE,
"image_grid_thw": Modality.IMAGE,
"image_attention_mask": Modality.IMAGE,
"image_emb_mask": Modality.IMAGE,
"images_spatial_crop": Modality.IMAGE,
"tgt_size": Modality.IMAGE,
"image_grid_hws": Modality.IMAGE,
"aspect_ratio_ids": Modality.IMAGE,
"aspect_ratio_mask": Modality.IMAGE,
"num_patches": Modality.IMAGE,
"patch_pixel_values": Modality.IMAGE,
# Audio-related attributes
"audio_features": Modality.AUDIO,
"audio_feature_lens": Modality.AUDIO,
"input_features": Modality.AUDIO,
"input_features_mask": Modality.AUDIO,
"audio_attention_mask": Modality.AUDIO,
"feature_attention_mask": Modality.AUDIO,
# Video-related attributes
"pixel_values_videos": Modality.VIDEO,
"second_per_grid_ts": Modality.VIDEO,
"video_grid_thw": Modality.VIDEO,
# Generic attributes that could apply to multiple modalities
# "precomputed_embeddings" - handled specially as it can be any modality
}
# name of the feature filed
# TODO: pass from processors
self.FEATURE_NAMES = [
"pixel_values",
"pixel_values_videos",
"audio_features",
"input_features",
]
def process_mm_data(
self, input_text, images=None, videos=None, audios=None, **kwargs
) -> dict:
"""
process multimodal data with transformers AutoProcessor
"""
if images:
kwargs["images"] = images
if videos:
kwargs["videos"] = videos
if audios:
if self._processor.__class__.__name__ in {
"Gemma3nProcessor",
"Qwen2AudioProcessor",
"Qwen3OmniMoeProcessor",
}:
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
kwargs["audio"] = audios
else:
kwargs["audios"] = audios
processor = self._processor
if (
hasattr(processor, "image_processor")
and isinstance(processor.image_processor, BaseImageProcessorFast)
and not self.server_args.disable_fast_image_processor
):
if not _is_npu:
kwargs["device"] = "cuda"
elif processor.__class__.__name__ not in {
"Qwen2_5_VLProcessor",
"Qwen3VLProcessor",
}:
# Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
kwargs["device"] = "npu"
result = processor.__call__(
text=[input_text],
padding=True,
return_tensors="pt",
**kwargs,
)
if not self.server_args.keep_mm_feature_on_device:
# move feature tensors to cpu
for feature_name in self.FEATURE_NAMES:
if feature_name in result and isinstance(
result[feature_name], torch.Tensor
):
result[feature_name] = result[feature_name].to("cpu")
return result
@abstractmethod
async def process_mm_data_async(
self,
image_data,
audio_data,
input_text,
request_obj,
**kwargs,
) -> Optional[Dict[str, Any]]:
pass
def get_estimated_frames_list(self, image_data):
"""
estimate the total frame count from all visual input
"""
# Lazy import because decord is not available on some arm platforms.
from decord import VideoReader, cpu
# Before processing inputs
if not image_data or len(image_data) == 0:
return []
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
@staticmethod
def _load_single_item(
data,
modality: Modality,
frame_count_limit=None,
audio_sample_rate: Optional[int] = None,
discard_alpha_channel=True,
):
"""
Load a single multimodal data.
If data is precomputed, returns directly.
Static method that can be pickled for multiprocessing"""
if isinstance(data, dict):
return data
try:
if modality == Modality.IMAGE:
img, _ = load_image(data)
if discard_alpha_channel and img.mode != "RGB":
img = img.convert("RGB")
return img
elif modality == Modality.VIDEO:
return load_video(data, frame_count_limit)
elif modality == Modality.AUDIO:
return load_audio(data, audio_sample_rate)
except Exception as e:
raise RuntimeError(f"Error while loading data {data}: {e}")
def submit_data_loading_tasks(
self,
text_parts: List[str],
multimodal_tokens: MultimodalSpecialTokens,
data_iterators: dict[Modality, Iterator[Any]],
discard_alpha_channel: bool = True,
image_estimated_frames_iter: Optional[iter] = None,
image_scaling_factor: float = 1.0,
max_image_frames: int = 30,
audio_sample_rate: Optional[int] = None,
) -> Tuple[List, List]:
"""
load multimodal data parallelly using iterators.
"""
futures = []
task_info = []
for text_part in text_parts:
modality = multimodal_tokens.get_modality_of_token(text_part)
if modality is not None:
data_iterator = data_iterators.get(modality)
if data_iterator is None:
raise ValueError(f"No data iterator found for token: {text_part}")
try:
data = next(data_iterator)
except StopIteration:
raise ValueError(
f"Mismatch: More '{text_part}' tokens found than corresponding data items provided."
)
frame_count_limit = None
if modality == Modality.IMAGE and image_estimated_frames_iter:
try:
estimated_frames = next(image_estimated_frames_iter)
# Use the pre-calculated scaling factor and max frames
frame_count_limit = max(
1, int(estimated_frames * image_scaling_factor)
)
# Ensure we don't exceed the absolute max (redundant if scaling_factor handles it)
# frame_count_limit = min(frame_count_limit, max_image_frames)
except StopIteration:
raise ValueError(
"Mismatch between image tokens and estimated frame counts."
)
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
modality,
frame_count_limit,
audio_sample_rate,
discard_alpha_channel,
)
)
task_info.append((modality, data, frame_count_limit))
for modality, iterator in data_iterators.items():
try:
next(iterator)
logger.warning(
f"Warning: More {modality.name.lower()} data items provided than corresponding tokens found in the prompt."
)
except StopIteration:
pass
except Exception:
pass
return futures, task_info
def load_mm_data(
self,
prompt: str,
multimodal_tokens: MultimodalSpecialTokens,
image_data: Optional[list] = None,
video_data: Optional[list] = None,
audio_data: Optional[list] = None,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
audio_sample_rate: Optional[int] = None,
) -> BaseMultiModalProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
Args:
multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
e.g. image token or audio token
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
multimodal_tokens_pattern = multimodal_tokens.get_combined_regex()
if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
prompt = self._processor.tokenizer.decode(prompt)
else:
prompt = prompt
assert isinstance(prompt, str)
# split text into list of normal text and special tokens
text_parts = re.split(multimodal_tokens_pattern, prompt)
# collect all data
data_iterators = {}
if multimodal_tokens.image_token and image_data:
data_iterators[Modality.IMAGE] = iter(image_data)
if multimodal_tokens.video_token and video_data:
data_iterators[Modality.VIDEO] = iter(video_data)
if multimodal_tokens.audio_token and audio_data:
data_iterators[Modality.AUDIO] = iter(audio_data)
# futures: the futures of loaded data
# task_info: modality, raw_data, and other metadata of each data
futures, task_info = self.submit_data_loading_tasks(
text_parts=text_parts,
multimodal_tokens=multimodal_tokens,
data_iterators=data_iterators,
discard_alpha_channel=discard_alpha_channel,
audio_sample_rate=audio_sample_rate,
)
task_info_iter = iter(task_info)
futures_iter = iter(futures)
# Process results
images, videos, audios = [], [], []
new_text_parts = []
for text_part in text_parts:
try:
if multimodal_tokens_pattern.match(text_part):
modality, raw_data, frame_limit = next(task_info_iter)
is_precomputed = isinstance(raw_data, dict)
result = next(futures_iter).result()
if modality == Modality.IMAGE:
# If data is already processed it will be a
# dictionary(precomputed). In this case we want to keep the
# expanded tokens in text_part. Otherwise, we will
# call the processor code, so keep only a single image
# token.
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.image_token
)
frames = [result] if not isinstance(result, list) else result
if frames:
# only for minicpmv
images += frames
new_text_parts += mm_tokens * len(frames)
elif modality == Modality.VIDEO:
# load as video
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.video_token
)
videos += [result]
new_text_parts += mm_tokens
elif modality == Modality.AUDIO:
# audio
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.audio_token
)
audios += [result]
new_text_parts += mm_tokens
else:
# normal text
new_text_parts += [text_part]
except Exception as e:
raise RuntimeError(
f"An exception occurred while loading multimodal data: {e}"
)
return BaseMultiModalProcessorOutput(
images=images,
audios=audios,
videos=videos,
input_text="".join(new_text_parts),
)
@staticmethod
def get_mm_items_offset(
input_ids: torch.Tensor, mm_token_id: int
) -> List[Tuple[int, int]]:
"""
Get a set of range for mm_items from input_ids
Example:
input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
mm_token_id = 3
return result = [(2,4),(6,7)]
"""
mask = input_ids == mm_token_id
start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
return list(zip(start_positions.tolist(), end_positions.tolist()))
@staticmethod
def get_mm_items_offset_by_pair(
input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
) -> List[Tuple[int, int]]:
indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1
return list(zip(indices_start.tolist(), indices_end.tolist()))
def collect_mm_items_from_processor_output(
self, data_dict: dict
) -> List[MultimodalDataItem]:
"""Create mm_items directly from processor output."""
items: dict[Modality, MultimodalDataItem] = {}
for attr_name, value in data_dict.items():
if attr_name == "input_ids":
continue
# Get modality for this attribute
modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
if attr_name == "precomputed_embeddings":
modality_str = data_dict.get("modality")
modality = Modality.IMAGE
if modality_str:
try:
modality = Modality.from_str(modality_str)
except ValueError:
pass
if modality:
# Create item if needed
if modality not in items:
items[modality] = MultimodalDataItem(modality=modality)
if attr_name in self.FEATURE_NAMES:
attr_name = "feature"
items[modality].set(attr_name, value)
return list(items.values())
def _process_and_collect_mm_items(
self, input_text: str, images=None, audios=None, videos=None, **kwargs
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
"""
Helper method to process multimodal data and create mm_items in one step.
Returns:
Tuple of (created mm_items, input_ids)
"""
ret = self.process_mm_data(
input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
)
input_ids = ret["input_ids"].flatten()
collected_items = self.collect_mm_items_from_processor_output(ret)
return collected_items, input_ids, ret
def process_and_combine_mm_data(
self,
base_output: BaseMultiModalProcessorOutput,
mm_tokens: MultimodalSpecialTokens,
**kwargs,
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
"""
Process multimodal data and return the combined multimodal items and input_ids.
Supports mixed modalities (images and audio in the same request).
Returns:
Tuple of (list of mm_items, input_ids)
"""
# Collect all items and categorize them
all_items = base_output.organize_results()
# Handle text-only case
if not all_items:
input_ids = self._processor.tokenizer(
base_output.input_text,
return_tensors="pt",
add_special_tokens=True,
).input_ids.flatten()
return [], input_ids, {}
dict_items, raw_images, raw_audios, raw_videos = [], [], [], []
for modality, item in all_items:
if isinstance(item, dict):
dict_items.append(item)
elif modality == Modality.IMAGE:
raw_images.append(item)
elif modality == Modality.AUDIO:
raw_audios.append(item)
elif modality == Modality.VIDEO:
raw_videos.append(item)
else:
raise ValueError(f"Unknown multimodal item type: {type(item)}")
# Process items and get input_ids
all_collected_items: list[MultimodalDataItem] = []
input_ids = None
# Handle raw items (need processing)
if raw_images or raw_audios or raw_videos:
collected_items, input_ids, ret = self._process_and_collect_mm_items(
input_text=base_output.input_text,
images=raw_images,
audios=raw_audios,
videos=raw_videos,
**kwargs,
)
all_collected_items = collected_items
else:
ret = None
# Handle dict items (already processed)
for dict_item in dict_items:
all_collected_items.extend(
self.collect_mm_items_from_processor_output(dict_item)
)
# Fallback tokenization if no raw items were processed
if input_ids is None:
input_ids = self._processor.tokenizer(
base_output.input_text,
return_tensors="pt",
add_special_tokens=True,
).input_ids.flatten()
# Add offsets to all items
for mm_item in all_collected_items:
mm_token_id = mm_tokens.get_token_id_by_modality(mm_item.modality)
if mm_token_id is None:
raise ValueError(f"No token id found for modality: {mm_item.modality}")
mm_item.offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=mm_token_id,
)
return all_collected_items, input_ids, ret