[VLM] Adopt fast image processor by default (#5065)
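Summary: this change makes SGLang request Hugging Face's fast (torch-based) image processors by default when a model's processor is loaded, runs their preprocessing on GPU when the loaded processor supports it, moves per-item image/audio/video loading onto an executor so items decode in parallel, and adds a --disable-fast-image-processor server flag for falling back to the base processor.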
@@ -215,6 +215,7 @@ def get_processor(
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
     tokenizer_revision: Optional[str] = None,
+    use_fast: Optional[bool] = True,
     **kwargs,
 ):
     # pop 'revision' from kwargs if present.

@@ -232,6 +233,9 @@ def get_processor(
         if "size" not in kwargs:
             kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
 
+    if config.model_type not in {"llava", "clip"}:
+        kwargs["use_fast"] = use_fast
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,

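For readers unfamiliar with the transformers knob being forwarded here, a minimal sketch (not part of this diff; the model id is illustrative) of what use_fast selects:

# Minimal sketch, assuming a checkpoint whose image processor has a fast
# (torch-based) variant; the model id below is illustrative only.
from transformers import AutoImageProcessor

slow = AutoImageProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", use_fast=False)
fast = AutoImageProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", use_fast=True)
print(type(slow).__name__)  # e.g. Qwen2VLImageProcessor
print(type(fast).__name__)  # e.g. Qwen2VLImageProcessorFast
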
@@ -4,14 +4,16 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import numpy as np
 import PIL
 from decord import VideoReader, cpu
 from PIL import Image
+from transformers import BaseImageProcessorFast
 
-from sglang.srt.utils import encode_video, load_audio, load_image, logger
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image
 
 
 @dataclasses.dataclass

@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audios"] = audios
 
-        result = self._processor.__call__(
+        processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
+        result = processor.__call__(
             text=[input_text],
             padding=True,

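The hunk above only sets device when the underlying image processor is a BaseImageProcessorFast, since slow processors are PIL/NumPy based and do not take a device argument. A hedged sketch of the same check outside SGLang (model id and prompt token are illustrative, and CUDA is assumed available):

# Sketch, assuming use_fast resolved to a fast processor on this checkpoint.
from PIL import Image
from transformers import AutoProcessor, BaseImageProcessorFast

proc = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", use_fast=True)
kwargs = {}
if hasattr(proc, "image_processor") and isinstance(
    proc.image_processor, BaseImageProcessorFast
):
    kwargs["device"] = "cuda"  # the same check the diff above performs
out = proc(
    text=["<|image_pad|>"],
    images=[Image.new("RGB", (64, 64))],
    padding=True,
    return_tensors="pt",
    **kwargs,
)
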
@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):
 
         return estimated_frames_list
 
+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        Load multimodal data in parallel.
+        """
+
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,

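Not part of the diff: a self-contained sketch of the submit-then-consume pattern the new submit_data_loading_tasks / load_mm_data pair relies on. All loads are submitted up front so decoding overlaps across items, while results are still consumed in prompt order.

# Standalone sketch; decode() stands in for load_image/encode_video/load_audio.
from concurrent.futures import ThreadPoolExecutor

def decode(item: str) -> str:
    return item.upper()  # pretend this does expensive I/O + decoding

items = ["img0", "video:clip0", "aud0"]
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(decode, it) for it in items]  # all tasks start now
    results = [f.result() for f in futures]              # consumed in input order
assert results == ["IMG0", "VIDEO:CLIP0", "AUD0"]
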
@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)
 
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        image_index, audio_index = 0, 0
-        hashes, image_sizes, images, audios = [], [], [], []
-        new_text = ""
-        for index, text_part in enumerate(text_parts):
-            try:
-                if text_part == multimodal_tokens.image_token:
-                    # load as image
-                    if len(images) >= MAX_NUM_FRAMES:
-                        frames_to_process = 0
-                    else:
-                        estimated_frames = estimated_frames_list[image_index]
-                        frames_to_process = max(
-                            1, int(estimated_frames * scaling_factor)
-                        )
-
-                    if frames_to_process == 0:
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                        if len(frames) == 0:
-                            continue
-
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
-                        new_text += multimodal_tokens.image_token * len(frames)
-                    assert frames_to_process == len(frames)
-                elif text_part == multimodal_tokens.audio_token:
-                    # load as audio
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
-                else:
-                    # TODO(mick): handle video
-                    # normal text
-                    new_text += text_part
-
-            except Exception as e:
-                logger.error(f"An exception occurred while loading images: {e}")
-                raise RuntimeError(f"An exception occurred while loading images: {e}")
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
+        new_text = ""
+        task_ptr = 0
+
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
+
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
+                        new_text += multimodal_tokens.image_token * len(frames)
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
+                    new_text += multimodal_tokens.audio_token
+                # TODO: handle video
+            else:
+                new_text += text_part
 
         out = BaseMultiModalProcessorOutput(
             images=images,

@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         base_out = self.load_mm_data(
             prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=processor.image_token
+            ),
             max_req_input_len=max_req_input_len,
         )

@@ -222,10 +222,10 @@ class MultimodalDataItem:
             # memoryview() doesn't support PyTorch's BFloat16 dtype
             tensor = tensor.float()
 
         assert isinstance(tensor, torch.Tensor)
         if tensor.is_cuda:
-            tensor_cpu = torch.frombuffer(
-                tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
-            ).clone()
+            # TODO: improve this
+            tensor_cpu = tensor.cpu()
         else:
             tensor_cpu = tensor

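Context for the hunk above, as my illustration rather than the PR's own code: memoryview and numpy need host memory, and torch.frombuffer presumably cannot read a CUDA tensor's device-side storage, so the hash path now copies to CPU first.

# Sketch of the CPU-side byte hashing this code path feeds into; the sha256
# scheme here is illustrative, not necessarily what schedule_batch.py uses.
import hashlib
import torch

t = torch.arange(8, dtype=torch.float32)
if torch.cuda.is_available():
    t = t.cuda()
t_cpu = t.cpu() if t.is_cuda else t  # same fallback as the diff above
digest = hashlib.sha256(memoryview(t_cpu.numpy().tobytes())).digest()[:8]
pad_value = int.from_bytes(digest, byteorder="big", signed=False)
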
@@ -321,7 +321,6 @@ class MultimodalInputs:
             item.set_pad_value()
 
         optional_args = [
             "modalities",
             "im_token_id",
             "im_start_id",
             "im_end_id",

@@ -452,6 +452,7 @@ class Scheduler(
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )
             self.tokenizer = self.processor.tokenizer
         else:

@@ -180,6 +180,7 @@ class TokenizerManager:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )
 
             # We want to parallelize the image pre-processing so we create an executor for it

@@ -462,6 +462,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
         )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -515,15 +516,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-        if is_mrope_enabled:
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions
 
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if is_mrope_enabled:
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"

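The Qwen2.5-VL and Qwen2-VL hunks are the same micro-refactor: the mrope_section lookup is hoisted out of forward() into __init__, so the config dict is read once per model instead of once per step. A minimal sketch of the pattern (class and names are illustrative):

# Minimal sketch of hoisting a config-derived flag out of the hot path.
class TinyModel:
    def __init__(self, rope_scaling: dict):
        # computed once, at construction time
        self.is_mrope_enabled = "mrope_section" in rope_scaling

    def forward(self) -> str:
        # forward() now reads a cached attribute instead of the config dict
        return "mrope" if self.is_mrope_enabled else "rope"

assert TinyModel({"mrope_section": [16, 24, 24]}).forward() == "mrope"
assert TinyModel({}).forward() == "rope"
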
@@ -467,6 +467,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix=add_prefix("lm_head", prefix),
         )
 
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -521,15 +522,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-        if is_mrope_enabled:
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions
 
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if is_mrope_enabled:
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"

@@ -196,6 +196,9 @@ class ServerArgs:
     disaggregation_mode: str = "null"
     disaggregation_bootstrap_port: int = 8998
 
+    # multimodal
+    disable_fast_image_processor: bool = False
+
     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:

@@ -979,6 +982,7 @@ class ServerArgs:
         )
         parser.add_argument(
             "--enable-llama4-multimodal",
             default=ServerArgs.enable_llama4_multimodal,
             action="store_true",
             help="Enable the multimodal functionality for Llama-4.",
         )

@@ -1170,6 +1174,13 @@ class ServerArgs:
             help="Bootstrap server port on the prefill server. Default is 8998.",
         )
 
+        # Multimodal
+        parser.add_argument(
+            "--disable-fast-image-processor",
+            action="store_true",
+            help="Adopt base image processor instead of fast image processor.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size

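A self-contained illustration (the parser below is a stand-in, not SGLang's actual CLI) of how the new flag feeds the use_fast=not server_args.disable_fast_image_processor expression seen in the scheduler and tokenizer-manager hunks:

# Sketch: the flag is off by default, so use_fast stays True unless a user
# explicitly opts back into the base (slow) image processor.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-fast-image-processor", action="store_true")

args = parser.parse_args([])  # default launch
assert not args.disable_fast_image_processor  # -> use_fast=True

args = parser.parse_args(["--disable-fast-image-processor"])
assert args.disable_fast_image_processor      # -> use_fast=False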