[VLM] Adopt fast image processor by default (#5065)
This commit is contained in:
@@ -89,5 +89,4 @@ if __name__ == "__main__":
|
|||||||
EvalArgs.add_cli_args(parser)
|
EvalArgs.add_cli_args(parser)
|
||||||
args = add_common_sglang_args_and_parse(parser)
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
eval_mmmu(args)
|
eval_mmmu(args)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import os
|
|||||||
import pprint
|
import pprint
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -117,29 +118,38 @@ def prepare_samples(eval_args: EvalArgs):
|
|||||||
# merge all dataset
|
# merge all dataset
|
||||||
dataset = concatenate_datasets(sub_dataset_list)
|
dataset = concatenate_datasets(sub_dataset_list)
|
||||||
|
|
||||||
## prepare images
|
# Prepare images in parallel
|
||||||
samples = []
|
|
||||||
skip_count = 0
|
|
||||||
|
|
||||||
# use image file as input to ensure the consistency between sglang and hf
|
|
||||||
images_path = os.path.expanduser("~/.cache/mmmu/images")
|
images_path = os.path.expanduser("~/.cache/mmmu/images")
|
||||||
os.makedirs(images_path, exist_ok=True)
|
os.makedirs(images_path, exist_ok=True)
|
||||||
print(f"Saving images to: {images_path}")
|
print(f"Saving images to: {images_path}")
|
||||||
|
|
||||||
for i, sample in enumerate(tqdm(dataset)):
|
samples = []
|
||||||
|
skip_count = 0
|
||||||
|
|
||||||
|
def process_sample(i, sample):
|
||||||
sample = process_single_sample(sample)
|
sample = process_single_sample(sample)
|
||||||
sample = construct_prompt(sample, eval_args.config)
|
sample = construct_prompt(sample, eval_args.config)
|
||||||
image = sample["image"]
|
image = sample["image"]
|
||||||
|
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
if width * height >= eval_args.image_pixels_limit:
|
if width * height >= eval_args.image_pixels_limit:
|
||||||
skip_count += 1
|
return None, True
|
||||||
continue
|
|
||||||
image_path = f"{images_path}/image_{i}.png"
|
image_path = f"{images_path}/image_{i}.png"
|
||||||
if not os.path.exists(image_path):
|
if not os.path.exists(image_path):
|
||||||
image.save(image_path)
|
image.save(image_path)
|
||||||
sample["image_path"] = image_path
|
sample["image_path"] = image_path
|
||||||
samples.append(sample)
|
return sample, False
|
||||||
|
|
||||||
|
with ThreadPoolExecutor() as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(process_sample, i, sample)
|
||||||
|
for i, sample in enumerate(dataset)
|
||||||
|
]
|
||||||
|
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||||
|
sample, skipped = future.result()
|
||||||
|
if skipped:
|
||||||
|
skip_count += 1
|
||||||
|
elif sample:
|
||||||
|
samples.append(sample)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
|
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
|
|||||||
Please consult the documentation below to learn more about the parameters you may provide when launching a server.
|
Please consult the documentation below to learn more about the parameters you may provide when launching a server.
|
||||||
|
|
||||||
|
|
||||||
## Model and tokenizer
|
## Model, processor and tokenizer
|
||||||
|
|
||||||
* `model_path`: Path to the model that will be served.
|
* `model_path`: Path to the model that will be served.
|
||||||
* `tokenizer_path`: Defaults to the `model_path`.
|
* `tokenizer_path`: Defaults to the `model_path`.
|
||||||
@@ -62,6 +62,7 @@ Please consult the documentation below to learn more about the parameters you ma
|
|||||||
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
|
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
|
||||||
* `json_model_override_args`: Override model config with the provided JSON.
|
* `json_model_override_args`: Override model config with the provided JSON.
|
||||||
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
|
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
|
||||||
|
* `disable_fast_image_processor`: Use the base image processor instead of the fast image processor (the fast image processor is used by default). For more details, see: https://huggingface.co/docs/transformers/main/en/main_classes/image_processor#image-processor
|
||||||
|
|
||||||
|
|
||||||
## Serving: HTTP & API
|
## Serving: HTTP & API
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ def get_processor(
|
|||||||
tokenizer_mode: str = "auto",
|
tokenizer_mode: str = "auto",
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
tokenizer_revision: Optional[str] = None,
|
tokenizer_revision: Optional[str] = None,
|
||||||
|
use_fast: Optional[bool] = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# pop 'revision' from kwargs if present.
|
# pop 'revision' from kwargs if present.
|
||||||
@@ -232,6 +233,9 @@ def get_processor(
|
|||||||
if "size" not in kwargs:
|
if "size" not in kwargs:
|
||||||
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
|
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
|
||||||
|
|
||||||
|
if config.model_type not in {"llava", "clip"}:
|
||||||
|
kwargs["use_fast"] = use_fast
|
||||||
|
|
||||||
processor = AutoProcessor.from_pretrained(
|
processor = AutoProcessor.from_pretrained(
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
*args,
|
*args,
|
||||||
|
|||||||
@@ -4,14 +4,16 @@ import dataclasses
|
|||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
from decord import VideoReader, cpu
|
from decord import VideoReader, cpu
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from transformers import BaseImageProcessorFast
|
||||||
|
|
||||||
from sglang.srt.utils import encode_video, load_audio, load_image, logger
|
from sglang.srt.managers.schedule_batch import Modality
|
||||||
|
from sglang.srt.utils import encode_video, load_audio, load_image
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
kwargs["audios"] = audios
|
kwargs["audios"] = audios
|
||||||
|
|
||||||
processor = self._processor
|
processor = self._processor
|
||||||
|
if hasattr(processor, "image_processor") and isinstance(
|
||||||
|
processor.image_processor, BaseImageProcessorFast
|
||||||
|
):
|
||||||
|
kwargs["device"] = "cuda"
|
||||||
result = processor.__call__(
|
result = processor.__call__(
|
||||||
text=[input_text],
|
text=[input_text],
|
||||||
padding=True,
|
padding=True,
|
||||||
@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
|
|
||||||
return estimated_frames_list
|
return estimated_frames_list
|
||||||
|
|
||||||
|
@staticmethod
def _load_single_item(
    data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
):
    """Load a single multimodal input: an audio clip, a video, or an image.

    Kept as a static method so that it can be pickled for multiprocessing.

    Args:
        data: The raw input handed to the underlying loader. For video
            inputs this is a string whose ``video:`` prefix is stripped
            before the remainder is passed on as the path.
        is_video: Treat ``data`` as a ``video:``-prefixed path.
        is_audio: Treat ``data`` as audio. Checked before ``is_video``.
        frame_count_limit: Forwarded to ``encode_video`` for video inputs;
            unused for audio and images.
        discard_alpha_channel: When true, a loaded image is converted to
            RGB before being returned.

    Returns:
        The result of ``load_audio`` for audio, the result of
        ``encode_video`` for video, or the loaded image otherwise.

    Raises:
        RuntimeError: If loading fails for any reason. The original
            exception is attached as the explicit cause.
    """
    try:
        if is_audio:
            return load_audio(data)
        if is_video:
            path = data[len("video:") :]
            return encode_video(path, frame_count_limit)
        img, _ = load_image(data)
        return img.convert("RGB") if discard_alpha_channel else img
    except Exception as e:
        # Chain explicitly so the traceback shows the loader failure as the
        # direct cause instead of an incidental error during handling.
        raise RuntimeError(f"Error while loading data {data}: {e}") from e
|
||||||
|
|
||||||
|
def submit_data_loading_tasks(
    self,
    text_parts: List[str],
    multimodal_tokens: MultimodalSpecialTokens,
    image_data: Optional[list] = None,
    audio_data: Optional[list] = None,
    discard_alpha_channel: bool = True,
):
    """Submit one loading task per multimodal token found in ``text_parts``.

    Walks ``text_parts`` in order. Every part equal to the image token
    consumes the next entry of ``image_data``, and every part equal to the
    audio token consumes the next entry of ``audio_data``. Each consumed
    entry is handed to ``self.io_executor`` so the items load in parallel.

    Args:
        text_parts: The prompt split into plain text and special tokens.
        multimodal_tokens: Provides ``image_token`` and ``audio_token``,
            which the parts are compared against.
        image_data: Image / video inputs, indexed in token order. Its
            length is taken unconditionally, so it is presumably expected
            to be a list here rather than ``None`` -- verify against caller.
        audio_data: Audio inputs, indexed in token order.
        discard_alpha_channel: Forwarded to ``_load_single_item``.

    Returns:
        A ``(futures, task_info)`` pair of equally long lists in submission
        order. Each ``task_info`` entry is ``(modality, data, frame_limit)``;
        ``frame_limit`` is ``None`` for audio.
    """

    # TODO(mick): load from server_args, env, or sampling_params
    MAX_NUM_FRAMES = 30
    frame_estimates = self.get_estimated_frames_list(image_data=image_data)
    # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
    # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
    scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, sum(frame_estimates)))

    assert len(image_data) == len(frame_estimates)

    pending, descriptors = [], []

    def enqueue(modality, item, as_video, as_audio, limit):
        # Keep the future and its description at the same list position.
        pending.append(
            self.io_executor.submit(
                BaseMultimodalProcessor._load_single_item,
                item,
                as_video,
                as_audio,
                limit,
                discard_alpha_channel,
            )
        )
        descriptors.append((modality, item, limit))

    next_image = 0
    next_audio = 0
    for part in text_parts:
        if part == multimodal_tokens.image_token:
            item = image_data[next_image]
            as_video = isinstance(item, str) and item.startswith("video:")
            # Every visual input gets at least one frame after scaling.
            limit = max(1, int(frame_estimates[next_image] * scaling_factor))
            enqueue(Modality.IMAGE, item, as_video, False, limit)
            next_image += 1
        elif part == multimodal_tokens.audio_token:
            enqueue(Modality.AUDIO, audio_data[next_audio], False, True, None)
            next_audio += 1

    return pending, descriptors
|
||||||
|
|
||||||
def load_mm_data(
|
def load_mm_data(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
# split text into list of normal text and special tokens
|
# split text into list of normal text and special tokens
|
||||||
text_parts = re.split(pattern, prompt)
|
text_parts = re.split(pattern, prompt)
|
||||||
|
|
||||||
# TODO(mick): load from server_args, env, or sampling_params
|
futures, task_info = self.submit_data_loading_tasks(
|
||||||
MAX_NUM_FRAMES = 30
|
text_parts=text_parts,
|
||||||
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
|
multimodal_tokens=multimodal_tokens,
|
||||||
total_frame_count = sum(estimated_frames_list)
|
image_data=image_data,
|
||||||
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
|
audio_data=audio_data,
|
||||||
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
|
discard_alpha_channel=discard_alpha_channel,
|
||||||
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
|
)
|
||||||
|
# Process results
|
||||||
assert len(image_data) == len(estimated_frames_list)
|
image_sizes, images, audios = [], [], []
|
||||||
|
|
||||||
image_index, audio_index = 0, 0
|
|
||||||
hashes, image_sizes, images, audios = [], [], [], []
|
|
||||||
new_text = ""
|
new_text = ""
|
||||||
for index, text_part in enumerate(text_parts):
|
task_ptr = 0
|
||||||
try:
|
|
||||||
if text_part == multimodal_tokens.image_token:
|
|
||||||
# load as image
|
|
||||||
if len(images) >= MAX_NUM_FRAMES:
|
|
||||||
frames_to_process = 0
|
|
||||||
else:
|
|
||||||
estimated_frames = estimated_frames_list[image_index]
|
|
||||||
frames_to_process = max(
|
|
||||||
1, int(estimated_frames * scaling_factor)
|
|
||||||
)
|
|
||||||
|
|
||||||
if frames_to_process == 0:
|
for text_part in text_parts:
|
||||||
frames = []
|
if text_part in multimodal_tokens.collect():
|
||||||
else:
|
task_type, data, frame_limit = task_info[task_ptr]
|
||||||
image_file = image_data[image_index]
|
result = futures[task_ptr].result()
|
||||||
if isinstance(image_file, str) and image_file.startswith(
|
task_ptr += 1
|
||||||
"video:"
|
|
||||||
):
|
|
||||||
# video
|
|
||||||
path = image_file[len("video:") :]
|
|
||||||
frames = encode_video(
|
|
||||||
path, frame_count_limit=frames_to_process
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# image
|
|
||||||
raw_image, _size = load_image(image_file)
|
|
||||||
if discard_alpha_channel:
|
|
||||||
raw_image = raw_image.convert("RGB")
|
|
||||||
frames = [raw_image]
|
|
||||||
if len(frames) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
image_sizes += frames[0].size * len(frames)
|
if task_type == Modality.IMAGE:
|
||||||
|
frames = [result] if not isinstance(result, list) else result
|
||||||
# Generate a hashable value for the image file
|
if frames:
|
||||||
if isinstance(image_file, Image.Image):
|
image_sizes += frames[0].size * len(frames)
|
||||||
# For PIL.Image objects, use the ID as a hashable value
|
images += frames
|
||||||
hash_value = hash(id(image_file))
|
|
||||||
else:
|
|
||||||
# For other types (strings, etc.), use the regular hash
|
|
||||||
hash_value = hash(image_file)
|
|
||||||
|
|
||||||
hashes += [hash_value] * len(frames)
|
|
||||||
images += frames
|
|
||||||
image_index += 1
|
|
||||||
if frames_to_process != 0:
|
|
||||||
new_text += multimodal_tokens.image_token * len(frames)
|
new_text += multimodal_tokens.image_token * len(frames)
|
||||||
assert frames_to_process == len(frames)
|
elif task_type == Modality.AUDIO:
|
||||||
elif text_part == multimodal_tokens.audio_token:
|
# audio
|
||||||
# load as audio
|
audios.append(result)
|
||||||
audio_file = audio_data[audio_index]
|
|
||||||
audio = load_audio(audio_file)
|
|
||||||
hashes += [hash(audio_file)]
|
|
||||||
audios += [audio]
|
|
||||||
audio_index += 1
|
|
||||||
new_text += multimodal_tokens.audio_token
|
new_text += multimodal_tokens.audio_token
|
||||||
else:
|
# TODO: handle video
|
||||||
# TODO(mick): handle video
|
else:
|
||||||
# normal text
|
new_text += text_part
|
||||||
new_text += text_part
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"An exception occurred while loading images: {e}")
|
|
||||||
raise RuntimeError(f"An exception occurred while loading images: {e}")
|
|
||||||
|
|
||||||
out = BaseMultiModalProcessorOutput(
|
out = BaseMultiModalProcessorOutput(
|
||||||
images=images,
|
images=images,
|
||||||
|
|||||||
@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
|||||||
base_out = self.load_mm_data(
|
base_out = self.load_mm_data(
|
||||||
prompt=input_ids,
|
prompt=input_ids,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
|
multimodal_tokens=MultimodalSpecialTokens(
|
||||||
|
image_token=processor.image_token
|
||||||
|
),
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -222,10 +222,10 @@ class MultimodalDataItem:
|
|||||||
# memoryview() doesn't support PyTorch's BFloat16 dtype
|
# memoryview() doesn't support PyTorch's BFloat16 dtype
|
||||||
tensor = tensor.float()
|
tensor = tensor.float()
|
||||||
|
|
||||||
|
assert isinstance(tensor, torch.Tensor)
|
||||||
if tensor.is_cuda:
|
if tensor.is_cuda:
|
||||||
tensor_cpu = torch.frombuffer(
|
# TODO: improve this
|
||||||
tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
|
tensor_cpu = tensor.cpu()
|
||||||
).clone()
|
|
||||||
else:
|
else:
|
||||||
tensor_cpu = tensor
|
tensor_cpu = tensor
|
||||||
|
|
||||||
@@ -321,7 +321,6 @@ class MultimodalInputs:
|
|||||||
item.set_pad_value()
|
item.set_pad_value()
|
||||||
|
|
||||||
optional_args = [
|
optional_args = [
|
||||||
"modalities",
|
|
||||||
"im_token_id",
|
"im_token_id",
|
||||||
"im_start_id",
|
"im_start_id",
|
||||||
"im_end_id",
|
"im_end_id",
|
||||||
|
|||||||
@@ -452,6 +452,7 @@ class Scheduler(
|
|||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
revision=server_args.revision,
|
revision=server_args.revision,
|
||||||
|
use_fast=not server_args.disable_fast_image_processor,
|
||||||
)
|
)
|
||||||
self.tokenizer = self.processor.tokenizer
|
self.tokenizer = self.processor.tokenizer
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ class TokenizerManager:
|
|||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
revision=server_args.revision,
|
revision=server_args.revision,
|
||||||
|
use_fast=not server_args.disable_fast_image_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We want to parallelize the image pre-processing so we create an executor for it
|
# We want to parallelize the image pre-processing so we create an executor for it
|
||||||
|
|||||||
@@ -462,6 +462,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("lm_head", prefix),
|
prefix=add_prefix("lm_head", prefix),
|
||||||
)
|
)
|
||||||
|
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
@@ -515,15 +516,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
otherwise it will be `(seq_len,).
|
otherwise it will be `(seq_len,).
|
||||||
(Use input_metadata.mrope_positions to replace it)
|
(Use input_metadata.mrope_positions to replace it)
|
||||||
"""
|
"""
|
||||||
is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
if self.is_mrope_enabled:
|
||||||
if is_mrope_enabled:
|
|
||||||
positions = forward_batch.mrope_positions
|
positions = forward_batch.mrope_positions
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
forward_batch.forward_mode.is_decode()
|
forward_batch.forward_mode.is_decode()
|
||||||
or not forward_batch.contains_image_inputs()
|
or not forward_batch.contains_image_inputs()
|
||||||
):
|
):
|
||||||
if is_mrope_enabled:
|
if self.is_mrope_enabled:
|
||||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||||
"multimodal section rotary embedding requires "
|
"multimodal section rotary embedding requires "
|
||||||
f"(3, seq_len) positions, but got {positions.size()}"
|
f"(3, seq_len) positions, but got {positions.size()}"
|
||||||
|
|||||||
@@ -467,6 +467,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
prefix=add_prefix("lm_head", prefix),
|
prefix=add_prefix("lm_head", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
@@ -521,15 +522,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
otherwise it will be `(seq_len,).
|
otherwise it will be `(seq_len,).
|
||||||
(Use input_metadata.mrope_positions to replace it)
|
(Use input_metadata.mrope_positions to replace it)
|
||||||
"""
|
"""
|
||||||
is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
if self.is_mrope_enabled:
|
||||||
if is_mrope_enabled:
|
|
||||||
positions = forward_batch.mrope_positions
|
positions = forward_batch.mrope_positions
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
forward_batch.forward_mode.is_decode()
|
forward_batch.forward_mode.is_decode()
|
||||||
or not forward_batch.contains_image_inputs()
|
or not forward_batch.contains_image_inputs()
|
||||||
):
|
):
|
||||||
if is_mrope_enabled:
|
if self.is_mrope_enabled:
|
||||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||||
"multimodal section rotary embedding requires "
|
"multimodal section rotary embedding requires "
|
||||||
f"(3, seq_len) positions, but got {positions.size()}"
|
f"(3, seq_len) positions, but got {positions.size()}"
|
||||||
|
|||||||
@@ -196,6 +196,9 @@ class ServerArgs:
|
|||||||
disaggregation_mode: str = "null"
|
disaggregation_mode: str = "null"
|
||||||
disaggregation_bootstrap_port: int = 8998
|
disaggregation_bootstrap_port: int = 8998
|
||||||
|
|
||||||
|
# multimodal
|
||||||
|
disable_fast_image_processor: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
if self.enable_ep_moe:
|
if self.enable_ep_moe:
|
||||||
@@ -979,6 +982,7 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-llama4-multimodal",
|
"--enable-llama4-multimodal",
|
||||||
|
default=ServerArgs.enable_llama4_multimodal,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable the multimodal functionality for Llama-4.",
|
help="Enable the multimodal functionality for Llama-4.",
|
||||||
)
|
)
|
||||||
@@ -1170,6 +1174,13 @@ class ServerArgs:
|
|||||||
help="Bootstrap server port on the prefill server. Default is 8998.",
|
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Multimodal
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-fast-image-processor",
|
||||||
|
action="store_true",
|
||||||
|
help="Adopt base image processor instead of fast image processor.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
|
|||||||
Reference in New Issue
Block a user