[VLM] Adopt fast image processor by default (#5065)

This commit is contained in:
Mick
2025-04-12 12:46:58 +08:00
committed by GitHub
parent 611720919d
commit 34ef6c8135
12 changed files with 163 additions and 98 deletions

View File

@@ -89,5 +89,4 @@ if __name__ == "__main__":
EvalArgs.add_cli_args(parser) EvalArgs.add_cli_args(parser)
args = add_common_sglang_args_and_parse(parser) args = add_common_sglang_args_and_parse(parser)
args = parser.parse_args() args = parser.parse_args()
eval_mmmu(args) eval_mmmu(args)

View File

@@ -7,6 +7,7 @@ import os
import pprint import pprint
import random import random
import re import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Optional from typing import Dict, Optional
import numpy as np import numpy as np
@@ -117,29 +118,38 @@ def prepare_samples(eval_args: EvalArgs):
# merge all dataset # merge all dataset
dataset = concatenate_datasets(sub_dataset_list) dataset = concatenate_datasets(sub_dataset_list)
## prepare images # Prepare images in parallel
samples = []
skip_count = 0
# use image file as input to ensure the consistency between sglang and hf
images_path = os.path.expanduser("~/.cache/mmmu/images") images_path = os.path.expanduser("~/.cache/mmmu/images")
os.makedirs(images_path, exist_ok=True) os.makedirs(images_path, exist_ok=True)
print(f"Saving images to: {images_path}") print(f"Saving images to: {images_path}")
for i, sample in enumerate(tqdm(dataset)): samples = []
skip_count = 0
def process_sample(i, sample):
sample = process_single_sample(sample) sample = process_single_sample(sample)
sample = construct_prompt(sample, eval_args.config) sample = construct_prompt(sample, eval_args.config)
image = sample["image"] image = sample["image"]
width, height = image.size width, height = image.size
if width * height >= eval_args.image_pixels_limit: if width * height >= eval_args.image_pixels_limit:
skip_count += 1 return None, True
continue
image_path = f"{images_path}/image_{i}.png" image_path = f"{images_path}/image_{i}.png"
if not os.path.exists(image_path): if not os.path.exists(image_path):
image.save(image_path) image.save(image_path)
sample["image_path"] = image_path sample["image_path"] = image_path
samples.append(sample) return sample, False
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(process_sample, i, sample)
for i, sample in enumerate(dataset)
]
for future in tqdm(as_completed(futures), total=len(futures)):
sample, skipped = future.result()
if skipped:
skip_count += 1
elif sample:
samples.append(sample)
print( print(
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset" f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"

View File

@@ -45,7 +45,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
Please consult the documentation below to learn more about the parameters you may provide when launching a server. Please consult the documentation below to learn more about the parameters you may provide when launching a server.
## Model and tokenizer ## Model, processor and tokenizer
* `model_path`: Path to the model that will be served. * `model_path`: Path to the model that will be served.
* `tokenizer_path`: Defaults to the `model_path`. * `tokenizer_path`: Defaults to the `model_path`.
@@ -62,6 +62,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/). * `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
* `json_model_override_args`: Override model config with the provided JSON. * `json_model_override_args`: Override model config with the provided JSON.
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model. * `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
* `disable_fast_image_processor`: Adopt base image processor instead of fast image processor(which is by default). For more detail, see: https://huggingface.co/docs/transformers/main/en/main_classes/image_processor#image-processor
## Serving: HTTP & API ## Serving: HTTP & API

View File

@@ -215,6 +215,7 @@ def get_processor(
tokenizer_mode: str = "auto", tokenizer_mode: str = "auto",
trust_remote_code: bool = False, trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
use_fast: Optional[bool] = True,
**kwargs, **kwargs,
): ):
# pop 'revision' from kwargs if present. # pop 'revision' from kwargs if present.
@@ -232,6 +233,9 @@ def get_processor(
if "size" not in kwargs: if "size" not in kwargs:
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520} kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
if config.model_type not in {"llava", "clip"}:
kwargs["use_fast"] = use_fast
processor = AutoProcessor.from_pretrained( processor = AutoProcessor.from_pretrained(
tokenizer_name, tokenizer_name,
*args, *args,

View File

@@ -4,14 +4,16 @@ import dataclasses
import multiprocessing as mp import multiprocessing as mp
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
import PIL import PIL
from decord import VideoReader, cpu from decord import VideoReader, cpu
from PIL import Image from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.utils import encode_video, load_audio, load_image, logger from sglang.srt.managers.schedule_batch import Modality
from sglang.srt.utils import encode_video, load_audio, load_image
@dataclasses.dataclass @dataclasses.dataclass
@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
kwargs["audios"] = audios kwargs["audios"] = audios
processor = self._processor processor = self._processor
if hasattr(processor, "image_processor") and isinstance(
processor.image_processor, BaseImageProcessorFast
):
kwargs["device"] = "cuda"
result = processor.__call__( result = processor.__call__(
text=[input_text], text=[input_text],
padding=True, padding=True,
@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):
return estimated_frames_list return estimated_frames_list
@staticmethod
def _load_single_item(
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
):
"""Static method that can be pickled for multiprocessing"""
try:
if is_audio:
return load_audio(data)
elif is_video:
path = data[len("video:") :]
return encode_video(path, frame_count_limit)
else:
img, _ = load_image(data)
return img.convert("RGB") if discard_alpha_channel else img
except Exception as e:
raise RuntimeError(f"Error while loading data {data}: {e}")
def submit_data_loading_tasks(
self,
text_parts: List[str],
multimodal_tokens: MultimodalSpecialTokens,
image_data: Optional[list] = None,
audio_data: Optional[list] = None,
discard_alpha_channel: bool = True,
):
"""
load multimodal data parallelly
"""
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
# Submit all tasks
futures = []
task_info = []
image_index, audio_index = 0, 0
for text_part in text_parts:
if text_part == multimodal_tokens.image_token:
data = image_data[image_index]
is_video = isinstance(data, str) and data.startswith("video:")
estimated_frames = estimated_frames_list[image_index]
frame_count_limit = max(1, int(estimated_frames * scaling_factor))
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
is_video,
False,
frame_count_limit,
discard_alpha_channel,
)
)
task_info.append((Modality.IMAGE, data, frame_count_limit))
image_index += 1
elif text_part == multimodal_tokens.audio_token:
data = audio_data[audio_index]
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
False,
True,
None,
discard_alpha_channel,
)
)
task_info.append((Modality.AUDIO, data, None))
audio_index += 1
return futures, task_info
def load_mm_data( def load_mm_data(
self, self,
prompt: str, prompt: str,
@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
# split text into list of normal text and special tokens # split text into list of normal text and special tokens
text_parts = re.split(pattern, prompt) text_parts = re.split(pattern, prompt)
# TODO(mick): load from server_args, env, or sampling_params futures, task_info = self.submit_data_loading_tasks(
MAX_NUM_FRAMES = 30 text_parts=text_parts,
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data) multimodal_tokens=multimodal_tokens,
total_frame_count = sum(estimated_frames_list) image_data=image_data,
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs. audio_data=audio_data,
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used discard_alpha_channel=discard_alpha_channel,
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count)) )
# Process results
assert len(image_data) == len(estimated_frames_list) image_sizes, images, audios = [], [], []
image_index, audio_index = 0, 0
hashes, image_sizes, images, audios = [], [], [], []
new_text = "" new_text = ""
for index, text_part in enumerate(text_parts): task_ptr = 0
try:
if text_part == multimodal_tokens.image_token:
# load as image
if len(images) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
estimated_frames = estimated_frames_list[image_index]
frames_to_process = max(
1, int(estimated_frames * scaling_factor)
)
if frames_to_process == 0: for text_part in text_parts:
frames = [] if text_part in multimodal_tokens.collect():
else: task_type, data, frame_limit = task_info[task_ptr]
image_file = image_data[image_index] result = futures[task_ptr].result()
if isinstance(image_file, str) and image_file.startswith( task_ptr += 1
"video:"
):
# video
path = image_file[len("video:") :]
frames = encode_video(
path, frame_count_limit=frames_to_process
)
else:
# image
raw_image, _size = load_image(image_file)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
if len(frames) == 0:
continue
image_sizes += frames[0].size * len(frames) if task_type == Modality.IMAGE:
frames = [result] if not isinstance(result, list) else result
# Generate a hashable value for the image file if frames:
if isinstance(image_file, Image.Image): image_sizes += frames[0].size * len(frames)
# For PIL.Image objects, use the ID as a hashable value images += frames
hash_value = hash(id(image_file))
else:
# For other types (strings, etc.), use the regular hash
hash_value = hash(image_file)
hashes += [hash_value] * len(frames)
images += frames
image_index += 1
if frames_to_process != 0:
new_text += multimodal_tokens.image_token * len(frames) new_text += multimodal_tokens.image_token * len(frames)
assert frames_to_process == len(frames) elif task_type == Modality.AUDIO:
elif text_part == multimodal_tokens.audio_token: # audio
# load as audio audios.append(result)
audio_file = audio_data[audio_index]
audio = load_audio(audio_file)
hashes += [hash(audio_file)]
audios += [audio]
audio_index += 1
new_text += multimodal_tokens.audio_token new_text += multimodal_tokens.audio_token
else: # TODO: handle video
# TODO(mick): handle video else:
# normal text new_text += text_part
new_text += text_part
except Exception as e:
logger.error(f"An exception occurred while loading images: {e}")
raise RuntimeError(f"An exception occurred while loading images: {e}")
out = BaseMultiModalProcessorOutput( out = BaseMultiModalProcessorOutput(
images=images, images=images,

View File

@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
base_out = self.load_mm_data( base_out = self.load_mm_data(
prompt=input_ids, prompt=input_ids,
image_data=image_data, image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag), multimodal_tokens=MultimodalSpecialTokens(
image_token=processor.image_token
),
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
) )

View File

@@ -222,10 +222,10 @@ class MultimodalDataItem:
# memoryview() doesn't support PyTorch's BFloat16 dtype # memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float() tensor = tensor.float()
assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda: if tensor.is_cuda:
tensor_cpu = torch.frombuffer( # TODO: improve this
tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel() tensor_cpu = tensor.cpu()
).clone()
else: else:
tensor_cpu = tensor tensor_cpu = tensor
@@ -321,7 +321,6 @@ class MultimodalInputs:
item.set_pad_value() item.set_pad_value()
optional_args = [ optional_args = [
"modalities",
"im_token_id", "im_token_id",
"im_start_id", "im_start_id",
"im_end_id", "im_end_id",

View File

@@ -452,6 +452,7 @@ class Scheduler(
tokenizer_mode=server_args.tokenizer_mode, tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code, trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision, revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
) )
self.tokenizer = self.processor.tokenizer self.tokenizer = self.processor.tokenizer
else: else:

View File

@@ -180,6 +180,7 @@ class TokenizerManager:
tokenizer_mode=server_args.tokenizer_mode, tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code, trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision, revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
) )
# We want to parallelize the image pre-processing so we create an executor for it # We want to parallelize the image pre-processing so we create an executor for it

View File

@@ -462,6 +462,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
) )
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -515,15 +516,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
otherwise it will be `(seq_len,). otherwise it will be `(seq_len,).
(Use input_metadata.mrope_positions to replace it) (Use input_metadata.mrope_positions to replace it)
""" """
is_mrope_enabled = "mrope_section" in self.config.rope_scaling if self.is_mrope_enabled:
if is_mrope_enabled:
positions = forward_batch.mrope_positions positions = forward_batch.mrope_positions
if not ( if not (
forward_batch.forward_mode.is_decode() forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs() or not forward_batch.contains_image_inputs()
): ):
if is_mrope_enabled: if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, ( assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires " "multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}" f"(3, seq_len) positions, but got {positions.size()}"

View File

@@ -467,6 +467,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
) )
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -521,15 +522,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
otherwise it will be `(seq_len,). otherwise it will be `(seq_len,).
(Use input_metadata.mrope_positions to replace it) (Use input_metadata.mrope_positions to replace it)
""" """
is_mrope_enabled = "mrope_section" in self.config.rope_scaling if self.is_mrope_enabled:
if is_mrope_enabled:
positions = forward_batch.mrope_positions positions = forward_batch.mrope_positions
if not ( if not (
forward_batch.forward_mode.is_decode() forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs() or not forward_batch.contains_image_inputs()
): ):
if is_mrope_enabled: if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, ( assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires " "multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}" f"(3, seq_len) positions, but got {positions.size()}"

View File

@@ -196,6 +196,9 @@ class ServerArgs:
disaggregation_mode: str = "null" disaggregation_mode: str = "null"
disaggregation_bootstrap_port: int = 8998 disaggregation_bootstrap_port: int = 8998
# multimodal
disable_fast_image_processor: bool = False
def __post_init__(self): def __post_init__(self):
# Expert parallelism # Expert parallelism
if self.enable_ep_moe: if self.enable_ep_moe:
@@ -979,6 +982,7 @@ class ServerArgs:
) )
parser.add_argument( parser.add_argument(
"--enable-llama4-multimodal", "--enable-llama4-multimodal",
default=ServerArgs.enable_llama4_multimodal,
action="store_true", action="store_true",
help="Enable the multimodal functionality for Llama-4.", help="Enable the multimodal functionality for Llama-4.",
) )
@@ -1170,6 +1174,13 @@ class ServerArgs:
help="Bootstrap server port on the prefill server. Default is 8998.", help="Bootstrap server port on the prefill server. Default is 8998.",
) )
# Multimodal
parser.add_argument(
"--disable-fast-image-processor",
action="store_true",
help="Adopt base image processor instead of fast image processor.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size