Files
sglang/python/sglang/srt/multimodal/processors/base_processor.py

593 lines
23 KiB
Python

import concurrent
import concurrent.futures
import dataclasses
import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import encode_video, load_audio, load_image
class MultimodalInputFormat(Enum):
"""Enum for different multimodal input formats."""
RAW_IMAGES = "raw_images"
PRECOMPUTED_FEATURES = "precomputed_features"
PIXEL_VALUES = "pixel_values"
AUDIO = "audio"
@dataclasses.dataclass
class BaseMultiModalProcessorOutput:
# input_text, with each frame of video/image represented with a image_token
input_text: str
# frames loaded from image and video, in given order
images: Optional[list[Union[Image.Image, dict]]] = None
# audios
audios: Optional[list[Union[np.ndarray, dict]]] = None
def normalize(self):
for field_name in ["images", "audios"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
@dataclasses.dataclass
class MultimodalSpecialTokens:
image_token: Optional[Union[int, str, List[str]]] = None
video_token: Optional[Union[int, str, List[str]]] = None
audio_token: Optional[Union[int, str, List[str]]] = None
def convert_to_str(self, token: Union[str, int], processor) -> str:
if token is None:
return token
if isinstance(token, str):
return token
return processor.tokenizer.convert_ids_to_tokens([token])[0]
def convert_to_strs(self, processor):
self.image_token = self.convert_to_str(self.image_token, processor)
self.video_token = self.convert_to_str(self.video_token, processor)
self.audio_token = self.convert_to_str(self.audio_token, processor)
image_token_regex: Optional[re.Pattern] = None
video_token_regex: Optional[re.Pattern] = None
audio_token_regex: Optional[re.Pattern] = None
def __post_init__(self):
if self.image_token_regex is None and self.image_token is not None:
self.image_token_regex = re.compile(re.escape(self.image_token))
if self.video_token_regex is None and self.video_token is not None:
self.video_token_regex = re.compile(re.escape(self.video_token))
if self.audio_token_regex is None and self.audio_token is not None:
self.audio_token_regex = re.compile(re.escape(self.audio_token))
def collect(self) -> re.Pattern:
tokens = [
self.image_token_regex,
self.video_token_regex,
self.audio_token_regex,
]
patterns = []
flags = 0
for t in tokens:
if t is not None:
patterns.append(t.pattern)
flags |= t.flags
combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
return re.compile(combined, flags)
class BaseMultimodalProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
self.io_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
)
self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
mp_context=mp.get_context("fork"),
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
)
def process_mm_data(
self, input_text, images=None, videos=None, audios=None, **kwargs
):
"""
process multimodal data with transformers AutoProcessor
"""
if images is not None:
kwargs["images"] = images
if videos is not None:
kwargs["videos"] = videos
if audios is not None:
kwargs["audios"] = audios
processor = self._processor
if hasattr(processor, "image_processor") and isinstance(
processor.image_processor, BaseImageProcessorFast
):
kwargs["device"] = "cuda"
result = processor.__call__(
text=[input_text],
padding=True,
return_tensors="pt",
**kwargs,
)
if "pixel_values" in result and isinstance(
result["pixel_values"], torch.Tensor
):
result["pixel_values"] = result["pixel_values"].to("cpu")
return result
@abstractmethod
async def process_mm_data_async(
self,
image_data,
input_text,
request_obj,
max_req_input_len,
**kwargs,
) -> Optional[Dict[str, Any]]:
pass
def get_estimated_frames_list(self, image_data):
"""
estimate the total frame count from all visual input
"""
# Lazy import because decord is not available on some arm platforms.
from decord import VideoReader, cpu
# Before processing inputs
if not image_data or len(image_data) == 0:
return []
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
@staticmethod
def _load_single_item(
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
):
"""Static method that can be pickled for multiprocessing"""
if isinstance(data, dict):
return data
try:
if is_audio:
return load_audio(data)
elif is_video:
path = data[len("video:") :]
return encode_video(path, frame_count_limit)
else:
img, _ = load_image(data)
return img.convert("RGB") if discard_alpha_channel else img
except Exception as e:
raise RuntimeError(f"Error while loading data {data}: {e}")
def submit_data_loading_tasks(
self,
text_parts: List[str],
multimodal_tokens: MultimodalSpecialTokens,
image_data: Optional[list] = None,
audio_data: Optional[list] = None,
discard_alpha_channel: bool = True,
):
"""
load multimodal data parallelly
"""
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
# Submit all tasks
futures = []
task_info = []
image_index, audio_index = 0, 0
for text_part in text_parts:
if (
multimodal_tokens.image_token_regex
and multimodal_tokens.image_token_regex.match(text_part)
):
data = image_data[image_index]
is_video = isinstance(data, str) and data.startswith("video:")
estimated_frames = estimated_frames_list[image_index]
frame_count_limit = max(1, int(estimated_frames * scaling_factor))
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
is_video,
False,
frame_count_limit,
discard_alpha_channel,
)
)
task_info.append((Modality.IMAGE, data, frame_count_limit))
image_index += 1
elif (
multimodal_tokens.audio_token_regex
and multimodal_tokens.audio_token_regex.match(text_part)
):
data = audio_data[audio_index]
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
False,
True,
None,
discard_alpha_channel,
)
)
task_info.append((Modality.AUDIO, data, None))
audio_index += 1
return futures, task_info
def load_mm_data(
self,
prompt: str | List[int],
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
audio_data: Optional[list] = None,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
) -> BaseMultiModalProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
Args:
multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
e.g. image token or audio token
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if not return_text:
raise NotImplementedError()
if image_data is None:
image_data = []
multimodal_tokens.convert_to_strs(self._processor)
multimodal_tokens_pattern = multimodal_tokens.collect()
if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
prompt = self._processor.tokenizer.decode(prompt)
else:
prompt = prompt
assert isinstance(prompt, str)
# split text into list of normal text and special tokens
text_parts = re.split(multimodal_tokens_pattern, prompt)
futures, task_info = self.submit_data_loading_tasks(
text_parts=text_parts,
multimodal_tokens=multimodal_tokens,
image_data=image_data,
audio_data=audio_data,
discard_alpha_channel=discard_alpha_channel,
)
# Process results
images, audios = [], []
new_text = ""
task_ptr = 0
for text_part in text_parts:
if multimodal_tokens_pattern.match(text_part):
task_type, data, frame_limit = task_info[task_ptr]
result = futures[task_ptr].result()
task_ptr += 1
if task_type == Modality.IMAGE:
# If data is already processed it will be a
# dictionary. In this case we want to keep the
# expanded tokens in text_part. Otherwise, we will
# call the processor code, so keep only a single image
# token.
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.image_token
)
frames = [result] if not isinstance(result, list) else result
if frames:
images += frames
new_text += mm_tokens * len(frames)
elif task_type == Modality.AUDIO:
# audio
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.audio_token
)
audios.append(result)
new_text += mm_tokens
# TODO: handle video
else:
new_text += text_part
out = BaseMultiModalProcessorOutput(
input_text=new_text,
images=images,
audios=audios,
)
out.normalize()
return out
@staticmethod
def get_mm_items_offset(
input_ids: torch.Tensor, mm_token_id: int
) -> List[Tuple[int, int]]:
"""
Get a set of range for mm_items from input_ids
Example:
input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
mm_token_id = 3
return result = [(2,4),(6,7)]
"""
mask = input_ids == mm_token_id
start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
return list(zip(start_positions.tolist(), end_positions.tolist()))
@staticmethod
def get_mm_items_offset_by_pair(
input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
) -> List[Tuple[int, int]]:
indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1
return list(zip(indices_start.tolist(), indices_end.tolist()))
@staticmethod
def _extract_processor_features(
items: List[dict], attr_name: str
) -> Optional[torch.Tensor]:
"""
Helper function to concat extracted attributes from processor output.
"""
values = [value for item in items if (value := item.get(attr_name)) is not None]
return torch.cat(values) if values else None
# When we assume that all the items have the same attributes
def _extract_processor_features_from_all_attributes(
self, items: List[dict]
) -> dict:
values = {}
# Verify all items have the same keys
first_keys = set(items[0].keys())
for item in items[1:]:
if set(item.keys()) != first_keys:
raise ValueError(
f"All items must have the same attributes. "
f"First item has {first_keys}, but found {set(item.keys())}"
)
# Process each attribute
for k, v in items[0].items():
if isinstance(v, list):
values[k] = self._extract_processor_features(items, k)
else:
# Verify all items have the same value for non-list attributes
for item in items[1:]:
if item[k] != v:
raise ValueError(
f"All items must have the same value for attribute {k}. "
f"First item has {v}, but found {item[k]}"
)
values[k] = v
return values
def process_and_combine_mm_data(
self, base_output: BaseMultiModalProcessorOutput
) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
"""
Process multimodal data and return the combined multimodal item and input_ids.
Handles all three input formats at the same abstraction level.
Returns:
Tuple of (combined_mm_item, input_ids)
"""
def tokenize_text(input_text: str) -> torch.Tensor:
"""Tokenize input text."""
return self._processor.tokenizer(
input_text,
return_tensors="pt",
add_special_tokens=True,
).input_ids.flatten()
def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
"""Categorize multimodal inputs and validate consistency."""
try:
has_image = False
has_pixel_values = False
has_precomputed_features = False
has_audio = False
for mm_input in mm_inputs:
if isinstance(mm_input, Image.Image):
has_image = True
elif isinstance(mm_input, np.ndarray):
has_audio = True
elif isinstance(mm_input, dict):
if mm_input.get("precomputed_features", None) is not None:
has_precomputed_features = True
elif mm_input.get("pixel_values", None) is not None:
has_pixel_values = True
else:
raise ValueError(
f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
)
else:
raise ValueError(
f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
)
# Validate format consistency
format_count = sum(
[has_image, has_pixel_values, has_precomputed_features, has_audio]
)
if format_count > 1:
raise ValueError(
"Unsupported: mixture of multimodal input formats. "
f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
f"precomputed_features={has_precomputed_features}, audio={has_audio}"
)
if has_image:
return MultimodalInputFormat.RAW_IMAGES
elif has_precomputed_features:
return MultimodalInputFormat.PRECOMPUTED_FEATURES
elif has_pixel_values:
return MultimodalInputFormat.PIXEL_VALUES
elif has_audio:
return MultimodalInputFormat.AUDIO
else:
raise ValueError("No valid multimodal input format found")
except Exception as e:
raise ValueError(f"Failed to categorize inputs: {e}")
def process_raw_images(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process raw Image.Image objects using transformers processor."""
ret = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
)
combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
# Copy all fields from processor output except input_ids
for key, value in ret.items():
if key != "input_ids" and hasattr(combined_mm_item, key):
setattr(combined_mm_item, key, value)
input_ids = ret["input_ids"].flatten()
return combined_mm_item, input_ids
def process_precomputed_features(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process inputs with precomputed features."""
combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
combined_mm_item.precomputed_features = self._extract_processor_features(
base_output.images, "precomputed_features"
)
input_ids = tokenize_text(base_output.input_text)
return combined_mm_item, input_ids
def process_pixel_values(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process inputs with pixel values."""
values = self._extract_processor_features_from_all_attributes(
base_output.images
)
combined_mm_item = MultimodalDataItem.from_dict(values)
input_ids = tokenize_text(base_output.input_text)
return combined_mm_item, input_ids
def process_audio(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process inputs with audio."""
ret = self.process_mm_data(
input_text=base_output.input_text,
audio=base_output.audios, # Note: "audio" is for gemma3n only
)
combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
for key, value in ret.items():
if key != "input_ids" and hasattr(combined_mm_item, key):
setattr(combined_mm_item, key, value)
input_ids = ret["input_ids"].flatten()
return combined_mm_item, input_ids
def finalize_mm_item(
combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
) -> MultimodalDataItem:
"""Apply common post-processing to the multimodal item."""
if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
combined_mm_item.image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.IM_TOKEN_ID,
)
elif combined_mm_item.modality == Modality.AUDIO:
combined_mm_item.audio_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.AUDIO_TOKEN_ID,
)
elif combined_mm_item.modality == Modality.VIDEO:
combined_mm_item.video_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.VIDEO_TOKEN_ID,
)
else:
raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
return combined_mm_item
# Main logic - determine input type and handle text-only case
mm_inputs = base_output.images or base_output.audios
if not mm_inputs:
input_ids = tokenize_text(base_output.input_text)
return None, input_ids
# Categorize input formats
input_format = categorize_mm_inputs(mm_inputs)
# Process based on format
if input_format == MultimodalInputFormat.RAW_IMAGES:
combined_mm_item, input_ids = process_raw_images(base_output)
elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
combined_mm_item, input_ids = process_precomputed_features(base_output)
elif input_format == MultimodalInputFormat.PIXEL_VALUES:
combined_mm_item, input_ids = process_pixel_values(base_output)
elif input_format == MultimodalInputFormat.AUDIO:
combined_mm_item, input_ids = process_audio(base_output)
else:
raise ValueError(f"Unknown input format: {input_format}")
# Finalize with common processing
combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
return combined_mm_item, input_ids