From 5cb552b1d461844e4be280060301509fc6ff6cc6 Mon Sep 17 00:00:00 2001 From: Mick Date: Tue, 1 Apr 2025 00:57:51 +0800 Subject: [PATCH] refactor: multimodal data (#4754) --- benchmark/mmmu/bench_hf.py | 43 ++- benchmark/mmmu/eval_utils.py | 2 + python/sglang/srt/managers/mm_utils.py | 360 ++++++++++-------- .../srt/managers/multimodal_processor.py | 2 - .../multimodal_processors/base_processor.py | 106 ++---- .../managers/multimodal_processors/clip.py | 33 +- .../multimodal_processors/deepseek_vl_v2.py | 75 +--- .../managers/multimodal_processors/gemma3.py | 39 +- .../multimodal_processors/janus_pro.py | 68 +--- .../managers/multimodal_processors/llava.py | 35 +- .../managers/multimodal_processors/minicpm.py | 73 ++-- .../managers/multimodal_processors/mlama.py | 33 +- .../managers/multimodal_processors/qwen_vl.py | 65 +--- python/sglang/srt/managers/schedule_batch.py | 212 ++++++----- python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/managers/utils.py | 7 +- .../srt/model_executor/forward_batch_info.py | 5 - .../sglang/srt/model_executor/model_runner.py | 24 +- python/sglang/srt/models/clip.py | 19 +- .../sglang/srt/models/deepseek_janus_pro.py | 25 +- python/sglang/srt/models/deepseek_v2.py | 3 + python/sglang/srt/models/deepseek_vl2.py | 213 +++++------ python/sglang/srt/models/gemma3_mm.py | 94 +---- python/sglang/srt/models/llama.py | 3 + python/sglang/srt/models/llava.py | 50 ++- python/sglang/srt/models/llavavid.py | 23 +- python/sglang/srt/models/minicpmo.py | 204 +++------- python/sglang/srt/models/minicpmv.py | 44 +-- python/sglang/srt/models/mllama.py | 43 ++- python/sglang/srt/models/qwen2.py | 15 +- python/sglang/srt/models/qwen2_5_vl.py | 52 +-- python/sglang/srt/models/qwen2_vl.py | 41 +- python/sglang/srt/openai_api/adapter.py | 24 +- python/sglang/srt/utils.py | 42 +- test/srt/test_vision_openai_server.py | 12 +- test/srt/test_vlm_accuracy.py | 36 +- 36 files changed, 989 insertions(+), 1138 deletions(-) diff --git a/benchmark/mmmu/bench_hf.py b/benchmark/mmmu/bench_hf.py index c6588c7b9..6735ce7ec 100644 --- a/benchmark/mmmu/bench_hf.py +++ b/benchmark/mmmu/bench_hf.py @@ -72,17 +72,38 @@ def eval_mmmu(args): if suffix: contents += [{"type": "text", "text": suffix}] messages = [{"role": "user", "content": contents}] - model_inputs = processor.apply_chat_template( - messages, - tokenize=True, - return_dict=True, - add_generation_prompt=True, - return_tensors="pt", - ).to(model.device) - input_len = model_inputs["input_ids"].shape[-1] - generation = model.generate(**model_inputs, generation_config=generation_config) - generation = generation[0][input_len:] - response = processor.decode(generation, skip_special_tokens=True) + try: + model_inputs = processor.tokenizer.apply_chat_template( + messages, + tokenize=True, + return_dict=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(model.device) + input_len = model_inputs["input_ids"].shape[-1] + generation = model.generate( + **model_inputs, generation_config=generation_config + ) + generation = generation[0][input_len:] + response = processor.decode(generation, skip_special_tokens=True) + except: + contents = [] + if prefix: + contents += [prefix] + image = PIL.Image.open(sample["image_path"]) + contents += [image] + if suffix: + contents += [suffix] + messages = [{"role": "user", "content": contents}] + response = model.chat( + msgs=messages, + tokenizer=processor.tokenizer, + sampling=False, + max_new_tokens=sampling_params["max_new_tokens"], + use_tts_template=False, + generate_audio=False, + temperature=0.0, + ) print(f"response: {response}") process_result(response, sample, answer_dict, out_samples) diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index 046bed6b2..2a4c9a939 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -442,6 +442,8 @@ def calculate_ins_level_acc(results: Dict): def process_result(response, sample, answer_dict, out_samples): + if response is None: + return if sample["question_type"] == "multiple-choice": pred_ans = parse_multi_choice_response( response, sample["all_choices"], sample["index2ans"] diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index ea51fdeff..6d1a33455 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -1,5 +1,5 @@ """ - Multimodality utils + Multi-modality utils """ from abc import abstractmethod @@ -9,11 +9,13 @@ import torch from torch import nn from sglang.srt.managers.schedule_batch import ( + MultimodalDataItem, MultimodalInputs, global_server_args_dict, logger, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.utils import print_warning_once from sglang.utils import logger @@ -26,7 +28,7 @@ class MultiModalityDataPaddingPattern: @abstractmethod def pad_input_tokens( - self, input_ids: List[int], image_inputs: MultimodalInputs + self, input_ids: List[int], mm_inputs: MultimodalInputs ) -> List[int]: """ Pad the input ids sequence containing data tokens, and replace them with pad_values @@ -49,13 +51,13 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern) """ This function will replace the data-tokens inbetween with pad_values accordingly """ - pad_values = mm_inputs.pad_values + pad_values = [item.pad_value for item in mm_inputs.mm_items] data_token_pairs = self.data_token_id_pairs - mm_inputs.image_offsets = [] + mm_inputs.data_offsets = [] if data_token_pairs is None: data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id] if data_token_pairs is None: - logger.warning( + print_warning_once( "No data_token_pairs provided, RadixAttention might be influenced." ) return input_ids @@ -77,10 +79,10 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern) if input_ids[start_idx] in start_token_ids: data_idx += 1 - mm_inputs.image_offsets += [start_idx] + mm_inputs.data_offsets += [start_idx] - if data_idx >= len(mm_inputs.pad_values): - data_idx = len(mm_inputs.pad_values) - 1 + if data_idx >= len(pad_values): + data_idx = len(pad_values) - 1 num_tokens = end_idx - start_idx - 1 pad_value = pad_values[data_idx] @@ -94,68 +96,19 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern) return padded_ids -class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern): - """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ), - which needs first to be expanded to multiple tokens, then replaced with their padding values - - This strategy should be used when a single data token represents content that should - be expanded to multiple tokens during processing. - """ - - def __init__( - self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int] - ) -> None: - self.num_data_token_calc_func = num_data_token_calc_func - - def pad_input_tokens( - self, input_ids: List[int], mm_inputs: MultimodalInputs - ) -> List[int]: - """ - This function will follow the procedure of: - 1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func` - 2. the padded data tokens will be replaced with their pad_values - """ - image_grid_thws = mm_inputs.image_grid_thws - pad_values = mm_inputs.pad_values - - image_indices = [ - idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id - ] - - mm_inputs.image_offsets = [] - - input_ids_with_image = [] - for image_cnt, _ in enumerate(image_grid_thws): - # print(f"image_cnt {image_cnt}") - num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt]) - if image_cnt == 0: - non_image_tokens = input_ids[: image_indices[image_cnt]] - else: - non_image_tokens = input_ids[ - image_indices[image_cnt - 1] + 1 : image_indices[image_cnt] - ] - input_ids_with_image.extend(non_image_tokens) - mm_inputs.image_offsets.append(len(input_ids_with_image)) - pad_ids = pad_values * ( - (num_image_tokens + len(pad_values)) // len(pad_values) - ) - input_ids_with_image.extend(pad_ids[:num_image_tokens]) - input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :]) - - return input_ids_with_image - - class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern): - """In this pattern, data tokens should be represented as image tokens (e.g. ....)""" + """In this pattern, data tokens should be represented as repetitions of a single token + e.g. ...., or