[Refactor] simplify multimodal data processing (#8107)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -155,17 +155,15 @@ class BaseMultimodalProcessor(ABC):
|
||||
self.ATTR_NAME_TO_MODALITY = {
|
||||
# Image-related attributes
|
||||
"pixel_values": Modality.IMAGE,
|
||||
"pixel_values_videos": Modality.VIDEO,
|
||||
"image_sizes": Modality.IMAGE,
|
||||
"image_grid_thw": Modality.IMAGE,
|
||||
"image_attention_mask": Modality.IMAGE,
|
||||
"image_emb_mask": Modality.IMAGE,
|
||||
"image_spatial_crop": Modality.IMAGE,
|
||||
"images_spatial_crop": Modality.IMAGE,
|
||||
"tgt_size": Modality.IMAGE,
|
||||
"image_grid_hws": Modality.IMAGE,
|
||||
"aspect_ratio_id": Modality.IMAGE,
|
||||
"aspect_ratio_ids": Modality.IMAGE,
|
||||
"aspect_ratio_mask": Modality.IMAGE,
|
||||
"second_per_grid_ts": Modality.IMAGE,
|
||||
# Audio-related attributes
|
||||
"audio_features": Modality.AUDIO,
|
||||
"audio_feature_lens": Modality.AUDIO,
|
||||
@@ -173,9 +171,11 @@ class BaseMultimodalProcessor(ABC):
|
||||
"input_features_mask": Modality.AUDIO,
|
||||
"audio_attention_mask": Modality.AUDIO,
|
||||
# Video-related attributes
|
||||
"pixel_values_videos": Modality.VIDEO,
|
||||
"second_per_grid_ts": Modality.VIDEO,
|
||||
"video_grid_thw": Modality.VIDEO,
|
||||
# Generic attributes that could apply to multiple modalities
|
||||
# "precomputed_features" - handled specially as it can be any modality
|
||||
# "precomputed_embeddings" - handled specially as it can be any modality
|
||||
}
|
||||
|
||||
# name of the feature filed
|
||||
@@ -222,7 +222,6 @@ class BaseMultimodalProcessor(ABC):
|
||||
audio_data,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
pass
|
||||
@@ -283,7 +282,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
self,
|
||||
text_parts: List[str],
|
||||
multimodal_tokens: MultimodalSpecialTokens,
|
||||
data_iterators: dict,
|
||||
data_iterators: dict[Modality, Iterator[Any]],
|
||||
discard_alpha_channel: bool = True,
|
||||
image_estimated_frames_iter: Optional[iter] = None,
|
||||
image_scaling_factor: float = 1.0,
|
||||
@@ -354,7 +353,6 @@ class BaseMultimodalProcessor(ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
multimodal_tokens: MultimodalSpecialTokens,
|
||||
max_req_input_len: int,
|
||||
image_data: Optional[list] = None,
|
||||
video_data: Optional[list] = None,
|
||||
audio_data: Optional[list] = None,
|
||||
@@ -489,50 +487,11 @@ class BaseMultimodalProcessor(ABC):
|
||||
|
||||
return list(zip(indices_start.tolist(), indices_end.tolist()))
|
||||
|
||||
@staticmethod
|
||||
def _extract_processor_features(
|
||||
items: List[dict], attr_name: str
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Helper function to concat extracted attributes from processor output.
|
||||
"""
|
||||
values = [value for item in items if (value := item.get(attr_name)) is not None]
|
||||
return torch.cat(values) if values else None
|
||||
|
||||
# When we assume that all the items have the same attributes
|
||||
def _extract_processor_features_from_all_attributes(
|
||||
self, items: List[dict]
|
||||
) -> dict:
|
||||
values = {}
|
||||
# Verify all items have the same keys
|
||||
first_keys = set(items[0].keys())
|
||||
for item in items[1:]:
|
||||
if set(item.keys()) != first_keys:
|
||||
raise ValueError(
|
||||
f"All items must have the same attributes. "
|
||||
f"First item has {first_keys}, but found {set(item.keys())}"
|
||||
)
|
||||
|
||||
# Process each attribute
|
||||
for k, v in items[0].items():
|
||||
if isinstance(v, list):
|
||||
values[k] = self._extract_processor_features(items, k)
|
||||
else:
|
||||
# Verify all items have the same value for non-list attributes
|
||||
for item in items[1:]:
|
||||
if item[k] != v:
|
||||
raise ValueError(
|
||||
f"All items must have the same value for attribute {k}. "
|
||||
f"First item has {v}, but found {item[k]}"
|
||||
)
|
||||
values[k] = v
|
||||
return values
|
||||
|
||||
def collect_mm_items_from_processor_output(
|
||||
self, data_dict: dict
|
||||
) -> List[MultimodalDataItem]:
|
||||
"""Create mm_items directly from processor output."""
|
||||
items = {} # modality -> MultimodalDataItem
|
||||
items: dict[Modality, MultimodalDataItem] = {}
|
||||
|
||||
for attr_name, value in data_dict.items():
|
||||
if attr_name == "input_ids":
|
||||
@@ -541,16 +500,15 @@ class BaseMultimodalProcessor(ABC):
|
||||
# Get modality for this attribute
|
||||
modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
|
||||
|
||||
if not modality and attr_name == "precomputed_features":
|
||||
if attr_name == "precomputed_embeddings":
|
||||
modality_str = data_dict.get("modality")
|
||||
try:
|
||||
modality = (
|
||||
Modality.from_str(modality_str)
|
||||
if modality_str
|
||||
else Modality.IMAGE
|
||||
)
|
||||
except ValueError:
|
||||
modality = Modality.IMAGE
|
||||
modality = Modality.IMAGE
|
||||
if modality_str:
|
||||
try:
|
||||
modality = Modality.from_str(modality_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if modality:
|
||||
# Create item if needed
|
||||
if modality not in items:
|
||||
@@ -559,8 +517,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
if attr_name in self.FEATURE_NAMES:
|
||||
attr_name = "feature"
|
||||
|
||||
# Set attribute
|
||||
setattr(items[modality], attr_name, value)
|
||||
items[modality].set(attr_name, value)
|
||||
|
||||
return list(items.values())
|
||||
|
||||
@@ -586,6 +543,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
self,
|
||||
base_output: BaseMultiModalProcessorOutput,
|
||||
mm_tokens: MultimodalSpecialTokens,
|
||||
**kwargs,
|
||||
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
|
||||
"""
|
||||
Process multimodal data and return the combined multimodal items and input_ids.
|
||||
@@ -618,7 +576,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
else:
|
||||
raise ValueError(f"Unknown multimodal item type: {type(item)}")
|
||||
# Process items and get input_ids
|
||||
all_collected_items = []
|
||||
all_collected_items: list[MultimodalDataItem] = []
|
||||
input_ids = None
|
||||
|
||||
# Handle dict items (already processed)
|
||||
@@ -634,6 +592,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
images=raw_images,
|
||||
audios=raw_audios,
|
||||
videos=raw_videos,
|
||||
**kwargs,
|
||||
)
|
||||
all_collected_items.extend(collected_items)
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.clip import CLIPModel
|
||||
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
|
||||
|
||||
class ClipImageProcessor(BaseMultimodalProcessor):
|
||||
@@ -11,23 +12,24 @@ class ClipImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
|
||||
_processor
|
||||
)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
image_data=image_data,
|
||||
)
|
||||
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||
image_inputs["data_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
image_inputs["mm_items"] = [
|
||||
MultimodalDataItem(
|
||||
feature=image_inputs["pixel_values"], modality=Modality.IMAGE
|
||||
)
|
||||
]
|
||||
|
||||
return image_inputs
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
}
|
||||
|
||||
@@ -33,9 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
|
||||
_processor
|
||||
)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<image>", image_token_id=self._processor.image_token_id
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -50,36 +50,16 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
res = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output,
|
||||
self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
conversations=base_output.input_text,
|
||||
)
|
||||
images_seq_mask = res["images_seq_mask"]
|
||||
images_spatial_crop = res["images_spatial_crop"]
|
||||
batched_images_spatial_crop = []
|
||||
batched_images_spatial_crop.append(images_spatial_crop)
|
||||
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
||||
|
||||
items = []
|
||||
input_ids = res["input_ids"]
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids, mm_token_id=self._processor.image_token_id
|
||||
)
|
||||
item = MultimodalDataItem(
|
||||
feature=res["images"],
|
||||
offsets=image_offsets,
|
||||
modality=Modality.IMAGE,
|
||||
image_emb_mask=images_seq_mask,
|
||||
image_spatial_crop=batched_images_spatial_crop,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
return {
|
||||
"mm_items": items,
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_token_id": self._processor.image_token_id,
|
||||
}
|
||||
|
||||
@@ -33,7 +33,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -41,7 +40,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -54,7 +54,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
|
||||
input_text: str = "",
|
||||
request_obj=None,
|
||||
max_req_input_len: int = 0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -63,7 +62,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
max_req_input_len=max_req_input_len,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -170,13 +170,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
return pixel_values, num_patches_list
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_text, request_obj, max_req_input_len, **kwargs
|
||||
self, image_data, input_text, request_obj, **kwargs
|
||||
):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,52 +11,35 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
models = [MultiModalityCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, processor):
|
||||
super().__init__(hf_config, server_args, processor)
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=processor.image_token
|
||||
).build(processor)
|
||||
image_token=_processor.image_token,
|
||||
image_token_id=_processor.image_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
processor = self._processor
|
||||
|
||||
base_out = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
images = base_out.images
|
||||
res = self.process_mm_data(
|
||||
input_text=base_out.input_text,
|
||||
prompt=base_out.input_text,
|
||||
images=images,
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_out, self.mm_tokens, prompt=base_out.input_text
|
||||
)
|
||||
|
||||
input_ids = res["input_ids"].flatten()
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids, mm_token_id=processor.image_id
|
||||
)
|
||||
return {
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
feature=res["pixel_values"],
|
||||
image_emb_mask=res["images_emb_mask"],
|
||||
offsets=image_offsets,
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
],
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_start_id": processor.image_start_id,
|
||||
"im_end_id": processor.image_end_id,
|
||||
"im_token_id": processor.image_id,
|
||||
"im_start_id": self._processor.image_start_id,
|
||||
"im_end_id": self._processor.image_end_id,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -34,7 +33,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
|
||||
@@ -159,7 +159,9 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
feature=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
model_specific_data={
|
||||
"image_sizes": image_sizes,
|
||||
},
|
||||
modality=modality,
|
||||
)
|
||||
],
|
||||
|
||||
@@ -17,10 +17,21 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
|
||||
self.slice_end_id = getattr(tokenizer, "slice_end_id", None)
|
||||
self.audio_start_id = getattr(tokenizer, "audio_start_id", None)
|
||||
self.audio_end_id = getattr(tokenizer, "audio_end_id", None)
|
||||
self.im_start_id = getattr(tokenizer, "im_start_id", None)
|
||||
self.im_end_id = getattr(tokenizer, "im_end_id", None)
|
||||
self.im_token_id = getattr(tokenizer, "unk_id", None)
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="(<image>./</image>)",
|
||||
audio_token="(<audio>./</audio>)",
|
||||
video_token="(<video>./</video>)",
|
||||
image_token_id=self.im_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
@@ -29,12 +40,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audio_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
@@ -48,24 +57,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audios=base_output.audios,
|
||||
)
|
||||
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = tokenizer.slice_start_id
|
||||
slice_end_id = tokenizer.slice_end_id
|
||||
if hasattr(tokenizer, "audio_start_id"):
|
||||
audio_start_id = tokenizer.audio_start_id
|
||||
audio_end_id = tokenizer.audio_end_id
|
||||
|
||||
im_start_id = tokenizer.im_start_id
|
||||
im_end_id = tokenizer.im_end_id
|
||||
im_token_id = tokenizer.unk_id
|
||||
pixel_values = res["pixel_values"]
|
||||
tgt_sizes = res["tgt_sizes"]
|
||||
|
||||
@@ -102,10 +93,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
items = []
|
||||
input_ids = res["input_ids"].flatten()
|
||||
image_offsets = self.get_mm_items_offset_by_pair(
|
||||
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
|
||||
input_ids=input_ids, mm_start_id=self.im_start_id, mm_end_id=self.im_end_id
|
||||
)
|
||||
slice_offsets = self.get_mm_items_offset_by_pair(
|
||||
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
|
||||
input_ids=input_ids,
|
||||
mm_start_id=self.slice_start_id,
|
||||
mm_end_id=self.slice_end_id,
|
||||
)
|
||||
image_offsets.extend(slice_offsets)
|
||||
image_offsets = sorted(image_offsets)
|
||||
@@ -114,7 +107,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
item = MultimodalDataItem(
|
||||
feature=pixel_values,
|
||||
offsets=image_offsets,
|
||||
tgt_size=tgt_sizes_flat,
|
||||
model_specific_data={"tgt_size": tgt_sizes_flat},
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
items += [item]
|
||||
@@ -124,17 +117,17 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
and res["audio_features"] is not None
|
||||
and len(res["audio_features"]) != 0
|
||||
):
|
||||
if audio_start_id is not None and audio_end_id is not None:
|
||||
if self.audio_start_id is not None and self.audio_end_id is not None:
|
||||
audio_offsets = self.get_mm_items_offset_by_pair(
|
||||
input_ids=input_ids,
|
||||
mm_start_id=audio_start_id,
|
||||
mm_end_id=audio_end_id,
|
||||
mm_start_id=self.audio_start_id,
|
||||
mm_end_id=self.audio_end_id,
|
||||
)
|
||||
else:
|
||||
audio_offsets = None
|
||||
item = MultimodalDataItem(
|
||||
feature=[res["audio_features"]],
|
||||
audio_feature_lens=res["audio_feature_lens"],
|
||||
model_specific_data={"audio_feature_lens": res["audio_feature_lens"]},
|
||||
offsets=audio_offsets,
|
||||
modality=Modality.AUDIO,
|
||||
)
|
||||
@@ -142,11 +135,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
return {
|
||||
"mm_items": items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"audio_start_id": audio_start_id,
|
||||
"audio_end_id": audio_end_id,
|
||||
"im_token_id": im_token_id,
|
||||
"im_start_id": im_start_id,
|
||||
"im_end_id": im_end_id,
|
||||
"slice_start_id": slice_start_id,
|
||||
"slice_end_id": slice_end_id,
|
||||
"audio_start_id": self.audio_start_id,
|
||||
"audio_end_id": self.audio_end_id,
|
||||
"im_token_id": self.im_token_id,
|
||||
"im_start_id": self.im_start_id,
|
||||
"im_end_id": self.im_end_id,
|
||||
"slice_start_id": self.slice_start_id,
|
||||
"slice_end_id": self.slice_end_id,
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.mllama import MllamaForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
@@ -11,24 +12,26 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=self._processor.image_token,
|
||||
image_token_id=self._processor.image_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
base_out = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
image_inputs["mm_items"] = [
|
||||
MultimodalDataItem(
|
||||
feature=image_inputs["pixel_values"],
|
||||
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
|
||||
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
]
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_out, self.mm_tokens
|
||||
)
|
||||
|
||||
return image_inputs
|
||||
return {
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
}
|
||||
|
||||
@@ -27,13 +27,13 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||
self.image_token_index = hf_config.image_token_index
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token,
|
||||
image_token_id=self.image_token_index,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
max_req_input_len=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -45,7 +45,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||
processed_data = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
max_req_input_len=max_req_input_len or 4096,
|
||||
image_data=image_data,
|
||||
return_text=True,
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
|
||||
for hf_key, sglang_key in key_mapping.items():
|
||||
if hf_key in result:
|
||||
result[sglang_key] = result[hf_key]
|
||||
del result[hf_key]
|
||||
|
||||
# Filter out None or empty tensors from the result.
|
||||
# This prevents the sglang function base_processor.collect_mm_items_from_processor_output()
|
||||
@@ -58,7 +59,7 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
self.AUDIO_TOKEN_ID = 200011
|
||||
self.AUDIO_SAMPLE_RATE = 16000
|
||||
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN,
|
||||
image_token_id=self.IM_TOKEN_ID,
|
||||
audio_token=self.AUDIO_TOKEN,
|
||||
@@ -71,15 +72,13 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audio_data,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
audio_sample_rate=self.AUDIO_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
@@ -91,12 +90,12 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
]
|
||||
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.multimodal_tokens
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"audio_token_id": self.AUDIO_TOKEN_ID,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
"audio_token_id": self.mm_tokens.audio_token_id,
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
|
||||
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||
)
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.pixtral import PixtralVisionModel
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
@@ -45,7 +44,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.image_token_id = getattr(
|
||||
self.IM_TOKEN_ID = getattr(
|
||||
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
|
||||
)
|
||||
# Instantiate the patcher logic helper using the class defined above
|
||||
@@ -53,8 +52,9 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
self.vision_config = hf_config.vision_config
|
||||
self.image_size = self.vision_config.image_size
|
||||
self.patch_size = self.vision_config.patch_size
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token,
|
||||
image_token_id=self.IM_TOKEN_ID,
|
||||
).build(_processor)
|
||||
_processor.tokenizer.add_special_tokens(
|
||||
{
|
||||
@@ -80,42 +80,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
):
|
||||
mm_data = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
max_req_input_len=kwargs.get("max_req_input_len", 4096),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
image_data=image_data,
|
||||
return_text=True,
|
||||
)
|
||||
|
||||
if mm_data.images:
|
||||
resize_tasks = [self._resize(image) for image in mm_data.images]
|
||||
mm_data.images = await asyncio.gather(*resize_tasks)
|
||||
|
||||
processor_output = self.process_mm_data(
|
||||
input_text=mm_data.input_text,
|
||||
images=mm_data.images,
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
mm_data, self.mm_tokens
|
||||
)
|
||||
|
||||
if "pixel_values" in processor_output:
|
||||
input_ids = processor_output["input_ids"].view(-1)
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=self.image_token_id,
|
||||
)
|
||||
mm_items = [
|
||||
MultimodalDataItem(
|
||||
feature=processor_output["pixel_values"],
|
||||
image_sizes=processor_output["image_sizes"],
|
||||
modality=Modality.IMAGE,
|
||||
offsets=image_offsets,
|
||||
)
|
||||
]
|
||||
|
||||
input_ids = input_ids.tolist()
|
||||
processor_output.update(
|
||||
input_ids=input_ids,
|
||||
mm_items=mm_items,
|
||||
# there's no im_start_id for pixtral, only im_token and im_end_token
|
||||
im_end_id=self.IMG_END_TOKEN_ID,
|
||||
im_token_id=self.image_token_id,
|
||||
)
|
||||
return processor_output
|
||||
return {
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"im_token": self._processor.image_token,
|
||||
}
|
||||
|
||||
65
python/sglang/srt/multimodal/processors/qwen_audio.py
Normal file
65
python/sglang/srt/multimodal/processors/qwen_audio.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import re
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
|
||||
|
||||
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
|
||||
models = [Qwen2AudioForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
self.AUDIO_TOKEN_REGEX = re.compile(
|
||||
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
|
||||
)
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
self.audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
|
||||
self.audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
|
||||
self.audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
audio_token=self.AUDIO_TOKEN,
|
||||
audio_token_regex=self.AUDIO_TOKEN_REGEX,
|
||||
audio_token_id=self.audio_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
audio_data,
|
||||
input_text,
|
||||
**kwargs,
|
||||
):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
audio_data=audio_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
mm_items, input_ids, ret = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
assert (
|
||||
"feature_attention_mask" in ret
|
||||
), "feature_attention_mask not found in processor output"
|
||||
input_lengths = ret["feature_attention_mask"].sum(dim=-1)
|
||||
input_lengths = (input_lengths - 1) // 2 + 1
|
||||
output_lengths = (input_lengths - 2) // 2 + 1
|
||||
|
||||
mm_items[0].model_specific_data["audio_feature_lens"] = output_lengths
|
||||
|
||||
return {
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"audio_start_id": self.audio_start_id,
|
||||
"audio_token_id": self.audio_token_id,
|
||||
"audio_end_id": self.audio_end_id,
|
||||
}
|
||||
@@ -227,7 +227,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -237,7 +236,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
image_data=image_data,
|
||||
video_data=request_obj.video_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
# Qwen-specific: resize images if they are raw Image objects
|
||||
|
||||
@@ -47,13 +47,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
||||
image_data: Optional[ImageDataInputItem | List[ImageDataInputItem]],
|
||||
input_text: str | List[int],
|
||||
request_obj: GenerateReqInput | EmbeddingReqInput,
|
||||
max_req_input_len: int,
|
||||
**kwargs,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
image_data=image_data,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user