model: Minicpmo (#3023)

This commit is contained in:
Mick
2025-03-25 11:08:40 +08:00
committed by GitHub
parent 64129fa632
commit 1e86457c90
40 changed files with 2906 additions and 493 deletions

View File

@@ -34,6 +34,7 @@ runtime_common = [
"pydantic",
"python-multipart",
"pyzmq>=25.1.2",
"soundfile==0.13.1",
"torchao>=0.7.0",
"transformers==4.50.0",
"uvicorn",

View File

@@ -15,6 +15,7 @@ class ChatTemplate:
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
stop_str: List[str] = ()
image_token: str = "<image>"
audio_token: str = "<audio>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(
@@ -253,6 +254,22 @@ register_chat_template(
)
)
# https://huggingface.co/openbmb/MiniCPM-o-2_6
register_chat_template(
ChatTemplate(
name="minicpmo",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", " "),
"user": ("user:", " "),
"assistant": ("assistant:", "</s>"),
},
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",
)
)
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(
@@ -474,12 +491,6 @@ def match_chat_ml(model_path: str):
return get_chat_template("chatml-llava")
@register_chat_template_matching_function
def match_chat_minicpm(model_path: str):
if "minicpm" in model_path:
return get_chat_template("minicpmv")
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
model_path = model_path.lower()
@@ -499,8 +510,10 @@ def match_gemma_it(model_path: str):
@register_chat_template_matching_function
def match_openbmb_minicpm(model_path: str):
model_path = model_path.lower()
if "minicpm" in model_path:
if "minicpm-v" in model_path:
return get_chat_template("minicpmv")
elif "minicpm-o" in model_path:
return get_chat_template("minicpmo")
@register_chat_template_matching_function

View File

@@ -462,18 +462,19 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
multimodal_model_archs = [
"DeepseekVL2ForCausalLM",
"LlavaLlamaForCausalLM",
"LlavaQwenForCausalLM",
"LlavaMistralForCausalLM",
"LlavaVidForCausalLM",
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"LlavaLlamaForCausalLM",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
"LlavaVidForCausalLM",
"MiniCPMO",
"MiniCPMV",
"MultiModalityCausalLM",
"MllamaForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"MiniCPMV",
"MultiModalityCausalLM",
]

View File

@@ -73,11 +73,14 @@ class Conversation:
stop_str: Union[str, List[str]] = None
# The string that represents an image token in the prompt
image_token: str = "<image>"
audio_token: str = "<audio>"
image_data: Optional[List[str]] = None
modalities: Optional[List[str]] = None
stop_token_ids: Optional[int] = None
audio_data: Optional[List[str]] = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
@@ -327,6 +330,10 @@ class Conversation:
"""Append a new message."""
self.image_data.append(image)
def append_audio(self, audio: str):
"""Append a new message."""
self.audio_data.append(audio)
def update_last_message(self, message: str):
"""Update the last output.
@@ -373,6 +380,7 @@ class Conversation:
sep2=self.sep2,
stop_str=self.stop_str,
image_token=self.image_token,
audio_token=self.audio_token,
)
def dict(self):
@@ -459,8 +467,10 @@ def generate_chat_conv(
sep2=conv.sep2,
stop_str=conv.stop_str,
image_data=[],
audio_data=[],
modalities=[],
image_token=conv.image_token,
audio_token=conv.audio_token,
)
if isinstance(request.messages, str):
@@ -498,6 +508,7 @@ def generate_chat_conv(
if conv.name != "qwen2-vl"
else conv.image_token
)
audio_token = conv.audio_token
for content in message.content:
if content.type == "text":
if num_image_url > 16:
@@ -507,6 +518,10 @@ def generate_chat_conv(
# NOTE: Only works for llava
real_content += image_token
conv.append_image(content.image_url.url)
elif content.type == "audio_url":
real_content += audio_token
conv.append_audio(content.audio_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
parsed_content = ""
@@ -704,3 +719,18 @@ register_conv_template(
image_token="<image_placeholder>",
)
)
# Reference: https://huggingface.co/openbmb/MiniCPM-o-2_6#usage
register_conv_template(
Conversation(
name="minicpmo",
system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
system_template="<|im_start|>system\n{system_message}",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep="<|im_end|>\n",
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",
)
)

View File

@@ -1,61 +0,0 @@
# TODO: also move pad_input_ids into this module
import importlib
import inspect
import logging
import pkgutil
from functools import lru_cache
from typing import Union
from torch import Tensor
from transformers import IMAGE_PROCESSOR_MAPPING
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor,
DummyImageProcessor,
)
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
IMAGE_PROCESSOR_MAPPING = {}
def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
raise ValueError(
f"No image processor found for architecture: {hf_config.architectures}"
)
def get_dummy_image_processor():
return DummyImageProcessor()
@lru_cache()
def import_image_processors():
package_name = "sglang.srt.managers.image_processors"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f" Ignore import error when loading {name}: " f"{e}")
continue
all_members = inspect.getmembers(module, inspect.isclass)
classes = [
member
for name, member in all_members
if member.__module__ == module.__name__
]
for cls in classes:
if issubclass(cls, BaseImageProcessor):
for arch in getattr(cls, "models"):
IMAGE_PROCESSOR_MAPPING[arch] = cls
# also register processors
import_image_processors()

View File

@@ -1,129 +0,0 @@
import asyncio
from typing import List, Union
import torch
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.minicpmv import MiniCPMV
class MiniCPMVImageProcessor(BaseImageProcessor):
models = [MiniCPMV]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)"
@staticmethod
def _process_images_task(images, input_text):
processor = get_global_processor()
result = processor.__call__(text=input_text, images=images, return_tensors="pt")
return {
"input_ids": result.input_ids,
"pixel_values": result.pixel_values,
"tgt_sizes": result.tgt_sizes,
}
async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MiniCPMVImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)
return image_inputs
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
):
if not image_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
base_output = self.load_images(
input_ids=input_ids,
image_data=image_data,
image_token=self.IMAGE_TOKEN,
max_req_input_len=max_req_input_len,
)
if base_output is None:
return None
if len(base_output.all_frames) == 0:
return None
res = await self._process_images(
images=base_output.all_frames, input_text=base_output.input_text
)
# Collect special token ids
tokenizer = self._processor.tokenizer
im_start_id = tokenizer.im_start_id
im_token_id = tokenizer.unk_token_id
im_end_id = tokenizer.im_end_id
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
# tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
# tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
# per patch
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
tgt_sizes = torch.stack(tgt_sizes_flat)
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"im_token_id": im_token_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}

View File

@@ -45,6 +45,8 @@ class GenerateReqInput:
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
@@ -167,6 +169,13 @@ class GenerateReqInput:
elif isinstance(self.image_data, list):
pass
if self.audio_data is None:
self.audio_data = [None] * num
elif not isinstance(self.audio_data, list):
self.audio_data = [self.audio_data] * num
elif isinstance(self.audio_data, list):
pass
if self.sampling_params is None:
self.sampling_params = [{}] * num
elif not isinstance(self.sampling_params, list):
@@ -231,6 +240,7 @@ class GenerateReqInput:
text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None,
image_data=self.image_data[i],
audio_data=self.audio_data[i],
sampling_params=self.sampling_params[i],
rid=self.rid[i],
return_logprob=self.return_logprob[i],
@@ -259,8 +269,8 @@ class TokenizedGenerateReqInput:
input_text: str
# The input token ids
input_ids: List[int]
# The image inputs
image_inputs: dict
# The multimodal inputs
mm_inputs: dict
# The sampling parameters
sampling_params: SamplingParams
# Whether to return the logprobs

View File

@@ -9,7 +9,7 @@ import torch
from torch import nn
from sglang.srt.managers.schedule_batch import (
ImageInputs,
MultimodalInputs,
global_server_args_dict,
logger,
)
@@ -26,7 +26,7 @@ class MultiModalityDataPaddingPattern:
@abstractmethod
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
self, input_ids: List[int], image_inputs: MultimodalInputs
) -> List[int]:
"""
Pad the input ids sequence containing data tokens, and replace them with pad_values
@@ -44,16 +44,16 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
self.data_token_id_pairs = data_token_pairs
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
This function will replace the data-tokens inbetween with pad_values accordingly
"""
pad_values = image_inputs.pad_values
pad_values = mm_inputs.pad_values
data_token_pairs = self.data_token_id_pairs
image_inputs.image_offsets = []
mm_inputs.image_offsets = []
if data_token_pairs is None:
data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
if data_token_pairs is None:
logger.warning(
"No data_token_pairs provided, RadixAttention might be influenced."
@@ -61,8 +61,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
return input_ids
start_token_ids = [s for s, _e in data_token_pairs]
end_tokens_ids = [e for _s, e in data_token_pairs]
# First start token marks new data
data_start_token = start_token_ids[0]
padded_ids = []
last_idx = 0
@@ -77,9 +75,12 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
for start_idx, end_idx in zip(start_indices, end_indices):
padded_ids.extend(input_ids[last_idx : start_idx + 1])
if input_ids[start_idx] == data_start_token:
if input_ids[start_idx] in start_token_ids:
data_idx += 1
image_inputs.image_offsets += [start_idx]
mm_inputs.image_offsets += [start_idx]
if data_idx >= len(mm_inputs.pad_values):
data_idx = len(mm_inputs.pad_values) - 1
num_tokens = end_idx - start_idx - 1
pad_value = pad_values[data_idx]
@@ -89,7 +90,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
padded_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(padded_ids)
assert len(input_ids) == len(padded_ids), "Length validation fails"
return padded_ids
@@ -107,26 +108,25 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
self.num_data_token_calc_func = num_data_token_calc_func
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
This function will follow the procedure of:
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
2. the padded data tokens will be replaced with their pad_values
"""
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
image_grid_thws = mm_inputs.image_grid_thws
pad_values = mm_inputs.pad_values
image_indices = [
idx
for idx, token in enumerate(input_ids)
if token == image_inputs.im_token_id
idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
]
image_inputs.image_offsets = []
mm_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
# print(f"image_cnt {image_cnt}")
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
@@ -135,7 +135,7 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
image_inputs.image_offsets.append(len(input_ids_with_image))
mm_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
@@ -170,11 +170,11 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
return input_ids_tensor.tolist()
def embed_image_inputs(
image_input: ImageInputs,
def embed_mm_inputs(
mm_input: MultimodalInputs,
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
image_embedding_func,
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]:
"""
@@ -184,10 +184,10 @@ def embed_image_inputs(
Returns:
final embedding: Optional[torch.Tensor]
"""
if image_input is None:
if mm_input is None:
return None
placeholder_token_ids = placeholder_token_ids or image_input.pad_values
placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
# boolean masking the special tokens
special_image_mask = torch.isin(
@@ -196,12 +196,18 @@ def embed_image_inputs(
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum()
# print(f"{num_image_tokens_in_input_ids}")
# print(f"{input_ids}")
# return
if num_image_tokens_in_input_ids == 0:
# unexpected
inputs_embeds = input_embedding(input_ids)
else:
image_embedding = image_embedding_func(image_input)
# print(f"Getting image feature")
image_embedding = mm_data_embedding_func(mm_input)
# print(f"image_embedding: {image_embedding.shape}")
if image_embedding.dim() == 2:
num_image_tokens_in_embedding = image_embedding.shape[0]
@@ -273,31 +279,95 @@ def embed_image_embedding(
def general_mm_embed_routine(
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
embed_tokens: nn.Embedding,
image_embedding_func: Callable[[ImageInputs], torch.Tensor],
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
placeholder_token_ids: List[int] = None,
):
"""
a general wrapper function to get final input embeds from multimodal models
with a language model as causal model
Args:
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
"""
if (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
not forward_batch.forward_mode.is_decode()
and forward_batch.contains_mm_inputs()
):
inputs_embeds = embed_tokens(input_ids)
else:
image = forward_batch.merge_image_inputs()
inputs_embeds = embed_image_inputs(
image_input=image,
image = forward_batch.merge_mm_inputs()
inputs_embeds = embed_mm_inputs(
mm_input=image,
input_ids=input_ids,
input_embedding=embed_tokens,
image_embedding_func=image_embedding_func,
mm_data_embedding_func=mm_data_embedding_func,
placeholder_token_ids=placeholder_token_ids,
)
# once used, image_inputs is useless
# once used, mm_inputs is useless
# just being defensive here
forward_batch.image_inputs = None
forward_batch.mm_inputs = None
else:
inputs_embeds = embed_tokens(input_ids)
return inputs_embeds
def get_multimodal_data_bounds(
input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
"""
Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)
Returns:
[bounds_count, 2]
"""
# All the images in the batch should share the same special image
# bound token ids.
start_tokens = [s for s, _e in token_pairs]
end_tokens = [e for _s, e in token_pairs]
assert all(isinstance(t, int) for t in start_tokens)
assert all(isinstance(t, int) for t in end_tokens)
# print(input_ids)
start_cond = torch.isin(
input_ids, torch.tensor(start_tokens, device=input_ids.device)
)
end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
(data_start_tokens,) = torch.where(start_cond)
(data_end_tokens,) = torch.where(end_cond)
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
if len(data_start_tokens) != len(data_end_tokens):
if (
len(data_start_tokens) + 1 == len(data_end_tokens)
and input_ids[0] in pad_values
and data_end_tokens[0] < data_start_tokens[0]
):
data_start_tokens = torch.cat(
[
torch.tensor([0], device=data_start_tokens.device),
data_start_tokens,
]
)
valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
if valid_image_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
# Filter out pairs where start_token >= end_token
valid_pairs = []
for i in range(valid_image_nums):
start_token = data_start_tokens[i]
end_token = data_end_tokens[i]
if start_token < end_token:
valid_pairs.append((start_token + 1, end_token - 1))
if not valid_pairs:
return torch.zeros((0, 2), device=input_ids.device)
# Convert valid pairs to tensor
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor

View File

@@ -0,0 +1,68 @@
# TODO: also move pad_input_ids into this module
import importlib
import inspect
import logging
import pkgutil
from functools import lru_cache
from transformers import PROCESSOR_MAPPING
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
)
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
PROCESSOR_MAPPING = {}
class DummyMultimodalProcessor(BaseMultimodalProcessor):
def __init__(self):
pass
async def process_mm_data_async(self, *args, **kwargs):
return None
def get_dummy_processor():
return DummyMultimodalProcessor()
@lru_cache()
def import_processors():
package_name = "sglang.srt.managers.multimodal_processors"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}: " f"{e}")
continue
all_members = inspect.getmembers(module, inspect.isclass)
classes = [
member
for name, member in all_members
if member.__module__ == module.__name__
]
for cls in (
cls for cls in classes if issubclass(cls, BaseMultimodalProcessor)
):
assert hasattr(cls, "models")
for arch in getattr(cls, "models"):
PROCESSOR_MAPPING[arch] = cls
def get_mm_processor(
hf_config, server_args: ServerArgs, processor
) -> BaseMultimodalProcessor:
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
raise ValueError(
f"No processor registered for architecture: {hf_config.architectures}.\n"
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
)
self.image_proce

View File

@@ -4,16 +4,16 @@ import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Optional
import numpy as np
import PIL
import transformers
from decord import VideoReader, cpu
from openai import BadRequestError
from PIL import Image
from sglang.srt.utils import load_image
from sglang.utils import logger
from sglang.srt.utils import load_audio, load_image, logger
global global_processor
@@ -24,21 +24,41 @@ def get_global_processor():
@dataclasses.dataclass
class BaseImageProcessorOutput:
image_hashes: list[int]
image_sizes: list[tuple[int, int]]
all_frames: [PIL.Image]
# input_text, with each frame of video/image represented as an image_token
class BaseMultiModalProcessorOutput:
# input_text, with each frame of video/image represented with a image_token
input_text: str
mm_data_hashes: Optional[list[int]]
# images
image_sizes: Optional[list[int]]
# frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None
# audios
audios: Optional[list[np.ndarray]] = None
def normalize(self):
for field_name in ["data_hashes", "image_sizes", "all_frames"]:
for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
class BaseImageProcessor(ABC):
@dataclasses.dataclass
class MultimodalSpecialTokens:
image_token: Optional[str] = None
video_token: Optional[str] = None
audio_token: Optional[str] = None
def collect(self) -> list[str]:
return [
token
for token in [self.image_token, self.video_token, self.audio_token]
if token
]
class BaseMultimodalProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor):
@@ -72,7 +92,7 @@ class BaseImageProcessor(ABC):
)
@abstractmethod
async def process_images_async(
async def process_mm_data_async(
self, image_data, input_text, max_req_input_len, **kwargs
):
pass
@@ -120,29 +140,33 @@ class BaseImageProcessor(ABC):
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
def load_images(
def load_mm_data(
self,
input_ids: list[int],
image_data,
image_token: Union[int, str],
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
audio_data: Optional[list] = None,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
) -> BaseImageProcessorOutput:
) -> BaseMultiModalProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
Args:
image_token: The token ID representing the image placeholder.
multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
e.g. image token or audio token
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if isinstance(image_token, int):
image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
image_token
if isinstance(multimodal_tokens.image_token, int):
multimodal_tokens.image_token = (
self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token
)
)
else:
image_token_str = image_token
multimodal_tokens.image_token = multimodal_tokens.image_token
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
@@ -152,7 +176,11 @@ class BaseImageProcessor(ABC):
if return_text:
import re
pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
pattern = (
"("
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+ ")"
)
# split text into list of normal text and special tokens
text_parts = re.split(pattern, input_text)
@@ -162,7 +190,7 @@ class BaseImageProcessor(ABC):
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
_scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
@@ -171,9 +199,16 @@ class BaseImageProcessor(ABC):
new_text = ""
for index, text_part in enumerate(text_parts):
try:
if text_part == image_token:
if text_part == multimodal_tokens.image_token:
# load as image
frames_to_process = estimated_frames_list[image_index]
if len(images) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
estimated_frames = estimated_frames_list[image_index]
frames_to_process = max(
1, int(estimated_frames * scaling_factor)
)
if frames_to_process == 0:
frames = []
else:
@@ -183,7 +218,7 @@ class BaseImageProcessor(ABC):
):
# video
path = image_file[len("video:") :]
frames = self.encode_video(
frames = BaseMultimodalProcessor.encode_video(
path, frame_count_limit=frames_to_process
)
else:
@@ -200,40 +235,41 @@ class BaseImageProcessor(ABC):
images += frames
image_index += 1
if frames_to_process != 0:
new_text += image_token * len(frames)
new_text += multimodal_tokens.image_token * len(frames)
assert frames_to_process == len(frames)
elif text_part == multimodal_tokens.audio_token:
# load as audio
audio_file = audio_data[audio_index]
audio = load_audio(audio_file)
hashes += [hash(audio_file)]
audios += [audio]
audio_index += 1
new_text += multimodal_tokens.audio_token
else:
# TODO(mick): handle video
# normal text
new_text += text_part
except Exception as e:
logger.error(f"An exception occurred while loading images: {e}")
raise BadRequestError(
f"An exception occurred while loading images: {e}"
)
return BaseImageProcessorOutput(
image_hashes=hashes,
out = BaseMultiModalProcessorOutput(
mm_data_hashes=hashes,
image_sizes=image_sizes,
all_frames=images,
images=images,
audios=audios,
input_text=new_text,
)
out.normalize()
return out
class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
pass
async def process_images_async(self, *args, **kwargs):
return None
def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
"""Init the global processor for multi-modal models."""
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
"""
Init the global processor for multimodal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = sglang_image_processor._build_processor(server_args=server_args)
global_processor = sglang_processor._build_processor(server_args=server_args)

View File

@@ -20,14 +20,15 @@ import asyncio
import torch
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
class DeepseekVL2ImageProcessor(BaseImageProcessor):
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
models = [DeepseekVL2ForCausalLM]
def __init__(self, hf_config, server_args, _processor):
@@ -63,7 +64,23 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
return image_inputs
async def process_images_async(
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def process_mm_data_async(
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
):
if not image_data:
@@ -75,11 +92,14 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
images, image_sizes = [], []
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids, image_data, image_token, max_req_input_len
base_output = self.load_mm_data(
input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
)
res = await self._process_images(
base_output.all_frames, base_output.input_text, max_req_input_len
base_output.images, base_output.input_text, max_req_input_len
)
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
@@ -91,7 +111,7 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
"input_ids": res["input_ids"].tolist(),
"pixel_values": res["images"],
"im_token_id": res["im_token_id"],
"image_hashes": base_output.image_hashes,
"data_hashes": base_output.mm_data_hashes,
"image_sizes": image_sizes,
"images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,

View File

@@ -1,12 +1,12 @@
import asyncio
from typing import List, Union
from transformers.utils import logging
from sglang.srt.managers.image_processor import (
BaseImageProcessor as SGLangBaseImageProcessor,
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
@@ -16,7 +16,7 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
models = [Gemma3ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
@@ -47,7 +47,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
"pixel_values": pixel_values,
}
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
@@ -62,22 +62,22 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
base_output = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
ret = await self._process_single_image(
input_text=base_output.input_text, images=base_output.all_frames
input_text=base_output.input_text, images=base_output.images
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"data_hashes": base_output.mm_data_hashes,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}

View File

@@ -1,16 +1,15 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
class JanusProProcessor(SGLangBaseImageProcessor):
class JanusProImageProcessor(BaseMultimodalProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, _processor):
@@ -36,7 +35,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
JanusProProcessor._process_images_task,
JanusProImageProcessor._process_images_task,
images,
input_text,
)
@@ -47,7 +46,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
return image_inputs
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
@@ -61,20 +60,24 @@ class JanusProProcessor(SGLangBaseImageProcessor):
if not isinstance(image_data, list):
image_data = [image_data]
base_out = self.load_images(
base_out = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
image_token="<image_placeholder>",
multimodal_tokens=MultimodalSpecialTokens(
image_token="<image_placeholder>"
),
max_req_input_len=max_req_input_len,
)
images = base_out.all_frames
images = base_out.images
res = await self._process_images(images=images, input_text=base_out.input_text)
# print(res)
# print(base_out)
# print("", res["images_emb_mask"].shape)
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"images_emb_mask": res["images_emb_mask"],
"image_hashes": base_out.image_hashes,
"data_hashes": base_out.mm_data_hashes,
"im_start_id": res["im_start_id"],
"im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"],

View File

@@ -3,8 +3,8 @@ from typing import List, Optional, Union
import numpy as np
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -14,7 +14,7 @@ from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
class LlavaImageProcessor(BaseImageProcessor):
class LlavaImageProcessor(BaseMultimodalProcessor):
models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
def __init__(self, hf_config, server_args, _processor):
@@ -86,7 +86,7 @@ class LlavaImageProcessor(BaseImageProcessor):
image_data, aspect_ratio, grid_pinpoints
)
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
@@ -113,7 +113,7 @@ class LlavaImageProcessor(BaseImageProcessor):
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
pixel_values, data_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(
@@ -124,7 +124,7 @@ class LlavaImageProcessor(BaseImageProcessor):
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
data_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
@@ -134,14 +134,14 @@ class LlavaImageProcessor(BaseImageProcessor):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
data_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"data_hashes": data_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}

View File

@@ -0,0 +1,167 @@
import asyncio
from typing import List, Union
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV
# Compatible with both 'O' and 'V'
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
models = [MiniCPMV, MiniCPMO]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)"
@staticmethod
def _process_data_task(input_text, images=None, audios=None):
if isinstance(images, list) and len(images) == 0:
images = None
if isinstance(audios, list) and len(audios) == 0:
audios = None
result = get_global_processor().__call__(
text=input_text,
images=images,
audios=audios,
return_tensors="pt",
chunk_input=True,
)
return {
"input_ids": result.input_ids,
"pixel_values": getattr(result, "pixel_values", None),
"tgt_sizes": getattr(result, "tgt_sizes", None),
"audio_features": getattr(result, "audio_features", None),
"audio_feature_lens": getattr(result, "audio_feature_lens", None),
"audio_bounds": getattr(result, "audio_bounds", None),
}
async def _process_data(self, images, input_text, audios=None):
if self.executor is not None:
loop = asyncio.get_event_loop()
multimodal_data_inputs = await loop.run_in_executor(
self.executor,
MiniCPMMultimodalProcessor._process_data_task,
input_text,
images,
audios,
)
else:
multimodal_data_inputs = self._processor(
images=images, text=input_text, audios=audios, return_tensors="pt"
)
return multimodal_data_inputs
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
):
audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
if not isinstance(audio_data, list):
audio_data = [audio_data]
base_output = self.load_mm_data(
input_ids=input_ids,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.image_token, audio_token=self.audio_token
),
)
if base_output is None:
return None
res = await self._process_data(
images=base_output.images,
input_text=base_output.input_text,
audios=base_output.audios,
)
# Collect special token ids
tokenizer = self._processor.tokenizer
slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
None,
None,
None,
None,
)
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
if hasattr(tokenizer, "audio_start_id"):
audio_start_id = tokenizer.audio_start_id
audio_end_id = tokenizer.audio_end_id
im_token_id = tokenizer.unk_token_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
if len(tgt_sizes_flat) == 0:
tgt_sizes = None
else:
tgt_sizes = torch.stack(tgt_sizes_flat)
if not isinstance(res["audio_features"], list):
res["audio_features"] = [res["audio_features"]]
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"audio_features": res["audio_features"],
"audio_bounds": res["audio_bounds"],
"audio_feature_lens": res["audio_feature_lens"],
"im_token_id": im_token_id,
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}

View File

@@ -1,15 +1,15 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image
class MllamaImageProcessor(BaseImageProcessor):
class MllamaImageProcessor(BaseMultimodalProcessor):
models = [MllamaForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
@@ -34,7 +34,7 @@ class MllamaImageProcessor(BaseImageProcessor):
return image_inputs
async def process_images_async(
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
@@ -53,7 +53,7 @@ class MllamaImageProcessor(BaseImageProcessor):
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["image_hashes"] = [hash(str(image_data))]
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs

View File

@@ -1,12 +1,16 @@
import asyncio
import math
import time
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
@@ -14,7 +18,7 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(BaseImageProcessor):
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
@@ -59,7 +63,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
@@ -68,16 +72,17 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
*args,
**kwargs,
):
start = time.time()
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
base_output = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
)
@@ -139,7 +144,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
images = [resize_image(image) for image in base_output.all_frames]
images = [resize_image(image) for image in base_output.images]
ret = await self._process_single_image(
images=images, input_text=base_output.input_text
@@ -147,11 +152,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": image_grid_thws,
"video_grid_thws": video_grid_thws,

View File

@@ -144,11 +144,11 @@ class FINISH_ABORT(BaseFinishReason):
@dataclasses.dataclass
class ImageInputs:
class MultimodalInputs:
"""The image related inputs."""
pixel_values: Union[torch.Tensor, np.array]
image_hashes: Optional[list] = None
data_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
image_pad_len: Optional[list] = None
@@ -182,20 +182,27 @@ class ImageInputs:
im_end_id: Optional[int] = None
slice_start_id: Optional[int] = None
slice_end_id: Optional[int] = None
# [num_images, 2 (w, h)]
tgt_sizes: Optional[list] = None
# audio
audio_start_id: Optional[torch.Tensor] = None
audio_end_id: Optional[torch.Tensor] = None
audio_features: Optional[List[torch.Tensor]] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
ret = MultimodalInputs(
pixel_values=obj["pixel_values"],
image_hashes=obj["image_hashes"],
data_hashes=obj["data_hashes"],
)
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example.
ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
optional_args = [
"image_sizes",
@@ -211,6 +218,10 @@ class ImageInputs:
"slice_start_id",
"slice_end_id",
"tgt_sizes",
"audio_start_id",
"audio_end_id",
"audio_features",
"audio_feature_lens",
]
for arg in optional_args:
if arg in obj:
@@ -223,9 +234,19 @@ class ImageInputs:
or isinstance(ret.pixel_values, list)
)
assert ret.audio_features is None or isinstance(ret.audio_features, list)
return ret
def merge(self, other: ImageInputs):
def contains_image_inputs(self) -> bool:
""" """
return self.pixel_values is not None and self.pixel_values != []
def contains_audio_inputs(self) -> bool:
""" """
return self.audio_features is not None and self.audio_features != []
def merge(self, other: MultimodalInputs):
"""
merge image inputs when requests are being merged
"""
@@ -268,10 +289,12 @@ class ImageInputs:
# Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example.
self.image_hashes += other.image_hashes
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
self.data_hashes += other.data_hashes
self.pad_values = [x % (1 << 30) for x in self.data_hashes]
# args needed to be merged
optional_args = [
"audio_features",
"image_sizes",
"image_offsets",
"image_pad_len",
@@ -362,7 +385,7 @@ class Req:
self.decoded_text = ""
# For multimodal inputs
self.image_inputs: Optional[ImageInputs] = None
self.multimodal_inputs: Optional[MultimodalInputs] = None
# Prefix info
# The indices to kv cache for the shared prefix.
@@ -458,10 +481,10 @@ class Req:
return len(self.origin_input_ids) + len(self.output_ids)
def extend_image_inputs(self, image_inputs):
if self.image_inputs is None:
self.image_inputs = image_inputs
if self.multimodal_inputs is None:
self.multimodal_inputs = image_inputs
else:
self.image_inputs.merge(image_inputs)
self.multimodal_inputs.merge(image_inputs)
def finished(self) -> bool:
# Whether request reached finished condition
@@ -802,7 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.encoder_cached = []
for req in self.reqs:
im = req.image_inputs
im = req.multimodal_inputs
if im is None or im.num_image_tokens is None:
# No image input
self.encoder_lens_cpu.append(0)
@@ -1391,7 +1414,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=[r.image_inputs for r in self.reqs],
multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,
@@ -1474,7 +1497,7 @@ class ModelWorkerBatch:
extend_input_logprob_token_ids: Optional[torch.Tensor]
# For multimodal
image_inputs: Optional[List[ImageInputs]]
multimodal_inputs: Optional[List[MultimodalInputs]]
# For encoder-decoder
encoder_cached: Optional[List[bool]]

View File

@@ -88,7 +88,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ImageInputs,
MultimodalInputs,
Req,
ScheduleBatch,
global_server_args_dict,
@@ -841,8 +841,8 @@ class Scheduler(
return
# Handle multimodal inputs
if recv_req.image_inputs is not None:
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
if recv_req.mm_inputs is not None:
image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
@@ -856,7 +856,7 @@ class Scheduler(
)
logger.error(error_msg)
req.origin_input_ids = [0]
req.image_inputs = None
req.multimodal_inputs = None
req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -960,7 +960,7 @@ class Scheduler(
# Handle multimodal inputs
if recv_req.image_inputs is not None:
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
@@ -974,7 +974,7 @@ class Scheduler(
)
logger.error(error_msg)
req.origin_input_ids = [0]
req.image_inputs = None
req.multimodal_inputs = None
req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"

View File

@@ -138,7 +138,7 @@ class Session:
token_ids_logprob=req.token_ids_logprob,
)
if last_req is not None:
new_req.image_inputs = last_req.image_inputs
new_req.multimodal_inputs = last_req.mm_inputs
new_req.tokenizer = tokenizer
if abort:
new_req.to_abort = True

View File

@@ -16,7 +16,6 @@
import asyncio
import copy
import dataclasses
import json
import logging
import os
import pickle
@@ -52,10 +51,6 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.conn import KVBootstrapServer
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
get_image_processor,
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
@@ -93,6 +88,11 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.multimodal_processor import (
get_dummy_processor,
get_mm_processor,
import_processors,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -171,6 +171,7 @@ class TokenizerManager:
self.image_token_id = self.model_config.image_token_id
if self.model_config.is_multimodal:
import_processors()
_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
@@ -179,9 +180,9 @@ class TokenizerManager:
)
# We want to parallelize the image pre-processing so we create an executor for it
# We create image_processor for any skip_tokenizer_init to make sure we still encode
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
# images even with skip_tokenizer_init=False.
self.image_processor = get_image_processor(
self.mm_processor = get_mm_processor(
self.model_config.hf_config, server_args, _processor
)
@@ -192,7 +193,7 @@ class TokenizerManager:
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
else:
self.image_processor = get_dummy_image_processor()
self.mm_processor = get_dummy_processor()
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
@@ -389,7 +390,7 @@ class TokenizerManager:
)
input_ids = self.tokenizer.encode(input_text)
image_inputs: Dict = await self.image_processor.process_images_async(
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
)
if image_inputs and "input_ids" in image_inputs:

View File

@@ -43,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -176,7 +176,7 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For multimodal
image_inputs: Optional[List[ImageInputs]] = None
mm_inputs: Optional[List[MultimodalInputs]] = None
# Encoder-decoder
encoder_cached: Optional[List[bool]] = None
@@ -242,7 +242,7 @@ class ForwardBatch:
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,
image_inputs=batch.image_inputs,
mm_inputs=batch.multimodal_inputs,
encoder_cached=batch.encoder_cached,
encoder_lens=batch.encoder_lens,
encoder_lens_cpu=batch.encoder_lens_cpu,
@@ -332,42 +332,53 @@ class ForwardBatch:
return ret
def merge_image_inputs(self) -> Optional[ImageInputs]:
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
"""
Merge all image inputs in the batch into a single ImageInputs object.
Merge all image inputs in the batch into a single MultiModalInputs object.
Returns:
if none, current batch contains no image input
"""
if not self.image_inputs or all(x is None for x in self.image_inputs):
if not self.mm_inputs or all(x is None for x in self.mm_inputs):
return None
# Filter out None values
valid_inputs = [x for x in self.image_inputs if x is not None]
valid_inputs = [x for x in self.mm_inputs if x is not None]
# Start with the first valid image input
merged = valid_inputs[0]
# Merge remaining inputs
for img_input in valid_inputs[1:]:
merged.merge(img_input)
for mm_input in valid_inputs[1:]:
merged.merge(mm_input)
if isinstance(merged.pixel_values, np.ndarray):
merged.pixel_values = torch.from_numpy(merged.pixel_values)
if isinstance(merged.audio_features, np.ndarray):
merged.audio_features = torch.from_numpy(merged.audio_features)
return merged
def contains_image_inputs(self) -> bool:
""" """
if self.image_inputs is None:
return True
if self.mm_inputs is None:
return False
return any(
image_input.pixel_values is not None and image_input.pixel_values is not []
for image_input in self.image_inputs
if image_input is not None
mm_input is not None and mm_input.contains_image_inputs()
for mm_input in self.mm_inputs
)
def contains_audio_inputs(self) -> bool:
if self.mm_inputs is None:
return False
return any(
mm_input is not None and mm_input.contains_audio_inputs()
for mm_input in self.mm_inputs
)
def contains_mm_inputs(self) -> bool:
return self.contains_audio_inputs() or self.contains_image_inputs()
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
@@ -378,8 +389,8 @@ class ForwardBatch:
for i, _ in enumerate(mrope_positions_list):
mrope_position_delta = (
0
if batch.image_inputs[i] is None
else batch.image_inputs[i].mrope_position_delta
if batch.multimodal_inputs[i] is None
else batch.multimodal_inputs[i].mrope_position_delta
)
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
mrope_position_delta,
@@ -388,13 +399,13 @@ class ForwardBatch:
)
elif self.forward_mode.is_extend():
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
for i, image_inputs in enumerate(batch.image_inputs):
for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
extend_start_loc_cpu[i],
batch.extend_seq_lens[i],
batch.extend_prefix_lens[i],
)
if image_inputs is None:
if multimodal_inputs is None:
# text only
mrope_positions = [
[
@@ -411,20 +422,22 @@ class ForwardBatch:
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
],
image_grid_thw=image_inputs.image_grid_thws,
video_grid_thw=image_inputs.video_grid_thws,
image_token_id=image_inputs.im_token_id,
video_token_id=image_inputs.video_token_id,
image_grid_thw=multimodal_inputs.image_grid_thws,
video_grid_thw=multimodal_inputs.video_grid_thws,
image_token_id=multimodal_inputs.im_token_id,
video_token_id=multimodal_inputs.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
seq_len=len(self.input_ids),
second_per_grid_ts=image_inputs.second_per_grid_ts,
second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
tokens_per_second=hf_config.vision_config.tokens_per_second,
)
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
batch.multimodal_inputs[i].mrope_position_delta = (
mrope_position_delta
)
mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.cat(

View File

@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
@@ -1959,7 +1959,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
)
self.logits_processor = LogitsProcessor(config)
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values
bs, n = pixel_values.shape[0:2]
pixel_values = pixel_values.to(
@@ -1988,10 +1988,9 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
mm_data_embedding_func=self.get_image_feature,
)
return self.language_model(
@@ -2005,7 +2004,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
return self.gen_aligner(self.gen_embed(image_ids))
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
im_start_id = image_inputs.im_start_id
im_end_id = image_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]

View File

@@ -11,7 +11,7 @@ from sglang.srt.configs.deepseekvl2 import (
)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
@@ -222,7 +222,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.image_inputs):
for idx, image in enumerate(forward_batch.mm_inputs):
if image is None:
continue
start_idx = extend_start_loc_cpu[idx]
@@ -262,10 +262,10 @@ class DeepseekVL2ForCausalLM(nn.Module):
weights_loader = getattr(param, "weight_loader", default_weight_loader)
weights_loader(param, loaded_weight)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
return input_ids
def get_image_feature(self, image_input: ImageInputs):
def get_image_feature(self, image_input: MultimodalInputs):
pixel_values = image_input.pixel_values.type(
next(self.vision.parameters()).dtype
).to(device=next(self.vision.parameters()).device)

View File

@@ -38,7 +38,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
@@ -185,7 +185,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
self.post_init()
def pad_input_ids(
self, input_ids: List[int], image_inputs: ImageInputs
self, input_ids: List[int], image_inputs: MultimodalInputs
) -> List[int]:
"""Pad input IDs with image tokens."""
# Get special token IDs
@@ -268,7 +268,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings()
def get_image_feature(self, image_input: ImageInputs):
def get_image_feature(self, image_input: MultimodalInputs):
"""
Projects the last hidden state from the vision model into language model space.
@@ -286,11 +286,11 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
image_features = self.multi_modal_projector(vision_outputs)
return image_features
def embed_image_inputs(
def embed_mm_inputs(
self,
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
image_input: ImageInputs,
image_input: MultimodalInputs,
) -> torch.Tensor:
if input_ids is None:
raise ValueError("Unimplemented")
@@ -401,10 +401,9 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
inputs_embeds = general_mm_embed_routine(
input_ids=llm_input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
mm_data_embedding_func=self.get_image_feature,
)
outputs = self.language_model(

View File

@@ -17,7 +17,7 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn

View File

@@ -31,7 +31,7 @@ from transformers import (
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
@@ -46,7 +46,7 @@ from sglang.srt.utils import add_prefix
class LlavaBaseForCausalLM(nn.Module):
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
# hardcode for spatial_unpad + anyres
@@ -134,7 +134,7 @@ class LlavaBaseForCausalLM(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
image_inputs = forward_batch.image_inputs
image_inputs = forward_batch.mm_inputs
if forward_batch.forward_mode.is_extend():
# Clamp input ids. This is because the input_ids for the image tokens are

View File

@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
@@ -57,7 +57,7 @@ class LlavaVidForCausalLM(nn.Module):
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
pad_values = image_inputs.pad_values
new_image_feature_len = self.image_feature_len
@@ -112,7 +112,7 @@ class LlavaVidForCausalLM(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
image_inputs = forward_batch.image_inputs
image_inputs = forward_batch.mm_inputs
if forward_batch.forward_mode.is_extend():
bs = forward_batch.batch_size

File diff suppressed because it is too large Load Diff

View File

@@ -52,9 +52,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
embed_image_inputs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -862,24 +862,12 @@ class MiniCPMVBaseModel(nn.Module):
forward_batch: ForwardBatch,
**kwargs: Any,
) -> torch.Tensor:
if (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids)
else:
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
image_inputs = forward_batch.merge_image_inputs()
inputs_embeds = embed_image_inputs(
image_input=image_inputs,
input_ids=input_ids,
input_embedding=self.get_input_embeddings(),
image_embedding_func=self.get_image_features,
placeholder_token_ids=[image_inputs.im_token_id]
+ image_inputs.pad_values,
)
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
mm_data_embedding_func=self.get_image_features,
)
hidden_states = self.llm.model(
input_ids=None,
@@ -925,7 +913,7 @@ class MiniCPMVBaseModel(nn.Module):
) -> torch.Tensor:
raise NotImplementedError
def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor:
def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
raise NotImplementedError
@@ -1037,7 +1025,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
def get_image_features(
self,
image_inputs: ImageInputs,
image_inputs: MultimodalInputs,
) -> torch.Tensor:
# list of tensors
pixel_values = image_inputs.pixel_values
@@ -1075,7 +1063,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
)
return self.resampler(vision_embedding, tgt_sizes)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id

View File

@@ -32,7 +32,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
@@ -796,7 +796,7 @@ class MllamaForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config.text_config)
self.capture_mode = False
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
pixel_values = image_inputs.pixel_values
pad_values = image_inputs.pad_values
@@ -815,7 +815,7 @@ class MllamaForConditionalGeneration(nn.Module):
# pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
max_num_images = max_num_tiles = bs = 0
for i, im in enumerate(forward_batch.image_inputs):
for i, im in enumerate(forward_batch.mm_inputs):
if not forward_batch.encoder_cached[i] and im is not None:
max_num_images = max(max_num_images, im.pixel_values.shape[1])
max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
@@ -842,7 +842,7 @@ class MllamaForConditionalGeneration(nn.Module):
)
i = 0
encoder_lens_need = []
for k, im in enumerate(forward_batch.image_inputs):
for k, im in enumerate(forward_batch.mm_inputs):
if forward_batch.encoder_cached[k] or im is None:
continue

View File

@@ -57,7 +57,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
@@ -513,7 +513,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
@@ -523,7 +523,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, image_inputs)
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
return image_embeds
@@ -572,10 +572,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
mm_data_embedding_func=self.get_image_feature,
)
hidden_states = self.model(

View File

@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
@@ -472,16 +472,16 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
im_start_id: int = multi_modal_inputs.im_start_id
im_end_id: int = multi_modal_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs)
return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
return image_embeds
@@ -530,10 +530,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
mm_data_embedding_func=self.get_image_feature,
)
hidden_states = self.model(

View File

@@ -899,6 +899,7 @@ def v1_chat_generate_request(
input_ids = []
sampling_params_list = []
image_data_list = []
audio_data_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
@@ -912,6 +913,7 @@ def v1_chat_generate_request(
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
@@ -956,7 +958,7 @@ def v1_chat_generate_request(
)
except:
# This except branch will be triggered when the chosen model
# has a different tools input format that is not compatiable
# has a different tools input format that is not compatible
# with openAI's apply_chat_template tool_call format, like Mistral.
tools = [t if "function" in t else {"function": t} for t in tools]
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
@@ -976,11 +978,13 @@ def v1_chat_generate_request(
prompt_ids += encoded
stop = request.stop
image_data = None
audio_data = None
modalities = []
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
image_data = conv.image_data
audio_data = conv.audio_data
modalities = conv.modalities
stop = conv.stop_str or []
if request.stop:
@@ -994,6 +998,7 @@ def v1_chat_generate_request(
prompt_ids = request.messages
stop = request.stop
image_data = None
audio_data = None
modalities = []
input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs)
@@ -1034,6 +1039,7 @@ def v1_chat_generate_request(
sampling_params_list.append(sampling_params)
image_data_list.append(image_data)
audio_data_list.append(audio_data)
modalities_list.append(modalities)
if len(all_requests) == 1:
if isinstance(input_ids[0], str):
@@ -1042,6 +1048,7 @@ def v1_chat_generate_request(
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0]
audio_data_list = audio_data_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
@@ -1056,6 +1063,7 @@ def v1_chat_generate_request(
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=image_data_list,
audio_data=audio_data_list,
sampling_params=sampling_params_list,
return_logprob=return_logprobs,
logprob_start_len=logprob_start_lens,

View File

@@ -227,14 +227,25 @@ class ChatCompletionMessageContentImageURL(BaseModel):
detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentAudioURL(BaseModel):
url: str
class ChatCompletionMessageContentImagePart(BaseModel):
type: Literal["image_url"]
image_url: ChatCompletionMessageContentImageURL
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
class ChatCompletionMessageContentAudioPart(BaseModel):
type: Literal["audio_url"]
audio_url: ChatCompletionMessageContentAudioURL
ChatCompletionMessageContentPart = Union[
ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
ChatCompletionMessageContentTextPart,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentAudioPart,
]

View File

@@ -55,14 +55,13 @@ import triton
import zmq
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
from packaging.version import Version, parse
from PIL import Image
from starlette.routing import Mount
from torch import nn
from torch.func import functional_call
from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils._contextlib import _DecoratorContextManager
from torch.utils.cpp_extension import CUDA_HOME
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
@@ -507,9 +506,37 @@ def decode_video_base64(video_base64):
) # Return an empty array and size tuple if no frames were found
def load_image(image_file: Union[str, bytes]):
from PIL import Image
def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
# Use soundfile here, since librosa use it under the hood,
# and librosa will not support audio loading in the future
import soundfile as sf
from scipy.signal import resample
# print(f"loading {audio_file}")
# Load audio data
if isinstance(audio_file, bytes):
audio, original_sr = sf.read(BytesIO(audio_file))
elif audio_file.startswith("data:"):
audio_file = audio_file.split(",")[1]
audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
elif isinstance(audio_file, str):
audio, original_sr = sf.read(audio_file)
else:
raise ValueError(f"Invalid audio format: {audio_file}")
# Resample audio if the original sample rate is different from the desired sample rate
if original_sr != sr:
num_samples = int(len(audio) * float(sr) / original_sr)
audio = resample(audio, num_samples)
# Convert to mono if requested and audio is stereo
if mono and len(audio.shape) > 1:
audio = np.mean(audio, axis=1)
return audio
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
image = image_size = None
if isinstance(image_file, bytes):