From ff2ce0b86fe8825f151c23b4d75b92d90983b074 Mon Sep 17 00:00:00 2001 From: Mick Date: Wed, 12 Mar 2025 03:35:35 +0800 Subject: [PATCH] refactor: move image processors to separate files (#4229) --- benchmark/mmmu/bench_hf.py | 39 +- benchmark/mmmu/bench_sglang.py | 82 ++- benchmark/mmmu/data_utils.py | 1 + benchmark/mmmu/eval_utils.py | 25 +- python/sglang/srt/conversation.py | 3 +- python/sglang/srt/hf_transformers_utils.py | 1 + python/sglang/srt/layers/attention/vision.py | 99 ++- python/sglang/srt/managers/image_processor.py | 668 +----------------- .../image_processors/base_image_processor.py | 206 ++++++ .../srt/managers/image_processors/llava.py | 152 ++++ .../srt/managers/image_processors/minicpmv.py | 86 +++ .../srt/managers/image_processors/mlama.py | 60 ++ .../srt/managers/image_processors/qwen_vl.py | 161 +++++ .../srt/managers/multi_modality_padding.py | 134 ++++ python/sglang/srt/managers/schedule_batch.py | 16 +- .../sglang/srt/model_loader/weight_utils.py | 2 +- python/sglang/srt/models/minicpmv.py | 117 +-- python/sglang/srt/models/mllama.py | 2 +- python/sglang/srt/models/qwen2_5_vl.py | 74 +- python/sglang/srt/models/qwen2_vl.py | 82 +-- test/srt/test_vision_llm.py | 8 +- test/srt/test_vision_openai_server.py | 22 +- 22 files changed, 1085 insertions(+), 955 deletions(-) create mode 100644 python/sglang/srt/managers/image_processors/base_image_processor.py create mode 100644 python/sglang/srt/managers/image_processors/llava.py create mode 100644 python/sglang/srt/managers/image_processors/minicpmv.py create mode 100644 python/sglang/srt/managers/image_processors/mlama.py create mode 100644 python/sglang/srt/managers/image_processors/qwen_vl.py create mode 100644 python/sglang/srt/managers/multi_modality_padding.py diff --git a/benchmark/mmmu/bench_hf.py b/benchmark/mmmu/bench_hf.py index d2b7dc9d7..0a237b07b 100644 --- a/benchmark/mmmu/bench_hf.py +++ b/benchmark/mmmu/bench_hf.py @@ -11,11 +11,16 @@ import argparse import random import torch -from bench_sglang import EvalArgs, prepare_samples from data_utils import save_json -from eval_utils import eval_result, get_sampling_params, parse_multi_choice_response +from eval_utils import ( + EvalArgs, + eval_result, + get_sampling_params, + prepare_samples, + process_result, +) from tqdm import tqdm -from transformers import AutoModelForImageTextToText, AutoProcessor +from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig @torch.no_grad() @@ -28,7 +33,6 @@ def eval_mmmu(args): trust_remote_code=True, ) model = model.eval().cuda() - model = torch.compile(model) processor = AutoProcessor.from_pretrained( args.model_path, torch_dtype="auto", device_map="auto" @@ -38,6 +42,10 @@ def eval_mmmu(args): out_samples = dict() sampling_params = get_sampling_params(eval_args) + generation_config = GenerationConfig( + max_new_tokens=sampling_params["max_new_tokens"], + do_sample=False, + ) answer_dict = {} for sample in tqdm(samples): @@ -62,7 +70,6 @@ def eval_mmmu(args): text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - inputs = processor( text=[text], images=[image], @@ -70,13 +77,16 @@ def eval_mmmu(args): return_tensors="pt", ).to(model.device) - generated_ids = model.generate(**inputs, **sampling_params) + generated_ids = model.generate( + **inputs, generation_config=generation_config + ) response = processor.decode( generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False, )[len(text) :] + print(f"response: {response}") else: # multiple images actually if sample["question_type"] == "multiple-choice": all_choices = sample["all_choices"] @@ -85,24 +95,11 @@ def eval_mmmu(args): else: response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS" - if sample["question_type"] == "multiple-choice": - pred_ans = parse_multi_choice_response( - response, sample["all_choices"], sample["index2ans"] - ) - else: # open question - pred_ans = response - out_samples[sample["id"]] = pred_ans - - torch.cuda.empty_cache() - # set ground truth answer - answer_dict[sample["id"]] = { - "question_type": sample["question_type"], - "ground_truth": sample["answer"], - } + process_result(response, sample, answer_dict, out_samples) args.output_path = f"{args.model_path}_val_hf.json" save_json(args.output_path, out_samples) - eval_result(output_path=args.output_path, answer_dict=answer_dict) + eval_result(model_answer_path=args.output_path, answer_dict=answer_dict) if __name__ == "__main__": diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index 88983165c..ba03dced3 100644 --- a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -8,9 +8,9 @@ """ import argparse +import base64 import dataclasses import random -import re from io import BytesIO from data_utils import save_json @@ -18,13 +18,14 @@ from eval_utils import ( EvalArgs, eval_result, get_sampling_params, - parse_multi_choice_response, prepare_samples, + process_result, ) from tqdm import tqdm from sglang import Engine -from sglang.srt.conversation import chat_templates +from sglang.srt.conversation import generate_chat_conv +from sglang.srt.openai_api.protocol import ChatCompletionRequest from sglang.srt.server_args import ServerArgs @@ -35,61 +36,76 @@ def eval_mmmu(args): if server_args.chat_template is None: raise ValueError("Chat template must be provided for this benchmark") - samples = prepare_samples(eval_args) - backend = Engine(**dataclasses.asdict(server_args)) out_samples = dict() sampling_params = get_sampling_params(eval_args) - conv = chat_templates[server_args.chat_template].copy() - image_token = conv.image_token + samples = prepare_samples(eval_args) + answer_dict = {} + for sample in tqdm(samples): prompt = sample["final_input_prompt"] image = sample["image"] - bytes_io = BytesIO() - image.save(bytes_io, format="PNG") - png_bytes = bytes_io.getvalue() - - prompt = re.sub(r"<[^>]*>", image_token, prompt) + buff = BytesIO() + image.save(buff, format="PNG") + base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") + prefix = prompt.split("<")[0] + suffix = prompt.split(">")[1] + request_dict = { + "model": "", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prefix, + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_str}" + }, + }, + { + "type": "text", + "text": suffix, + }, + ], + } + ], + } + conv = generate_chat_conv( + ChatCompletionRequest(**request_dict), + template_name=server_args.chat_template, + ) + prompt = conv.get_prompt() if image is not None: gen_out = backend.generate( - prompt=prompt, image_data=[png_bytes], sampling_params=sampling_params + prompt=prompt, + image_data=conv.image_data, + sampling_params=sampling_params, )["text"] response = gen_out + else: # multiple images actually if sample["question_type"] == "multiple-choice": all_choices = sample["all_choices"] response = random.choice(all_choices) - else: response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS" - if sample["question_type"] == "multiple-choice": - pred_ans = parse_multi_choice_response( - response, sample["all_choices"], sample["index2ans"] - ) - else: # open question - pred_ans = response - out_samples[sample["id"]] = pred_ans - - # set ground truth answer - answer_dict[sample["id"]] = { - "question_type": sample["question_type"], - "ground_truth": ( - sample["correct_choice"] - if "correct_choice" in samples - else sample["answer"] - ), - } - + process_result(response, sample, answer_dict, out_samples) args.output_path = f"{args.model_path}_val_sglang.json" save_json(args.output_path, out_samples) - eval_result(output_path=args.output_path, answer_dict=answer_dict) + eval_result(model_answer_path=args.output_path, answer_dict=answer_dict) + + backend.shutdown() if __name__ == "__main__": diff --git a/benchmark/mmmu/data_utils.py b/benchmark/mmmu/data_utils.py index 40156c970..197e90638 100644 --- a/benchmark/mmmu/data_utils.py +++ b/benchmark/mmmu/data_utils.py @@ -143,6 +143,7 @@ def process_single_sample(data): # DATA SAVING def save_json(filename, ds): + print(f"answers saved to: {filename}") os.makedirs(os.path.dirname(filename), exist_ok=True) with open(filename, "w") as f: json.dump(ds, f, indent=4) diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index 712042d4a..dc26ccb1e 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -87,6 +87,7 @@ def set_seed(seed_value): def prepare_samples(eval_args: EvalArgs): + print("preparing samples...") # Build prompts set_seed(eval_args.seed) @@ -110,6 +111,7 @@ def prepare_samples(eval_args: EvalArgs): eval_args.dataset_path, subject, split=eval_args.split ) sub_dataset_list.append(sub_dataset) + # break # merge all dataset dataset = concatenate_datasets(sub_dataset_list) @@ -426,9 +428,26 @@ def calculate_ins_level_acc(results: Dict): return acc / ins_num -def eval_result(output_path, answer_dict): +def process_result(response, sample, answer_dict, out_samples): + if sample["question_type"] == "multiple-choice": + pred_ans = parse_multi_choice_response( + response, sample["all_choices"], sample["index2ans"] + ) + else: # open question + pred_ans = response + + out_samples[sample["id"]] = pred_ans + + # set ground truth answer + answer_dict[sample["id"]] = { + "question_type": sample["question_type"], + "ground_truth": sample["answer"], + } + + +def eval_result(model_answer_path, answer_dict): print("Evaluating...") - output_dict = json.load(open(output_path)) + output_dict = json.load(open(model_answer_path)) # answer_dict = json.load(open(answer_path)) # group by category @@ -521,7 +540,7 @@ def eval_result(output_path, answer_dict): "acc": overall_acc, } pprint.pprint(printable_results) - out = output_path + out = model_answer_path with open(out, "w", encoding="utf-8") as outfile: json.dump(printable_results, outfile) print(f"eval out saved to {out}") diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index a19a9e735..b6db4a2da 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -191,7 +191,7 @@ class Conversation: for i, (role, message) in enumerate(self.messages): if i % 2 == 0: - ret += f"[Round {i//2 + round_add_n}]{self.sep}" + ret += f"[Round {i // 2 + round_add_n}]{self.sep}" if message: ret += f"{role}:{message}{self.sep}" @@ -453,7 +453,6 @@ def generate_chat_conv( conv.system_message = getattr(message.content[0], "text", "") elif msg_role == "user": # Handle the various types of Chat Request content types here. - role = conv.roles[0] if isinstance(message.content, str): conv.append_message(conv.roles[0], message.content) else: diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index e6f06b399..6c1efe31d 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -66,6 +66,7 @@ def get_config( config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) + if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index f90673191..3aea3b7ae 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import lru_cache -from typing import Optional +from typing import Optional, Tuple import torch import torch.nn as nn @@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.utils import add_prefix -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) +# Copied from transformers, modeling_qwen2_vl.py +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) - return output + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + + return q_embed, k_embed class VisionAttention(nn.Module): @@ -75,8 +57,8 @@ class VisionAttention(nn.Module): use_context_forward (bool, default to True): if ``True``, a flash_attn style attention will be applied Otherwise, a full-sequence attention will be applied. - use_full_precision_softmax (bool, default to False): - if ``True``, the softmax will be performed in full-precision + softmax_in_single_precision (bool, default to False): + if ``True``, the softmax will be performed in single-precision Otherwise, it will be performed in half-precision """ @@ -90,7 +72,7 @@ class VisionAttention(nn.Module): quant_config: Optional[QuantizationConfig] = None, dropout: float = 0.0, use_context_forward: bool = True, - use_full_precision_softmax: bool = False, + softmax_in_single_precision: bool = False, flatten_batch: bool = False, prefix: str = "", ): @@ -113,7 +95,7 @@ class VisionAttention(nn.Module): head_size=self.head_size, dropout=dropout, flatten_batch=flatten_batch, - use_full_precision_softmax=use_full_precision_softmax, + softmax_in_single_precision=softmax_in_single_precision, ) self.use_qkv_parallel = use_qkv_parallel @@ -143,7 +125,7 @@ class VisionAttention(nn.Module): self, x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, - rotary_pos_emb: torch.Tensor = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: r""" @@ -151,21 +133,17 @@ class VisionAttention(nn.Module): x: [b, s, embed_dim] cu_seqlens: [b] Returns: - [s, b, num_heads * head] + [s, b, head * head_size] """ bsz, s, _ = x.shape + head = self.num_attention_heads_per_partition if self.use_qkv_parallel: # [b, s, embed_dim] --> [b, s, embed_dim] qkv, _ = self.qkv_proj(x) q, k, v = qkv.chunk(3, dim=-1) - # [b, s, embed_dim] --> [b * s, num_heads, head_size] - q, k, v = [ - x.reshape( - bsz * s, self.num_attention_heads_per_partition, -1 - ).contiguous() - for x in (q, k, v) - ] + # [b, s, embed_dim] --> [b * s, head, head_size] + q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)] else: # [b, s, embed_dim] --> [s, b, embed_dim] x = rearrange(x, "b s ... -> s b ...") @@ -173,7 +151,7 @@ class VisionAttention(nn.Module): qkv, _ = self.qkv_proj(x) # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] new_x_shape = qkv.size()[:-1] + ( - self.num_attention_heads_per_partition, + head, 3 * self.hidden_size_per_attention_head, ) qkv = qkv.view(*new_x_shape) @@ -186,9 +164,12 @@ class VisionAttention(nn.Module): rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) ] - if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + if position_embeddings is not None: + cos, sin = position_embeddings + original_shape = q.shape + q, k = q.view(s, head, -1), k.view(s, head, -1) + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + q, k = q.reshape(original_shape), k.reshape(original_shape) if self.use_qkv_parallel: pass @@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module): head_size: int, dropout: float = 0.0, flatten_batch: bool = False, - use_full_precision_softmax: bool = False, + softmax_in_single_precision: bool = False, ): super().__init__() self.head_size = head_size self.flatten_batch = flatten_batch - self.use_full_precision_softmax = use_full_precision_softmax + self.softmax_in_single_precision = softmax_in_single_precision self.dropout = dropout @staticmethod @@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module): ) if attention_mask is None: - if self.use_full_precision_softmax: + if self.softmax_in_single_precision: raise RuntimeError("Empty attention mask") else: attention_mask = attention_mask.to(device=q.device) q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]] - if self.use_full_precision_softmax: + if self.softmax_in_single_precision: scale = self.head_size**-0.5 k_transposed = rearrange(k, "b h s d -> b h d s") attn_weights = torch.matmul(q, k_transposed) * scale diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index 57fc4a6b4..c3ef9b51c 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -1,649 +1,55 @@ # TODO: also move pad_input_ids into this module -import asyncio -import concurrent.futures -import dataclasses +import importlib import logging -import multiprocessing as mp -import os -from abc import ABC, abstractmethod -from typing import List, Optional, Union +import pkgutil +from functools import lru_cache -import numpy as np -import PIL -import transformers -from decord import VideoReader, cpu -from PIL import Image +from transformers import IMAGE_PROCESSOR_MAPPING -from sglang.srt.hf_transformers_utils import get_processor -from sglang.srt.mm_utils import expand2square, process_anyres_image +from sglang.srt.managers.image_processors.base_image_processor import ( + BaseImageProcessor, + DummyImageProcessor, +) from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import load_image -from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) -global global_processor - -def init_global_processor(server_args: ServerArgs): - """Init the global processor for multi modal models.""" - global global_processor - transformers.logging.set_verbosity_error() - global_processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - ) - - -@dataclasses.dataclass -class BaseImageProcessorOutput: - image_hashes: list[int] - image_sizes: list[int] - all_frames: [PIL.Image] - # input_text, with each frame of video/image represented with a image_token - input_text: str - - -class BaseImageProcessor(ABC): - def __init__(self, hf_config, server_args, _processor): - self.hf_config = hf_config - self._processor = _processor - self.server_args = server_args - # FIXME: not accurate, model and image specific - self.NUM_TOKEN_PER_FRAME = 330 - - self.executor = concurrent.futures.ProcessPoolExecutor( - initializer=init_global_processor, - mp_context=mp.get_context("fork"), - initargs=(server_args,), - max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())), - ) - - @abstractmethod - async def process_images_async( - self, image_data, input_text, max_req_input_len, **kwargs - ): - pass - - def get_estimated_frames_list(self, image_data): - """ - estimate the total frame count from all visual input - """ - # Before processing inputs - estimated_frames_list = [] - for image in image_data: - if isinstance(image, str) and image.startswith("video:"): - path = image[len("video:") :] - # Estimate frames for the video - vr = VideoReader(path, ctx=cpu(0)) - num_frames = len(vr) - else: - # For images, each contributes one frame - num_frames = 1 - estimated_frames_list.append(num_frames) - - return estimated_frames_list - - def encode_video(self, video_path, frame_count_limit=None): - if not os.path.exists(video_path): - logger.error(f"Video {video_path} does not exist") - return [] - - if frame_count_limit == 0: - return [] - - def uniform_sample(l, n): - gap = len(l) / n - idxs = [int(i * gap + gap / 2) for i in range(n)] - return [l[i] for i in idxs] - - vr = VideoReader(video_path, ctx=cpu(0)) - sample_fps = round(vr.get_avg_fps() / 1) # FPS - frame_idx = [i for i in range(0, len(vr), sample_fps)] - if frame_count_limit is not None and len(frame_idx) > frame_count_limit: - frame_idx = uniform_sample(frame_idx, frame_count_limit) - frames = vr.get_batch(frame_idx).asnumpy() - frames = [Image.fromarray(v.astype("uint8")) for v in frames] - return frames - - def load_images( - self, - max_req_input_len: int, - input_ids: list, - image_data, - image_token: str, - ) -> BaseImageProcessorOutput: - """ - Each frame of video/image will be replaced by a single image token - """ - image_hashes, image_sizes = [], [] - all_frames = [] - new_text_parts = [] - - if isinstance(input_ids, list): - assert len(input_ids) and isinstance(input_ids[0], int) - input_text = self._processor.tokenizer.decode(input_ids) - else: - input_text = input_ids - - text_parts = input_text.split(image_token) - - # roughly calculate the max number of frames under the max_req_input_len limit - def calculate_max_num_frames() -> int: - ret = (max_req_input_len - len(input_ids)) // self.NUM_TOKEN_PER_FRAME - return min(ret, 100) - - MAX_NUM_FRAMES = calculate_max_num_frames() - estimated_frames_list = self.get_estimated_frames_list(image_data=image_data) - total_frame_count = sum(estimated_frames_list) - # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs. - # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used - scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count) - - # Process each input with allocated frames - for image_index, (image, estimated_frames) in enumerate( - zip(image_data, estimated_frames_list) - ): - if len(all_frames) >= MAX_NUM_FRAMES: - frames_to_process = 0 - else: - frames_to_process = max(1, int(estimated_frames * scaling_factor)) - - if frames_to_process == 0: - frames = [] - else: - try: - if isinstance(image, str) and image.startswith("video:"): - path = image[len("video:") :] - frames = self.encode_video( - path, frame_count_limit=frames_to_process - ) - else: - raw_image, _size = load_image(image) - frames = [raw_image] - if len(frames) == 0: - continue - except FileNotFoundError as e: - print(e) - return None - image_sizes += frames[0].size * len(frames) - image_hashes += [hash(image)] * len(frames) - all_frames += frames - - new_text_parts.append(text_parts[image_index]) - if frames_to_process != 0: - new_text_parts.append(image_token * len(frames)) - assert frames_to_process == len(frames) - - new_text_parts.append(text_parts[-1]) - - input_text = "".join(new_text_parts) - return BaseImageProcessorOutput( - image_hashes, image_sizes, all_frames, input_text - ) - - -class DummyImageProcessor(BaseImageProcessor): - def __init__(self): - pass - - async def process_images_async(self, *args, **kwargs): - return None - - -class LlavaImageProcessor(BaseImageProcessor): - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) - - @staticmethod - def _process_single_image_task( - image_data: Union[str, bytes], - image_aspect_ratio: Optional[str] = None, - image_grid_pinpoints: Optional[str] = None, - image_processor=None, - ): - image_processor = image_processor or global_processor.image_processor - - try: - image, image_size = load_image(image_data) - if image_size is not None: - # It is a video with multiple images - image_hash = hash(image_data) - pixel_values = image_processor(image)["pixel_values"] - for _ in range(len(pixel_values)): - pixel_values[_] = pixel_values[_].astype(np.float16) - pixel_values = np.stack(pixel_values, axis=0) - return pixel_values, image_hash, image_size - else: - # It is an image - image_hash = hash(image_data) - if image_aspect_ratio == "pad": - image = expand2square( - image, - tuple(int(x * 255) for x in image_processor.image_mean), - ) - pixel_values = image_processor(image.convert("RGB"))[ - "pixel_values" - ][0] - elif image_aspect_ratio == "anyres" or ( - image_aspect_ratio is not None - and "anyres_max" in image_aspect_ratio - ): - pixel_values = process_anyres_image( - image, image_processor, image_grid_pinpoints - ) - else: - pixel_values = image_processor(image)["pixel_values"][0] - - if isinstance(pixel_values, np.ndarray): - pixel_values = pixel_values.astype(np.float16) - - return pixel_values, image_hash, image.size - except Exception: - logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) - - async def _process_single_image( - self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str - ): - if self.executor is not None: - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, - LlavaImageProcessor._process_single_image_task, - image_data, - aspect_ratio, - grid_pinpoints, - ) - else: - return self._process_single_image_task( - image_data, aspect_ratio, grid_pinpoints - ) - - async def process_images_async( - self, - image_data: List[Union[str, bytes]], - input_text, - request_obj, - *args, - **kwargs, - ): - if not image_data: - return None - - modalities = request_obj.modalities or ["image"] - aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) - grid_pinpoints = ( - self.hf_config.image_grid_pinpoints - if hasattr(self.hf_config, "image_grid_pinpoints") - and "anyres" in aspect_ratio - else None - ) - - if isinstance(image_data, str): - image_data = [image_data] - - if isinstance(image_data, list) and len(image_data) > 0: - if "multi-images" in modalities or "video" in modalities: - # Multiple images - aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres - pixel_values, image_hashes, image_sizes = [], [], [] - res = [] - for img_data in image_data: - res.append( - self._process_single_image( - img_data, aspect_ratio, grid_pinpoints - ) - ) - res = await asyncio.gather(*res) - for pixel_v, image_h, image_s in res: - pixel_values.append(pixel_v) - image_hashes.append(image_h) - image_sizes.append(image_s) - - if isinstance(pixel_values[0], np.ndarray): - pixel_values = np.stack(pixel_values, axis=0) - else: - # A single image - pixel_values, image_hash, image_size = await self._process_single_image( - image_data[0], aspect_ratio, grid_pinpoints - ) - image_hashes = [image_hash] - image_sizes = [image_size] - else: - raise ValueError(f"Invalid image data: {image_data}") - - return { - "pixel_values": pixel_values, - "image_hashes": image_hashes, - "image_sizes": image_sizes, - "modalities": request_obj.modalities or ["image"], - } - - -class MllamaImageProcessor(BaseImageProcessor): - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) - - @staticmethod - def _process_single_image_task(images, input_text): - # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask' - return global_processor(images, input_text, return_tensors="pt") - - async def _process_single_image(self, images, input_text): - if self.executor is not None: - loop = asyncio.get_event_loop() - image_inputs = await loop.run_in_executor( - self.executor, - MllamaImageProcessor._process_single_image_task, - images, - input_text, - ) - else: - image_inputs = self._processor(images, input_text, return_tensors="pt") - - return image_inputs - - async def process_images_async( - self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs - ): - if not image_data: - return None - - if isinstance(input_text, list): - assert len(input_text) and isinstance(input_text[0], int) - input_text = self._processor.tokenizer.decode(input_text) - - if not isinstance(image_data, list): - image_data = [image_data] - - if len(image_data) > 0: - images = [load_image(image)[0] for image in image_data] - else: - images = load_image(image_data[0])[0] - - image_inputs = await self._process_single_image(images, input_text) - image_inputs["image_hashes"] = [hash(str(image_data))] - image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] - - return image_inputs - - -class MiniCPMVImageProcessor(BaseImageProcessor): - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) - self.IMAGE_TOKEN = "(./)" - - @staticmethod - def _process_images_task(images, input_text): - result = global_processor.__call__( - text=input_text, images=images, return_tensors="pt" - ) - return { - "input_ids": result.input_ids, - "pixel_values": result.pixel_values, - "tgt_sizes": result.tgt_sizes, - } - - async def _process_images(self, images, input_text): - if self.executor is not None: - loop = asyncio.get_event_loop() - image_inputs = await loop.run_in_executor( - self.executor, - MiniCPMVImageProcessor._process_images_task, - images, - input_text, - ) - else: - image_inputs = self._processor( - images=images, text=input_text, return_tensors="pt" - ) - - return image_inputs - - async def process_images_async( - self, - image_data: List[Union[str, bytes]], - input_ids, - request_obj, - max_req_input_len, - ): - if not image_data: - return None - if not isinstance(image_data, list): - image_data = [image_data] - - base_output = self.load_images( - max_req_input_len, input_ids, image_data, self.IMAGE_TOKEN - ) - if base_output is None: - return None - - if len(base_output.all_frames) == 0: - return None - res = await self._process_images( - images=base_output.all_frames, input_text=base_output.input_text - ) - - # Collect special token ids - tokenizer = self._processor.tokenizer - im_start_id = [tokenizer.im_start_id] - im_end_id = [tokenizer.im_end_id] - if tokenizer.slice_start_id: - slice_start_id = [tokenizer.slice_start_id] - slice_end_id = [tokenizer.slice_end_id] - return { - "input_ids": res["input_ids"].flatten().tolist(), - "pixel_values": res["pixel_values"], - "tgt_sizes": res["tgt_sizes"], - "image_hashes": base_output.image_hashes, - "modalities": request_obj.modalities or ["image"], - "im_start_id": im_start_id, - "im_end_id": im_end_id, - "slice_start_id": slice_start_id, - "slice_end_id": slice_end_id, - } - - -class Qwen2VLImageProcessor(BaseImageProcessor): - def __init__(self, hf_config, server_args, _image_processor): - self.hf_config = hf_config - self._image_processor = _image_processor - self.executor = concurrent.futures.ProcessPoolExecutor( - initializer=init_global_processor, - mp_context=mp.get_context("fork"), - initargs=(server_args,), - max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())), - ) - - @staticmethod - def _process_single_image_task( - image_data: Union[str, bytes], - image_processor=None, - ): - image_processor = image_processor or global_processor.image_processor - - try: - image, image_size = load_image(image_data) - if image_size is not None: - # It is a video with multiple images - image_hash = hash(image_data) - process_result = image_processor(image) - pixel_values, image_grid_thws = ( - process_result["pixel_values"], - process_result["image_grid_thw"][0], - ) - for _ in range(len(pixel_values)): - pixel_values[_] = pixel_values[_].astype(np.float16) - pixel_values = np.stack(pixel_values, axis=0) - image_grid_thws = np.stack(image_grid_thws, axis=0) - return pixel_values, image_hash, image_size, image_grid_thws - else: - # It is an image - image_hash = hash(image_data) - process_result = image_processor(image) - pixel_values, image_grid_thws = ( - process_result["pixel_values"], - process_result["image_grid_thw"][0], - ) - if isinstance(pixel_values, np.ndarray): - pixel_values = pixel_values.astype(np.float16) - - return pixel_values, image_hash, image.size, image_grid_thws - except Exception: - logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) - - async def _process_single_image(self, image_data: Union[bytes, str]): - if self.executor is not None: - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, - Qwen2VLImageProcessor._process_single_image_task, - image_data, - ) - else: - return self._process_single_image_task(image_data) - - async def process_images_async( - self, - image_data: List[Union[str, bytes]], - input_text, - request_obj, - *args, - **kwargs, - ): - if not image_data: - return None - - if isinstance(image_data, list) and len(image_data) > 0: - # Multiple images - if len(image_data) > 1: - pixel_values, image_hashes, image_sizes, image_grid_thws = ( - [], - [], - [], - [], - ) - res = [] - for img_data in image_data: - res.append(self._process_single_image(img_data)) - res = await asyncio.gather(*res) - for pixel_v, image_h, image_s, image_thw in res: - pixel_values.append(pixel_v) - image_hashes.append(image_h) - image_sizes.append(image_s) - image_grid_thws.append(image_thw) - - if isinstance(pixel_values[0], np.ndarray): - pixel_values = np.concatenate(pixel_values, axis=0) - else: - # A single image - pixel_values, image_hash, image_size, image_grid_thw = ( - await self._process_single_image(image_data[0]) - ) - image_hashes = [image_hash] - image_sizes = [image_size] - image_grid_thws = [image_grid_thw] - elif isinstance(image_data, str) or isinstance(image_data, bytes): - # A single image - pixel_values, image_hash, image_size, image_grid_thw = ( - await self._process_single_image(image_data) - ) - image_hashes = [image_hash] - image_sizes = [image_size] - image_grid_thws = [image_grid_thw] - else: - - raise ValueError(f"Invalid image data: {image_data}") - - return { - "pixel_values": pixel_values, - "image_hashes": image_hashes, - "image_sizes": image_sizes, - "modalities": request_obj.modalities or ["image"], - "image_grid_thws": image_grid_thws, - } - - -class Qwen2_5VLImageProcessor(BaseImageProcessor): - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) - self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" - self.IM_START_TOKEN_ID = hf_config.vision_start_token_id - self.IM_END_TOKEN_ID = hf_config.vision_end_token_id - self.NUM_TOKEN_PER_FRAME = 770 - - @staticmethod - def _process_images_task(images, input_text): - result = global_processor.__call__( - text=input_text, images=images, return_tensors="pt" - ) - return { - "input_ids": result.input_ids, - "pixel_values": result.pixel_values, - "image_grid_thws": result.image_grid_thw, - } - - async def _process_images(self, images, input_text) -> dict: - if self.executor is not None: - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, - Qwen2_5VLImageProcessor._process_images_task, - images, - input_text, - ) - else: - return self._process_images_task(images, input_text) - - async def process_images_async( - self, - image_data: List[Union[str, bytes]], - input_ids, - request_obj, - max_req_input_len, - *args, - **kwargs, - ): - if not image_data: - return None - if isinstance(image_data, str): - image_data = [image_data] - - image_token = self.IMAGE_TOKEN - base_output = self.load_images( - max_req_input_len, input_ids, image_data, image_token - ) - - ret = await self._process_images(base_output.all_frames, base_output.input_text) - - return { - "input_ids": ret["input_ids"].flatten().tolist(), - "pixel_values": ret["pixel_values"], - "image_hashes": base_output.image_hashes, - "modalities": request_obj.modalities or ["image"], - "image_grid_thws": ret["image_grid_thws"], - "im_start_id": self.IM_START_TOKEN_ID, - "im_end_id": self.IM_END_TOKEN_ID, - } +IMAGE_PROCESSOR_MAPPING = {} def get_image_processor( hf_config, server_args: ServerArgs, processor ) -> BaseImageProcessor: - if "MllamaForConditionalGeneration" in hf_config.architectures: - return MllamaImageProcessor(hf_config, server_args, processor) - elif "Qwen2VLForConditionalGeneration" in hf_config.architectures: - - return Qwen2VLImageProcessor(hf_config, server_args, processor) - elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: - return Qwen2_5VLImageProcessor(hf_config, server_args, processor) - - elif "MiniCPMV" in hf_config.architectures: - return MiniCPMVImageProcessor(hf_config, server_args, processor) - else: - return LlavaImageProcessor(hf_config, server_args, processor.image_processor) + for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items(): + if model_cls.__name__ in hf_config.architectures: + return processor_cls(hf_config, server_args, processor) + raise ValueError( + f"No image processor found for architecture: {hf_config.architectures}" + ) def get_dummy_image_processor(): return DummyImageProcessor() + + +@lru_cache() +def import_image_processors(): + package_name = "sglang.srt.managers.image_processors" + package = importlib.import_module(package_name) + for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): + if not ispkg: + try: + module = importlib.import_module(name) + except Exception as e: + logger.warning(f"Ignore import error when loading {name}: " f"{e}") + continue + if hasattr(module, "ImageProcessorMapping"): + entry = module.ImageProcessorMapping + if isinstance(entry, dict): + for processor_name, cls in entry.items(): + IMAGE_PROCESSOR_MAPPING[processor_name] = cls + + +# also register processors +import_image_processors() diff --git a/python/sglang/srt/managers/image_processors/base_image_processor.py b/python/sglang/srt/managers/image_processors/base_image_processor.py new file mode 100644 index 000000000..b5365b8ae --- /dev/null +++ b/python/sglang/srt/managers/image_processors/base_image_processor.py @@ -0,0 +1,206 @@ +import concurrent +import concurrent.futures +import dataclasses +import multiprocessing as mp +import os +from abc import ABC, abstractmethod +from typing import Optional + +import PIL +import transformers +from decord import VideoReader, cpu +from PIL import Image + +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import load_image + +global global_processor + + +def get_global_processor(): + global global_processor + return global_processor + + +@dataclasses.dataclass +class BaseImageProcessorOutput: + image_hashes: list[int] + image_sizes: list[tuple[int, int]] + all_frames: [PIL.Image] + # input_text, with each frame of video/image represented as an image_token + input_text: str + + +class BaseImageProcessor(ABC): + def __init__(self, hf_config, server_args, _processor): + self.hf_config = hf_config + self._processor = _processor + self.server_args = server_args + # FIXME: not accurate, model and image specific + self.NUM_TOKEN_PER_FRAME = 330 + + self.executor = concurrent.futures.ProcessPoolExecutor( + initializer=init_global_processor, + mp_context=mp.get_context("fork"), + initargs=( + self, + server_args, + ), + max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())), + ) + + def _build_processor(self, server_args): + """Init the global processor for multi modal models.""" + from sglang.srt.hf_transformers_utils import get_processor + + return get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + + @abstractmethod + async def process_images_async( + self, image_data, input_text, max_req_input_len, **kwargs + ): + pass + + def get_estimated_frames_list(self, image_data): + """ + estimate the total frame count from all visual input + """ + # Before processing inputs + estimated_frames_list = [] + for image in image_data: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + # Estimate frames for the video + vr = VideoReader(path, ctx=cpu(0)) + num_frames = len(vr) + else: + # For images, each contributes one frame + num_frames = 1 + estimated_frames_list.append(num_frames) + + return estimated_frames_list + + @staticmethod + def encode_video(video_path, frame_count_limit=None): + if not os.path.exists(video_path): + logger.error(f"Video {video_path} does not exist") + return [] + + if frame_count_limit == 0: + return [] + + def uniform_sample(l, n): + gap = len(l) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [l[i] for i in idxs] + + vr = VideoReader(video_path, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) # FPS + frame_indices = [i for i in range(0, len(vr), sample_fps)] + if frame_count_limit is not None and len(frame_indices) > frame_count_limit: + frame_indices = uniform_sample(frame_indices, frame_count_limit) + + frames = vr.get_batch(frame_indices).asnumpy() + frames = [Image.fromarray(v.astype("uint8")) for v in frames] + return frames + + def load_images( + self, + input_ids: list, + image_data, + image_token: str, + max_req_input_len: int, + return_text: Optional[bool] = True, + discard_alpha_channel: bool = True, + ) -> BaseImageProcessorOutput: + """ + Each frame of video/image will be replaced by a single image token + """ + image_hashes, image_sizes = [], [] + all_frames = [] + new_text_parts = [] + + if isinstance(input_ids, list) and return_text: + assert len(input_ids) and isinstance(input_ids[0], int) + input_text = self._processor.tokenizer.decode(input_ids) + else: + input_text = input_ids + + if return_text: + text_parts = input_text.split(image_token) + + # roughly calculate the max number of frames under the max_req_input_len limit + MAX_NUM_FRAMES = 30 + estimated_frames_list = self.get_estimated_frames_list(image_data=image_data) + total_frame_count = sum(estimated_frames_list) + # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs. + # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used + scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count) + + assert len(image_data) == len(estimated_frames_list) + + # Process each input with allocated frames + for image_index, (image, estimated_frames) in enumerate( + zip(image_data, estimated_frames_list) + ): + if len(all_frames) >= MAX_NUM_FRAMES: + max_frames_to_process = 0 + else: + max_frames_to_process = max(1, int(estimated_frames * scaling_factor)) + + if max_frames_to_process == 0: + frames = [] + else: + try: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + frames = BaseImageProcessor.encode_video( + path, frame_count_limit=max_frames_to_process + ) + else: + raw_image, _size = load_image(image) + if discard_alpha_channel: + raw_image = raw_image.convert("RGB") + frames = [raw_image] + assert len(frames) != 0 + except FileNotFoundError as e: + print(e) + return None + + image_sizes += [frames[0].size] * len(frames) + image_hashes += [hash(image)] * len(frames) + all_frames += frames + + if return_text: + new_text_parts.append(text_parts[image_index]) + if max_frames_to_process != 0: + new_text_parts.append(image_token * len(frames)) + assert max_frames_to_process >= len(frames) + if return_text: + new_text_parts.append(text_parts[-1]) + + input_text = "".join(new_text_parts) + return BaseImageProcessorOutput( + image_hashes, image_sizes, all_frames, input_text + ) + + +class DummyImageProcessor(BaseImageProcessor): + def __init__(self): + pass + + async def process_images_async(self, *args, **kwargs): + return None + + +def init_global_processor( + sglang_image_processor: BaseImageProcessor, server_args: ServerArgs +): + """Init the global processor for multi-modal models.""" + global global_processor + transformers.logging.set_verbosity_error() + global_processor = sglang_image_processor._build_processor(server_args=server_args) diff --git a/python/sglang/srt/managers/image_processors/llava.py b/python/sglang/srt/managers/image_processors/llava.py new file mode 100644 index 000000000..eee08ff40 --- /dev/null +++ b/python/sglang/srt/managers/image_processors/llava.py @@ -0,0 +1,152 @@ +import asyncio +from typing import List, Optional, Union + +import numpy as np + +from sglang.srt.managers.image_processor import BaseImageProcessor +from sglang.srt.managers.image_processors.base_image_processor import ( + get_global_processor, +) +from sglang.srt.mm_utils import expand2square, process_anyres_image +from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM +from sglang.srt.models.llavavid import LlavaVidForCausalLM +from sglang.srt.utils import load_image, logger +from sglang.utils import get_exception_traceback + + +class LlavaImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + + @staticmethod + def _process_single_image_task( + image_data: Union[str, bytes], + image_aspect_ratio: Optional[str] = None, + image_grid_pinpoints: Optional[str] = None, + image_processor=None, + ): + processor = get_global_processor() + + image_processor = image_processor or processor.image_processor + + try: + image, image_size = load_image(image_data) + if image_size is not None: + # It is a video with multiple images + image_hash = hash(image_data) + pixel_values = image_processor(image)["pixel_values"] + for _ in range(len(pixel_values)): + pixel_values[_] = pixel_values[_].astype(np.float16) + pixel_values = np.stack(pixel_values, axis=0) + return pixel_values, image_hash, image_size + else: + # It is an image + image_hash = hash(image_data) + if image_aspect_ratio == "pad": + image = expand2square( + image, + tuple(int(x * 255) for x in image_processor.image_mean), + ) + pixel_values = image_processor(image.convert("RGB"))[ + "pixel_values" + ][0] + elif image_aspect_ratio == "anyres" or ( + image_aspect_ratio is not None + and "anyres_max" in image_aspect_ratio + ): + pixel_values = process_anyres_image( + image, image_processor, image_grid_pinpoints + ) + else: + pixel_values = image_processor(image)["pixel_values"][0] + + if isinstance(pixel_values, np.ndarray): + pixel_values = pixel_values.astype(np.float16) + + return pixel_values, image_hash, image.size + except Exception: + logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) + + async def _process_single_image( + self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str + ): + if self.executor is not None: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, + LlavaImageProcessor._process_single_image_task, + image_data, + aspect_ratio, + grid_pinpoints, + ) + else: + return self._process_single_image_task( + image_data, aspect_ratio, grid_pinpoints + ) + + async def process_images_async( + self, + image_data: List[Union[str, bytes]], + input_text, + request_obj, + *args, + **kwargs, + ): + if not image_data: + return None + + modalities = request_obj.modalities or ["image"] + aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) + grid_pinpoints = ( + self.hf_config.image_grid_pinpoints + if hasattr(self.hf_config, "image_grid_pinpoints") + and "anyres" in aspect_ratio + else None + ) + + if isinstance(image_data, str): + image_data = [image_data] + + if isinstance(image_data, list) and len(image_data) > 0: + if "multi-images" in modalities or "video" in modalities: + # Multiple images + aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres + pixel_values, image_hashes, image_sizes = [], [], [] + res = [] + for img_data in image_data: + res.append( + self._process_single_image( + img_data, aspect_ratio, grid_pinpoints + ) + ) + res = await asyncio.gather(*res) + for pixel_v, image_h, image_s in res: + pixel_values.append(pixel_v) + image_hashes.append(image_h) + image_sizes.append(image_s) + + if isinstance(pixel_values[0], np.ndarray): + pixel_values = np.stack(pixel_values, axis=0) + else: + # A single image + pixel_values, image_hash, image_size = await self._process_single_image( + image_data[0], aspect_ratio, grid_pinpoints + ) + image_hashes = [image_hash] + image_sizes = [image_size] + else: + raise ValueError(f"Invalid image data: {image_data}") + + return { + "pixel_values": pixel_values, + "image_hashes": image_hashes, + "image_sizes": image_sizes, + "modalities": request_obj.modalities or ["image"], + } + + +ImageProcessorMapping = { + LlavaVidForCausalLM: LlavaImageProcessor, + LlavaQwenForCausalLM: LlavaImageProcessor, + LlavaMistralForCausalLM: LlavaImageProcessor, +} diff --git a/python/sglang/srt/managers/image_processors/minicpmv.py b/python/sglang/srt/managers/image_processors/minicpmv.py new file mode 100644 index 000000000..1b36f7fe4 --- /dev/null +++ b/python/sglang/srt/managers/image_processors/minicpmv.py @@ -0,0 +1,86 @@ +import asyncio +from typing import List, Union + +from sglang.srt.managers.image_processor import BaseImageProcessor +from sglang.srt.managers.image_processors.base_image_processor import ( + get_global_processor, +) +from sglang.srt.models.minicpmv import MiniCPMV + + +class MiniCPMVImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + self.IMAGE_TOKEN = "(./)" + + @staticmethod + def _process_images_task(images, input_text): + processor = get_global_processor() + result = processor.__call__(text=input_text, images=images, return_tensors="pt") + return { + "input_ids": result.input_ids, + "pixel_values": result.pixel_values, + "tgt_sizes": result.tgt_sizes, + } + + async def _process_images(self, images, input_text): + if self.executor is not None: + loop = asyncio.get_event_loop() + image_inputs = await loop.run_in_executor( + self.executor, + MiniCPMVImageProcessor._process_images_task, + images, + input_text, + ) + else: + image_inputs = self._processor( + images=images, text=input_text, return_tensors="pt" + ) + + return image_inputs + + async def process_images_async( + self, + image_data: List[Union[str, bytes]], + input_ids, + request_obj, + max_req_input_len, + ): + if not image_data: + return None + if not isinstance(image_data, list): + image_data = [image_data] + + base_output = self.load_images( + input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len + ) + if base_output is None: + return None + + if len(base_output.all_frames) == 0: + return None + res = await self._process_images( + images=base_output.all_frames, input_text=base_output.input_text + ) + + # Collect special token ids + tokenizer = self._processor.tokenizer + im_start_id = tokenizer.im_start_id + im_end_id = tokenizer.im_end_id + if tokenizer.slice_start_id: + slice_start_id = tokenizer.slice_start_id + slice_end_id = tokenizer.slice_end_id + return { + "input_ids": res["input_ids"].flatten().tolist(), + "pixel_values": res["pixel_values"], + "tgt_sizes": res["tgt_sizes"], + "image_hashes": base_output.image_hashes, + "modalities": request_obj.modalities or ["image"], + "im_start_id": im_start_id, + "im_end_id": im_end_id, + "slice_start_id": slice_start_id, + "slice_end_id": slice_end_id, + } + + +ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor} diff --git a/python/sglang/srt/managers/image_processors/mlama.py b/python/sglang/srt/managers/image_processors/mlama.py new file mode 100644 index 000000000..8043067d8 --- /dev/null +++ b/python/sglang/srt/managers/image_processors/mlama.py @@ -0,0 +1,60 @@ +import asyncio +from typing import List, Union + +from sglang.srt.managers.image_processor import BaseImageProcessor +from sglang.srt.managers.image_processors.base_image_processor import ( + get_global_processor, +) +from sglang.srt.models.mllama import MllamaForConditionalGeneration +from sglang.srt.utils import load_image + + +class MllamaImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + + @staticmethod + def _process_single_image_task(images, input_text): + # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask' + return get_global_processor()(images, input_text, return_tensors="pt") + + async def _process_single_image(self, images, input_text): + if self.executor is not None: + loop = asyncio.get_event_loop() + image_inputs = await loop.run_in_executor( + self.executor, + MllamaImageProcessor._process_single_image_task, + images, + input_text, + ) + else: + image_inputs = self._processor(images, input_text, return_tensors="pt") + + return image_inputs + + async def process_images_async( + self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs + ): + if not image_data: + return None + + if isinstance(input_text, list): + assert len(input_text) and isinstance(input_text[0], int) + input_text = self._processor.tokenizer.decode(input_text) + + if not isinstance(image_data, list): + image_data = [image_data] + + if len(image_data) > 0: + images = [load_image(image)[0] for image in image_data] + else: + images = load_image(image_data[0])[0] + + image_inputs = await self._process_single_image(images, input_text) + image_inputs["image_hashes"] = [hash(str(image_data))] + image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] + + return image_inputs + + +ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor} diff --git a/python/sglang/srt/managers/image_processors/qwen_vl.py b/python/sglang/srt/managers/image_processors/qwen_vl.py new file mode 100644 index 000000000..eeebf6f81 --- /dev/null +++ b/python/sglang/srt/managers/image_processors/qwen_vl.py @@ -0,0 +1,161 @@ +import asyncio +import math +from typing import List, Union + +from PIL import Image + +from sglang.srt.managers.image_processor import BaseImageProcessor +from sglang.srt.managers.image_processors.base_image_processor import ( + get_global_processor, +) +from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration +from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration + + +# Compatible with Qwen2VL and Qwen2_5VL +class Qwen2_5VLImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" + self.IM_START_TOKEN_ID = hf_config.vision_start_token_id + self.IM_END_TOKEN_ID = hf_config.vision_end_token_id + self.image_token_id = hf_config.image_token_id + self.video_token_id = hf_config.video_token_id + self.NUM_TOKEN_PER_FRAME = 770 + self.IMAGE_FACTOR = 28 + self.MIN_PIXELS = 4 * 28 * 28 + self.MAX_PIXELS = 16384 * 28 * 28 + self.MAX_PIXELS = 16384 * 28 * 28 + self.MAX_RATIO = 200 + + @staticmethod + def _process_images_task(images, input_text, _hf_config): + if isinstance(images, list) and len(images) == 0: + images = None + result = get_global_processor().__call__( + text=[input_text], images=images, padding=True, return_tensors="pt" + ) + + return { + "input_ids": result.input_ids, + "pixel_values": getattr(result, "pixel_values", None), + "image_grid_thw": getattr(result, "image_grid_thw", None), + "second_per_grid_ts": getattr(result, "second_per_grid_ts", None), + "video_grid_thws": getattr(result, "video_grid_thws", None), + } + + async def _process_images(self, images, input_text) -> dict: + if self.executor is not None: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, + Qwen2_5VLImageProcessor._process_images_task, + images, + input_text, + self.hf_config, + ) + else: + return self._process_images_task(images, input_text, self.hf_config) + + async def process_images_async( + self, + image_data: List[Union[str, bytes]], + input_ids, + request_obj, + max_req_input_len, + *args, + **kwargs, + ): + if not image_data: + return None + if isinstance(image_data, str): + image_data = [image_data] + + image_token = self.IMAGE_TOKEN + base_output = self.load_images( + input_ids, + image_data, + image_token, + max_req_input_len, + ) + + def smart_resize( + height: int, + width: int, + factor: int = self.IMAGE_FACTOR, + min_pixels: int = self.MIN_PIXELS, + max_pixels: int = self.MAX_PIXELS, + ) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > self.MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image: + width, height = image.size + min_pixels = self.MIN_PIXELS + max_pixels = self.MAX_PIXELS + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + return image + + def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + images = [resize_image(image) for image in base_output.all_frames] + + ret = await self._process_images(images, base_output.input_text) + return { + "input_ids": ret["input_ids"].flatten().tolist(), + "pixel_values": ret["pixel_values"], + "image_hashes": base_output.image_hashes, + "modalities": request_obj.modalities or ["image"], + "image_grid_thws": ret["image_grid_thw"], + "video_grid_thws": ret["video_grid_thws"], + "im_start_id": self.IM_START_TOKEN_ID, + "im_end_id": self.IM_END_TOKEN_ID, + "im_token_id": self.image_token_id, + "video_token_id": self.video_token_id, + "second_per_grid_ts": ret["second_per_grid_ts"], + } + + +ImageProcessorMapping = { + Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor, + Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor, +} diff --git a/python/sglang/srt/managers/multi_modality_padding.py b/python/sglang/srt/managers/multi_modality_padding.py new file mode 100644 index 000000000..b0b662b7c --- /dev/null +++ b/python/sglang/srt/managers/multi_modality_padding.py @@ -0,0 +1,134 @@ +from abc import abstractmethod +from typing import Callable, List, Optional, Tuple + +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.utils import logger + + +class MultiModalityDataPaddingPattern: + """ + Data tokens (like image tokens) often need special handling during padding + to maintain model compatibility. This class provides the interface for + implementing different padding strategies for data tokens + """ + + @abstractmethod + def pad_input_tokens( + self, input_ids: List[int], image_inputs: ImageInputs + ) -> List[int]: + """ + Pad the input ids sequence containing data tokens, and replace them with pad_values + """ + pass + + +class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern): + """In this pattern, data tokens should be enclosed by special token pairs (e.g. ..., data_token_pairs) + + This strategy should be applied when data content is marked by start/end token pairs in the input sequence. + """ + + def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None: + self.data_token_id_pairs = data_token_pairs + + def pad_input_tokens( + self, input_ids: List[int], image_inputs: ImageInputs + ) -> List[int]: + """ + This function will replace the data-tokens inbetween with pad_values accordingly + """ + pad_values = image_inputs.pad_values + data_token_pairs = self.data_token_id_pairs + image_inputs.image_offsets = [] + if data_token_pairs is None: + data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id] + if data_token_pairs is None: + logger.warning( + "No data_token_pairs provided, RadixAttention might be influenced." + ) + return input_ids + start_token_ids = [s for s, _e in data_token_pairs] + end_tokens_ids = [e for _s, e in data_token_pairs] + # First start token marks new data + data_start_token = start_token_ids[0] + + padded_ids = [] + last_idx = 0 + data_idx = -1 + + start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids] + end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids] + + if len(start_indices) != len(end_indices): + return input_ids + + for start_idx, end_idx in zip(start_indices, end_indices): + padded_ids.extend(input_ids[last_idx : start_idx + 1]) + + if input_ids[start_idx] == data_start_token: + data_idx += 1 + image_inputs.image_offsets += [start_idx] + + num_tokens = end_idx - start_idx - 1 + pad_value = pad_values[data_idx] + padded_ids.extend([pad_value] * num_tokens) + + last_idx = end_idx + + padded_ids.extend(input_ids[last_idx:]) + + assert len(input_ids) == len(padded_ids) + return padded_ids + + +class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern): + """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ), + which needs first to be expanded to multiple tokens, then replaced with their padding values + + This strategy should be used when a single data token represents content that should + be expanded to multiple tokens during processing. + """ + + def __init__( + self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int] + ) -> None: + self.num_data_token_calc_func = num_data_token_calc_func + + def pad_input_tokens( + self, input_ids: List[int], image_inputs: ImageInputs + ) -> List[int]: + """ + This function will follow the procedure of: + 1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func` + 2. the padded data tokens will be replaced with their pad_values + """ + image_grid_thws = image_inputs.image_grid_thws + pad_values = image_inputs.pad_values + + image_indices = [ + idx + for idx, token in enumerate(input_ids) + if token == image_inputs.im_token_id + ] + + image_inputs.image_offsets = [] + + input_ids_with_image = [] + for image_cnt, _ in enumerate(image_grid_thws): + print(f"image_cnt {image_cnt}") + num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt]) + if image_cnt == 0: + non_image_tokens = input_ids[: image_indices[image_cnt]] + else: + non_image_tokens = input_ids[ + image_indices[image_cnt - 1] + 1 : image_indices[image_cnt] + ] + input_ids_with_image.extend(non_image_tokens) + image_inputs.image_offsets.append(len(input_ids_with_image)) + pad_ids = pad_values * ( + (num_image_tokens + len(pad_values)) // len(pad_values) + ) + input_ids_with_image.extend(pad_ids[:num_image_tokens]) + input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :]) + + return input_ids_with_image diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5cd2de6b8..219b9d145 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -158,15 +158,19 @@ class ImageInputs: image_grid_thws: List[Tuple[int, int, int]] = None mrope_position_delta: Optional[torch.Tensor] = None - # MiniCPMV related + # The id of the single-image placeholder token + im_token_id: Optional[torch.Tensor] = None # All the images in the batch should share the same special image # bound token ids. - im_start_id: Optional[torch.Tensor] = None - im_end_id: Optional[torch.Tensor] = None - slice_start_id: Optional[torch.Tensor] = None - slice_end_id: Optional[torch.Tensor] = None + im_start_id: Optional[int] = None + im_end_id: Optional[int] = None + slice_start_id: Optional[int] = None + slice_end_id: Optional[int] = None tgt_sizes: Optional[list] = None + # denotes the number of valid image tokens in each image + images_emb_mask: Optional[torch.BoolTensor] = None + @staticmethod def from_dict(obj: dict): ret = ImageInputs( @@ -186,11 +190,13 @@ class ImageInputs: "aspect_ratio_ids", "aspect_ratio_mask", "image_grid_thws", + "im_token_id", "im_start_id", "im_end_id", "slice_start_id", "slice_end_id", "tgt_sizes", + "images_emb_mask", ] for arg in optional_args: if arg in obj: diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 1106c6cb7..be54f8a5d 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -455,7 +455,7 @@ def pt_weights_iterator( disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, map_location="cpu") + state = torch.load(bin_file, map_location="cpu", weights_only=True) yield from state.items() del state torch.cuda.empty_cache() diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 7905c808b..0e98a1392 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -41,7 +41,6 @@ from torch import nn from torch.nn.init import trunc_normal_ from transformers import PretrainedConfig -from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ( @@ -51,6 +50,9 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.multi_modality_padding import ( + MultiModalityDataPaddingPatternTokenPairs, +) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.utils import set_default_torch_dtype @@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module): ) -> None: super().__init__() self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - num_heads_per_partition = divide(self.num_heads, tp_size) self.self_attn = VisionAttention( embed_dim=config.hidden_size, - num_heads=num_heads_per_partition, + num_heads=self.num_heads, projection_size=config.intermediate_size, use_qkv_parallel=True, quant_config=quant_config, dropout=config.attention_dropout, use_context_forward=False, - use_full_precision_softmax=True, + softmax_in_single_precision=True, flatten_batch=False, prefix=add_prefix("self_attn", prefix), ) @@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module): self, input_ids: torch.Tensor, pad_values: List[int], - im_start_id: torch.Tensor, - im_end_id: torch.Tensor, - slice_start_id: Optional[torch.Tensor] = None, - slice_end_id: Optional[torch.Tensor] = None, + im_start_id: int, + im_end_id: int, + slice_start_id: Optional[int] = None, + slice_end_id: Optional[int] = None, ) -> torch.Tensor: """ Returns a tensor indicating the bounds (start and end token ids) of the images """ # All the images in the batch should share the same special image # bound token ids. - start_cond = input_ids == im_start_id[0] - end_cond = input_ids == im_end_id[0] + start_cond = input_ids == im_start_id + end_cond = input_ids == im_end_id if slice_start_id is not None: - start_cond |= input_ids == slice_start_id[0] - end_cond |= input_ids == slice_end_id[0] + start_cond |= input_ids == slice_start_id + end_cond |= input_ids == slice_end_id (image_start_tokens,) = torch.where(start_cond) image_start_tokens += 1 @@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module): if ( len(image_start_tokens) + 1 == len(image_end_tokens) and input_ids[0] in pad_values + and len(image_start_tokens) != 0 + and len(image_end_tokens) != 0 and image_end_tokens[0] < image_start_tokens[0] ): image_start_tokens = torch.cat( @@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module): forward_batch: ForwardBatch, **kwargs: Any, ) -> torch.Tensor: - if forward_batch.image_inputs is not None and forward_batch.image_inputs != [ - None - ]: + if ( + forward_batch.image_inputs is not None + and len(forward_batch.image_inputs) > 0 + and forward_batch.image_inputs[0] is not None + ): + # TODO: bath kwargs.update( { "pixel_values": ( @@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel): return self.resampler(vision_embedding, tgt_sizes) def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): - if not isinstance(image_inputs.im_start_id, list) or not isinstance( - image_inputs.im_end_id, list - ): - return input_ids - - new_input_ids = [] - last_idx = 0 - image_idx = -1 - image_inputs.image_offsets = [] - # Get all special token IDs - im_start_id = ( - image_inputs.im_start_id[0].item() - if isinstance(image_inputs.im_start_id[0], torch.Tensor) - else image_inputs.im_start_id[0] - ) - im_end_id = ( - image_inputs.im_end_id[0].item() - if isinstance(image_inputs.im_end_id[0], torch.Tensor) - else image_inputs.im_end_id[0] - ) - slice_start_id = ( - image_inputs.slice_start_id[0].item() - if isinstance(image_inputs.slice_start_id[0], torch.Tensor) - else image_inputs.slice_start_id[0] - ) - slice_end_id = ( - image_inputs.slice_end_id[0].item() - if isinstance(image_inputs.slice_end_id[0], torch.Tensor) - else image_inputs.slice_end_id[0] - ) + im_start_id: int = image_inputs.im_start_id + im_end_id: int = image_inputs.im_end_id + slice_start_id: int = image_inputs.slice_start_id + slice_end_id: int = image_inputs.slice_end_id - # Find all start and end positions for both types - start_indices = [ - i - for i, x in enumerate(input_ids) - if x == im_start_id or x == slice_start_id - ] - end_indices = [ - i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id - ] + media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)] + pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) - if len(start_indices) != len(end_indices): - return input_ids - # Process each region (both image and slice) - for start_idx, end_idx in zip(start_indices, end_indices): - # Add non-image tokens before this region - new_input_ids.extend( - input_ids[last_idx : start_idx + 1] - ) # include start token - - is_image_start = input_ids[start_idx] == im_start_id - - if is_image_start: - image_inputs.image_offsets += [start_idx] - image_idx += 1 - - num_tokens = end_idx - start_idx - 1 # exclude start and end tokens - - # Generate pad_ids - pad_values = [image_inputs.pad_values[image_idx]] - - pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values)) - pad_ids = pad_ids[:num_tokens] - - # Add pad_ids - new_input_ids.extend(pad_ids) - - # Update last_idx to after end token - last_idx = end_idx - - # Add remaining tokens after last region - new_input_ids.extend(input_ids[last_idx:]) - assert len(input_ids) == len(new_input_ids) - return new_input_ids + return pattern.pad_input_tokens(input_ids, image_inputs) _SUPPORT_VERSION = {(2, 6): MiniCPMV2_6} diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index dd52ae6fd..6e33955bb 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module): quant_config=None, dropout=0.0, use_context_forward=False, - use_full_precision_softmax=False, + softmax_in_single_precision=False, flatten_batch=False, prefix=add_prefix("self_attn", prefix), ) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 027153be9..2792c0c98 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.managers.multi_modality_padding import ( + MultiModalityDataPaddingPatternTokenPairs, +) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module): self.norm2 = Qwen2RMSNorm(dim, eps=1e-6) if attn_implementation == "sdpa": use_context_forward = False - use_full_precision_softmax = False + softmax_in_single_precision = False elif attn_implementation == "flash_attention_2": - use_full_precision_softmax = False + softmax_in_single_precision = False use_context_forward = True elif attn_implementation == "eager": - use_full_precision_softmax = True + softmax_in_single_precision = True use_context_forward = False self.attn = VisionAttention( @@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module): projection_size=dim, use_qkv_parallel=False, use_context_forward=use_context_forward, - use_full_precision_softmax=use_full_precision_softmax, + softmax_in_single_precision=softmax_in_single_precision, flatten_batch=True, quant_config=quant_config, prefix=add_prefix("attn", prefix), @@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module): ) def forward( - self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor, ) -> torch.Tensor: hidden_states = self.norm1(x) hidden_states = rearrange(hidden_states, "s b ... -> b s ...") attn = self.attn( - hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, ) attn = rearrange(attn, "b s ... -> s b ...") x = x + attn @@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module): ) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) # compute cu_seqlens cu_seqlens = torch.repeat_interleave( @@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb) + x = blk( + x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings + ) # adapter x = self.merger(x) @@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): return num_image_tokens def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): - new_input_ids = [] - last_idx = 0 - image_idx = -1 - image_inputs.image_offsets = [] - # Get all special token IDs - im_start_id = image_inputs.im_start_id - im_end_id = image_inputs.im_end_id + im_start_id: int = image_inputs.im_start_id + im_end_id: int = image_inputs.im_end_id - # Find all start and end positions for both types - start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id] - end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id] + media_token_pairs = [(im_start_id, im_end_id)] + pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) - if len(start_indices) != len(end_indices): - return input_ids - # Process each region (both image and slice) - for start_idx, end_idx in zip(start_indices, end_indices): - # Add non-image tokens before this region - new_input_ids.extend(input_ids[last_idx : start_idx + 1]) - - is_image_start = input_ids[start_idx] == im_start_id - - if is_image_start: - image_inputs.image_offsets += [start_idx] - image_idx += 1 - - num_tokens = end_idx - start_idx - 1 # exclude start and end tokens - - # Generate pad_ids - pad_values = [image_inputs.pad_values[image_idx]] - - pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values)) - pad_ids = pad_ids[:num_tokens] - - # Add pad_ids - new_input_ids.extend(pad_ids) - - # Update last_idx to after end token - last_idx = end_idx - - # Add remaining tokens after last region - new_input_ids.extend(input_ids[last_idx:]) - assert len(input_ids) == len(new_input_ids) - return new_input_ids + return pattern.pad_input_tokens(input_ids, image_inputs) def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: pixel_values = image_input["pixel_values"].type(self.visual.dtype) @@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu for i, image in enumerate(forward_batch.image_inputs): - if image is None: + if image is None or image.pixel_values is None: continue start_idx = extend_start_loc_cpu[i] prefix_len = prefix_lens_cpu[i] diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 63b479113..1bfcf526b 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -42,6 +42,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.managers.multi_modality_padding import ( + MultiModalityDataPaddingPatternTokenPairs, +) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -137,12 +140,12 @@ class Qwen2VisionBlock(nn.Module): mlp_hidden_dim = int(dim * mlp_ratio) if attn_implementation == "sdpa": use_context_forward = False - use_full_precision_softmax = False + softmax_in_single_precision = False elif attn_implementation == "flash_attention_2": - use_full_precision_softmax = False + softmax_in_single_precision = False use_context_forward = True elif attn_implementation == "eager": - use_full_precision_softmax = True + softmax_in_single_precision = True use_context_forward = False self.attn = VisionAttention( @@ -151,7 +154,7 @@ class Qwen2VisionBlock(nn.Module): projection_size=dim, use_qkv_parallel=False, use_context_forward=use_context_forward, - use_full_precision_softmax=use_full_precision_softmax, + softmax_in_single_precision=softmax_in_single_precision, flatten_batch=True, quant_config=quant_config, prefix=add_prefix("attn", prefix), @@ -165,12 +168,17 @@ class Qwen2VisionBlock(nn.Module): ) def forward( - self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor, ) -> torch.Tensor: hidden_states = self.norm1(x) hidden_states = rearrange(hidden_states, "s b ... -> b s ...") attn = self.attn( - hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, ) attn = rearrange(attn, "b s ... -> s b ...") x = x + attn @@ -392,7 +400,8 @@ class Qwen2VisionTransformer(nn.Module): # compute position embedding rotary_pos_emb = self.rot_pos_emb(grid_thw) - + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) # compute cu_seqlens cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -402,7 +411,7 @@ class Qwen2VisionTransformer(nn.Module): # transformers x = x.unsqueeze(1) for blk in self.blocks: - x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) # adapter x = self.merger(x) @@ -425,40 +434,6 @@ class Qwen2VLForConditionalGeneration(nn.Module): ) return num_image_tokens - # Use grid_t * grid_w * grid_h to pad tokens for each image - # add replaced padding by unique image hash - def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): - image_grid_thws = image_inputs.image_grid_thws - pad_values = image_inputs.pad_values - - image_indices = [ - idx - for idx, token in enumerate(input_ids) - if token == self.config.image_token_id - ] - image_inputs.image_offsets = [] - - input_ids_with_image = [] - for image_cnt, _ in enumerate(image_grid_thws): - num_image_tokens = self.calculate_num_image_tokens( - image_grid_thws[image_cnt] - ) - if image_cnt == 0: - non_image_tokens = input_ids[: image_indices[image_cnt]] - else: - non_image_tokens = input_ids[ - image_indices[image_cnt - 1] + 1 : image_indices[image_cnt] - ] - input_ids_with_image.extend(non_image_tokens) - image_inputs.image_offsets.append(len(input_ids_with_image)) - pad_ids = pad_values * ( - (num_image_tokens + len(pad_values)) // len(pad_values) - ) - input_ids_with_image.extend(pad_ids[:num_image_tokens]) - input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :]) - - return input_ids_with_image - def __init__( self, config: Qwen2VLConfig, @@ -494,6 +469,17 @@ class Qwen2VLForConditionalGeneration(nn.Module): self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # Use grid_t * grid_w * grid_h to pad tokens for each image + # add replaced padding by unique image hash + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + # Get all special token IDs + im_start_id: int = image_inputs.im_start_id + im_end_id: int = image_inputs.im_end_id + + media_token_pairs = [(im_start_id, im_end_id)] + pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + return pattern.pad_input_tokens(input_ids, image_inputs) + def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: pixel_values = image_input["pixel_values"].type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) @@ -556,12 +542,12 @@ class Qwen2VLForConditionalGeneration(nn.Module): extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu for i, image in enumerate(forward_batch.image_inputs): - if image is None: + if image is None or image.pixel_values is None: continue start_idx = extend_start_loc_cpu[i] prefix_len = prefix_lens_cpu[i] + pixel_values = image.pixel_values.clone() - pixel_values = torch.tensor(image.pixel_values, device="cuda") image_grid_thws = torch.tensor( np.array(image.image_grid_thws), device="cuda" ) @@ -579,15 +565,13 @@ class Qwen2VLForConditionalGeneration(nn.Module): image_grid_thws[idx] ) - left_idx = start_idx + (image_offset - prefix_len) - right_idx = ( - start_idx + (image_offset - prefix_len) + num_image_tokens - ) - + left_idx = start_idx + (image_offset - prefix_len + 1) + right_idx = left_idx + num_image_tokens inputs_embeds[left_idx:right_idx] = image_embeds[ image_embeds_offset : image_embeds_offset + num_image_tokens ] image_embeds_offset += num_image_tokens + input_ids = None hidden_states = self.model( input_ids=input_ids, diff --git a/test/srt/test_vision_llm.py b/test/srt/test_vision_llm.py index 7cda64fc0..1ff4de2b3 100644 --- a/test/srt/test_vision_llm.py +++ b/test/srt/test_vision_llm.py @@ -193,10 +193,10 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): **{ "pixel_values": [inputs["pixel_values"]], "tgt_sizes": [inputs["tgt_sizes"]], - "im_start_id": [self.tokenizer.im_start_id], - "im_end_id": [self.tokenizer.im_end_id], - "slice_start_id": [self.tokenizer.slice_start_id], - "slice_end_id": [self.tokenizer.slice_end_id], + "im_start_id": self.tokenizer.im_start_id, + "im_end_id": self.tokenizer.im_end_id, + "slice_start_id": self.tokenizer.slice_start_id, + "slice_end_id": self.tokenizer.slice_end_id, }, ) (sglang_output, _) = model.get_embedding( diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 2d57766d7..a88f0e65c 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -47,7 +47,7 @@ class TestOpenAIVisionServer(unittest.TestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_chat_completion(self): + def test_single_image_chat_completion(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) response = client.chat.completions.create( @@ -75,7 +75,9 @@ class TestOpenAIVisionServer(unittest.TestCase): assert response.choices[0].message.role == "assistant" text = response.choices[0].message.content assert isinstance(text, str) - assert "man" in text or "cab" in text, text + assert "man" in text or "person" in text, text + assert "cab" in text or "taxi" in text or "SUV" in text, text + assert "iron" in text, text assert response.id assert response.created assert response.usage.prompt_tokens > 0 @@ -169,7 +171,7 @@ class TestOpenAIVisionServer(unittest.TestCase): assert response.choices[0].message.role == "assistant" text = response.choices[0].message.content assert isinstance(text, str) - print(text) + print(f"LLM response: {text}") assert "man" in text or "cab" in text or "SUV" in text or "taxi" in text, text assert "logo" in text or '"S"' in text or "SG" in text, text assert response.id @@ -379,6 +381,8 @@ class TestQWen2VLServer(TestOpenAIVisionServer): other_args=[ "--chat-template", "qwen2-vl", + "--chunked-prefill-size", + "10000", ], ) cls.base_url += "/v1" @@ -408,7 +412,7 @@ class TestQWen2_5_VLServer(TestOpenAIVisionServer): cls.base_url += "/v1" -class TestQWen2VLServerContextLengthIssue(unittest.TestCase): +class TestVLMContextLengthIssue(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = "Qwen/Qwen2-VL-7B-Instruct" @@ -433,7 +437,7 @@ class TestQWen2VLServerContextLengthIssue(unittest.TestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_chat_completion(self): + def test_single_image_chat_completion(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) with self.assertRaises(openai.BadRequestError) as cm: @@ -459,9 +463,11 @@ class TestQWen2VLServerContextLengthIssue(unittest.TestCase): temperature=0, ) - self.assertIn( - "Multimodal prompt is too long after expanding multimodal tokens.", - str(cm.exception), + # context length is checked first, then max_req_input_len, which is calculated from the former + assert ( + "Multimodal prompt is too long after expanding multimodal tokens." + in str(cm.exception) + or "is longer than the model's context length" in str(cm.exception) )