Refactor mm processors and Enable mixed modality processing (#7629)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
@@ -17,15 +17,6 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image
 
 
-class MultimodalInputFormat(Enum):
-    """Enum for different multimodal input formats."""
-
-    RAW_IMAGES = "raw_images"
-    PRECOMPUTED_FEATURES = "precomputed_features"
-    PIXEL_VALUES = "pixel_values"
-    AUDIO = "audio"
-
-
 @dataclasses.dataclass
 class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with an image_token
@@ -110,18 +101,45 @@ class BaseMultimodalProcessor(ABC):
             max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )
 
+        # Mapping from attribute names to modality types
+        self.ATTR_NAME_TO_MODALITY = {
+            # Image-related attributes
+            "pixel_values": Modality.IMAGE,
+            "image_sizes": Modality.IMAGE,
+            "image_grid_thw": Modality.IMAGE,
+            "image_emb_mask": Modality.IMAGE,
+            "image_spatial_crop": Modality.IMAGE,
+            "tgt_size": Modality.IMAGE,
+            "image_grid_hws": Modality.IMAGE,
+            "aspect_ratio_id": Modality.IMAGE,
+            "aspect_ratio_mask": Modality.IMAGE,
+            "second_per_grid_ts": Modality.IMAGE,
+            # Audio-related attributes
+            "audio_features": Modality.AUDIO,
+            "audio_feature_lens": Modality.AUDIO,
+            "input_features": Modality.AUDIO,
+            "input_features_mask": Modality.AUDIO,
+            # Video-related attributes
+            "video_grid_thws": Modality.VIDEO,
+            # Generic attributes that could apply to multiple modalities
+            # "precomputed_features" - handled specially as it can be any modality
+        }
+
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
     ):
         """
         process multimodal data with transformers AutoProcessor
         """
-        if images is not None:
+        if images:
             kwargs["images"] = images
-        if videos is not None:
+        if videos:
             kwargs["videos"] = videos
-        if audios is not None:
+        if audios:
             kwargs["audios"] = audios
+            if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+                # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
+                kwargs["audio"] = audios
 
         processor = self._processor
         if hasattr(processor, "image_processor") and isinstance(
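The ATTR_NAME_TO_MODALITY table introduced above is what later lets a flat HuggingFace processor output be split into per-modality items. A minimal, self-contained sketch of that routing idea (a simplified re-implementation for illustration, not the sglang classes themselves):

```python
from enum import Enum


class Modality(Enum):
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"


# Trimmed-down copy of the mapping introduced in the hunk above.
ATTR_NAME_TO_MODALITY = {
    "pixel_values": Modality.IMAGE,
    "image_grid_thw": Modality.IMAGE,
    "input_features": Modality.AUDIO,
    "video_grid_thws": Modality.VIDEO,
}


def split_by_modality(processor_output: dict) -> dict:
    """Group processor-output tensors by the modality they belong to."""
    grouped = {}
    for attr_name, value in processor_output.items():
        if attr_name == "input_ids":
            continue  # token ids are shared, not modality-specific
        modality = ATTR_NAME_TO_MODALITY.get(attr_name)
        if modality is not None:
            grouped.setdefault(modality, {})[attr_name] = value
    return grouped


# A fake mixed output: image pixels plus audio features in one dict.
out = {"input_ids": [1, 2, 3], "pixel_values": "img-tensor", "input_features": "aud-tensor"}
print(split_by_modality(out))  # image and audio tensors land in separate groups
```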
@@ -144,6 +162,7 @@ class BaseMultimodalProcessor(ABC):
     async def process_mm_data_async(
         self,
         image_data,
+        audio_data,
         input_text,
         request_obj,
         max_req_input_len,
@@ -418,175 +437,137 @@ class BaseMultimodalProcessor(ABC):
                 values[k] = v
         return values
 
-    def process_and_combine_mm_data(
-        self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
+    def collect_mm_items_from_processor_output(
+        self, data_dict: dict
+    ) -> List[MultimodalDataItem]:
+        """Create mm_items directly from processor output."""
+        items = {}  # modality -> MultimodalDataItem
+
+        for attr_name, value in data_dict.items():
+            if attr_name == "input_ids":
+                continue
+
+            # Get modality for this attribute
+            modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
+
+            if not modality and attr_name == "precomputed_features":
+                modality_str = data_dict.get("modality")
+                try:
+                    modality = (
+                        Modality.from_str(modality_str)
+                        if modality_str
+                        else Modality.IMAGE
+                    )
+                except ValueError:
+                    modality = Modality.IMAGE
+
+            if modality:
+                # Create item if needed
+                if modality not in items:
+                    items[modality] = MultimodalDataItem(modality=modality)
+
+                # Set attribute
+                if hasattr(items[modality], attr_name):
+                    setattr(items[modality], attr_name, value)
+
+        return list(items.values())
+
+    def _process_and_collect_mm_items(
+        self, input_text: str, images=None, audios=None, videos=None, **kwargs
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
         """
-        Process multimodal data and return the combined multimodal item and input_ids.
-        Handles all three input formats at the same abstraction level.
+        Helper method to process multimodal data and create mm_items in one step.
 
         Returns:
-            Tuple of (combined_mm_item, input_ids)
+            Tuple of (created mm_items, input_ids)
         """
+        ret = self.process_mm_data(
+            input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
+        )
 
-        def tokenize_text(input_text: str) -> torch.Tensor:
-            """Tokenize input text."""
-            return self._processor.tokenizer(
-                input_text,
+        input_ids = ret["input_ids"].flatten()
+        collected_items = self.collect_mm_items_from_processor_output(ret)
+
+        return collected_items, input_ids
+
+    def process_and_combine_mm_data(
+        self, base_output: BaseMultiModalProcessorOutput
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+        """
+        Process multimodal data and return the combined multimodal items and input_ids.
+        Supports mixed modalities (images and audio in the same request).
+
+        Returns:
+            Tuple of (list of mm_items, input_ids)
+        """
+        # Collect all items and categorize them
+        all_items = (base_output.images or []) + (base_output.audios or [])
+
+        # Handle text-only case
+        if not all_items:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
                 return_tensors="pt",
                 add_special_tokens=True,
             ).input_ids.flatten()
+            return [], input_ids
 
-        def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
-            """Categorize multimodal inputs and validate consistency."""
-            try:
-                has_image = False
-                has_pixel_values = False
-                has_precomputed_features = False
-                has_audio = False
-
-                for mm_input in mm_inputs:
-                    if isinstance(mm_input, Image.Image):
-                        has_image = True
-                    elif isinstance(mm_input, np.ndarray):
-                        has_audio = True
-                    elif isinstance(mm_input, dict):
-                        if mm_input.get("precomputed_features", None) is not None:
-                            has_precomputed_features = True
-                        elif mm_input.get("pixel_values", None) is not None:
-                            has_pixel_values = True
-                        else:
-                            raise ValueError(
-                                f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
-                            )
-                    else:
-                        raise ValueError(
-                            f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
-                        )
-
-                # Validate format consistency
-                format_count = sum(
-                    [has_image, has_pixel_values, has_precomputed_features, has_audio]
-                )
-                if format_count > 1:
-                    raise ValueError(
-                        "Unsupported: mixture of multimodal input formats. "
-                        f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                        f"precomputed_features={has_precomputed_features}, audio={has_audio}"
-                    )
-
-                if has_image:
-                    return MultimodalInputFormat.RAW_IMAGES
-                elif has_precomputed_features:
-                    return MultimodalInputFormat.PRECOMPUTED_FEATURES
-                elif has_pixel_values:
-                    return MultimodalInputFormat.PIXEL_VALUES
-                elif has_audio:
-                    return MultimodalInputFormat.AUDIO
-                else:
-                    raise ValueError("No valid multimodal input format found")
-            except Exception as e:
-                raise ValueError(f"Failed to categorize inputs: {e}")
-
-        def process_raw_images(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process raw Image.Image objects using transformers processor."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                images=base_output.images,
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-
-            # Copy all fields from processor output except input_ids
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def process_precomputed_features(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with precomputed features."""
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-            combined_mm_item.precomputed_features = self._extract_processor_features(
-                base_output.images, "precomputed_features"
-            )
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_pixel_values(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with pixel values."""
-            values = self._extract_processor_features_from_all_attributes(
-                base_output.images
-            )
-            combined_mm_item = MultimodalDataItem.from_dict(values)
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_audio(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with audio."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                audio=base_output.audios,  # Note: "audio" is for gemma3n only
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def finalize_mm_item(
-            combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
-        ) -> MultimodalDataItem:
-            """Apply common post-processing to the multimodal item."""
-            if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                combined_mm_item.image_offsets = self.get_mm_items_offset(
+        dict_items, raw_images, raw_audios = [], [], []
+        for item in all_items:
+            if isinstance(item, dict):
+                dict_items.append(item)
+            elif isinstance(item, Image.Image):
+                raw_images.append(item)
+            elif isinstance(item, np.ndarray):
+                raw_audios.append(item)
+            else:
+                raise ValueError(f"Unknown multimodal item type: {type(item)}")
+
+        # Process items and get input_ids
+        all_collected_items = []
+        input_ids = None
+
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
+
+        # Handle raw items (need processing)
+        if raw_images or raw_audios:
+            collected_items, input_ids = self._process_and_collect_mm_items(
+                input_text=base_output.input_text,
+                images=raw_images,
+                audios=raw_audios,
+            )
+            all_collected_items.extend(collected_items)
+
+        # Fallback tokenization if no raw items were processed
+        if input_ids is None:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+
+        # Add offsets to all items
+        for mm_item in all_collected_items:
+            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                mm_item.image_offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.IM_TOKEN_ID,
                 )
-            elif combined_mm_item.modality == Modality.AUDIO:
-                combined_mm_item.audio_offsets = self.get_mm_items_offset(
+            elif mm_item.modality == Modality.AUDIO:
+                mm_item.audio_offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.AUDIO_TOKEN_ID,
                 )
-            elif combined_mm_item.modality == Modality.VIDEO:
-                combined_mm_item.video_offsets = self.get_mm_items_offset(
+            elif mm_item.modality == Modality.VIDEO:
+                mm_item.video_offsets = self.get_mm_items_offset(
                    input_ids=input_ids,
                    mm_token_id=self.VIDEO_TOKEN_ID,
                )
             else:
-                raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
-            return combined_mm_item
+                raise ValueError(f"Unknown modality: {mm_item.modality}")
 
-        # Main logic - determine input type and handle text-only case
-        mm_inputs = base_output.images or base_output.audios
-        if not mm_inputs:
-            input_ids = tokenize_text(base_output.input_text)
-            return None, input_ids
-
-        # Categorize input formats
-        input_format = categorize_mm_inputs(mm_inputs)
-
-        # Process based on format
-        if input_format == MultimodalInputFormat.RAW_IMAGES:
-            combined_mm_item, input_ids = process_raw_images(base_output)
-        elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
-            combined_mm_item, input_ids = process_precomputed_features(base_output)
-        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
-            combined_mm_item, input_ids = process_pixel_values(base_output)
-        elif input_format == MultimodalInputFormat.AUDIO:
-            combined_mm_item, input_ids = process_audio(base_output)
-        else:
-            raise ValueError(f"Unknown input format: {input_format}")
-
-        # Finalize with common processing
-        combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
-        return combined_mm_item, input_ids
+        return all_collected_items, input_ids
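The rewritten process_and_combine_mm_data above accepts raw PIL images, raw numpy audio, and already-processed dicts in a single request. A toy sketch of its type-dispatch step, using the same isinstance checks as the diff (the inputs here are fabricated for illustration):

```python
import numpy as np
from PIL import Image

# Illustrative mixed request: one raw image, one raw audio clip, and one
# already-processed dict (e.g. precomputed features from a cache).
all_items = [
    Image.new("RGB", (8, 8)),
    np.zeros(16000, dtype=np.float32),
    {"precomputed_features": "cached-tensor", "modality": "image"},
]

dict_items, raw_images, raw_audios = [], [], []
for item in all_items:
    if isinstance(item, dict):
        dict_items.append(item)  # already processed: routed straight to item collection
    elif isinstance(item, Image.Image):
        raw_images.append(item)  # needs the HF processor
    elif isinstance(item, np.ndarray):
        raw_audios.append(item)  # needs the HF feature extractor
    else:
        raise ValueError(f"Unknown multimodal item type: {type(item)}")

print(len(dict_items), len(raw_images), len(raw_audios))  # 1 1 1
```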
@@ -15,20 +15,11 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)
 
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
+        images = [load_image(image)[0] for image in image_data]
 
         image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]
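Both image-loading branches above collapse into a single list comprehension; the `[0]` index reflects that sglang's load_image helper returns a tuple whose first element is the decoded image (the exact return shape is the helper's own; this stand-in only illustrates why the indexing is there):

```python
from io import BytesIO
from PIL import Image


def load_image_like(image_file: bytes):
    """Toy stand-in for sglang.srt.utils.load_image, assumed to return a
    (PIL image, size) tuple -- hence the `[0]` indexing in the hunk above.
    Handles bytes only; the real helper also handles URLs and paths."""
    image = Image.open(BytesIO(image_file)).convert("RGB")
    return image, image.size


buf = BytesIO()
Image.new("RGB", (4, 4)).save(buf, format="PNG")
images = [load_image_like(data)[0] for data in [buf.getvalue()]]
print(images[0].size)  # (4, 4)
```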
@@ -44,17 +44,10 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs
     ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
             max_req_input_len=max_req_input_len,
         )
         res = self.process_mm_data(
@@ -36,11 +36,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -51,11 +46,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             discard_alpha_channel=True,
         )
 
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "mm_items": mm_items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
@@ -59,17 +59,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
         **kwargs,
     ):
         """Process multimodal data including images and audio."""
-
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
-        if isinstance(audio_data, str):
-            audio_data = [audio_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -83,13 +72,11 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             ),
         )
 
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
-            "im_start_id": self.IM_START_TOKEN_ID,
-            "im_end_id": self.IM_END_TOKEN_ID,
-            "audio_start_id": self.AUDIO_START_TOKEN_ID,
-            "audio_end_id": self.AUDIO_END_TOKEN_ID,
+            "mm_items": mm_items,
+            "im_token_id": self.IM_TOKEN_ID,
+            "audio_token_id": self.AUDIO_TOKEN_ID,
         }
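Gemma3n's return payload now exposes single placeholder ids (im_token_id, audio_token_id) instead of start/end pairs; each item's offsets are then recovered as contiguous runs of that placeholder in input_ids (the get_mm_items_offset calls in the base class above). A hedged, standalone sketch of that run-finding idea, not the sglang implementation itself:

```python
import torch


def token_runs(input_ids: torch.Tensor, token_id: int):
    """Return (start, end) index pairs for contiguous runs of token_id.

    Illustrative only: the real offset logic lives in
    BaseMultimodalProcessor.get_mm_items_offset.
    """
    mask = (input_ids == token_id).long()
    # Run boundaries appear where the padded mask value changes.
    diff = torch.diff(mask, prepend=torch.tensor([0]), append=torch.tensor([0]))
    starts = (diff == 1).nonzero(as_tuple=True)[0]
    ends = (diff == -1).nonzero(as_tuple=True)[0] - 1
    return list(zip(starts.tolist(), ends.tolist()))


ids = torch.tensor([5, 7, 7, 7, 9, 7, 7])
print(token_runs(ids, 7))  # [(1, 3), (5, 6)]
```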
@@ -172,13 +172,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data, input_text, request_obj, max_req_input_len, **kwargs
     ):
-        if not image_data:
-            return None
-
-        # Ensure image_data is a list
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -22,12 +22,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         max_req_input_len,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
         processor = self._processor
 
         base_out = self.load_mm_data(
@@ -30,11 +30,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -44,10 +39,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )
 
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "mm_items": mm_items,
             "im_token_id": self.IM_TOKEN_ID,
         }
@@ -110,9 +110,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
         modalities = request_obj.modalities or ["image"]
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
@@ -122,9 +119,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             else None
         )
 
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         if isinstance(image_data, list) and len(image_data) > 0:
             if "multi-images" in modalities or "video" in modalities:
                 # Multiple images
@@ -23,19 +23,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
+        audio_data: List[Union[str, bytes]],
         input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
     ):
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             max_req_input_len=max_req_input_len,
@@ -15,21 +15,11 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)
 
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
-
+        images = [load_image(image)[0] for image in image_data]
         image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
         image_inputs["mm_items"] = [
@@ -37,9 +37,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)
@@ -26,22 +26,12 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
+        audio_data,
         input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
     ):
-        audio_data = request_obj.audio_data
-
-        if not image_data and not audio_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
-
         if audio_data:
             logger.warning(
                 "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
@@ -78,12 +78,6 @@ class PixtralProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         mm_data = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=self.multimodal_tokens,
@@ -49,9 +49,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -130,12 +127,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
 
-        video_grid_thw = None  # TODO
-
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
-        if combined_mm_item is None:
+        if not mm_items:
+            # Note(Xinyuan): This is the case where image loading fails.
             return None
 
+        combined_mm_item = mm_items[0]  # only image is supported for now
+        video_grid_thw = None  # TODO
         second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
@@ -157,7 +155,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
 
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item],
+            "mm_items": mm_items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.IM_TOKEN_ID,
@@ -37,6 +37,8 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         _processor: VILAProcessor,
     ) -> None:
         super().__init__(hf_config, server_args, _processor)
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id
 
     async def process_mm_data_async(
         self,
@@ -46,13 +48,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         max_req_input_len: int,
         **kwargs,
     ) -> Optional[Dict[str, Any]]:
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        mm_data = self.load_mm_data(
+        base_output = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=self._processor.tokenizer.image_token
@@ -61,25 +57,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
         )
 
-        inputs = self.process_mm_data(
-            input_text=mm_data.input_text,
-            images=mm_data.images,
-        )
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
-        image_offsets = self.get_mm_items_offset(
-            input_ids=inputs.input_ids[0],
-            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
-        )
-
-        mm_items: List[MultimodalDataItem] = [
-            MultimodalDataItem(
-                modality=Modality.IMAGE,
-                image_offsets=image_offsets,
-                pixel_values=inputs.pixel_values,
-            )
-        ]
-
-        return dict(
-            input_ids=inputs.input_ids[0].tolist(),
-            mm_items=mm_items,
-        )
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.IM_TOKEN_ID,
+            "video_token_id": self.VIDEO_TOKEN_ID,
+        }
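With VILA converted, every processor in this diff now returns the same dict shape: token ids, a flat mm_items list, and per-modality placeholder token ids. A hedged sketch of a consumer relying only on that shared contract (field names are taken from the returns above; the summarize helper itself is hypothetical):

```python
def summarize_processor_output(out: dict) -> str:
    """Toy consumer of the now-uniform processor return contract.

    Assumes the shape seen throughout this diff: "input_ids" as a list of
    token ids, "mm_items" as a list of per-modality items, and optional
    "*_token_id" entries naming each modality's placeholder token.
    """
    if out is None:
        return "text-only or empty request"
    n_tokens = len(out["input_ids"])
    modalities = [getattr(item, "modality", "?") for item in out["mm_items"]]
    token_ids = {k: v for k, v in out.items() if k.endswith("_token_id")}
    return f"{n_tokens} tokens, items={modalities}, placeholder ids={token_ids}"


print(summarize_processor_output(
    {"input_ids": [1, 2, 3], "mm_items": [], "im_token_id": 42}
))
```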