Refactor mm processors and Enable mixed modality processing (#7629)

commit 3a911b854d (parent 886d344964), committed by GitHub
Author: Xinyuan Tong <justinning0323@outlook.com>
Date:   2025-06-30 23:14:48 -07:00
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>

28 changed files with 235 additions and 428 deletions

View File

@@ -17,15 +17,6 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image
 
-class MultimodalInputFormat(Enum):
-    """Enum for different multimodal input formats."""
-
-    RAW_IMAGES = "raw_images"
-    PRECOMPUTED_FEATURES = "precomputed_features"
-    PIXEL_VALUES = "pixel_values"
-    AUDIO = "audio"
-
 @dataclasses.dataclass
 class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with an image_token
@@ -110,18 +101,45 @@ class BaseMultimodalProcessor(ABC):
             max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )
 
+        # Mapping from attribute names to modality types
+        self.ATTR_NAME_TO_MODALITY = {
+            # Image-related attributes
+            "pixel_values": Modality.IMAGE,
+            "image_sizes": Modality.IMAGE,
+            "image_grid_thw": Modality.IMAGE,
+            "image_emb_mask": Modality.IMAGE,
+            "image_spatial_crop": Modality.IMAGE,
+            "tgt_size": Modality.IMAGE,
+            "image_grid_hws": Modality.IMAGE,
+            "aspect_ratio_id": Modality.IMAGE,
+            "aspect_ratio_mask": Modality.IMAGE,
+            "second_per_grid_ts": Modality.IMAGE,
+            # Audio-related attributes
+            "audio_features": Modality.AUDIO,
+            "audio_feature_lens": Modality.AUDIO,
+            "input_features": Modality.AUDIO,
+            "input_features_mask": Modality.AUDIO,
+            # Video-related attributes
+            "video_grid_thws": Modality.VIDEO,
+            # Generic attributes that could apply to multiple modalities
+            # "precomputed_features" - handled specially as it can be any modality
+        }
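
This table is the heart of the refactor: one processor call can now emit tensors for several modalities at once, and each output key is routed to its modality by name. A minimal runnable sketch of that routing, with a toy enum standing in for sglang's `Modality`:

```python
# Toy sketch of name-based modality routing; this enum stands in for
# sglang.srt.managers.schedule_batch.Modality.
from enum import Enum
from typing import Optional


class Modality(Enum):
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"


ATTR_NAME_TO_MODALITY = {
    "pixel_values": Modality.IMAGE,
    "image_grid_thw": Modality.IMAGE,
    "input_features": Modality.AUDIO,
    "video_grid_thws": Modality.VIDEO,
}


def route(attr_name: str) -> Optional[Modality]:
    # "precomputed_features" is deliberately not in the table: it can belong
    # to any modality, so the caller resolves it from the dict's "modality" key.
    return ATTR_NAME_TO_MODALITY.get(attr_name)


assert route("pixel_values") is Modality.IMAGE
assert route("precomputed_features") is None
```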
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
     ):
         """
         Process multimodal data with the transformers AutoProcessor.
         """
-        if images is not None:
+        if images:
             kwargs["images"] = images
-        if videos is not None:
+        if videos:
             kwargs["videos"] = videos
-        if audios is not None:
+        if audios:
             kwargs["audios"] = audios
+            if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+                # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
+                kwargs["audio"] = audios
 
         processor = self._processor
         if hasattr(processor, "image_processor") and isinstance(
@@ -144,6 +162,7 @@ class BaseMultimodalProcessor(ABC):
     async def process_mm_data_async(
         self,
         image_data,
+        audio_data,
         input_text,
         request_obj,
         max_req_input_len,
@@ -418,175 +437,137 @@ class BaseMultimodalProcessor(ABC):
                 values[k] = v
         return values
 
-    def process_and_combine_mm_data(
-        self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
+    def collect_mm_items_from_processor_output(
+        self, data_dict: dict
+    ) -> List[MultimodalDataItem]:
+        """Create mm_items directly from processor output."""
+        items = {}  # modality -> MultimodalDataItem
+        for attr_name, value in data_dict.items():
+            if attr_name == "input_ids":
+                continue
+
+            # Get modality for this attribute
+            modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
+
+            if not modality and attr_name == "precomputed_features":
+                modality_str = data_dict.get("modality")
+                try:
+                    modality = (
+                        Modality.from_str(modality_str)
+                        if modality_str
+                        else Modality.IMAGE
+                    )
+                except ValueError:
+                    modality = Modality.IMAGE
+
+            if modality:
+                # Create item if needed
+                if modality not in items:
+                    items[modality] = MultimodalDataItem(modality=modality)
+
+                # Set attribute
+                if hasattr(items[modality], attr_name):
+                    setattr(items[modality], attr_name, value)
+
+        return list(items.values())
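
Concretely, a single transformers output dict containing both `pixel_values` and `input_features` now yields two separate items, one per modality. A self-contained sketch of that grouping, with a plain dataclass standing in for `MultimodalDataItem`:

```python
import dataclasses
from enum import Enum
from typing import Any, Dict, List


class Modality(Enum):
    IMAGE = "image"
    AUDIO = "audio"


@dataclasses.dataclass
class Item:
    """Stand-in for MultimodalDataItem: one bag of tensors per modality."""
    modality: Modality
    pixel_values: Any = None
    input_features: Any = None


ATTR_NAME_TO_MODALITY = {
    "pixel_values": Modality.IMAGE,
    "input_features": Modality.AUDIO,
}


def collect(data_dict: Dict[str, Any]) -> List[Item]:
    items: Dict[Modality, Item] = {}
    for name, value in data_dict.items():
        if name == "input_ids":
            continue  # token ids travel separately from feature tensors
        modality = ATTR_NAME_TO_MODALITY.get(name)
        if modality is None:
            continue  # unknown keys are ignored in this sketch
        item = items.setdefault(modality, Item(modality=modality))
        if hasattr(item, name):
            setattr(item, name, value)
    return list(items.values())


# One processor pass over a mixed request yields one item per modality:
out = collect({"input_ids": [1, 2], "pixel_values": "img", "input_features": "aud"})
assert {item.modality for item in out} == {Modality.IMAGE, Modality.AUDIO}
```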
+    def _process_and_collect_mm_items(
+        self, input_text: str, images=None, audios=None, videos=None, **kwargs
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
         """
-        Process multimodal data and return the combined multimodal item and input_ids.
-        Handles all three input formats at the same abstraction level.
+        Helper method to process multimodal data and create mm_items in one step.
 
         Returns:
-            Tuple of (combined_mm_item, input_ids)
+            Tuple of (created mm_items, input_ids)
         """
+        ret = self.process_mm_data(
+            input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
+        )
-        def tokenize_text(input_text: str) -> torch.Tensor:
-            """Tokenize input text."""
-            return self._processor.tokenizer(
-                input_text,
+
+        input_ids = ret["input_ids"].flatten()
+        collected_items = self.collect_mm_items_from_processor_output(ret)
+
+        return collected_items, input_ids
+    def process_and_combine_mm_data(
+        self, base_output: BaseMultiModalProcessorOutput
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+        """
+        Process multimodal data and return the combined multimodal items and input_ids.
+        Supports mixed modalities (images and audio in the same request).
+
+        Returns:
+            Tuple of (list of mm_items, input_ids)
+        """
+        # Collect all items and categorize them
+        all_items = (base_output.images or []) + (base_output.audios or [])
+
+        # Handle text-only case
+        if not all_items:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
                 return_tensors="pt",
                 add_special_tokens=True,
             ).input_ids.flatten()
+            return [], input_ids
+
+        dict_items, raw_images, raw_audios = [], [], []
+        for item in all_items:
+            if isinstance(item, dict):
+                dict_items.append(item)
+            elif isinstance(item, Image.Image):
+                raw_images.append(item)
+            elif isinstance(item, np.ndarray):
+                raw_audios.append(item)
+            else:
+                raise ValueError(f"Unknown multimodal item type: {type(item)}")
+
+        # Process items and get input_ids
+        all_collected_items = []
+        input_ids = None
+
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
+
+        # Handle raw items (need processing)
+        if raw_images or raw_audios:
+            collected_items, input_ids = self._process_and_collect_mm_items(
+                input_text=base_output.input_text,
+                images=raw_images,
+                audios=raw_audios,
+            )
+            all_collected_items.extend(collected_items)
+
+        # Fallback tokenization if no raw items were processed
+        if input_ids is None:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
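
This loop is what actually enables mixed modalities: instead of classifying the whole request into a single `MultimodalInputFormat` (the enum deleted at the top of this file), each item is bucketed by its Python type, so dicts with precomputed features can coexist with raw images and raw waveforms in one request. A runnable sketch of the bucketing, dispatching on the same PIL and numpy types the real loop checks:

```python
import numpy as np
from PIL import Image


def bucket(all_items):
    dict_items, raw_images, raw_audios = [], [], []
    for item in all_items:
        if isinstance(item, dict):            # already-processed features
            dict_items.append(item)
        elif isinstance(item, Image.Image):   # raw image: needs the processor
            raw_images.append(item)
        elif isinstance(item, np.ndarray):    # raw audio waveform
            raw_audios.append(item)
        else:
            raise ValueError(f"Unknown multimodal item type: {type(item)}")
    return dict_items, raw_images, raw_audios


dicts, imgs, auds = bucket([
    {"precomputed_features": "emb", "modality": "image"},
    Image.new("RGB", (8, 8)),
    np.zeros(16000, dtype=np.float32),  # one second of 16 kHz audio
])
assert len(dicts) == len(imgs) == len(auds) == 1
```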
-        def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
-            """Categorize multimodal inputs and validate consistency."""
-            try:
-                has_image = False
-                has_pixel_values = False
-                has_precomputed_features = False
-                has_audio = False
-
-                for mm_input in mm_inputs:
-                    if isinstance(mm_input, Image.Image):
-                        has_image = True
-                    elif isinstance(mm_input, np.ndarray):
-                        has_audio = True
-                    elif isinstance(mm_input, dict):
-                        if mm_input.get("precomputed_features", None) is not None:
-                            has_precomputed_features = True
-                        elif mm_input.get("pixel_values", None) is not None:
-                            has_pixel_values = True
-                        else:
-                            raise ValueError(
-                                f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
-                            )
-                    else:
-                        raise ValueError(
-                            f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
-                        )
-
-                # Validate format consistency
-                format_count = sum(
-                    [has_image, has_pixel_values, has_precomputed_features, has_audio]
-                )
-                if format_count > 1:
-                    raise ValueError(
-                        "Unsupported: mixture of multimodal input formats. "
-                        f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                        f"precomputed_features={has_precomputed_features}, audio={has_audio}"
-                    )
-
-                if has_image:
-                    return MultimodalInputFormat.RAW_IMAGES
-                elif has_precomputed_features:
-                    return MultimodalInputFormat.PRECOMPUTED_FEATURES
-                elif has_pixel_values:
-                    return MultimodalInputFormat.PIXEL_VALUES
-                elif has_audio:
-                    return MultimodalInputFormat.AUDIO
-                else:
-                    raise ValueError("No valid multimodal input format found")
-            except Exception as e:
-                raise ValueError(f"Failed to categorize inputs: {e}")
-
-        def process_raw_images(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process raw Image.Image objects using transformers processor."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                images=base_output.images,
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-
-            # Copy all fields from processor output except input_ids
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def process_precomputed_features(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with precomputed features."""
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-            combined_mm_item.precomputed_features = self._extract_processor_features(
-                base_output.images, "precomputed_features"
-            )
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_pixel_values(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with pixel values."""
-            values = self._extract_processor_features_from_all_attributes(
-                base_output.images
-            )
-            combined_mm_item = MultimodalDataItem.from_dict(values)
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_audio(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with audio."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                audio=base_output.audios,  # Note: "audio" is for gemma3n only
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def finalize_mm_item(
-            combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
-        ) -> MultimodalDataItem:
-            """Apply common post-processing to the multimodal item."""
-            if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                combined_mm_item.image_offsets = self.get_mm_items_offset(
+        # Add offsets to all items
+        for mm_item in all_collected_items:
+            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                mm_item.image_offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.IM_TOKEN_ID,
                 )
-            elif combined_mm_item.modality == Modality.AUDIO:
-                combined_mm_item.audio_offsets = self.get_mm_items_offset(
+            elif mm_item.modality == Modality.AUDIO:
+                mm_item.audio_offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.AUDIO_TOKEN_ID,
                 )
-            elif combined_mm_item.modality == Modality.VIDEO:
-                combined_mm_item.video_offsets = self.get_mm_items_offset(
+            elif mm_item.modality == Modality.VIDEO:
+                mm_item.video_offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.VIDEO_TOKEN_ID,
                 )
             else:
-                raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
-            return combined_mm_item
+                raise ValueError(f"Unknown modality: {mm_item.modality}")
-
-        # Main logic - determine input type and handle text-only case
-        mm_inputs = base_output.images or base_output.audios
-        if not mm_inputs:
-            input_ids = tokenize_text(base_output.input_text)
-            return None, input_ids
-
-        # Categorize input formats
-        input_format = categorize_mm_inputs(mm_inputs)
-
-        # Process based on format
-        if input_format == MultimodalInputFormat.RAW_IMAGES:
-            combined_mm_item, input_ids = process_raw_images(base_output)
-        elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
-            combined_mm_item, input_ids = process_precomputed_features(base_output)
-        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
-            combined_mm_item, input_ids = process_pixel_values(base_output)
-        elif input_format == MultimodalInputFormat.AUDIO:
-            combined_mm_item, input_ids = process_audio(base_output)
-        else:
-            raise ValueError(f"Unknown input format: {input_format}")
-
-        # Finalize with common processing
-        combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
-        return combined_mm_item, input_ids
+
+        return all_collected_items, input_ids
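
The final pass stamps per-modality offsets onto every collected item using that modality's placeholder token id. The return shape of `get_mm_items_offset` is not shown in this diff; the sketch below assumes it yields the contiguous `(start, end)` runs of the placeholder token inside `input_ids`, which is what lets one shared `input_ids` serve image and audio items at the same time:

```python
from typing import List, Tuple


def mm_items_offset(input_ids: List[int], mm_token_id: int) -> List[Tuple[int, int]]:
    runs, start = [], None
    for i, tok in enumerate(input_ids):
        if tok == mm_token_id and start is None:
            start = i                    # a placeholder run opens
        elif tok != mm_token_id and start is not None:
            runs.append((start, i - 1))  # the run closes
            start = None
    if start is not None:
        runs.append((start, len(input_ids) - 1))
    return runs


# <im> <im> txt <audio> txt <im>, with im=7 and audio=9: each modality scans
# the same input_ids with its own token id, which is why mixed items work.
ids = [7, 7, 1, 9, 2, 7]
assert mm_items_offset(ids, mm_token_id=7) == [(0, 1), (5, 5)]
assert mm_items_offset(ids, mm_token_id=9) == [(3, 3)]
```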

View File

@@ -15,20 +15,11 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)
 
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
+        images = [load_image(image)[0] for image in image_data]
 
         image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]

View File

@@ -44,17 +44,10 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs
     ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
             max_req_input_len=max_req_input_len,
         )
         res = self.process_mm_data(

View File

@@ -36,11 +36,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -51,11 +46,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             discard_alpha_channel=True,
         )
 
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "mm_items": mm_items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }

View File

@@ -59,17 +59,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
         **kwargs,
     ):
         """Process multimodal data including images and audio."""
-        audio_data = request_obj.audio_data
-
-        if not image_data and not audio_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
-        if isinstance(audio_data, str):
-            audio_data = [audio_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -83,13 +72,11 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             ),
         )
 
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
-            "im_start_id": self.IM_START_TOKEN_ID,
-            "im_end_id": self.IM_END_TOKEN_ID,
-            "audio_start_id": self.AUDIO_START_TOKEN_ID,
-            "audio_end_id": self.AUDIO_END_TOKEN_ID,
+            "mm_items": mm_items,
+            "im_token_id": self.IM_TOKEN_ID,
+            "audio_token_id": self.AUDIO_TOKEN_ID,
         }
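
Beyond the list-valued `mm_items`, Gemma3n's return contract swaps the start/end wrapper ids for the placeholder token ids themselves, which is what the shared offset pass keys on. A small sketch of a downstream consumer of this dict shape (the real consumer lives on the scheduler side and is not part of this diff):

```python
# Hypothetical consumer: the dict keys come from the diff above, the
# function itself is an illustration only.
def placeholder_positions(result: dict) -> dict:
    ids = result["input_ids"]
    return {
        "image": [i for i, t in enumerate(ids) if t == result["im_token_id"]],
        "audio": [i for i, t in enumerate(ids) if t == result["audio_token_id"]],
    }


res = {"input_ids": [5, 5, 1, 8], "mm_items": [], "im_token_id": 5, "audio_token_id": 8}
assert placeholder_positions(res) == {"image": [0, 1], "audio": [3]}
```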

View File

@@ -172,13 +172,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data, input_text, request_obj, max_req_input_len, **kwargs
     ):
-        if not image_data:
-            return None
-
-        # Ensure image_data is a list
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,

View File

@@ -22,12 +22,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         max_req_input_len,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
         processor = self._processor
 
         base_out = self.load_mm_data(

View File

@@ -30,11 +30,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -44,10 +39,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )
 
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "mm_items": mm_items,
             "im_token_id": self.IM_TOKEN_ID,
         }

View File

@@ -110,9 +110,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
         modalities = request_obj.modalities or ["image"]
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
@@ -122,9 +119,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             else None
         )
 
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         if isinstance(image_data, list) and len(image_data) > 0:
             if "multi-images" in modalities or "video" in modalities:
                 # Multiple images

View File

@@ -23,19 +23,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
+        audio_data: List[Union[str, bytes]],
         input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
     ):
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             max_req_input_len=max_req_input_len,
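
MiniCPM now receives `audio_data` as an explicit parameter instead of re-reading it from `request_obj`, matching the `audio_data` argument added to the base class's `process_mm_data_async` earlier in this diff. A minimal sketch of the unified entry point (only the shape is taken from the diff; the body is elided):

```python
from abc import ABC, abstractmethod


class BaseMultimodalProcessorSketch(ABC):
    @abstractmethod
    async def process_mm_data_async(
        self,
        image_data,
        audio_data,  # now an explicit argument for every processor
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ): ...
```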

View File

@@ -15,21 +15,11 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)
 
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
+        images = [load_image(image)[0] for image in image_data]
 
         image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
         image_inputs["mm_items"] = [

View File

@@ -37,9 +37,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)

View File

@@ -26,22 +26,12 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
+        audio_data,
         input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
     ):
-        audio_data = request_obj.audio_data
-
-        if not image_data and not audio_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
-
         if audio_data:
             logger.warning(
                 "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."

View File

@@ -78,12 +78,6 @@ class PixtralProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         mm_data = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=self.multimodal_tokens,

View File

@@ -49,9 +49,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -130,12 +127,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
-        video_grid_thw = None  # TODO
-
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
-        if combined_mm_item is None:
+        if not mm_items:
             # Note(Xinyuan): This is the case where image loading fails.
             return None
 
+        combined_mm_item = mm_items[0]  # only image is supported for now
+        video_grid_thw = None  # TODO
         second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
@@ -157,7 +155,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item],
+            "mm_items": mm_items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.IM_TOKEN_ID,

View File

@@ -37,6 +37,8 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         _processor: VILAProcessor,
     ) -> None:
         super().__init__(hf_config, server_args, _processor)
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id
 
     async def process_mm_data_async(
         self,
@@ -46,13 +48,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         max_req_input_len: int,
         **kwargs,
     ) -> Optional[Dict[str, Any]]:
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        mm_data = self.load_mm_data(
+        base_output = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=self._processor.tokenizer.image_token
@@ -61,25 +57,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
         )
 
-        inputs = self.process_mm_data(
-            input_text=mm_data.input_text,
-            images=mm_data.images,
-        )
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
-        image_offsets = self.get_mm_items_offset(
-            input_ids=inputs.input_ids[0],
-            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
-        )
-
-        mm_items: List[MultimodalDataItem] = [
-            MultimodalDataItem(
-                modality=Modality.IMAGE,
-                image_offsets=image_offsets,
-                pixel_values=inputs.pixel_values,
-            )
-        ]
-
-        return dict(
-            input_ids=inputs.input_ids[0].tolist(),
-            mm_items=mm_items,
-        )
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.IM_TOKEN_ID,
+            "video_token_id": self.VIDEO_TOKEN_ID,
+        }
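
Taken together, the refactor leaves every processor with the same four-step recipe: load, process, collect, offset. A compact end-to-end sketch with stubs in place of the tokenizer and AutoProcessor (token id 7 and the stub names are assumptions for illustration, not sglang APIs):

```python
from typing import Any, Dict, List

IM_TOKEN_ID = 7  # assumed placeholder token id for this sketch


def stub_processor(input_text: str, images: List[Any]) -> Dict[str, Any]:
    # Pretend each image expands to one <image> placeholder token (id 7).
    return {
        "input_ids": [1] + [IM_TOKEN_ID] * len(images) + [2],
        "pixel_values": [f"feat({img})" for img in images],
    }


def process_and_combine(input_text: str, images: List[Any]):
    if not images:                                    # text-only fast path
        return [], [1, 2]
    ret = stub_processor(input_text, images)          # process_mm_data
    items = [{"modality": "image",                    # collect_mm_items_...
              "pixel_values": ret["pixel_values"]}]
    for item in items:                                # per-modality offset pass
        item["image_offsets"] = [
            i for i, t in enumerate(ret["input_ids"]) if t == IM_TOKEN_ID
        ]
    return items, ret["input_ids"]


items, input_ids = process_and_combine("hi <image>", ["img0"])
assert items[0]["image_offsets"] == [1]
```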