[VLM] Adopt fast image processor by default (#5065)

This commit is contained in:
Mick
2025-04-12 12:46:58 +08:00
committed by GitHub
parent 611720919d
commit 34ef6c8135
12 changed files with 163 additions and 98 deletions

View File

@@ -215,6 +215,7 @@ def get_processor(
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
use_fast: Optional[bool] = True,
**kwargs,
):
# pop 'revision' from kwargs if present.
@@ -232,6 +233,9 @@ def get_processor(
if "size" not in kwargs:
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
if config.model_type not in {"llava", "clip"}:
kwargs["use_fast"] = use_fast
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,

View File

@@ -4,14 +4,16 @@ import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional
from typing import List, Optional
import numpy as np
import PIL
from decord import VideoReader, cpu
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.utils import encode_video, load_audio, load_image, logger
from sglang.srt.managers.schedule_batch import Modality
from sglang.srt.utils import encode_video, load_audio, load_image
@dataclasses.dataclass
@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
kwargs["audios"] = audios
processor = self._processor
if hasattr(processor, "image_processor") and isinstance(
processor.image_processor, BaseImageProcessorFast
):
kwargs["device"] = "cuda"
result = processor.__call__(
text=[input_text],
padding=True,
@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):
return estimated_frames_list
@staticmethod
def _load_single_item(
    data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
):
    """Load one multimodal item: audio, video, or still image.

    Kept as a static method so it can be pickled and shipped to worker
    processes/executors. Any loader failure is re-raised as a RuntimeError
    that names the offending item.
    """
    try:
        if is_audio:
            return load_audio(data)
        if is_video:
            # Strip the "video:" scheme prefix to recover the path.
            video_path = data[len("video:") :]
            return encode_video(video_path, frame_count_limit)
        # Plain image; optionally drop the alpha channel.
        image, _ = load_image(data)
        if discard_alpha_channel:
            return image.convert("RGB")
        return image
    except Exception as e:
        raise RuntimeError(f"Error while loading data {data}: {e}")
def submit_data_loading_tasks(
    self,
    text_parts: List[str],
    multimodal_tokens: MultimodalSpecialTokens,
    image_data: Optional[list] = None,
    audio_data: Optional[list] = None,
    discard_alpha_channel: bool = True,
):
    """Schedule asynchronous loading of all multimodal items in the prompt.

    Walks ``text_parts`` in order and, for each image/audio placeholder
    token, submits a ``_load_single_item`` task to ``self.io_executor``.
    Video inputs receive a per-item frame budget derived from a global
    frame cap so long videos are subsampled.

    Returns:
        (futures, task_info): the submitted futures and, aligned with them,
        ``(modality, raw_data, frame_count_limit)`` descriptors.
    """
    # TODO(mick): load from server_args, env, or sampling_params
    MAX_NUM_FRAMES = 30
    frame_estimates = self.get_estimated_frames_list(image_data=image_data)
    estimated_total = sum(frame_estimates)

    # Heuristic cap: the fraction of frames kept across all visual inputs,
    # e.g. 0.1 means one of every ten input frames gets embedded.
    scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, estimated_total))

    assert len(image_data) == len(frame_estimates)

    # Submit every loading task, recording (modality, data, limit) per task.
    futures = []
    task_info = []
    image_idx = 0
    audio_idx = 0
    for part in text_parts:
        if part == multimodal_tokens.image_token:
            item = image_data[image_idx]
            treat_as_video = isinstance(item, str) and item.startswith("video:")
            frame_limit = max(1, int(frame_estimates[image_idx] * scaling_factor))
            futures.append(
                self.io_executor.submit(
                    BaseMultimodalProcessor._load_single_item,
                    item,
                    treat_as_video,
                    False,
                    frame_limit,
                    discard_alpha_channel,
                )
            )
            task_info.append((Modality.IMAGE, item, frame_limit))
            image_idx += 1
        elif part == multimodal_tokens.audio_token:
            item = audio_data[audio_idx]
            futures.append(
                self.io_executor.submit(
                    BaseMultimodalProcessor._load_single_item,
                    item,
                    False,
                    True,
                    None,
                    discard_alpha_channel,
                )
            )
            task_info.append((Modality.AUDIO, item, None))
            audio_idx += 1
    return futures, task_info
def load_mm_data(
self,
prompt: str,
@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
# split text into list of normal text and special tokens
text_parts = re.split(pattern, prompt)
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
image_index, audio_index = 0, 0
hashes, image_sizes, images, audios = [], [], [], []
futures, task_info = self.submit_data_loading_tasks(
text_parts=text_parts,
multimodal_tokens=multimodal_tokens,
image_data=image_data,
audio_data=audio_data,
discard_alpha_channel=discard_alpha_channel,
)
# Process results
image_sizes, images, audios = [], [], []
new_text = ""
for index, text_part in enumerate(text_parts):
try:
if text_part == multimodal_tokens.image_token:
# load as image
if len(images) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
estimated_frames = estimated_frames_list[image_index]
frames_to_process = max(
1, int(estimated_frames * scaling_factor)
)
task_ptr = 0
if frames_to_process == 0:
frames = []
else:
image_file = image_data[image_index]
if isinstance(image_file, str) and image_file.startswith(
"video:"
):
# video
path = image_file[len("video:") :]
frames = encode_video(
path, frame_count_limit=frames_to_process
)
else:
# image
raw_image, _size = load_image(image_file)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
if len(frames) == 0:
continue
for text_part in text_parts:
if text_part in multimodal_tokens.collect():
task_type, data, frame_limit = task_info[task_ptr]
result = futures[task_ptr].result()
task_ptr += 1
image_sizes += frames[0].size * len(frames)
# Generate a hashable value for the image file
if isinstance(image_file, Image.Image):
# For PIL.Image objects, use the ID as a hashable value
hash_value = hash(id(image_file))
else:
# For other types (strings, etc.), use the regular hash
hash_value = hash(image_file)
hashes += [hash_value] * len(frames)
images += frames
image_index += 1
if frames_to_process != 0:
if task_type == Modality.IMAGE:
frames = [result] if not isinstance(result, list) else result
if frames:
image_sizes += frames[0].size * len(frames)
images += frames
new_text += multimodal_tokens.image_token * len(frames)
assert frames_to_process == len(frames)
elif text_part == multimodal_tokens.audio_token:
# load as audio
audio_file = audio_data[audio_index]
audio = load_audio(audio_file)
hashes += [hash(audio_file)]
audios += [audio]
audio_index += 1
elif task_type == Modality.AUDIO:
# audio
audios.append(result)
new_text += multimodal_tokens.audio_token
else:
# TODO(mick): handle video
# normal text
new_text += text_part
except Exception as e:
logger.error(f"An exception occurred while loading images: {e}")
raise RuntimeError(f"An exception occurred while loading images: {e}")
# TODO: handle video
else:
new_text += text_part
out = BaseMultiModalProcessorOutput(
images=images,

View File

@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
base_out = self.load_mm_data(
prompt=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
multimodal_tokens=MultimodalSpecialTokens(
image_token=processor.image_token
),
max_req_input_len=max_req_input_len,
)

View File

@@ -222,10 +222,10 @@ class MultimodalDataItem:
# memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float()
assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda:
tensor_cpu = torch.frombuffer(
tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
).clone()
# TODO: improve this
tensor_cpu = tensor.cpu()
else:
tensor_cpu = tensor
@@ -321,7 +321,6 @@ class MultimodalInputs:
item.set_pad_value()
optional_args = [
"modalities",
"im_token_id",
"im_start_id",
"im_end_id",

View File

@@ -452,6 +452,7 @@ class Scheduler(
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
self.tokenizer = self.processor.tokenizer
else:

View File

@@ -180,6 +180,7 @@ class TokenizerManager:
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
# We want to parallelize the image pre-processing so we create an executor for it

View File

@@ -462,6 +462,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -515,15 +516,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
otherwise it will be `(seq_len,).
(Use input_metadata.mrope_positions to replace it)
"""
is_mrope_enabled = "mrope_section" in self.config.rope_scaling
if is_mrope_enabled:
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if is_mrope_enabled:
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"

View File

@@ -467,6 +467,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -521,15 +522,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
otherwise it will be `(seq_len,).
(Use input_metadata.mrope_positions to replace it)
"""
is_mrope_enabled = "mrope_section" in self.config.rope_scaling
if is_mrope_enabled:
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if is_mrope_enabled:
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"

View File

@@ -196,6 +196,9 @@ class ServerArgs:
disaggregation_mode: str = "null"
disaggregation_bootstrap_port: int = 8998
# multimodal
disable_fast_image_processor: bool = False
def __post_init__(self):
# Expert parallelism
if self.enable_ep_moe:
@@ -979,6 +982,7 @@ class ServerArgs:
)
parser.add_argument(
"--enable-llama4-multimodal",
default=ServerArgs.enable_llama4_multimodal,
action="store_true",
help="Enable the multimodal functionality for Llama-4.",
)
@@ -1170,6 +1174,13 @@ class ServerArgs:
help="Bootstrap server port on the prefill server. Default is 8998.",
)
# Multimodal
parser.add_argument(
"--disable-fast-image-processor",
action="store_true",
help="Adopt base image processor instead of fast image processor.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size