diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py
index 3786d2b09..3f91678ac 100644
--- a/benchmark/mmmu/bench_sglang.py
+++ b/benchmark/mmmu/bench_sglang.py
@@ -89,5 +89,4 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     EvalArgs.add_cli_args(parser)
     args = add_common_sglang_args_and_parse(parser)
-    args = parser.parse_args()
     eval_mmmu(args)
diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py
index 2a4c9a939..2613be788 100644
--- a/benchmark/mmmu/eval_utils.py
+++ b/benchmark/mmmu/eval_utils.py
@@ -7,6 +7,7 @@ import os
 import pprint
 import random
 import re
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Dict, Optional
 
 import numpy as np
@@ -117,29 +118,38 @@ def prepare_samples(eval_args: EvalArgs):
     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)
 
-    ## prepare images
-    samples = []
-    skip_count = 0
-
-    # use image file as input to ensure the consistency between sglang and hf
+    # Prepare images in parallel
     images_path = os.path.expanduser("~/.cache/mmmu/images")
     os.makedirs(images_path, exist_ok=True)
     print(f"Saving images to: {images_path}")
-    for i, sample in enumerate(tqdm(dataset)):
+    samples = []
+    skip_count = 0
+
+    def process_sample(i, sample):
         sample = process_single_sample(sample)
         sample = construct_prompt(sample, eval_args.config)
         image = sample["image"]
-
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
-            skip_count += 1
-            continue
+            return None, True
         image_path = f"{images_path}/image_{i}.png"
         if not os.path.exists(image_path):
             image.save(image_path)
         sample["image_path"] = image_path
-        samples.append(sample)
+        return sample, False
+
+    with ThreadPoolExecutor() as executor:
+        futures = [
+            executor.submit(process_sample, i, sample)
+            for i, sample in enumerate(dataset)
+        ]
+        for future in tqdm(as_completed(futures), total=len(futures)):
+            sample, skipped = future.result()
+            if skipped:
+                skip_count += 1
+            elif sample:
+                samples.append(sample)
 
     print(
         f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 983c7316c..7ae15f1f2 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -45,7 +45,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 
 Please consult the documentation below to learn more about the parameters you may provide when launching a server.
 
-## Model and tokenizer
+## Model, processor and tokenizer
 
 * `model_path`: Path to the model that will be served.
 * `tokenizer_path`: Defaults to the `model_path`.
@@ -62,6 +62,7 @@ Please consult the documentation below to learn more about the parameters you ma
 * `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
 * `json_model_override_args`: Override model config with the provided JSON.
 * `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
+* `disable_fast_image_processor`: Use the base image processor instead of the fast image processor (the default). For more detail, see: https://huggingface.co/docs/transformers/main/en/main_classes/image_processor#image-processor
 
 ## Serving: HTTP & API
 
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 397404d30..0a189a7bf 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -215,6 +215,7 @@ def get_processor(
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
     tokenizer_revision: Optional[str] = None,
+    use_fast: Optional[bool] = True,
     **kwargs,
 ):
     # pop 'revision' from kwargs if present.
@@ -232,6 +233,9 @@ def get_processor(
         if "size" not in kwargs:
             kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
 
+    if config.model_type not in {"llava", "clip"}:
+        kwargs["use_fast"] = use_fast
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py
index c976f24f7..22ad7e797 100644
--- a/python/sglang/srt/managers/multimodal_processors/base_processor.py
+++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py
@@ -4,14 +4,16 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import numpy as np
 import PIL
 from decord import VideoReader, cpu
 from PIL import Image
+from transformers import BaseImageProcessorFast
 
-from sglang.srt.utils import encode_video, load_audio, load_image, logger
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image
 
 
 @dataclasses.dataclass
@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audios"] = audios
 
         processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
         result = processor.__call__(
             text=[input_text],
             padding=True,
@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):
 
         return estimated_frames_list
 
+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        Load multimodal data in parallel.
+        """
+
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,
@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)
 
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        image_index, audio_index = 0, 0
-        hashes, image_sizes, images, audios = [], [], [], []
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
         new_text = ""
-        for index, text_part in enumerate(text_parts):
-            try:
-                if text_part == multimodal_tokens.image_token:
-                    # load as image
-                    if len(images) >= MAX_NUM_FRAMES:
-                        frames_to_process = 0
-                    else:
-                        estimated_frames = estimated_frames_list[image_index]
-                        frames_to_process = max(
-                            1, int(estimated_frames * scaling_factor)
-                        )
+        task_ptr = 0
 
-                    if frames_to_process == 0:
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                    if len(frames) == 0:
-                        continue
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
 
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
                         new_text += multimodal_tokens.image_token * len(frames)
-                        assert frames_to_process == len(frames)
-                elif text_part == multimodal_tokens.audio_token:
-                    # load as audio
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
                     new_text += multimodal_tokens.audio_token
-                else:
-                    # TODO(mick): handle video
-                    # normal text
-                    new_text += text_part
-
-            except Exception as e:
-                logger.error(f"An exception occurred while loading images: {e}")
-                raise RuntimeError(f"An exception occurred while loading images: {e}")
+            # TODO: handle video
+            else:
+                new_text += text_part
 
         out = BaseMultiModalProcessorOutput(
             images=images,
diff --git a/python/sglang/srt/managers/multimodal_processors/janus_pro.py b/python/sglang/srt/managers/multimodal_processors/janus_pro.py
index 58fcc180b..6db62b11e 100644
--- a/python/sglang/srt/managers/multimodal_processors/janus_pro.py
+++ b/python/sglang/srt/managers/multimodal_processors/janus_pro.py
@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         base_out = self.load_mm_data(
             prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=processor.image_token
+            ),
             max_req_input_len=max_req_input_len,
         )
 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 86612a135..020dbcf4b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -222,10 +222,10 @@ class MultimodalDataItem:
             # memoryview() doesn't support PyTorch's BFloat16 dtype
             tensor = tensor.float()
 
+        assert isinstance(tensor, torch.Tensor)
         if tensor.is_cuda:
-            tensor_cpu = torch.frombuffer(
-                tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
-            ).clone()
+            # TODO: improve this
+            tensor_cpu = tensor.cpu()
         else:
             tensor_cpu = tensor
 
@@ -321,7 +321,6 @@ class MultimodalInputs:
             item.set_pad_value()
 
         optional_args = [
-            "modalities",
             "im_token_id",
             "im_start_id",
             "im_end_id",
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 2c69191c3..2bba79770 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -452,6 +452,7 @@ class Scheduler(
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )
             self.tokenizer = self.processor.tokenizer
         else:
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 9f18ae63c..69df67058 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -180,6 +180,7 @@ class TokenizerManager:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )
 
             # We want to parallelize the image pre-processing so we create an executor for it
diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py
index a9a586cf5..a98666c80 100644
--- a/python/sglang/srt/models/qwen2_5_vl.py
+++ b/python/sglang/srt/models/qwen2_5_vl.py
@@ -462,6 +462,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
         )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
@@ -515,15 +516,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 otherwise it will be `(seq_len,).
                 (Use input_metadata.mrope_positions to replace it)
         """
-        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-        if is_mrope_enabled:
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions
 
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if is_mrope_enabled:
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index da878e867..8871ea1f4 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -467,6 +467,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix=add_prefix("lm_head", prefix),
         )
 
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
@@ -521,15 +522,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 otherwise it will be `(seq_len,).
                 (Use input_metadata.mrope_positions to replace it)
         """
-        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-        if is_mrope_enabled:
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions
 
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if is_mrope_enabled:
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6580f7688..59b744c14 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -196,6 +196,9 @@ class ServerArgs:
     disaggregation_mode: str = "null"
     disaggregation_bootstrap_port: int = 8998
 
+    # multimodal
+    disable_fast_image_processor: bool = False
+
     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:
@@ -979,6 +982,7 @@ class ServerArgs:
         )
         parser.add_argument(
             "--enable-llama4-multimodal",
+            default=ServerArgs.enable_llama4_multimodal,
             action="store_true",
             help="Enable the multimodal functionality for Llama-4.",
         )
@@ -1170,6 +1174,13 @@ class ServerArgs:
             help="Bootstrap server port on the prefill server. Default is 8998.",
         )
 
+        # Multimodal
+        parser.add_argument(
+            "--disable-fast-image-processor",
+            action="store_true",
+            help="Use the base image processor instead of the fast image processor.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
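
Note (editor's addition, not part of the patch): the new flag flows from
`ServerArgs.disable_fast_image_processor` through `Scheduler` and
`TokenizerManager` into `get_processor(use_fast=...)`, which forwards
`use_fast` to `AutoProcessor.from_pretrained`. Below is a minimal sketch of
what the two settings select in `transformers`; the checkpoint name is only
an example, and any model that ships a fast image processor should behave
the same way.

    from transformers import AutoProcessor, BaseImageProcessorFast

    # use_fast=True (sglang's new default) selects the torchvision-backed
    # *ImageProcessorFast variant when the checkpoint provides one. These
    # classes subclass BaseImageProcessorFast and accept a `device` argument,
    # which is why the patched BaseMultimodalProcessor passes device="cuda"
    # when it detects one.
    fast = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_fast=True)
    assert isinstance(fast.image_processor, BaseImageProcessorFast)

    # use_fast=False (i.e. launching with --disable-fast-image-processor)
    # falls back to the base PIL/NumPy image processor.
    base = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_fast=False)
    assert not isinstance(base.image_processor, BaseImageProcessorFast)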