Refactor mm processors and Enable mixed modality processing (#7629)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-06-30 23:14:48 -07:00
committed by GitHub
parent 886d344964
commit 3a911b854d
28 changed files with 235 additions and 428 deletions

View File

@@ -125,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
    e.g. <image><image>....<image>, or <audio><audio>...<audio>
    """

-    def __init__(self, token_ids: List[int]) -> None:
-        self.token_ids = token_ids
-
    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
-        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
-        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
+        Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
+        Each modality (image, audio, video) is handled separately based on its token_id.
        """
-        pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        if not pad_values:
-            # No multimodal items, return original input_ids
-            return input_ids
-        if not input_ids:
-            return []
+        if not input_ids or not mm_inputs.mm_items:
+            return input_ids

        input_ids_tensor = torch.tensor(input_ids)
-        device = input_ids_tensor.device
-        token_ids_tensor = torch.tensor(self.token_ids, device=device)
-        mask = torch.isin(input_ids_tensor, token_ids_tensor)

-        if not mask.any():
-            # No tokens match token_ids, return original input_ids
-            return input_ids
-
-        # Find contiguous regions
-        padded_mask = torch.cat(
-            (
-                torch.tensor([False], device=device),
-                mask,
-                torch.tensor([False], device=device),
-            )
-        )
-        # Find indices where the mask value changes
-        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
-
-        # Start indices are where False changes to True
-        starts = diff_indices[::2]
-        # End indices are where True changes to False (exclusive index)
-        ends = diff_indices[1::2]
-
-        # Check if the number of regions matches the number of pad values
-        if len(starts) != len(pad_values):
-            # Maybe log a warning here?
-            num_regions = len(starts)
-            num_pad_values = len(pad_values)
-            if num_regions > 0 and num_pad_values > 0:
-                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
-                    :num_regions
-                ]
-            else:  # If no regions or no pad_values, this loop won't run anyway.
-                pad_values = []  # Ensure pad_values is empty if starts is empty
-
-        # Create a copy to modify
-        output_ids_tensor = input_ids_tensor.clone()
-
-        # Replace tokens in each region with the corresponding pad value
-        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
-        for i in range(min(len(starts), len(pad_values))):
-            start_idx = starts[i]
-            end_idx = ends[i]
-            pad_value = pad_values[i]
-
-            if pad_value is not None:  # Ensure pad_value is not None before assignment
-                output_ids_tensor[start_idx:end_idx] = pad_value
-            else:
-                logger.warning(f"Skipping region {i} due to None pad_value.")
-
-        return output_ids_tensor.tolist()
+        # Create mapping of token_ids to pad_values for each modality
+        token_to_pad_mapping = {}
+
+        for item in mm_inputs.mm_items:
+            if item.is_image() and mm_inputs.im_token_id is not None:
+                token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
+            elif item.is_audio() and mm_inputs.audio_token_id is not None:
+                token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
+            elif item.is_video() and mm_inputs.video_token_id is not None:
+                token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
+            else:
+                raise ValueError(f"No multimodal token id provided for {item.modality}")
+
+        # Apply replacements for all tokens at once
+        for token_id, pad_value in token_to_pad_mapping.items():
+            input_ids_tensor[input_ids_tensor == token_id] = pad_value
+
+        ret_input_ids = input_ids_tensor.tolist()
+
+        return ret_input_ids


embedding_cache = None
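For reference, a minimal sketch of how the refactored pattern is driven after this change. The token ids and the request object below are illustrative, not taken from the repository; the real ids come from the processor's returned dict and from MultimodalInputs:

    # Hypothetical sketch: token ids 9001/9002 are placeholders for illustration.
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    # mm_inputs is assumed to expose im_token_id / audio_token_id plus one
    # MultimodalDataItem per modality, each carrying its pad_value.
    padded = pattern.pad_input_tokens(
        input_ids=[1, 9001, 9001, 42, 9002, 7],  # 9001 = image token, 9002 = audio token
        mm_inputs=mm_inputs,
    )
    # Every occurrence of 9001 becomes the image item's pad_value, and every
    # occurrence of 9002 becomes the audio item's pad_value, in a single pass.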

View File

@@ -174,6 +174,15 @@ class Modality(Enum):
    VIDEO = auto()
    AUDIO = auto()

+    @staticmethod
+    def from_str(modality_str: str):
+        try:
+            return Modality[modality_str.upper()]
+        except KeyError:
+            raise ValueError(
+                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
+            )


@dataclasses.dataclass
class MultimodalDataItem:
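A quick illustration of the new helper (the strings are illustrative):

    Modality.from_str("image")   # -> Modality.IMAGE
    Modality.from_str("AUDIO")   # -> Modality.AUDIO (lookup is case-insensitive via .upper())
    Modality.from_str("text")    # -> ValueError listing the valid modality names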

View File

@@ -482,20 +482,25 @@ class TokenizerManager:
        token_type_ids = encoded.get("token_type_ids", [None])[0]

        if self.mm_processor and obj.contains_mm_input():
-            image_inputs: Dict = await self.mm_processor.process_mm_data_async(
+            if not isinstance(obj.image_data, list):
+                obj.image_data = [obj.image_data]
+            if not isinstance(obj.audio_data, list):
+                obj.audio_data = [obj.audio_data]
+            mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
+                audio_data=obj.audio_data,
                input_text=input_text or input_ids,
                request_obj=obj,
                max_req_input_len=self.max_req_input_len,
            )
-            if image_inputs and "input_ids" in image_inputs:
-                input_ids = image_inputs["input_ids"]
+            if mm_inputs and "input_ids" in mm_inputs:
+                input_ids = mm_inputs["input_ids"]
        else:
-            image_inputs: Optional[Dict] = None
+            mm_inputs = None

        self._validate_one_request(obj, input_ids)
        return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
+            obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
        )

    def _validate_one_request(
@@ -559,7 +564,7 @@ class TokenizerManager:
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[Union[List[float], None]] = None,
-        image_inputs: Optional[Dict] = None,
+        mm_inputs: Optional[Dict] = None,
        token_type_ids: Optional[List[int]] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Create a tokenized request object from common parameters."""
@@ -584,7 +589,7 @@ class TokenizerManager:
            obj.rid,
            input_text,
            input_ids,
-            image_inputs,
+            mm_inputs,
            sampling_params,
            obj.return_logprob,
            obj.logprob_start_len,
@@ -606,7 +611,7 @@ class TokenizerManager:
            obj.rid,
            input_text,
            input_ids,
-            image_inputs,
+            mm_inputs,
            token_type_ids,
            sampling_params,
        )
@@ -644,9 +649,9 @@ class TokenizerManager:
    ) -> None:
        """Validate constraints for batch tokenization processing."""
        for i in range(batch_size):
-            if self.is_generation and obj[i].image_data:
+            if self.is_generation and obj[i].contains_mm_input():
                raise ValueError(
-                    "For image input processing do not set `enable_tokenizer_batch_encode`."
+                    "For multimodal input processing do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_ids is not None:
                raise ValueError(

View File

@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
            weights_loader = getattr(param, "weight_loader", default_weight_loader)
            weights_loader(param, loaded_weight)

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        helper = MultiModalityDataPaddingPatternMultimodalTokens(
-            [image_inputs.im_token_id]
-        )
-        return helper.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]):

View File

@@ -21,7 +21,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
@@ -244,26 +244,11 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
    def pad_input_ids(
        self,
        input_ids: List[int],
-        mm_inputs: Optional[MultimodalInputs] = None,
+        mm_inputs: MultimodalInputs,
    ) -> List[int]:
        """Pad input IDs with image and audio tokens."""
-        if mm_inputs is None:
-            return input_ids
-
-        # Collect available media token pairs
-        media_token_pairs = []
-        for attr_name in ["im_start_id", "audio_start_id"]:
-            if hasattr(mm_inputs, attr_name):
-                start_id = getattr(mm_inputs, attr_name)
-                end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
-                media_token_pairs.append((start_id, end_id))
-
-        # Apply padding pattern if we have media tokens
-        if media_token_pairs:
-            pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-            return pattern.pad_input_tokens(input_ids, mm_inputs)
-
-        return input_ids
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.language_model.get_input_embeddings()
@@ -431,7 +416,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
            )
            positions += 1

        if input_ids is not None:
-
            # Prepare per-layer inputs from inputs_ids
            per_layer_inputs_mask = torch.logical_and(

View File

@@ -154,8 +154,7 @@ class KimiVLForConditionalGeneration(nn.Module):
        return res

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def forward(

View File

@@ -50,10 +50,7 @@ class Llama4ForConditionalGeneration(nn.Module):
        self.logits_processor = LogitsProcessor(config.text_config)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(

View File

@@ -446,9 +446,7 @@ class Phi4MMForCausalLM(nn.Module):
        return hidden_states

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def should_apply_lora(self, module_name: str) -> bool:

View File

@@ -268,15 +268,14 @@ class PixtralHFVisionModel(nn.Module):
    DEFAULT_IMAGE_TOKEN_ID = 10

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, mm_inputs)

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
-        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
@@ -314,11 +313,8 @@ class PixtralHFVisionModel(nn.Module):
        )

        # Initialize patch position embedding
-        self.image_token_id = image_token_id
        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
-        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
-            [self.image_token_id]
-        )
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens()

    @property
    def dtype(self):

View File

@@ -493,9 +493,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:

View File

@@ -479,10 +479,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:

View File

@@ -270,15 +270,10 @@ class VILAForConditionalGeneration(nn.Module):
            weight_loader(param, loaded_weight)

    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        image_inputs: MultimodalInputs,
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
-            token_ids=[self.config.image_token_id],
-        )
-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

    ##### BEGIN COPY modeling_vila.py #####

View File

@@ -17,15 +17,6 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import encode_video, load_audio, load_image

-class MultimodalInputFormat(Enum):
-    """Enum for different multimodal input formats."""
-
-    RAW_IMAGES = "raw_images"
-    PRECOMPUTED_FEATURES = "precomputed_features"
-    PIXEL_VALUES = "pixel_values"
-    AUDIO = "audio"
-
-
@dataclasses.dataclass
class BaseMultiModalProcessorOutput:
    # input_text, with each frame of video/image represented with a image_token
@@ -110,18 +101,45 @@ class BaseMultimodalProcessor(ABC):
            max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
        )

+        # Mapping from attribute names to modality types
+        self.ATTR_NAME_TO_MODALITY = {
+            # Image-related attributes
+            "pixel_values": Modality.IMAGE,
+            "image_sizes": Modality.IMAGE,
+            "image_grid_thw": Modality.IMAGE,
+            "image_emb_mask": Modality.IMAGE,
+            "image_spatial_crop": Modality.IMAGE,
+            "tgt_size": Modality.IMAGE,
+            "image_grid_hws": Modality.IMAGE,
+            "aspect_ratio_id": Modality.IMAGE,
+            "aspect_ratio_mask": Modality.IMAGE,
+            "second_per_grid_ts": Modality.IMAGE,
+            # Audio-related attributes
+            "audio_features": Modality.AUDIO,
+            "audio_feature_lens": Modality.AUDIO,
+            "input_features": Modality.AUDIO,
+            "input_features_mask": Modality.AUDIO,
+            # Video-related attributes
+            "video_grid_thws": Modality.VIDEO,
+            # Generic attributes that could apply to multiple modalities
+            # "precomputed_features" - handled specially as it can be any modality
+        }

    def process_mm_data(
        self, input_text, images=None, videos=None, audios=None, **kwargs
    ):
        """
        process multimodal data with transformers AutoProcessor
        """
-        if images is not None:
+        if images:
            kwargs["images"] = images
-        if videos is not None:
+        if videos:
            kwargs["videos"] = videos
-        if audios is not None:
+        if audios:
            kwargs["audios"] = audios
+            if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+                # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
+                kwargs["audio"] = audios

        processor = self._processor
        if hasattr(processor, "image_processor") and isinstance(
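The ATTR_NAME_TO_MODALITY table is what lets a single processor output dict be split into per-modality items by the collect_mm_items_from_processor_output helper added further down in this diff. A hedged sketch of the idea, with placeholder tensors and a hypothetical processor instance:

    # Hypothetical processor output carrying both image and audio tensors.
    out = {
        "input_ids": input_ids,            # skipped by the collector
        "pixel_values": pixel_values,      # -> routed to a Modality.IMAGE item
        "image_grid_thw": image_grid_thw,  # -> same IMAGE item
        "input_features": input_features,  # -> routed to a Modality.AUDIO item
    }
    items = processor.collect_mm_items_from_processor_output(out)
    # items -> one MultimodalDataItem per modality found in the dict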
@@ -144,6 +162,7 @@ class BaseMultimodalProcessor(ABC):
    async def process_mm_data_async(
        self,
        image_data,
+        audio_data,
        input_text,
        request_obj,
        max_req_input_len,
@@ -418,175 +437,137 @@ class BaseMultimodalProcessor(ABC):
                values[k] = v
        return values

-    def process_and_combine_mm_data(
-        self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
-        """
-        Process multimodal data and return the combined multimodal item and input_ids.
-
-        Returns:
-            Tuple of (combined_mm_item, input_ids)
-        """
-
-        def tokenize_text(input_text: str) -> torch.Tensor:
-            """Tokenize input text."""
-            return self._processor.tokenizer(
-                input_text,
-                return_tensors="pt",
-                add_special_tokens=True,
-            ).input_ids.flatten()
-
-        def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
-            """Categorize multimodal inputs and validate consistency."""
-            try:
-                has_image = False
-                has_pixel_values = False
-                has_precomputed_features = False
-                has_audio = False
-
-                for mm_input in mm_inputs:
-                    if isinstance(mm_input, Image.Image):
-                        has_image = True
-                    elif isinstance(mm_input, np.ndarray):
-                        has_audio = True
-                    elif isinstance(mm_input, dict):
-                        if mm_input.get("precomputed_features", None) is not None:
-                            has_precomputed_features = True
-                        elif mm_input.get("pixel_values", None) is not None:
-                            has_pixel_values = True
-                        else:
-                            raise ValueError(
-                                f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
-                            )
-                    else:
-                        raise ValueError(
-                            f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
-                        )
-
-                # Validate format consistency
-                format_count = sum(
-                    [has_image, has_pixel_values, has_precomputed_features, has_audio]
-                )
-                if format_count > 1:
-                    raise ValueError(
-                        "Unsupported: mixture of multimodal input formats. "
-                        f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                        f"precomputed_features={has_precomputed_features}, audio={has_audio}"
-                    )
-
-                if has_image:
-                    return MultimodalInputFormat.RAW_IMAGES
-                elif has_precomputed_features:
-                    return MultimodalInputFormat.PRECOMPUTED_FEATURES
-                elif has_pixel_values:
-                    return MultimodalInputFormat.PIXEL_VALUES
-                elif has_audio:
-                    return MultimodalInputFormat.AUDIO
-                else:
-                    raise ValueError("No valid multimodal input format found")
-            except Exception as e:
-                raise ValueError(f"Failed to categorize inputs: {e}")
-
-        def process_raw_images(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process raw Image.Image objects using transformers processor."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                images=base_output.images,
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-
-            # Copy all fields from processor output except input_ids
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def process_precomputed_features(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with precomputed features."""
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-            combined_mm_item.precomputed_features = self._extract_processor_features(
-                base_output.images, "precomputed_features"
-            )
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_pixel_values(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with pixel values."""
-            values = self._extract_processor_features_from_all_attributes(
-                base_output.images
-            )
-            combined_mm_item = MultimodalDataItem.from_dict(values)
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_audio(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with audio."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                audio=base_output.audios,  # Note: "audio" is for gemma3n only
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def finalize_mm_item(
-            combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
-        ) -> MultimodalDataItem:
-            """Apply common post-processing to the multimodal item."""
-            if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                combined_mm_item.image_offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.IM_TOKEN_ID,
-                )
-            elif combined_mm_item.modality == Modality.AUDIO:
-                combined_mm_item.audio_offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.AUDIO_TOKEN_ID,
-                )
-            elif combined_mm_item.modality == Modality.VIDEO:
-                combined_mm_item.video_offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.VIDEO_TOKEN_ID,
-                )
-            else:
-                raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
-            return combined_mm_item
-
-        # Main logic - determine input type and handle text-only case
-        mm_inputs = base_output.images or base_output.audios
-        if not mm_inputs:
-            input_ids = tokenize_text(base_output.input_text)
-            return None, input_ids
-
-        # Categorize input formats
-        input_format = categorize_mm_inputs(mm_inputs)
-
-        # Process based on format
-        if input_format == MultimodalInputFormat.RAW_IMAGES:
-            combined_mm_item, input_ids = process_raw_images(base_output)
-        elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
-            combined_mm_item, input_ids = process_precomputed_features(base_output)
-        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
-            combined_mm_item, input_ids = process_pixel_values(base_output)
-        elif input_format == MultimodalInputFormat.AUDIO:
-            combined_mm_item, input_ids = process_audio(base_output)
-        else:
-            raise ValueError(f"Unknown input format: {input_format}")
-
-        # Finalize with common processing
-        combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
-        return combined_mm_item, input_ids
+    def collect_mm_items_from_processor_output(
+        self, data_dict: dict
+    ) -> List[MultimodalDataItem]:
+        """Create mm_items directly from processor output."""
+        items = {}  # modality -> MultimodalDataItem
+        for attr_name, value in data_dict.items():
+            if attr_name == "input_ids":
+                continue
+
+            # Get modality for this attribute
+            modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
+
+            if not modality and attr_name == "precomputed_features":
+                modality_str = data_dict.get("modality")
+                try:
+                    modality = (
+                        Modality.from_str(modality_str)
+                        if modality_str
+                        else Modality.IMAGE
+                    )
+                except ValueError:
+                    modality = Modality.IMAGE
+
+            if modality:
+                # Create item if needed
+                if modality not in items:
+                    items[modality] = MultimodalDataItem(modality=modality)
+
+                # Set attribute
+                if hasattr(items[modality], attr_name):
+                    setattr(items[modality], attr_name, value)
+
+        return list(items.values())
+
+    def _process_and_collect_mm_items(
+        self, input_text: str, images=None, audios=None, videos=None, **kwargs
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+        """
+        Helper method to process multimodal data and create mm_items in one step.
+
+        Returns:
+            Tuple of (created mm_items, input_ids)
+        """
+        ret = self.process_mm_data(
+            input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
+        )
+
+        input_ids = ret["input_ids"].flatten()
+        collected_items = self.collect_mm_items_from_processor_output(ret)
+
+        return collected_items, input_ids
+
+    def process_and_combine_mm_data(
+        self, base_output: BaseMultiModalProcessorOutput
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+        """
+        Process multimodal data and return the combined multimodal items and input_ids.
+        Supports mixed modalities (images and audio in the same request).
+
+        Returns:
+            Tuple of (list of mm_items, input_ids)
+        """
+        # Collect all items and categorize them
+        all_items = (base_output.images or []) + (base_output.audios or [])
+
+        # Handle text-only case
+        if not all_items:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+            return [], input_ids
+
+        dict_items, raw_images, raw_audios = [], [], []
+        for item in all_items:
+            if isinstance(item, dict):
+                dict_items.append(item)
+            elif isinstance(item, Image.Image):
+                raw_images.append(item)
+            elif isinstance(item, np.ndarray):
+                raw_audios.append(item)
+            else:
+                raise ValueError(f"Unknown multimodal item type: {type(item)}")
+
+        # Process items and get input_ids
+        all_collected_items = []
+        input_ids = None
+
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
+
+        # Handle raw items (need processing)
+        if raw_images or raw_audios:
+            collected_items, input_ids = self._process_and_collect_mm_items(
+                input_text=base_output.input_text,
+                images=raw_images,
+                audios=raw_audios,
+            )
+            all_collected_items.extend(collected_items)
+
+        # Fallback tokenization if no raw items were processed
+        if input_ids is None:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+
+        # Add offsets to all items
+        for mm_item in all_collected_items:
+            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                mm_item.image_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.IM_TOKEN_ID,
+                )
+            elif mm_item.modality == Modality.AUDIO:
+                mm_item.audio_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.AUDIO_TOKEN_ID,
+                )
+            elif mm_item.modality == Modality.VIDEO:
+                mm_item.video_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.VIDEO_TOKEN_ID,
+                )
+            else:
+                raise ValueError(f"Unknown modality: {mm_item.modality}")
+
+        return all_collected_items, input_ids
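Taken together, a mixed request can now produce several items from one pass through the base processor. A rough sketch of the new flow, under the assumption that base_output carries one raw PIL image and one raw audio waveform:

    # Hedged sketch: base_output.images = [PIL.Image], base_output.audios = [np.ndarray].
    mm_items, input_ids = processor.process_and_combine_mm_data(base_output)
    # mm_items -> [MultimodalDataItem(IMAGE, image_offsets=...),
    #              MultimodalDataItem(AUDIO, audio_offsets=...)]
    # input_ids already contains the expanded image/audio placeholder tokens,
    # so the offsets computed above line up with the final prompt.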

View File

@@ -15,20 +15,11 @@ class ClipImageProcessor(BaseMultimodalProcessor):
    async def process_mm_data_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
-        if not image_data:
-            return None
-
        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
+        images = [load_image(image)[0] for image in image_data]

        image_inputs = self.process_mm_data(input_text=input_text, images=images)
        image_inputs["data_hashes"] = [hash(str(image_data))]

View File

@@ -44,17 +44,10 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
        *args,
        **kwargs
    ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        image_token = self.IMAGE_TOKEN
        base_output = self.load_mm_data(
            input_text,
            image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
            max_req_input_len=max_req_input_len,
        )
        res = self.process_mm_data(

View File

@@ -36,11 +36,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
        *args,
        **kwargs,
    ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
@@ -51,11 +46,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
            discard_alpha_channel=True,
        )

-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)

        return {
            "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "mm_items": mm_items,
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
        }

View File

@@ -59,17 +59,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
        **kwargs,
    ):
        """Process multimodal data including images and audio."""
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
-        if isinstance(audio_data, str):
-            audio_data = [audio_data]
-
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
@@ -83,13 +72,11 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
            ),
        )

-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)

        return {
            "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
-            "im_start_id": self.IM_START_TOKEN_ID,
-            "im_end_id": self.IM_END_TOKEN_ID,
-            "audio_start_id": self.AUDIO_START_TOKEN_ID,
-            "audio_end_id": self.AUDIO_END_TOKEN_ID,
+            "mm_items": mm_items,
+            "im_token_id": self.IM_TOKEN_ID,
+            "audio_token_id": self.AUDIO_TOKEN_ID,
        }
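Note how the processor now reports per-modality token ids (im_token_id, audio_token_id) instead of start/end pairs; these are exactly the fields the refactored MultiModalityDataPaddingPatternMultimodalTokens reads when padding a mixed prompt. A hedged sketch of the hand-off downstream (the constructor call is illustrative, not the exact scheduler code):

    # Sketch: the scheduler-side MultimodalInputs is assumed to carry these fields.
    mm_inputs = MultimodalInputs(
        mm_items=ret["mm_items"],
        im_token_id=ret["im_token_id"],
        audio_token_id=ret["audio_token_id"],
    )
    # pad_input_ids() can then replace image and audio tokens independently,
    # which is what enables mixed image+audio prompts for gemma3n.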

View File

@@ -172,13 +172,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
    async def process_mm_data_async(
        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
    ):
-        if not image_data:
-            return None
-
-        # Ensure image_data is a list
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,

View File

@@ -22,12 +22,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
        max_req_input_len,
        **kwargs,
    ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
        processor = self._processor
        base_out = self.load_mm_data(

View File

@@ -30,11 +30,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
        *args,
        **kwargs,
    ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
@@ -44,10 +39,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
            max_req_input_len=max_req_input_len,
        )
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
        return {
            "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "mm_items": mm_items,
            "im_token_id": self.IM_TOKEN_ID,
        }

View File

@@ -110,9 +110,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
        *args,
        **kwargs,
    ):
-        if not image_data:
-            return None
-
        modalities = request_obj.modalities or ["image"]
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
@@ -122,9 +119,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
            else None
        )

-        if isinstance(image_data, str):
-            image_data = [image_data]
-
        if isinstance(image_data, list) and len(image_data) > 0:
            if "multi-images" in modalities or "video" in modalities:
                # Multiple images

View File

@@ -23,19 +23,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
+        audio_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
        base_output = self.load_mm_data(
            prompt=input_text,
            max_req_input_len=max_req_input_len,

View File

@@ -15,21 +15,11 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
    async def process_mm_data_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
-        if not image_data:
-            return None
-
        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
+        images = [load_image(image)[0] for image in image_data]

        image_inputs = self.process_mm_data(input_text=input_text, images=images)
        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
        image_inputs["mm_items"] = [

View File

@@ -37,9 +37,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
        *args,
        **kwargs,
    ):
-        if not image_data:
-            return None
-
        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

View File

@@ -26,22 +26,12 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
+        audio_data,
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
-
        if audio_data:
            logger.warning(
                "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."

View File

@@ -78,12 +78,6 @@ class PixtralProcessor(BaseMultimodalProcessor):
        *args,
        **kwargs,
    ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
        mm_data = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=self.multimodal_tokens,

View File

@@ -49,9 +49,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
        *args,
        **kwargs,
    ):
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
@@ -130,12 +127,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
        video_grid_thw = None  # TODO

-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
-        if combined_mm_item is None:
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        if not mm_items:
            # Note(Xinyuan): This is the case where image loading fails.
            return None
+        combined_mm_item = mm_items[0]  # only image is supported for now

        video_grid_thw = None  # TODO
        second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
@@ -157,7 +155,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
        return {
            "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item],
+            "mm_items": mm_items,
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
            "im_token_id": self.IM_TOKEN_ID,

View File

@@ -37,6 +37,8 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
        _processor: VILAProcessor,
    ) -> None:
        super().__init__(hf_config, server_args, _processor)
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id

    async def process_mm_data_async(
        self,
@@ -46,13 +48,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
        max_req_input_len: int,
        **kwargs,
    ) -> Optional[Dict[str, Any]]:
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        mm_data = self.load_mm_data(
+        base_output = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=MultimodalSpecialTokens(
                image_token=self._processor.tokenizer.image_token
@@ -61,25 +57,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
            image_data=image_data,
        )

-        inputs = self.process_mm_data(
-            input_text=mm_data.input_text,
-            images=mm_data.images,
-        )
-
-        image_offsets = self.get_mm_items_offset(
-            input_ids=inputs.input_ids[0],
-            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
-        )
-
-        mm_items: List[MultimodalDataItem] = [
-            MultimodalDataItem(
-                modality=Modality.IMAGE,
-                image_offsets=image_offsets,
-                pixel_values=inputs.pixel_values,
-            )
-        ]
-
-        return dict(
-            input_ids=inputs.input_ids[0].tolist(),
-            mm_items=mm_items,
-        )
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.IM_TOKEN_ID,
+            "video_token_id": self.VIDEO_TOKEN_ID,
+        }