model: Minicpmo (#3023)
This commit is contained in:
@@ -34,6 +34,7 @@ runtime_common = [
|
||||
"pydantic",
|
||||
"python-multipart",
|
||||
"pyzmq>=25.1.2",
|
||||
"soundfile==0.13.1",
|
||||
"torchao>=0.7.0",
|
||||
"transformers==4.50.0",
|
||||
"uvicorn",
|
||||
|
||||
@@ -15,6 +15,7 @@ class ChatTemplate:
|
||||
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
|
||||
stop_str: List[str] = ()
|
||||
image_token: str = "<image>"
|
||||
audio_token: str = "<audio>"
|
||||
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
|
||||
|
||||
def get_prefix_and_suffix(
|
||||
@@ -253,6 +254,22 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
# https://huggingface.co/openbmb/MiniCPM-o-2_6
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="minicpmo",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", " "),
|
||||
"user": ("user:", " "),
|
||||
"assistant": ("assistant:", "</s>"),
|
||||
},
|
||||
stop_str=("<|im_end|>", "<|endoftext|>"),
|
||||
image_token="(<image>./</image>)",
|
||||
audio_token="(<audio>./</audio>)",
|
||||
)
|
||||
)
|
||||
|
||||
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
@@ -474,12 +491,6 @@ def match_chat_ml(model_path: str):
|
||||
return get_chat_template("chatml-llava")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_minicpm(model_path: str):
|
||||
if "minicpm" in model_path:
|
||||
return get_chat_template("minicpmv")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_yi(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
@@ -499,8 +510,10 @@ def match_gemma_it(model_path: str):
|
||||
@register_chat_template_matching_function
|
||||
def match_openbmb_minicpm(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "minicpm" in model_path:
|
||||
if "minicpm-v" in model_path:
|
||||
return get_chat_template("minicpmv")
|
||||
elif "minicpm-o" in model_path:
|
||||
return get_chat_template("minicpmo")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
|
||||
@@ -462,18 +462,19 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
||||
|
||||
multimodal_model_archs = [
|
||||
"DeepseekVL2ForCausalLM",
|
||||
"LlavaLlamaForCausalLM",
|
||||
"LlavaQwenForCausalLM",
|
||||
"LlavaMistralForCausalLM",
|
||||
"LlavaVidForCausalLM",
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Grok1VForCausalLM",
|
||||
"Grok1AForCausalLM",
|
||||
"LlavaLlamaForCausalLM",
|
||||
"LlavaMistralForCausalLM",
|
||||
"LlavaQwenForCausalLM",
|
||||
"LlavaVidForCausalLM",
|
||||
"MiniCPMO",
|
||||
"MiniCPMV",
|
||||
"MultiModalityCausalLM",
|
||||
"MllamaForConditionalGeneration",
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"MiniCPMV",
|
||||
"MultiModalityCausalLM",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -73,11 +73,14 @@ class Conversation:
|
||||
stop_str: Union[str, List[str]] = None
|
||||
# The string that represents an image token in the prompt
|
||||
image_token: str = "<image>"
|
||||
audio_token: str = "<audio>"
|
||||
|
||||
image_data: Optional[List[str]] = None
|
||||
modalities: Optional[List[str]] = None
|
||||
stop_token_ids: Optional[int] = None
|
||||
|
||||
audio_data: Optional[List[str]] = None
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt for generation."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
@@ -327,6 +330,10 @@ class Conversation:
|
||||
"""Append a new message."""
|
||||
self.image_data.append(image)
|
||||
|
||||
def append_audio(self, audio: str):
|
||||
"""Append a new message."""
|
||||
self.audio_data.append(audio)
|
||||
|
||||
def update_last_message(self, message: str):
|
||||
"""Update the last output.
|
||||
|
||||
@@ -373,6 +380,7 @@ class Conversation:
|
||||
sep2=self.sep2,
|
||||
stop_str=self.stop_str,
|
||||
image_token=self.image_token,
|
||||
audio_token=self.audio_token,
|
||||
)
|
||||
|
||||
def dict(self):
|
||||
@@ -459,8 +467,10 @@ def generate_chat_conv(
|
||||
sep2=conv.sep2,
|
||||
stop_str=conv.stop_str,
|
||||
image_data=[],
|
||||
audio_data=[],
|
||||
modalities=[],
|
||||
image_token=conv.image_token,
|
||||
audio_token=conv.audio_token,
|
||||
)
|
||||
|
||||
if isinstance(request.messages, str):
|
||||
@@ -498,6 +508,7 @@ def generate_chat_conv(
|
||||
if conv.name != "qwen2-vl"
|
||||
else conv.image_token
|
||||
)
|
||||
audio_token = conv.audio_token
|
||||
for content in message.content:
|
||||
if content.type == "text":
|
||||
if num_image_url > 16:
|
||||
@@ -507,6 +518,10 @@ def generate_chat_conv(
|
||||
# NOTE: Only works for llava
|
||||
real_content += image_token
|
||||
conv.append_image(content.image_url.url)
|
||||
elif content.type == "audio_url":
|
||||
real_content += audio_token
|
||||
conv.append_audio(content.audio_url.url)
|
||||
|
||||
conv.append_message(conv.roles[0], real_content)
|
||||
elif msg_role == "assistant":
|
||||
parsed_content = ""
|
||||
@@ -704,3 +719,18 @@ register_conv_template(
|
||||
image_token="<image_placeholder>",
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/openbmb/MiniCPM-o-2_6#usage
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="minicpmo",
|
||||
system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||
sep="<|im_end|>\n",
|
||||
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
||||
stop_str=("<|im_end|>", "<|endoftext|>"),
|
||||
image_token="(<image>./</image>)",
|
||||
audio_token="(<audio>./</audio>)",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
# TODO: also move pad_input_ids into this module
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
|
||||
from torch import Tensor
|
||||
from transformers import IMAGE_PROCESSOR_MAPPING
|
||||
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
BaseImageProcessor,
|
||||
DummyImageProcessor,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
IMAGE_PROCESSOR_MAPPING = {}
|
||||
|
||||
|
||||
def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
|
||||
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
|
||||
if model_cls.__name__ in hf_config.architectures:
|
||||
return processor_cls(hf_config, server_args, processor)
|
||||
raise ValueError(
|
||||
f"No image processor found for architecture: {hf_config.architectures}"
|
||||
)
|
||||
|
||||
|
||||
def get_dummy_image_processor():
|
||||
return DummyImageProcessor()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_image_processors():
|
||||
package_name = "sglang.srt.managers.image_processors"
|
||||
package = importlib.import_module(package_name)
|
||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||
if not ispkg:
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
except Exception as e:
|
||||
logger.warning(f" Ignore import error when loading {name}: " f"{e}")
|
||||
continue
|
||||
all_members = inspect.getmembers(module, inspect.isclass)
|
||||
classes = [
|
||||
member
|
||||
for name, member in all_members
|
||||
if member.__module__ == module.__name__
|
||||
]
|
||||
for cls in classes:
|
||||
if issubclass(cls, BaseImageProcessor):
|
||||
for arch in getattr(cls, "models"):
|
||||
IMAGE_PROCESSOR_MAPPING[arch] = cls
|
||||
|
||||
|
||||
# also register processors
|
||||
import_image_processors()
|
||||
@@ -1,129 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.minicpmv import MiniCPMV
|
||||
|
||||
|
||||
class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
models = [MiniCPMV]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "(<image>./</image>)"
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text):
|
||||
processor = get_global_processor()
|
||||
result = processor.__call__(text=input_text, images=images, return_tensors="pt")
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": result.pixel_values,
|
||||
"tgt_sizes": result.tgt_sizes,
|
||||
}
|
||||
|
||||
async def _process_images(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MiniCPMVImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(
|
||||
images=images, text=input_text, return_tensors="pt"
|
||||
)
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
base_output = self.load_images(
|
||||
input_ids=input_ids,
|
||||
image_data=image_data,
|
||||
image_token=self.IMAGE_TOKEN,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
if len(base_output.all_frames) == 0:
|
||||
return None
|
||||
res = await self._process_images(
|
||||
images=base_output.all_frames, input_text=base_output.input_text
|
||||
)
|
||||
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
im_start_id = tokenizer.im_start_id
|
||||
im_token_id = tokenizer.unk_token_id
|
||||
im_end_id = tokenizer.im_end_id
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = tokenizer.slice_start_id
|
||||
slice_end_id = tokenizer.slice_end_id
|
||||
|
||||
pixel_values = res["pixel_values"]
|
||||
tgt_sizes = res["tgt_sizes"]
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
|
||||
)
|
||||
|
||||
if not isinstance(tgt_sizes, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
|
||||
)
|
||||
|
||||
if len(pixel_values) != len(tgt_sizes):
|
||||
raise ValueError(
|
||||
"Inconsistent batch lengths, found: "
|
||||
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
|
||||
)
|
||||
|
||||
# tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
|
||||
# tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
|
||||
pixel_values_flat: List[torch.Tensor] = []
|
||||
tgt_sizes_flat: List[torch.Tensor] = []
|
||||
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
|
||||
# per image
|
||||
if len(pixel_b) != len(tgt_b):
|
||||
raise ValueError(
|
||||
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
|
||||
)
|
||||
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
||||
# per patch
|
||||
pixel_values_flat += [pixel_n]
|
||||
tgt_sizes_flat += [tgt_n]
|
||||
|
||||
pixel_values = pixel_values_flat
|
||||
tgt_sizes = torch.stack(tgt_sizes_flat)
|
||||
return {
|
||||
"input_ids": res["input_ids"].flatten().tolist(),
|
||||
"pixel_values": pixel_values,
|
||||
"tgt_sizes": tgt_sizes,
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"im_start_id": im_start_id,
|
||||
"im_token_id": im_token_id,
|
||||
"im_end_id": im_end_id,
|
||||
"slice_start_id": slice_start_id,
|
||||
"slice_end_id": slice_end_id,
|
||||
}
|
||||
@@ -45,6 +45,8 @@ class GenerateReqInput:
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None
|
||||
# The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
|
||||
audio_data: Optional[Union[List[str], str]] = None
|
||||
# The sampling_params. See descriptions below.
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
# The request id.
|
||||
@@ -167,6 +169,13 @@ class GenerateReqInput:
|
||||
elif isinstance(self.image_data, list):
|
||||
pass
|
||||
|
||||
if self.audio_data is None:
|
||||
self.audio_data = [None] * num
|
||||
elif not isinstance(self.audio_data, list):
|
||||
self.audio_data = [self.audio_data] * num
|
||||
elif isinstance(self.audio_data, list):
|
||||
pass
|
||||
|
||||
if self.sampling_params is None:
|
||||
self.sampling_params = [{}] * num
|
||||
elif not isinstance(self.sampling_params, list):
|
||||
@@ -231,6 +240,7 @@ class GenerateReqInput:
|
||||
text=self.text[i] if self.text is not None else None,
|
||||
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
||||
image_data=self.image_data[i],
|
||||
audio_data=self.audio_data[i],
|
||||
sampling_params=self.sampling_params[i],
|
||||
rid=self.rid[i],
|
||||
return_logprob=self.return_logprob[i],
|
||||
@@ -259,8 +269,8 @@ class TokenizedGenerateReqInput:
|
||||
input_text: str
|
||||
# The input token ids
|
||||
input_ids: List[int]
|
||||
# The image inputs
|
||||
image_inputs: dict
|
||||
# The multimodal inputs
|
||||
mm_inputs: dict
|
||||
# The sampling parameters
|
||||
sampling_params: SamplingParams
|
||||
# Whether to return the logprobs
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
ImageInputs,
|
||||
MultimodalInputs,
|
||||
global_server_args_dict,
|
||||
logger,
|
||||
)
|
||||
@@ -26,7 +26,7 @@ class MultiModalityDataPaddingPattern:
|
||||
|
||||
@abstractmethod
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
self, input_ids: List[int], image_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
Pad the input ids sequence containing data tokens, and replace them with pad_values
|
||||
@@ -44,16 +44,16 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
self.data_token_id_pairs = data_token_pairs
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
This function will replace the data-tokens inbetween with pad_values accordingly
|
||||
"""
|
||||
pad_values = image_inputs.pad_values
|
||||
pad_values = mm_inputs.pad_values
|
||||
data_token_pairs = self.data_token_id_pairs
|
||||
image_inputs.image_offsets = []
|
||||
mm_inputs.image_offsets = []
|
||||
if data_token_pairs is None:
|
||||
data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
|
||||
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
|
||||
if data_token_pairs is None:
|
||||
logger.warning(
|
||||
"No data_token_pairs provided, RadixAttention might be influenced."
|
||||
@@ -61,8 +61,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
return input_ids
|
||||
start_token_ids = [s for s, _e in data_token_pairs]
|
||||
end_tokens_ids = [e for _s, e in data_token_pairs]
|
||||
# First start token marks new data
|
||||
data_start_token = start_token_ids[0]
|
||||
|
||||
padded_ids = []
|
||||
last_idx = 0
|
||||
@@ -77,9 +75,12 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||
padded_ids.extend(input_ids[last_idx : start_idx + 1])
|
||||
|
||||
if input_ids[start_idx] == data_start_token:
|
||||
if input_ids[start_idx] in start_token_ids:
|
||||
data_idx += 1
|
||||
image_inputs.image_offsets += [start_idx]
|
||||
mm_inputs.image_offsets += [start_idx]
|
||||
|
||||
if data_idx >= len(mm_inputs.pad_values):
|
||||
data_idx = len(mm_inputs.pad_values) - 1
|
||||
|
||||
num_tokens = end_idx - start_idx - 1
|
||||
pad_value = pad_values[data_idx]
|
||||
@@ -89,7 +90,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
|
||||
padded_ids.extend(input_ids[last_idx:])
|
||||
|
||||
assert len(input_ids) == len(padded_ids)
|
||||
assert len(input_ids) == len(padded_ids), "Length validation fails"
|
||||
return padded_ids
|
||||
|
||||
|
||||
@@ -107,26 +108,25 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
|
||||
self.num_data_token_calc_func = num_data_token_calc_func
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
This function will follow the procedure of:
|
||||
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
|
||||
2. the padded data tokens will be replaced with their pad_values
|
||||
"""
|
||||
image_grid_thws = image_inputs.image_grid_thws
|
||||
pad_values = image_inputs.pad_values
|
||||
image_grid_thws = mm_inputs.image_grid_thws
|
||||
pad_values = mm_inputs.pad_values
|
||||
|
||||
image_indices = [
|
||||
idx
|
||||
for idx, token in enumerate(input_ids)
|
||||
if token == image_inputs.im_token_id
|
||||
idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
|
||||
]
|
||||
|
||||
image_inputs.image_offsets = []
|
||||
mm_inputs.image_offsets = []
|
||||
|
||||
input_ids_with_image = []
|
||||
for image_cnt, _ in enumerate(image_grid_thws):
|
||||
# print(f"image_cnt {image_cnt}")
|
||||
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
|
||||
if image_cnt == 0:
|
||||
non_image_tokens = input_ids[: image_indices[image_cnt]]
|
||||
@@ -135,7 +135,7 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
|
||||
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
|
||||
]
|
||||
input_ids_with_image.extend(non_image_tokens)
|
||||
image_inputs.image_offsets.append(len(input_ids_with_image))
|
||||
mm_inputs.image_offsets.append(len(input_ids_with_image))
|
||||
pad_ids = pad_values * (
|
||||
(num_image_tokens + len(pad_values)) // len(pad_values)
|
||||
)
|
||||
@@ -170,11 +170,11 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
|
||||
return input_ids_tensor.tolist()
|
||||
|
||||
|
||||
def embed_image_inputs(
|
||||
image_input: ImageInputs,
|
||||
def embed_mm_inputs(
|
||||
mm_input: MultimodalInputs,
|
||||
input_ids: torch.Tensor,
|
||||
input_embedding: nn.Embedding,
|
||||
image_embedding_func,
|
||||
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
|
||||
placeholder_token_ids: List[int] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
@@ -184,10 +184,10 @@ def embed_image_inputs(
|
||||
Returns:
|
||||
final embedding: Optional[torch.Tensor]
|
||||
"""
|
||||
if image_input is None:
|
||||
if mm_input is None:
|
||||
return None
|
||||
|
||||
placeholder_token_ids = placeholder_token_ids or image_input.pad_values
|
||||
placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
|
||||
|
||||
# boolean masking the special tokens
|
||||
special_image_mask = torch.isin(
|
||||
@@ -196,12 +196,18 @@ def embed_image_inputs(
|
||||
).unsqueeze(-1)
|
||||
|
||||
num_image_tokens_in_input_ids = special_image_mask.sum()
|
||||
# print(f"{num_image_tokens_in_input_ids}")
|
||||
# print(f"{input_ids}")
|
||||
|
||||
# return
|
||||
if num_image_tokens_in_input_ids == 0:
|
||||
# unexpected
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
else:
|
||||
image_embedding = image_embedding_func(image_input)
|
||||
# print(f"Getting image feature")
|
||||
image_embedding = mm_data_embedding_func(mm_input)
|
||||
|
||||
# print(f"image_embedding: {image_embedding.shape}")
|
||||
|
||||
if image_embedding.dim() == 2:
|
||||
num_image_tokens_in_embedding = image_embedding.shape[0]
|
||||
@@ -273,31 +279,95 @@ def embed_image_embedding(
|
||||
|
||||
def general_mm_embed_routine(
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
embed_tokens: nn.Embedding,
|
||||
image_embedding_func: Callable[[ImageInputs], torch.Tensor],
|
||||
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
|
||||
placeholder_token_ids: List[int] = None,
|
||||
):
|
||||
"""
|
||||
a general wrapper function to get final input embeds from multimodal models
|
||||
with a language model as causal model
|
||||
|
||||
Args:
|
||||
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
|
||||
|
||||
"""
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or not forward_batch.contains_image_inputs()
|
||||
not forward_batch.forward_mode.is_decode()
|
||||
and forward_batch.contains_mm_inputs()
|
||||
):
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
else:
|
||||
image = forward_batch.merge_image_inputs()
|
||||
inputs_embeds = embed_image_inputs(
|
||||
image_input=image,
|
||||
image = forward_batch.merge_mm_inputs()
|
||||
inputs_embeds = embed_mm_inputs(
|
||||
mm_input=image,
|
||||
input_ids=input_ids,
|
||||
input_embedding=embed_tokens,
|
||||
image_embedding_func=image_embedding_func,
|
||||
mm_data_embedding_func=mm_data_embedding_func,
|
||||
placeholder_token_ids=placeholder_token_ids,
|
||||
)
|
||||
# once used, image_inputs is useless
|
||||
# once used, mm_inputs is useless
|
||||
# just being defensive here
|
||||
forward_batch.image_inputs = None
|
||||
forward_batch.mm_inputs = None
|
||||
else:
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def get_multimodal_data_bounds(
|
||||
input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)
|
||||
|
||||
Returns:
|
||||
[bounds_count, 2]
|
||||
"""
|
||||
# All the images in the batch should share the same special image
|
||||
# bound token ids.
|
||||
start_tokens = [s for s, _e in token_pairs]
|
||||
end_tokens = [e for _s, e in token_pairs]
|
||||
|
||||
assert all(isinstance(t, int) for t in start_tokens)
|
||||
assert all(isinstance(t, int) for t in end_tokens)
|
||||
|
||||
# print(input_ids)
|
||||
start_cond = torch.isin(
|
||||
input_ids, torch.tensor(start_tokens, device=input_ids.device)
|
||||
)
|
||||
end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
|
||||
|
||||
(data_start_tokens,) = torch.where(start_cond)
|
||||
(data_end_tokens,) = torch.where(end_cond)
|
||||
|
||||
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
|
||||
if len(data_start_tokens) != len(data_end_tokens):
|
||||
if (
|
||||
len(data_start_tokens) + 1 == len(data_end_tokens)
|
||||
and input_ids[0] in pad_values
|
||||
and data_end_tokens[0] < data_start_tokens[0]
|
||||
):
|
||||
data_start_tokens = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=data_start_tokens.device),
|
||||
data_start_tokens,
|
||||
]
|
||||
)
|
||||
valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
|
||||
|
||||
if valid_image_nums == 0:
|
||||
return torch.zeros((0, 2), device=input_ids.device)
|
||||
|
||||
# Filter out pairs where start_token >= end_token
|
||||
valid_pairs = []
|
||||
for i in range(valid_image_nums):
|
||||
start_token = data_start_tokens[i]
|
||||
end_token = data_end_tokens[i]
|
||||
if start_token < end_token:
|
||||
valid_pairs.append((start_token + 1, end_token - 1))
|
||||
|
||||
if not valid_pairs:
|
||||
return torch.zeros((0, 2), device=input_ids.device)
|
||||
|
||||
# Convert valid pairs to tensor
|
||||
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
|
||||
return valid_pairs_tensor
|
||||
|
||||
68
python/sglang/srt/managers/multimodal_processor.py
Normal file
68
python/sglang/srt/managers/multimodal_processor.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# TODO: also move pad_input_ids into this module
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
|
||||
from transformers import PROCESSOR_MAPPING
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROCESSOR_MAPPING = {}
|
||||
|
||||
|
||||
class DummyMultimodalProcessor(BaseMultimodalProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def process_mm_data_async(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
def get_dummy_processor():
|
||||
return DummyMultimodalProcessor()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_processors():
|
||||
package_name = "sglang.srt.managers.multimodal_processors"
|
||||
package = importlib.import_module(package_name)
|
||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||
if not ispkg:
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Ignore import error when loading {name}: " f"{e}")
|
||||
continue
|
||||
all_members = inspect.getmembers(module, inspect.isclass)
|
||||
classes = [
|
||||
member
|
||||
for name, member in all_members
|
||||
if member.__module__ == module.__name__
|
||||
]
|
||||
for cls in (
|
||||
cls for cls in classes if issubclass(cls, BaseMultimodalProcessor)
|
||||
):
|
||||
assert hasattr(cls, "models")
|
||||
for arch in getattr(cls, "models"):
|
||||
PROCESSOR_MAPPING[arch] = cls
|
||||
|
||||
|
||||
def get_mm_processor(
|
||||
hf_config, server_args: ServerArgs, processor
|
||||
) -> BaseMultimodalProcessor:
|
||||
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
|
||||
if model_cls.__name__ in hf_config.architectures:
|
||||
return processor_cls(hf_config, server_args, processor)
|
||||
raise ValueError(
|
||||
f"No processor registered for architecture: {hf_config.architectures}.\n"
|
||||
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
|
||||
)
|
||||
|
||||
self.image_proce
|
||||
@@ -4,16 +4,16 @@ import dataclasses
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import transformers
|
||||
from decord import VideoReader, cpu
|
||||
from openai import BadRequestError
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.utils import logger
|
||||
from sglang.srt.utils import load_audio, load_image, logger
|
||||
|
||||
global global_processor
|
||||
|
||||
@@ -24,21 +24,41 @@ def get_global_processor():
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseImageProcessorOutput:
|
||||
image_hashes: list[int]
|
||||
image_sizes: list[tuple[int, int]]
|
||||
all_frames: [PIL.Image]
|
||||
# input_text, with each frame of video/image represented as an image_token
|
||||
class BaseMultiModalProcessorOutput:
|
||||
# input_text, with each frame of video/image represented with a image_token
|
||||
input_text: str
|
||||
|
||||
mm_data_hashes: Optional[list[int]]
|
||||
# images
|
||||
image_sizes: Optional[list[int]]
|
||||
# frames loaded from image and video, in given order
|
||||
images: Optional[list[PIL.Image]] = None
|
||||
|
||||
# audios
|
||||
audios: Optional[list[np.ndarray]] = None
|
||||
|
||||
def normalize(self):
|
||||
for field_name in ["data_hashes", "image_sizes", "all_frames"]:
|
||||
for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
|
||||
field = getattr(self, field_name, None)
|
||||
if field is not None and isinstance(field, list) and len(field) == 0:
|
||||
setattr(self, field_name, None)
|
||||
|
||||
|
||||
class BaseImageProcessor(ABC):
|
||||
@dataclasses.dataclass
|
||||
class MultimodalSpecialTokens:
|
||||
image_token: Optional[str] = None
|
||||
video_token: Optional[str] = None
|
||||
audio_token: Optional[str] = None
|
||||
|
||||
def collect(self) -> list[str]:
|
||||
return [
|
||||
token
|
||||
for token in [self.image_token, self.video_token, self.audio_token]
|
||||
if token
|
||||
]
|
||||
|
||||
|
||||
class BaseMultimodalProcessor(ABC):
|
||||
models = []
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
@@ -72,7 +92,7 @@ class BaseImageProcessor(ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def process_images_async(
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_text, max_req_input_len, **kwargs
|
||||
):
|
||||
pass
|
||||
@@ -120,29 +140,33 @@ class BaseImageProcessor(ABC):
|
||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||
return frames
|
||||
|
||||
def load_images(
|
||||
def load_mm_data(
|
||||
self,
|
||||
input_ids: list[int],
|
||||
image_data,
|
||||
image_token: Union[int, str],
|
||||
multimodal_tokens: MultimodalSpecialTokens,
|
||||
max_req_input_len: int,
|
||||
image_data: Optional[list] = None,
|
||||
audio_data: Optional[list] = None,
|
||||
return_text: Optional[bool] = True,
|
||||
discard_alpha_channel: bool = True,
|
||||
) -> BaseImageProcessorOutput:
|
||||
) -> BaseMultiModalProcessorOutput:
|
||||
"""
|
||||
Each frame of video/image will be replaced by a single image token
|
||||
|
||||
Args:
|
||||
image_token: The token ID representing the image placeholder.
|
||||
multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
|
||||
e.g. image token or audio token
|
||||
discard_alpha_channel: if True, discards the alpha channel in the returned images
|
||||
|
||||
"""
|
||||
if isinstance(image_token, int):
|
||||
image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
|
||||
image_token
|
||||
if isinstance(multimodal_tokens.image_token, int):
|
||||
multimodal_tokens.image_token = (
|
||||
self._processor.tokenizer.convert_ids_to_tokens(
|
||||
multimodal_tokens.image_token
|
||||
)
|
||||
)
|
||||
else:
|
||||
image_token_str = image_token
|
||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
||||
|
||||
if isinstance(input_ids, list) and return_text:
|
||||
assert len(input_ids) and isinstance(input_ids[0], int)
|
||||
@@ -152,7 +176,11 @@ class BaseImageProcessor(ABC):
|
||||
if return_text:
|
||||
import re
|
||||
|
||||
pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
|
||||
pattern = (
|
||||
"("
|
||||
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
|
||||
+ ")"
|
||||
)
|
||||
# split text into list of normal text and special tokens
|
||||
text_parts = re.split(pattern, input_text)
|
||||
|
||||
@@ -162,7 +190,7 @@ class BaseImageProcessor(ABC):
|
||||
total_frame_count = sum(estimated_frames_list)
|
||||
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
|
||||
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
|
||||
_scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
|
||||
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
|
||||
|
||||
assert len(image_data) == len(estimated_frames_list)
|
||||
|
||||
@@ -171,9 +199,16 @@ class BaseImageProcessor(ABC):
|
||||
new_text = ""
|
||||
for index, text_part in enumerate(text_parts):
|
||||
try:
|
||||
if text_part == image_token:
|
||||
if text_part == multimodal_tokens.image_token:
|
||||
# load as image
|
||||
frames_to_process = estimated_frames_list[image_index]
|
||||
if len(images) >= MAX_NUM_FRAMES:
|
||||
frames_to_process = 0
|
||||
else:
|
||||
estimated_frames = estimated_frames_list[image_index]
|
||||
frames_to_process = max(
|
||||
1, int(estimated_frames * scaling_factor)
|
||||
)
|
||||
|
||||
if frames_to_process == 0:
|
||||
frames = []
|
||||
else:
|
||||
@@ -183,7 +218,7 @@ class BaseImageProcessor(ABC):
|
||||
):
|
||||
# video
|
||||
path = image_file[len("video:") :]
|
||||
frames = self.encode_video(
|
||||
frames = BaseMultimodalProcessor.encode_video(
|
||||
path, frame_count_limit=frames_to_process
|
||||
)
|
||||
else:
|
||||
@@ -200,40 +235,41 @@ class BaseImageProcessor(ABC):
|
||||
images += frames
|
||||
image_index += 1
|
||||
if frames_to_process != 0:
|
||||
new_text += image_token * len(frames)
|
||||
new_text += multimodal_tokens.image_token * len(frames)
|
||||
assert frames_to_process == len(frames)
|
||||
elif text_part == multimodal_tokens.audio_token:
|
||||
# load as audio
|
||||
audio_file = audio_data[audio_index]
|
||||
audio = load_audio(audio_file)
|
||||
hashes += [hash(audio_file)]
|
||||
audios += [audio]
|
||||
audio_index += 1
|
||||
new_text += multimodal_tokens.audio_token
|
||||
else:
|
||||
# TODO(mick): handle video
|
||||
# normal text
|
||||
new_text += text_part
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"An exception occurred while loading images: {e}")
|
||||
raise BadRequestError(
|
||||
f"An exception occurred while loading images: {e}"
|
||||
)
|
||||
|
||||
return BaseImageProcessorOutput(
|
||||
image_hashes=hashes,
|
||||
out = BaseMultiModalProcessorOutput(
|
||||
mm_data_hashes=hashes,
|
||||
image_sizes=image_sizes,
|
||||
all_frames=images,
|
||||
images=images,
|
||||
audios=audios,
|
||||
input_text=new_text,
|
||||
)
|
||||
out.normalize()
|
||||
return out
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def process_images_async(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
|
||||
"""Init the global processor for multi-modal models."""
|
||||
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
|
||||
"""
|
||||
Init the global processor for multimodal models."""
|
||||
global global_processor
|
||||
transformers.logging.set_verbosity_error()
|
||||
global_processor = sglang_image_processor._build_processor(server_args=server_args)
|
||||
global_processor = sglang_processor._build_processor(server_args=server_args)
|
||||
@@ -20,14 +20,15 @@ import asyncio
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
|
||||
|
||||
|
||||
class DeepseekVL2ImageProcessor(BaseImageProcessor):
|
||||
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
models = [DeepseekVL2ForCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
@@ -63,7 +64,23 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
async def _process_images(self, image_data, input_text, max_req_input_len):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
DeepseekVL2ImageProcessor._process_images_task,
|
||||
image_data,
|
||||
input_text,
|
||||
max_req_input_len,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._process_images_task(
|
||||
image_data, input_text, max_req_input_len
|
||||
)
|
||||
return image_inputs
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
@@ -75,11 +92,14 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
|
||||
images, image_sizes = [], []
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_images(
|
||||
input_ids, image_data, image_token, max_req_input_len
|
||||
base_output = self.load_mm_data(
|
||||
input_ids,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
res = await self._process_images(
|
||||
base_output.all_frames, base_output.input_text, max_req_input_len
|
||||
base_output.images, base_output.input_text, max_req_input_len
|
||||
)
|
||||
images_seq_mask = res["images_seq_mask"]
|
||||
images_spatial_crop = res["images_spatial_crop"]
|
||||
@@ -91,7 +111,7 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
|
||||
"input_ids": res["input_ids"].tolist(),
|
||||
"pixel_values": res["images"],
|
||||
"im_token_id": res["im_token_id"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"data_hashes": base_output.mm_data_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"images_emb_mask": images_seq_mask,
|
||||
"image_spatial_crop": batched_images_spatial_crop,
|
||||
@@ -1,12 +1,12 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.managers.image_processor import (
|
||||
BaseImageProcessor as SGLangBaseImageProcessor,
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
|
||||
@@ -16,7 +16,7 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
|
||||
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
models = [Gemma3ForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
@@ -47,7 +47,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
|
||||
async def process_images_async(
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
@@ -62,22 +62,22 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_images(
|
||||
base_output = self.load_mm_data(
|
||||
input_ids=input_ids,
|
||||
image_data=image_data,
|
||||
image_token=image_token,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
ret = await self._process_single_image(
|
||||
input_text=base_output.input_text, images=base_output.all_frames
|
||||
input_text=base_output.input_text, images=base_output.images
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"pixel_values": ret["pixel_values"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"data_hashes": base_output.mm_data_hashes,
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
}
|
||||
@@ -1,16 +1,15 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
BaseImageProcessor as SGLangBaseImageProcessor,
|
||||
)
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
|
||||
|
||||
|
||||
class JanusProProcessor(SGLangBaseImageProcessor):
|
||||
class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
models = [MultiModalityCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
@@ -36,7 +35,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
JanusProProcessor._process_images_task,
|
||||
JanusProImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
@@ -47,7 +46,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
@@ -61,20 +60,24 @@ class JanusProProcessor(SGLangBaseImageProcessor):
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
base_out = self.load_images(
|
||||
base_out = self.load_mm_data(
|
||||
input_ids=input_ids,
|
||||
image_data=image_data,
|
||||
image_token="<image_placeholder>",
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token="<image_placeholder>"
|
||||
),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
images = base_out.all_frames
|
||||
images = base_out.images
|
||||
res = await self._process_images(images=images, input_text=base_out.input_text)
|
||||
|
||||
# print(res)
|
||||
# print(base_out)
|
||||
# print("", res["images_emb_mask"].shape)
|
||||
return {
|
||||
"input_ids": res["input_ids"].flatten().tolist(),
|
||||
"pixel_values": res["pixel_values"],
|
||||
"images_emb_mask": res["images_emb_mask"],
|
||||
"image_hashes": base_out.image_hashes,
|
||||
"data_hashes": base_out.mm_data_hashes,
|
||||
"im_start_id": res["im_start_id"],
|
||||
"im_end_id": res["im_end_id"],
|
||||
"im_token_id": res["im_token_id"],
|
||||
@@ -3,8 +3,8 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
@@ -14,7 +14,7 @@ from sglang.srt.utils import load_image, logger
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
@@ -86,7 +86,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
image_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
|
||||
async def process_images_async(
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
@@ -113,7 +113,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
if "multi-images" in modalities or "video" in modalities:
|
||||
# Multiple images
|
||||
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
||||
pixel_values, image_hashes, image_sizes = [], [], []
|
||||
pixel_values, data_hashes, image_sizes = [], [], []
|
||||
res = []
|
||||
for img_data in image_data:
|
||||
res.append(
|
||||
@@ -124,7 +124,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
res = await asyncio.gather(*res)
|
||||
for pixel_v, image_h, image_s in res:
|
||||
pixel_values.append(pixel_v)
|
||||
image_hashes.append(image_h)
|
||||
data_hashes.append(image_h)
|
||||
image_sizes.append(image_s)
|
||||
|
||||
if isinstance(pixel_values[0], np.ndarray):
|
||||
@@ -134,14 +134,14 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data[0], aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
data_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
else:
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"image_hashes": image_hashes,
|
||||
"data_hashes": data_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
}
|
||||
167
python/sglang/srt/managers/multimodal_processors/minicpm.py
Normal file
167
python/sglang/srt/managers/multimodal_processors/minicpm.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.minicpmo import MiniCPMO
|
||||
from sglang.srt.models.minicpmv import MiniCPMV
|
||||
|
||||
|
||||
# Compatible with both 'O' and 'V'
|
||||
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
models = [MiniCPMV, MiniCPMO]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.image_token = "(<image>./</image>)"
|
||||
self.audio_token = "(<audio>./</audio>)"
|
||||
|
||||
@staticmethod
|
||||
def _process_data_task(input_text, images=None, audios=None):
|
||||
|
||||
if isinstance(images, list) and len(images) == 0:
|
||||
images = None
|
||||
if isinstance(audios, list) and len(audios) == 0:
|
||||
audios = None
|
||||
result = get_global_processor().__call__(
|
||||
text=input_text,
|
||||
images=images,
|
||||
audios=audios,
|
||||
return_tensors="pt",
|
||||
chunk_input=True,
|
||||
)
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": getattr(result, "pixel_values", None),
|
||||
"tgt_sizes": getattr(result, "tgt_sizes", None),
|
||||
"audio_features": getattr(result, "audio_features", None),
|
||||
"audio_feature_lens": getattr(result, "audio_feature_lens", None),
|
||||
"audio_bounds": getattr(result, "audio_bounds", None),
|
||||
}
|
||||
|
||||
async def _process_data(self, images, input_text, audios=None):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
multimodal_data_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MiniCPMMultimodalProcessor._process_data_task,
|
||||
input_text,
|
||||
images,
|
||||
audios,
|
||||
)
|
||||
else:
|
||||
multimodal_data_inputs = self._processor(
|
||||
images=images, text=input_text, audios=audios, return_tensors="pt"
|
||||
)
|
||||
|
||||
return multimodal_data_inputs
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
):
|
||||
audio_data = request_obj.audio_data
|
||||
if not image_data and not audio_data:
|
||||
return None
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
if not isinstance(audio_data, list):
|
||||
audio_data = [audio_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
input_ids=input_ids,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.image_token, audio_token=self.audio_token
|
||||
),
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
res = await self._process_data(
|
||||
images=base_output.images,
|
||||
input_text=base_output.input_text,
|
||||
audios=base_output.audios,
|
||||
)
|
||||
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = tokenizer.slice_start_id
|
||||
slice_end_id = tokenizer.slice_end_id
|
||||
if hasattr(tokenizer, "audio_start_id"):
|
||||
audio_start_id = tokenizer.audio_start_id
|
||||
audio_end_id = tokenizer.audio_end_id
|
||||
|
||||
im_token_id = tokenizer.unk_token_id
|
||||
pixel_values = res["pixel_values"]
|
||||
tgt_sizes = res["tgt_sizes"]
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
|
||||
)
|
||||
|
||||
if not isinstance(tgt_sizes, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
|
||||
)
|
||||
|
||||
if len(pixel_values) != len(tgt_sizes):
|
||||
raise ValueError(
|
||||
"Inconsistent batch lengths, found: "
|
||||
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
|
||||
)
|
||||
|
||||
pixel_values_flat: List[torch.Tensor] = []
|
||||
tgt_sizes_flat: List[torch.Tensor] = []
|
||||
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
|
||||
# per image
|
||||
if len(pixel_b) != len(tgt_b):
|
||||
raise ValueError(
|
||||
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
|
||||
)
|
||||
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
||||
pixel_values_flat += [pixel_n]
|
||||
tgt_sizes_flat += [tgt_n]
|
||||
|
||||
pixel_values = pixel_values_flat
|
||||
if len(tgt_sizes_flat) == 0:
|
||||
tgt_sizes = None
|
||||
else:
|
||||
tgt_sizes = torch.stack(tgt_sizes_flat)
|
||||
if not isinstance(res["audio_features"], list):
|
||||
res["audio_features"] = [res["audio_features"]]
|
||||
return {
|
||||
"input_ids": res["input_ids"].flatten().tolist(),
|
||||
"pixel_values": pixel_values,
|
||||
"tgt_sizes": tgt_sizes,
|
||||
"data_hashes": base_output.mm_data_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"audio_start_id": audio_start_id,
|
||||
"audio_end_id": audio_end_id,
|
||||
"audio_features": res["audio_features"],
|
||||
"audio_bounds": res["audio_bounds"],
|
||||
"audio_feature_lens": res["audio_feature_lens"],
|
||||
"im_token_id": im_token_id,
|
||||
"im_start_id": tokenizer.im_start_id,
|
||||
"im_end_id": tokenizer.im_end_id,
|
||||
"slice_start_id": slice_start_id,
|
||||
"slice_end_id": slice_end_id,
|
||||
}
|
||||
@@ -1,15 +1,15 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.mllama import MllamaForConditionalGeneration
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseImageProcessor):
|
||||
class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
models = [MllamaForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
@@ -34,7 +34,7 @@ class MllamaImageProcessor(BaseImageProcessor):
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
@@ -53,7 +53,7 @@ class MllamaImageProcessor(BaseImageProcessor):
|
||||
images = load_image(image_data[0])[0]
|
||||
|
||||
image_inputs = await self._process_single_image(images, input_text)
|
||||
image_inputs["image_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["data_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
|
||||
return image_inputs
|
||||
@@ -1,12 +1,16 @@
|
||||
import asyncio
|
||||
import math
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
@@ -14,7 +18,7 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
||||
|
||||
|
||||
# Compatible with Qwen2VL and Qwen2_5VL
|
||||
class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
@@ -59,7 +63,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
else:
|
||||
return self._process_images_task(images, input_text, self.hf_config)
|
||||
|
||||
async def process_images_async(
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
@@ -68,16 +72,17 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
start = time.time()
|
||||
if not image_data:
|
||||
return None
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_images(
|
||||
base_output = self.load_mm_data(
|
||||
input_ids=input_ids,
|
||||
image_data=image_data,
|
||||
image_token=image_token,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
@@ -139,7 +144,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
images = [resize_image(image) for image in base_output.all_frames]
|
||||
images = [resize_image(image) for image in base_output.images]
|
||||
|
||||
ret = await self._process_single_image(
|
||||
images=images, input_text=base_output.input_text
|
||||
@@ -147,11 +152,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
|
||||
image_grid_thws = torch.concat([ret["image_grid_thw"]])
|
||||
video_grid_thws = None
|
||||
|
||||
return {
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"pixel_values": ret["pixel_values"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"data_hashes": base_output.mm_data_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"image_grid_thws": image_grid_thws,
|
||||
"video_grid_thws": video_grid_thws,
|
||||
@@ -144,11 +144,11 @@ class FINISH_ABORT(BaseFinishReason):
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageInputs:
|
||||
class MultimodalInputs:
|
||||
"""The image related inputs."""
|
||||
|
||||
pixel_values: Union[torch.Tensor, np.array]
|
||||
image_hashes: Optional[list] = None
|
||||
data_hashes: Optional[list] = None
|
||||
image_sizes: Optional[list] = None
|
||||
image_offsets: Optional[list] = None
|
||||
image_pad_len: Optional[list] = None
|
||||
@@ -182,20 +182,27 @@ class ImageInputs:
|
||||
im_end_id: Optional[int] = None
|
||||
slice_start_id: Optional[int] = None
|
||||
slice_end_id: Optional[int] = None
|
||||
# [num_images, 2 (w, h)]
|
||||
tgt_sizes: Optional[list] = None
|
||||
|
||||
# audio
|
||||
audio_start_id: Optional[torch.Tensor] = None
|
||||
audio_end_id: Optional[torch.Tensor] = None
|
||||
audio_features: Optional[List[torch.Tensor]] = None
|
||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict):
|
||||
ret = ImageInputs(
|
||||
ret = MultimodalInputs(
|
||||
pixel_values=obj["pixel_values"],
|
||||
image_hashes=obj["image_hashes"],
|
||||
data_hashes=obj["data_hashes"],
|
||||
)
|
||||
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
||||
# errors in cuda kernels. See also llava.py for example.
|
||||
ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
|
||||
ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
@@ -211,6 +218,10 @@ class ImageInputs:
|
||||
"slice_start_id",
|
||||
"slice_end_id",
|
||||
"tgt_sizes",
|
||||
"audio_start_id",
|
||||
"audio_end_id",
|
||||
"audio_features",
|
||||
"audio_feature_lens",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
@@ -223,9 +234,19 @@ class ImageInputs:
|
||||
or isinstance(ret.pixel_values, list)
|
||||
)
|
||||
|
||||
assert ret.audio_features is None or isinstance(ret.audio_features, list)
|
||||
|
||||
return ret
|
||||
|
||||
def merge(self, other: ImageInputs):
|
||||
def contains_image_inputs(self) -> bool:
|
||||
""" """
|
||||
return self.pixel_values is not None and self.pixel_values != []
|
||||
|
||||
def contains_audio_inputs(self) -> bool:
|
||||
""" """
|
||||
return self.audio_features is not None and self.audio_features != []
|
||||
|
||||
def merge(self, other: MultimodalInputs):
|
||||
"""
|
||||
merge image inputs when requests are being merged
|
||||
"""
|
||||
@@ -268,10 +289,12 @@ class ImageInputs:
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
||||
# errors in cuda kernels. See also llava.py for example.
|
||||
self.image_hashes += other.image_hashes
|
||||
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
|
||||
self.data_hashes += other.data_hashes
|
||||
self.pad_values = [x % (1 << 30) for x in self.data_hashes]
|
||||
|
||||
# args needed to be merged
|
||||
optional_args = [
|
||||
"audio_features",
|
||||
"image_sizes",
|
||||
"image_offsets",
|
||||
"image_pad_len",
|
||||
@@ -362,7 +385,7 @@ class Req:
|
||||
self.decoded_text = ""
|
||||
|
||||
# For multimodal inputs
|
||||
self.image_inputs: Optional[ImageInputs] = None
|
||||
self.multimodal_inputs: Optional[MultimodalInputs] = None
|
||||
|
||||
# Prefix info
|
||||
# The indices to kv cache for the shared prefix.
|
||||
@@ -458,10 +481,10 @@ class Req:
|
||||
return len(self.origin_input_ids) + len(self.output_ids)
|
||||
|
||||
def extend_image_inputs(self, image_inputs):
|
||||
if self.image_inputs is None:
|
||||
self.image_inputs = image_inputs
|
||||
if self.multimodal_inputs is None:
|
||||
self.multimodal_inputs = image_inputs
|
||||
else:
|
||||
self.image_inputs.merge(image_inputs)
|
||||
self.multimodal_inputs.merge(image_inputs)
|
||||
|
||||
def finished(self) -> bool:
|
||||
# Whether request reached finished condition
|
||||
@@ -802,7 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.encoder_cached = []
|
||||
|
||||
for req in self.reqs:
|
||||
im = req.image_inputs
|
||||
im = req.multimodal_inputs
|
||||
if im is None or im.num_image_tokens is None:
|
||||
# No image input
|
||||
self.encoder_lens_cpu.append(0)
|
||||
@@ -1391,7 +1414,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||
image_inputs=[r.image_inputs for r in self.reqs],
|
||||
multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
|
||||
encoder_cached=self.encoder_cached,
|
||||
encoder_lens=self.encoder_lens,
|
||||
encoder_lens_cpu=self.encoder_lens_cpu,
|
||||
@@ -1474,7 +1497,7 @@ class ModelWorkerBatch:
|
||||
extend_input_logprob_token_ids: Optional[torch.Tensor]
|
||||
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]]
|
||||
multimodal_inputs: Optional[List[MultimodalInputs]]
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]]
|
||||
|
||||
@@ -88,7 +88,7 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
ImageInputs,
|
||||
MultimodalInputs,
|
||||
Req,
|
||||
ScheduleBatch,
|
||||
global_server_args_dict,
|
||||
@@ -841,8 +841,8 @@ class Scheduler(
|
||||
return
|
||||
|
||||
# Handle multimodal inputs
|
||||
if recv_req.image_inputs is not None:
|
||||
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
|
||||
if recv_req.mm_inputs is not None:
|
||||
image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
|
||||
# Expand a single image token into multiple dummy tokens for receiving image embeddings
|
||||
req.origin_input_ids = self.pad_input_ids_func(
|
||||
req.origin_input_ids, image_inputs
|
||||
@@ -856,7 +856,7 @@ class Scheduler(
|
||||
)
|
||||
logger.error(error_msg)
|
||||
req.origin_input_ids = [0]
|
||||
req.image_inputs = None
|
||||
req.multimodal_inputs = None
|
||||
req.sampling_params.max_new_tokens = 0
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
||||
@@ -960,7 +960,7 @@ class Scheduler(
|
||||
|
||||
# Handle multimodal inputs
|
||||
if recv_req.image_inputs is not None:
|
||||
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
|
||||
image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
|
||||
# Expand a single image token into multiple dummy tokens for receiving image embeddings
|
||||
req.origin_input_ids = self.pad_input_ids_func(
|
||||
req.origin_input_ids, image_inputs
|
||||
@@ -974,7 +974,7 @@ class Scheduler(
|
||||
)
|
||||
logger.error(error_msg)
|
||||
req.origin_input_ids = [0]
|
||||
req.image_inputs = None
|
||||
req.multimodal_inputs = None
|
||||
req.sampling_params.max_new_tokens = 0
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
||||
|
||||
@@ -138,7 +138,7 @@ class Session:
|
||||
token_ids_logprob=req.token_ids_logprob,
|
||||
)
|
||||
if last_req is not None:
|
||||
new_req.image_inputs = last_req.image_inputs
|
||||
new_req.multimodal_inputs = last_req.mm_inputs
|
||||
new_req.tokenizer = tokenizer
|
||||
if abort:
|
||||
new_req.to_abort = True
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
@@ -52,10 +51,6 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.disaggregation.conn import KVBootstrapServer
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.image_processor import (
|
||||
get_dummy_image_processor,
|
||||
get_image_processor,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
@@ -93,6 +88,11 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
get_dummy_processor,
|
||||
get_mm_processor,
|
||||
import_processors,
|
||||
)
|
||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
@@ -171,6 +171,7 @@ class TokenizerManager:
|
||||
self.image_token_id = self.model_config.image_token_id
|
||||
|
||||
if self.model_config.is_multimodal:
|
||||
import_processors()
|
||||
_processor = get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
@@ -179,9 +180,9 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
# We want to parallelize the image pre-processing so we create an executor for it
|
||||
# We create image_processor for any skip_tokenizer_init to make sure we still encode
|
||||
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
|
||||
# images even with skip_tokenizer_init=False.
|
||||
self.image_processor = get_image_processor(
|
||||
self.mm_processor = get_mm_processor(
|
||||
self.model_config.hf_config, server_args, _processor
|
||||
)
|
||||
|
||||
@@ -192,7 +193,7 @@ class TokenizerManager:
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
else:
|
||||
self.image_processor = get_dummy_image_processor()
|
||||
self.mm_processor = get_dummy_processor()
|
||||
|
||||
if server_args.skip_tokenizer_init:
|
||||
self.tokenizer = self.processor = None
|
||||
@@ -389,7 +390,7 @@ class TokenizerManager:
|
||||
)
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
|
||||
image_inputs: Dict = await self.image_processor.process_images_async(
|
||||
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
|
||||
@@ -43,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
@@ -176,7 +176,7 @@ class ForwardBatch:
|
||||
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
|
||||
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]] = None
|
||||
mm_inputs: Optional[List[MultimodalInputs]] = None
|
||||
|
||||
# Encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
@@ -242,7 +242,7 @@ class ForwardBatch:
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
image_inputs=batch.image_inputs,
|
||||
mm_inputs=batch.multimodal_inputs,
|
||||
encoder_cached=batch.encoder_cached,
|
||||
encoder_lens=batch.encoder_lens,
|
||||
encoder_lens_cpu=batch.encoder_lens_cpu,
|
||||
@@ -332,42 +332,53 @@ class ForwardBatch:
|
||||
|
||||
return ret
|
||||
|
||||
def merge_image_inputs(self) -> Optional[ImageInputs]:
|
||||
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
|
||||
"""
|
||||
Merge all image inputs in the batch into a single ImageInputs object.
|
||||
Merge all image inputs in the batch into a single MultiModalInputs object.
|
||||
|
||||
Returns:
|
||||
if none, current batch contains no image input
|
||||
|
||||
"""
|
||||
if not self.image_inputs or all(x is None for x in self.image_inputs):
|
||||
if not self.mm_inputs or all(x is None for x in self.mm_inputs):
|
||||
return None
|
||||
|
||||
# Filter out None values
|
||||
valid_inputs = [x for x in self.image_inputs if x is not None]
|
||||
valid_inputs = [x for x in self.mm_inputs if x is not None]
|
||||
|
||||
# Start with the first valid image input
|
||||
merged = valid_inputs[0]
|
||||
|
||||
# Merge remaining inputs
|
||||
for img_input in valid_inputs[1:]:
|
||||
merged.merge(img_input)
|
||||
for mm_input in valid_inputs[1:]:
|
||||
merged.merge(mm_input)
|
||||
|
||||
if isinstance(merged.pixel_values, np.ndarray):
|
||||
merged.pixel_values = torch.from_numpy(merged.pixel_values)
|
||||
if isinstance(merged.audio_features, np.ndarray):
|
||||
merged.audio_features = torch.from_numpy(merged.audio_features)
|
||||
|
||||
return merged
|
||||
|
||||
def contains_image_inputs(self) -> bool:
|
||||
""" """
|
||||
if self.image_inputs is None:
|
||||
return True
|
||||
if self.mm_inputs is None:
|
||||
return False
|
||||
return any(
|
||||
image_input.pixel_values is not None and image_input.pixel_values is not []
|
||||
for image_input in self.image_inputs
|
||||
if image_input is not None
|
||||
mm_input is not None and mm_input.contains_image_inputs()
|
||||
for mm_input in self.mm_inputs
|
||||
)
|
||||
|
||||
def contains_audio_inputs(self) -> bool:
|
||||
if self.mm_inputs is None:
|
||||
return False
|
||||
return any(
|
||||
mm_input is not None and mm_input.contains_audio_inputs()
|
||||
for mm_input in self.mm_inputs
|
||||
)
|
||||
|
||||
def contains_mm_inputs(self) -> bool:
|
||||
return self.contains_audio_inputs() or self.contains_image_inputs()
|
||||
|
||||
def _compute_mrope_positions(
|
||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||
):
|
||||
@@ -378,8 +389,8 @@ class ForwardBatch:
|
||||
for i, _ in enumerate(mrope_positions_list):
|
||||
mrope_position_delta = (
|
||||
0
|
||||
if batch.image_inputs[i] is None
|
||||
else batch.image_inputs[i].mrope_position_delta
|
||||
if batch.multimodal_inputs[i] is None
|
||||
else batch.multimodal_inputs[i].mrope_position_delta
|
||||
)
|
||||
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
|
||||
mrope_position_delta,
|
||||
@@ -388,13 +399,13 @@ class ForwardBatch:
|
||||
)
|
||||
elif self.forward_mode.is_extend():
|
||||
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
|
||||
for i, image_inputs in enumerate(batch.image_inputs):
|
||||
for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
|
||||
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||
extend_start_loc_cpu[i],
|
||||
batch.extend_seq_lens[i],
|
||||
batch.extend_prefix_lens[i],
|
||||
)
|
||||
if image_inputs is None:
|
||||
if multimodal_inputs is None:
|
||||
# text only
|
||||
mrope_positions = [
|
||||
[
|
||||
@@ -411,20 +422,22 @@ class ForwardBatch:
|
||||
input_tokens=self.input_ids[
|
||||
extend_start_loc : extend_start_loc + extend_seq_len
|
||||
],
|
||||
image_grid_thw=image_inputs.image_grid_thws,
|
||||
video_grid_thw=image_inputs.video_grid_thws,
|
||||
image_token_id=image_inputs.im_token_id,
|
||||
video_token_id=image_inputs.video_token_id,
|
||||
image_grid_thw=multimodal_inputs.image_grid_thws,
|
||||
video_grid_thw=multimodal_inputs.video_grid_thws,
|
||||
image_token_id=multimodal_inputs.im_token_id,
|
||||
video_token_id=multimodal_inputs.video_token_id,
|
||||
vision_start_token_id=hf_config.vision_start_token_id,
|
||||
vision_end_token_id=hf_config.vision_end_token_id,
|
||||
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
||||
context_len=0,
|
||||
seq_len=len(self.input_ids),
|
||||
second_per_grid_ts=image_inputs.second_per_grid_ts,
|
||||
second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
|
||||
tokens_per_second=hf_config.vision_config.tokens_per_second,
|
||||
)
|
||||
)
|
||||
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
||||
batch.multimodal_inputs[i].mrope_position_delta = (
|
||||
mrope_position_delta
|
||||
)
|
||||
mrope_positions_list[i] = mrope_positions
|
||||
|
||||
self.mrope_positions = torch.cat(
|
||||
|
||||
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
@@ -1959,7 +1959,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
|
||||
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
|
||||
pixel_values = image_input.pixel_values
|
||||
bs, n = pixel_values.shape[0:2]
|
||||
pixel_values = pixel_values.to(
|
||||
@@ -1988,10 +1988,9 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
|
||||
inputs_embeds = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
embed_tokens=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_feature,
|
||||
mm_data_embedding_func=self.get_image_feature,
|
||||
)
|
||||
|
||||
return self.language_model(
|
||||
@@ -2005,7 +2004,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
||||
return self.gen_aligner(self.gen_embed(image_ids))
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
im_start_id = image_inputs.im_start_id
|
||||
im_end_id = image_inputs.im_end_id
|
||||
media_token_pairs = [(im_start_id, im_end_id)]
|
||||
|
||||
@@ -11,7 +11,7 @@ from sglang.srt.configs.deepseekvl2 import (
|
||||
)
|
||||
from sglang.srt.layers.linear import ReplicatedLinear
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
||||
@@ -222,7 +222,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
):
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
|
||||
for idx, image in enumerate(forward_batch.image_inputs):
|
||||
for idx, image in enumerate(forward_batch.mm_inputs):
|
||||
if image is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[idx]
|
||||
@@ -262,10 +262,10 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
weights_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weights_loader(param, loaded_weight)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
return input_ids
|
||||
|
||||
def get_image_feature(self, image_input: ImageInputs):
|
||||
def get_image_feature(self, image_input: MultimodalInputs):
|
||||
pixel_values = image_input.pixel_values.type(
|
||||
next(self.vision.parameters()).dtype
|
||||
).to(device=next(self.vision.parameters()).device)
|
||||
|
||||
@@ -38,7 +38,7 @@ from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
@@ -185,7 +185,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
def pad_input_ids(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
self, input_ids: List[int], image_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""Pad input IDs with image tokens."""
|
||||
# Get special token IDs
|
||||
@@ -268,7 +268,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def get_image_feature(self, image_input: ImageInputs):
|
||||
def get_image_feature(self, image_input: MultimodalInputs):
|
||||
"""
|
||||
Projects the last hidden state from the vision model into language model space.
|
||||
|
||||
@@ -286,11 +286,11 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
image_features = self.multi_modal_projector(vision_outputs)
|
||||
return image_features
|
||||
|
||||
def embed_image_inputs(
|
||||
def embed_mm_inputs(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
image_input: ImageInputs,
|
||||
image_input: MultimodalInputs,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is None:
|
||||
raise ValueError("Unimplemented")
|
||||
@@ -401,10 +401,9 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
|
||||
inputs_embeds = general_mm_embed_routine(
|
||||
input_ids=llm_input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
embed_tokens=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_feature,
|
||||
mm_data_embedding_func=self.get_image_feature,
|
||||
)
|
||||
|
||||
outputs = self.language_model(
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -31,7 +31,7 @@ from transformers import (
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.mm_utils import (
|
||||
get_anyres_image_grid_shape,
|
||||
unpad_image,
|
||||
@@ -46,7 +46,7 @@ from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class LlavaBaseForCausalLM(nn.Module):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
|
||||
|
||||
# hardcode for spatial_unpad + anyres
|
||||
@@ -134,7 +134,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
image_inputs = forward_batch.image_inputs
|
||||
image_inputs = forward_batch.mm_inputs
|
||||
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
@@ -57,7 +57,7 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||
)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
pad_values = image_inputs.pad_values
|
||||
new_image_feature_len = self.image_feature_len
|
||||
|
||||
@@ -112,7 +112,7 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
image_inputs = forward_batch.image_inputs
|
||||
image_inputs = forward_batch.mm_inputs
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
|
||||
1995
python/sglang/srt/models/minicpmo.py
Normal file
1995
python/sglang/srt/models/minicpmo.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -52,9 +52,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
embed_image_inputs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
@@ -862,24 +862,12 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or not forward_batch.contains_image_inputs()
|
||||
):
|
||||
inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
||||
else:
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
image_inputs = forward_batch.merge_image_inputs()
|
||||
inputs_embeds = embed_image_inputs(
|
||||
image_input=image_inputs,
|
||||
input_ids=input_ids,
|
||||
input_embedding=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_features,
|
||||
placeholder_token_ids=[image_inputs.im_token_id]
|
||||
+ image_inputs.pad_values,
|
||||
)
|
||||
inputs_embeds = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
forward_batch=forward_batch,
|
||||
embed_tokens=self.get_input_embeddings(),
|
||||
mm_data_embedding_func=self.get_image_features,
|
||||
)
|
||||
|
||||
hidden_states = self.llm.model(
|
||||
input_ids=None,
|
||||
@@ -925,7 +913,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor:
|
||||
def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1037,7 +1025,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
image_inputs: ImageInputs,
|
||||
image_inputs: MultimodalInputs,
|
||||
) -> torch.Tensor:
|
||||
# list of tensors
|
||||
pixel_values = image_inputs.pixel_values
|
||||
@@ -1075,7 +1063,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
)
|
||||
return self.resampler(vision_embedding, tgt_sizes)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
# Get all special token IDs
|
||||
im_start_id: int = image_inputs.im_start_id
|
||||
im_end_id: int = image_inputs.im_end_id
|
||||
|
||||
@@ -32,7 +32,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
|
||||
@@ -796,7 +796,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config.text_config)
|
||||
self.capture_mode = False
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
pixel_values = image_inputs.pixel_values
|
||||
pad_values = image_inputs.pad_values
|
||||
|
||||
@@ -815,7 +815,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
|
||||
# pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
|
||||
max_num_images = max_num_tiles = bs = 0
|
||||
for i, im in enumerate(forward_batch.image_inputs):
|
||||
for i, im in enumerate(forward_batch.mm_inputs):
|
||||
if not forward_batch.encoder_cached[i] and im is not None:
|
||||
max_num_images = max(max_num_images, im.pixel_values.shape[1])
|
||||
max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
|
||||
@@ -842,7 +842,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
)
|
||||
i = 0
|
||||
encoder_lens_need = []
|
||||
for k, im in enumerate(forward_batch.image_inputs):
|
||||
for k, im in enumerate(forward_batch.mm_inputs):
|
||||
if forward_batch.encoder_cached[k] or im is None:
|
||||
continue
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
@@ -513,7 +513,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
# Get all special token IDs
|
||||
im_start_id: int = image_inputs.im_start_id
|
||||
im_end_id: int = image_inputs.im_end_id
|
||||
@@ -523,7 +523,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
|
||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
|
||||
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
|
||||
pixel_values = image_input.pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
|
||||
return image_embeds
|
||||
@@ -572,10 +572,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
|
||||
inputs_embeds = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
embed_tokens=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_feature,
|
||||
mm_data_embedding_func=self.get_image_feature,
|
||||
)
|
||||
|
||||
hidden_states = self.model(
|
||||
|
||||
@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
@@ -472,16 +472,16 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
|
||||
# Use grid_t * grid_w * grid_h to pad tokens for each image
|
||||
# add replaced padding by unique image hash
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
|
||||
# Get all special token IDs
|
||||
im_start_id: int = image_inputs.im_start_id
|
||||
im_end_id: int = image_inputs.im_end_id
|
||||
im_start_id: int = multi_modal_inputs.im_start_id
|
||||
im_end_id: int = multi_modal_inputs.im_end_id
|
||||
|
||||
media_token_pairs = [(im_start_id, im_end_id)]
|
||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
||||
return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
|
||||
|
||||
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
|
||||
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
|
||||
pixel_values = image_input.pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
|
||||
return image_embeds
|
||||
@@ -530,10 +530,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
|
||||
inputs_embeds = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
embed_tokens=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_feature,
|
||||
mm_data_embedding_func=self.get_image_feature,
|
||||
)
|
||||
|
||||
hidden_states = self.model(
|
||||
|
||||
@@ -899,6 +899,7 @@ def v1_chat_generate_request(
|
||||
input_ids = []
|
||||
sampling_params_list = []
|
||||
image_data_list = []
|
||||
audio_data_list = []
|
||||
return_logprobs = []
|
||||
logprob_start_lens = []
|
||||
top_logprobs_nums = []
|
||||
@@ -912,6 +913,7 @@ def v1_chat_generate_request(
|
||||
# - prompt: The full prompt string.
|
||||
# - stop: Custom stop tokens.
|
||||
# - image_data: None or a list of image strings (URLs or base64 strings).
|
||||
# - audio_data: None or a list of audio strings (URLs).
|
||||
# None skips any image processing in GenerateReqInput.
|
||||
if not isinstance(request.messages, str):
|
||||
# Apply chat template and its stop strings.
|
||||
@@ -956,7 +958,7 @@ def v1_chat_generate_request(
|
||||
)
|
||||
except:
|
||||
# This except branch will be triggered when the chosen model
|
||||
# has a different tools input format that is not compatiable
|
||||
# has a different tools input format that is not compatible
|
||||
# with openAI's apply_chat_template tool_call format, like Mistral.
|
||||
tools = [t if "function" in t else {"function": t} for t in tools]
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
@@ -976,11 +978,13 @@ def v1_chat_generate_request(
|
||||
prompt_ids += encoded
|
||||
stop = request.stop
|
||||
image_data = None
|
||||
audio_data = None
|
||||
modalities = []
|
||||
else:
|
||||
conv = generate_chat_conv(request, chat_template_name)
|
||||
prompt = conv.get_prompt()
|
||||
image_data = conv.image_data
|
||||
audio_data = conv.audio_data
|
||||
modalities = conv.modalities
|
||||
stop = conv.stop_str or []
|
||||
if request.stop:
|
||||
@@ -994,6 +998,7 @@ def v1_chat_generate_request(
|
||||
prompt_ids = request.messages
|
||||
stop = request.stop
|
||||
image_data = None
|
||||
audio_data = None
|
||||
modalities = []
|
||||
input_ids.append(prompt_ids)
|
||||
return_logprobs.append(request.logprobs)
|
||||
@@ -1034,6 +1039,7 @@ def v1_chat_generate_request(
|
||||
sampling_params_list.append(sampling_params)
|
||||
|
||||
image_data_list.append(image_data)
|
||||
audio_data_list.append(audio_data)
|
||||
modalities_list.append(modalities)
|
||||
if len(all_requests) == 1:
|
||||
if isinstance(input_ids[0], str):
|
||||
@@ -1042,6 +1048,7 @@ def v1_chat_generate_request(
|
||||
prompt_kwargs = {"input_ids": input_ids[0]}
|
||||
sampling_params_list = sampling_params_list[0]
|
||||
image_data_list = image_data_list[0]
|
||||
audio_data_list = audio_data_list[0]
|
||||
return_logprobs = return_logprobs[0]
|
||||
logprob_start_lens = logprob_start_lens[0]
|
||||
top_logprobs_nums = top_logprobs_nums[0]
|
||||
@@ -1056,6 +1063,7 @@ def v1_chat_generate_request(
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
image_data=image_data_list,
|
||||
audio_data=audio_data_list,
|
||||
sampling_params=sampling_params_list,
|
||||
return_logprob=return_logprobs,
|
||||
logprob_start_len=logprob_start_lens,
|
||||
|
||||
@@ -227,14 +227,25 @@ class ChatCompletionMessageContentImageURL(BaseModel):
|
||||
detail: Optional[Literal["auto", "low", "high"]] = "auto"
|
||||
|
||||
|
||||
class ChatCompletionMessageContentAudioURL(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class ChatCompletionMessageContentImagePart(BaseModel):
|
||||
type: Literal["image_url"]
|
||||
image_url: ChatCompletionMessageContentImageURL
|
||||
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
|
||||
|
||||
|
||||
class ChatCompletionMessageContentAudioPart(BaseModel):
|
||||
type: Literal["audio_url"]
|
||||
audio_url: ChatCompletionMessageContentAudioURL
|
||||
|
||||
|
||||
ChatCompletionMessageContentPart = Union[
|
||||
ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
|
||||
ChatCompletionMessageContentTextPart,
|
||||
ChatCompletionMessageContentImagePart,
|
||||
ChatCompletionMessageContentAudioPart,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -55,14 +55,13 @@ import triton
|
||||
import zmq
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from packaging.version import Version, parse
|
||||
from PIL import Image
|
||||
from starlette.routing import Mount
|
||||
from torch import nn
|
||||
from torch.func import functional_call
|
||||
from torch.library import Library
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from torch.utils._contextlib import _DecoratorContextManager
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from triton.runtime.cache import (
|
||||
FileCacheManager,
|
||||
default_cache_dir,
|
||||
@@ -507,9 +506,37 @@ def decode_video_base64(video_base64):
|
||||
) # Return an empty array and size tuple if no frames were found
|
||||
|
||||
|
||||
def load_image(image_file: Union[str, bytes]):
|
||||
from PIL import Image
|
||||
def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
|
||||
# Use soundfile here, since librosa use it under the hood,
|
||||
# and librosa will not support audio loading in the future
|
||||
import soundfile as sf
|
||||
from scipy.signal import resample
|
||||
|
||||
# print(f"loading {audio_file}")
|
||||
# Load audio data
|
||||
if isinstance(audio_file, bytes):
|
||||
audio, original_sr = sf.read(BytesIO(audio_file))
|
||||
elif audio_file.startswith("data:"):
|
||||
audio_file = audio_file.split(",")[1]
|
||||
audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
|
||||
elif isinstance(audio_file, str):
|
||||
audio, original_sr = sf.read(audio_file)
|
||||
else:
|
||||
raise ValueError(f"Invalid audio format: {audio_file}")
|
||||
|
||||
# Resample audio if the original sample rate is different from the desired sample rate
|
||||
if original_sr != sr:
|
||||
num_samples = int(len(audio) * float(sr) / original_sr)
|
||||
audio = resample(audio, num_samples)
|
||||
|
||||
# Convert to mono if requested and audio is stereo
|
||||
if mono and len(audio.shape) > 1:
|
||||
audio = np.mean(audio, axis=1)
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
|
||||
image = image_size = None
|
||||
|
||||
if isinstance(image_file, bytes):
|
||||
|
||||
Reference in New Issue
Block a user