[Refactor] simplify multimodal data processing (#8107)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-07-20 21:43:09 -07:00
committed by GitHub
parent c9e8613c97
commit 8430bfe3e9
30 changed files with 297 additions and 421 deletions

View File

@@ -5,7 +5,7 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -155,17 +155,15 @@ class BaseMultimodalProcessor(ABC):
self.ATTR_NAME_TO_MODALITY = {
# Image-related attributes
"pixel_values": Modality.IMAGE,
"pixel_values_videos": Modality.VIDEO,
"image_sizes": Modality.IMAGE,
"image_grid_thw": Modality.IMAGE,
"image_attention_mask": Modality.IMAGE,
"image_emb_mask": Modality.IMAGE,
"image_spatial_crop": Modality.IMAGE,
"images_spatial_crop": Modality.IMAGE,
"tgt_size": Modality.IMAGE,
"image_grid_hws": Modality.IMAGE,
"aspect_ratio_id": Modality.IMAGE,
"aspect_ratio_ids": Modality.IMAGE,
"aspect_ratio_mask": Modality.IMAGE,
"second_per_grid_ts": Modality.IMAGE,
# Audio-related attributes
"audio_features": Modality.AUDIO,
"audio_feature_lens": Modality.AUDIO,
@@ -173,9 +171,11 @@ class BaseMultimodalProcessor(ABC):
"input_features_mask": Modality.AUDIO,
"audio_attention_mask": Modality.AUDIO,
# Video-related attributes
"pixel_values_videos": Modality.VIDEO,
"second_per_grid_ts": Modality.VIDEO,
"video_grid_thw": Modality.VIDEO,
# Generic attributes that could apply to multiple modalities
# "precomputed_features" - handled specially as it can be any modality
# "precomputed_embeddings" - handled specially as it can be any modality
}
# name of the feature field
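Editor's note: with this hunk, video-specific keys now resolve to Modality.VIDEO instead of IMAGE. A quick illustration of the intended lookup (the dict below is only a subset of the mapping above, reproduced for clarity):

from sglang.srt.managers.schedule_batch import Modality

ATTR_NAME_TO_MODALITY = {
    "pixel_values": Modality.IMAGE,
    "pixel_values_videos": Modality.VIDEO,   # previously listed under image attributes
    "second_per_grid_ts": Modality.VIDEO,    # previously listed under image attributes
    "video_grid_thw": Modality.VIDEO,
}
assert ATTR_NAME_TO_MODALITY["second_per_grid_ts"] is Modality.VIDEO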
@@ -222,7 +222,6 @@ class BaseMultimodalProcessor(ABC):
audio_data,
input_text,
request_obj,
max_req_input_len,
**kwargs,
) -> Optional[Dict[str, Any]]:
pass
@@ -283,7 +282,7 @@ class BaseMultimodalProcessor(ABC):
self,
text_parts: List[str],
multimodal_tokens: MultimodalSpecialTokens,
data_iterators: dict,
data_iterators: dict[Modality, Iterator[Any]],
discard_alpha_channel: bool = True,
image_estimated_frames_iter: Optional[iter] = None,
image_scaling_factor: float = 1.0,
@@ -354,7 +353,6 @@ class BaseMultimodalProcessor(ABC):
self,
prompt: str,
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
video_data: Optional[list] = None,
audio_data: Optional[list] = None,
@@ -489,50 +487,11 @@ class BaseMultimodalProcessor(ABC):
return list(zip(indices_start.tolist(), indices_end.tolist()))
@staticmethod
def _extract_processor_features(
items: List[dict], attr_name: str
) -> Optional[torch.Tensor]:
"""
Helper function to concat extracted attributes from processor output.
"""
values = [value for item in items if (value := item.get(attr_name)) is not None]
return torch.cat(values) if values else None
# When we assume that all the items have the same attributes
def _extract_processor_features_from_all_attributes(
self, items: List[dict]
) -> dict:
values = {}
# Verify all items have the same keys
first_keys = set(items[0].keys())
for item in items[1:]:
if set(item.keys()) != first_keys:
raise ValueError(
f"All items must have the same attributes. "
f"First item has {first_keys}, but found {set(item.keys())}"
)
# Process each attribute
for k, v in items[0].items():
if isinstance(v, list):
values[k] = self._extract_processor_features(items, k)
else:
# Verify all items have the same value for non-list attributes
for item in items[1:]:
if item[k] != v:
raise ValueError(
f"All items must have the same value for attribute {k}. "
f"First item has {v}, but found {item[k]}"
)
values[k] = v
return values
def collect_mm_items_from_processor_output(
self, data_dict: dict
) -> List[MultimodalDataItem]:
"""Create mm_items directly from processor output."""
items = {} # modality -> MultimodalDataItem
items: dict[Modality, MultimodalDataItem] = {}
for attr_name, value in data_dict.items():
if attr_name == "input_ids":
@@ -541,16 +500,15 @@ class BaseMultimodalProcessor(ABC):
# Get modality for this attribute
modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
if not modality and attr_name == "precomputed_features":
if attr_name == "precomputed_embeddings":
modality_str = data_dict.get("modality")
try:
modality = (
Modality.from_str(modality_str)
if modality_str
else Modality.IMAGE
)
except ValueError:
modality = Modality.IMAGE
modality = Modality.IMAGE
if modality_str:
try:
modality = Modality.from_str(modality_str)
except ValueError:
pass
if modality:
# Create item if needed
if modality not in items:
@@ -559,8 +517,7 @@ class BaseMultimodalProcessor(ABC):
if attr_name in self.FEATURE_NAMES:
attr_name = "feature"
# Set attribute
setattr(items[modality], attr_name, value)
items[modality].set(attr_name, value)
return list(items.values())
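Editor's note: a minimal, self-contained sketch of the routing above. The stub enum and dataclass are assumptions for illustration only, not the sglang classes; only the control flow mirrors the new collect_mm_items_from_processor_output.

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List


class Modality(Enum):
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"

    @classmethod
    def from_str(cls, s: str) -> "Modality":
        return cls(s.lower())


@dataclass
class MultimodalDataItem:
    modality: Modality
    feature: Any = None
    model_specific_data: Dict[str, Any] = field(default_factory=dict)

    def set(self, name: str, value: Any) -> None:
        # Known attributes land on the item; anything else becomes model-specific data.
        # (Assumed behavior of MultimodalDataItem.set, used here as a stand-in.)
        if hasattr(self, name):
            setattr(self, name, value)
        else:
            self.model_specific_data[name] = value


ATTR_NAME_TO_MODALITY = {"pixel_values": Modality.IMAGE, "audio_features": Modality.AUDIO}
FEATURE_NAMES = {"pixel_values", "audio_features", "precomputed_embeddings"}


def collect_mm_items(data_dict: Dict[str, Any]) -> List[MultimodalDataItem]:
    items: Dict[Modality, MultimodalDataItem] = {}
    for attr_name, value in data_dict.items():
        if attr_name in ("input_ids", "modality"):
            continue
        modality = ATTR_NAME_TO_MODALITY.get(attr_name)
        if attr_name == "precomputed_embeddings":
            # Defaults to IMAGE unless the dict carries a parseable modality string.
            modality = Modality.IMAGE
            if modality_str := data_dict.get("modality"):
                try:
                    modality = Modality.from_str(modality_str)
                except ValueError:
                    pass
        if modality:
            items.setdefault(modality, MultimodalDataItem(modality=modality))
            items[modality].set("feature" if attr_name in FEATURE_NAMES else attr_name, value)
    return list(items.values())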
@@ -586,6 +543,7 @@ class BaseMultimodalProcessor(ABC):
self,
base_output: BaseMultiModalProcessorOutput,
mm_tokens: MultimodalSpecialTokens,
**kwargs,
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
"""
Process multimodal data and return the combined multimodal items and input_ids.
@@ -618,7 +576,7 @@ class BaseMultimodalProcessor(ABC):
else:
raise ValueError(f"Unknown multimodal item type: {type(item)}")
# Process items and get input_ids
all_collected_items = []
all_collected_items: list[MultimodalDataItem] = []
input_ids = None
# Handle dict items (already processed)
@@ -634,6 +592,7 @@ class BaseMultimodalProcessor(ABC):
images=raw_images,
audios=raw_audios,
videos=raw_videos,
**kwargs,
)
all_collected_items.extend(collected_items)
else:

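Editor's note: process_and_combine_mm_data now accepts extra keyword arguments and forwards them into the underlying process_mm_data call, which is how model-specific processor arguments are passed after this refactor. A sketch of the call shape on the caller side (conversations= is the DeepseekVL2 usage that appears later in this commit):

mm_items, input_ids, _ = self.process_and_combine_mm_data(
    base_output,
    self.mm_tokens,
    conversations=base_output.input_text,  # forwarded verbatim to the HF processor
)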
View File

@@ -1,9 +1,10 @@
from typing import List, Union
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.clip import CLIPModel
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.utils import load_image
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
class ClipImageProcessor(BaseMultimodalProcessor):
@@ -11,23 +12,24 @@ class ClipImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
_processor
)
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
base_output = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.mm_tokens,
image_data=image_data,
)
images = [load_image(image)[0] for image in image_data]
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
feature=image_inputs["pixel_values"], modality=Modality.IMAGE
)
]
return image_inputs
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
}
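Editor's note: after this change, the simple image processors converge on the same three-step shape. The skeleton below is assembled from the diffs in this commit; SomeModelProcessor is a placeholder name, not a real class.

from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


class SomeModelProcessor(BaseMultimodalProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        # Declare the model's special tokens once, up front.
        self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(_processor)

    async def process_mm_data_async(self, image_data, input_text, request_obj, **kwargs):
        # 1. Expand multimodal tokens in the prompt and load the raw media.
        base_output = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=self.mm_tokens,
            image_data=image_data,
        )
        # 2. Run the HF processor and collect typed MultimodalDataItems.
        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )
        # 3. Return the uniform payload the caller expects.
        return {"input_ids": input_ids.tolist(), "mm_items": mm_items}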

View File

@@ -33,9 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
_processor
)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<image>", image_token_id=self._processor.image_token_id
).build(_processor)
async def process_mm_data_async(
self,
@@ -50,36 +50,16 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
)
res = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_output,
self.mm_tokens,
max_req_input_len=max_req_input_len,
conversations=base_output.input_text,
)
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
batched_images_spatial_crop = []
batched_images_spatial_crop.append(images_spatial_crop)
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
items = []
input_ids = res["input_ids"]
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=self._processor.image_token_id
)
item = MultimodalDataItem(
feature=res["images"],
offsets=image_offsets,
modality=Modality.IMAGE,
image_emb_mask=images_seq_mask,
image_spatial_crop=batched_images_spatial_crop,
)
items += [item]
return {
"mm_items": items,
"mm_items": mm_items,
"input_ids": input_ids.tolist(),
"im_token_id": self._processor.image_token_id,
}

View File

@@ -33,7 +33,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
@@ -41,7 +40,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)

View File

@@ -54,7 +54,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
input_text: str = "",
request_obj=None,
max_req_input_len: int = 0,
*args,
**kwargs,
):
@@ -63,7 +62,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
prompt=input_text,
image_data=image_data,
audio_data=audio_data,
max_req_input_len=max_req_input_len,
multimodal_tokens=self.mm_tokens,
)

View File

@@ -170,13 +170,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
return pixel_values, num_patches_list
async def process_mm_data_async(
self, image_data, input_text, request_obj, max_req_input_len, **kwargs
self, image_data, input_text, request_obj, **kwargs
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)

View File

@@ -11,52 +11,35 @@ from sglang.srt.multimodal.processors.base_processor import (
class JanusProImageProcessor(BaseMultimodalProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, processor):
super().__init__(hf_config, server_args, processor)
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.mm_tokens = MultimodalSpecialTokens(
image_token=processor.image_token
).build(processor)
image_token=_processor.image_token,
image_token_id=_processor.image_id,
).build(_processor)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
**kwargs,
):
processor = self._processor
base_out = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
)
images = base_out.images
res = self.process_mm_data(
input_text=base_out.input_text,
prompt=base_out.input_text,
images=images,
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_out, self.mm_tokens, prompt=base_out.input_text
)
input_ids = res["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=processor.image_id
)
return {
"mm_items": [
MultimodalDataItem(
feature=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
offsets=image_offsets,
modality=Modality.IMAGE,
)
],
"mm_items": mm_items,
"input_ids": input_ids.tolist(),
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
"im_start_id": self._processor.image_start_id,
"im_end_id": self._processor.image_end_id,
"im_token_id": self.mm_tokens.image_token_id,
}

View File

@@ -26,7 +26,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
@@ -34,7 +33,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
)
mm_items, input_ids, _ = self.process_and_combine_mm_data(

View File

@@ -159,7 +159,9 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
"mm_items": [
MultimodalDataItem(
feature=pixel_values,
image_sizes=image_sizes,
model_specific_data={
"image_sizes": image_sizes,
},
modality=modality,
)
],
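Editor's note: per-model tensors such as image_sizes no longer need dedicated attributes on MultimodalDataItem; they ride along in the generic model_specific_data dict. A small before/after sketch (the dummy tensors are made up for illustration):

import torch

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

pixel_values = torch.zeros(1, 3, 336, 336)  # placeholder feature, illustration only
image_sizes = torch.tensor([[336, 336]])

# Before: MultimodalDataItem(feature=pixel_values, image_sizes=image_sizes, ...)
# After: anything model-specific goes through model_specific_data.
item = MultimodalDataItem(
    feature=pixel_values,
    model_specific_data={"image_sizes": image_sizes},
    modality=Modality.IMAGE,
)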

View File

@@ -17,10 +17,21 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
# Collect special token ids
tokenizer = self._processor.tokenizer
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
self.slice_end_id = getattr(tokenizer, "slice_end_id", None)
self.audio_start_id = getattr(tokenizer, "audio_start_id", None)
self.audio_end_id = getattr(tokenizer, "audio_end_id", None)
self.im_start_id = getattr(tokenizer, "im_start_id", None)
self.im_end_id = getattr(tokenizer, "im_end_id", None)
self.im_token_id = getattr(tokenizer, "unk_id", None)
self.mm_tokens = MultimodalSpecialTokens(
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",
video_token="(<video>./</video>)",
image_token_id=self.im_token_id,
).build(_processor)
async def process_mm_data_async(
@@ -29,12 +40,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
@@ -48,24 +57,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audios=base_output.audios,
)
# Collect special token ids
tokenizer = self._processor.tokenizer
slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
None,
None,
None,
None,
)
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
if hasattr(tokenizer, "audio_start_id"):
audio_start_id = tokenizer.audio_start_id
audio_end_id = tokenizer.audio_end_id
im_start_id = tokenizer.im_start_id
im_end_id = tokenizer.im_end_id
im_token_id = tokenizer.unk_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
@@ -102,10 +93,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
items = []
input_ids = res["input_ids"].flatten()
image_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
input_ids=input_ids, mm_start_id=self.im_start_id, mm_end_id=self.im_end_id
)
slice_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
input_ids=input_ids,
mm_start_id=self.slice_start_id,
mm_end_id=self.slice_end_id,
)
image_offsets.extend(slice_offsets)
image_offsets = sorted(image_offsets)
@@ -114,7 +107,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
item = MultimodalDataItem(
feature=pixel_values,
offsets=image_offsets,
tgt_size=tgt_sizes_flat,
model_specific_data={"tgt_size": tgt_sizes_flat},
modality=Modality.IMAGE,
)
items += [item]
@@ -124,17 +117,17 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
and res["audio_features"] is not None
and len(res["audio_features"]) != 0
):
if audio_start_id is not None and audio_end_id is not None:
if self.audio_start_id is not None and self.audio_end_id is not None:
audio_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids,
mm_start_id=audio_start_id,
mm_end_id=audio_end_id,
mm_start_id=self.audio_start_id,
mm_end_id=self.audio_end_id,
)
else:
audio_offsets = None
item = MultimodalDataItem(
feature=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
model_specific_data={"audio_feature_lens": res["audio_feature_lens"]},
offsets=audio_offsets,
modality=Modality.AUDIO,
)
@@ -142,11 +135,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
return {
"mm_items": items,
"input_ids": input_ids.tolist(),
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"im_token_id": im_token_id,
"im_start_id": im_start_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
"audio_start_id": self.audio_start_id,
"audio_end_id": self.audio_end_id,
"im_token_id": self.im_token_id,
"im_start_id": self.im_start_id,
"im_end_id": self.im_end_id,
"slice_start_id": self.slice_start_id,
"slice_end_id": self.slice_end_id,
}

View File

@@ -1,9 +1,10 @@
from typing import List, Union
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.utils import load_image
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
class MllamaImageProcessor(BaseMultimodalProcessor):
@@ -11,24 +12,26 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.mm_tokens = MultimodalSpecialTokens(
image_token=self._processor.image_token,
image_token_id=self._processor.image_token_id,
).build(_processor)
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
base_out = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
)
images = [load_image(image)[0] for image in image_data]
image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
feature=image_inputs["pixel_values"],
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
modality=Modality.IMAGE,
)
]
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_out, self.mm_tokens
)
return image_inputs
return {
"mm_items": mm_items,
"input_ids": input_ids.tolist(),
"im_token_id": self.mm_tokens.image_token_id,
}

View File

@@ -27,13 +27,13 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
self.image_token_index = hf_config.image_token_index
self.multimodal_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token,
image_token_id=self.image_token_index,
).build(_processor)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
max_req_input_len=None,
*args,
**kwargs,
):
@@ -45,7 +45,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
processed_data = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.multimodal_tokens,
max_req_input_len=max_req_input_len or 4096,
image_data=image_data,
return_text=True,
)

View File

@@ -31,6 +31,7 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
for hf_key, sglang_key in key_mapping.items():
if hf_key in result:
result[sglang_key] = result[hf_key]
del result[hf_key]
# Filter out None or empty tensors from the result.
# This prevents the sglang function base_processor.collect_mm_items_from_processor_output()
@@ -58,7 +59,7 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
self.AUDIO_TOKEN_ID = 200011
self.AUDIO_SAMPLE_RATE = 16000
self.multimodal_tokens = MultimodalSpecialTokens(
self.mm_tokens = MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_id=self.IM_TOKEN_ID,
audio_token=self.AUDIO_TOKEN,
@@ -71,15 +72,13 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
audio_data,
input_text,
request_obj,
max_req_input_len,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=self.multimodal_tokens,
multimodal_tokens=self.mm_tokens,
audio_sample_rate=self.AUDIO_SAMPLE_RATE,
)
@@ -91,12 +90,12 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
]
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_output, self.multimodal_tokens
base_output, self.mm_tokens
)
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
"im_token_id": self.IM_TOKEN_ID,
"audio_token_id": self.AUDIO_TOKEN_ID,
"im_token_id": self.mm_tokens.image_token_id,
"audio_token_id": self.mm_tokens.audio_token_id,
}
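Editor's note: the adapter now deletes the original HF key after copying it, so only the sglang-side names reach collect_mm_items_from_processor_output. The same rename can be written with dict.pop; key_mapping and result are the adapter's locals from the hunk above.

for hf_key, sglang_key in key_mapping.items():
    if hf_key in result:
        result[sglang_key] = result.pop(hf_key)  # copy under the new name, drop the old one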

View File

@@ -6,7 +6,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.pixtral import PixtralVisionModel
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
@@ -45,7 +44,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.image_token_id = getattr(
self.IM_TOKEN_ID = getattr(
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
)
# Instantiate the patcher logic helper using the class defined above
@@ -53,8 +52,9 @@ class PixtralProcessor(BaseMultimodalProcessor):
self.vision_config = hf_config.vision_config
self.image_size = self.vision_config.image_size
self.patch_size = self.vision_config.patch_size
self.multimodal_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token
self.mm_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token,
image_token_id=self.IM_TOKEN_ID,
).build(_processor)
_processor.tokenizer.add_special_tokens(
{
@@ -80,42 +80,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
):
mm_data = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.multimodal_tokens,
max_req_input_len=kwargs.get("max_req_input_len", 4096),
multimodal_tokens=self.mm_tokens,
image_data=image_data,
return_text=True,
)
if mm_data.images:
resize_tasks = [self._resize(image) for image in mm_data.images]
mm_data.images = await asyncio.gather(*resize_tasks)
processor_output = self.process_mm_data(
input_text=mm_data.input_text,
images=mm_data.images,
mm_items, input_ids, _ = self.process_and_combine_mm_data(
mm_data, self.mm_tokens
)
if "pixel_values" in processor_output:
input_ids = processor_output["input_ids"].view(-1)
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.image_token_id,
)
mm_items = [
MultimodalDataItem(
feature=processor_output["pixel_values"],
image_sizes=processor_output["image_sizes"],
modality=Modality.IMAGE,
offsets=image_offsets,
)
]
input_ids = input_ids.tolist()
processor_output.update(
input_ids=input_ids,
mm_items=mm_items,
# there's no im_start_id for pixtral, only im_token and im_end_token
im_end_id=self.IMG_END_TOKEN_ID,
im_token_id=self.image_token_id,
)
return processor_output
return {
"mm_items": mm_items,
"input_ids": input_ids.tolist(),
"im_token_id": self.IM_TOKEN_ID,
"im_token": self._processor.image_token,
}

View File

@@ -0,0 +1,65 @@
import re
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
models = [Qwen2AudioForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
self.AUDIO_TOKEN_REGEX = re.compile(
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
)
# Collect special token ids
tokenizer = self._processor.tokenizer
self.audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
self.audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
self.audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
self.mm_tokens = MultimodalSpecialTokens(
audio_token=self.AUDIO_TOKEN,
audio_token_regex=self.AUDIO_TOKEN_REGEX,
audio_token_id=self.audio_token_id,
).build(_processor)
async def process_mm_data_async(
self,
audio_data,
input_text,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
audio_data=audio_data,
multimodal_tokens=self.mm_tokens,
)
if base_output is None:
return None
mm_items, input_ids, ret = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
assert (
"feature_attention_mask" in ret
), "feature_attention_mask not found in processor output"
input_lengths = ret["feature_attention_mask"].sum(dim=-1)
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
mm_items[0].model_specific_data["audio_feature_lens"] = output_lengths
return {
"mm_items": mm_items,
"input_ids": input_ids.tolist(),
"audio_start_id": self.audio_start_id,
"audio_token_id": self.audio_token_id,
"audio_end_id": self.audio_end_id,
}
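Editor's note: the two integer divisions mirror successive stride-2 downsampling stages in the audio encoder (an assumption about the model; the commit itself only shows the arithmetic). A worked example with made-up mask lengths:

import torch

# feature_attention_mask.sum(dim=-1) yields the number of valid mel frames per clip.
input_lengths = torch.tensor([3000, 748])
after_conv = (input_lengths - 1) // 2 + 1     # tensor([1500, 374])
output_lengths = (after_conv - 2) // 2 + 1    # tensor([750, 187])
# output_lengths is what gets stored as audio_feature_lens on the audio item.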

View File

@@ -227,7 +227,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
@@ -237,7 +236,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_data=image_data,
video_data=request_obj.video_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
)
# Qwen-specific: resize images if they are raw Image objects

View File

@@ -47,13 +47,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
image_data: Optional[ImageDataInputItem | List[ImageDataInputItem]],
input_text: str | List[int],
request_obj: GenerateReqInput | EmbeddingReqInput,
max_req_input_len: int,
**kwargs,
) -> Optional[Dict[str, Any]]:
base_output = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
image_data=image_data,
)