Move multimodal processors into a separate folder (#7581)
This commit is contained in:
@@ -22,7 +22,7 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.mm_utils import has_valid_data
|
||||
from sglang.srt.multimodal.mm_utils import has_valid_data
|
||||
|
||||
# handle serialization of Image for pydantic
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
Multi-modality utils
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@ import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
)
|
||||
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -29,7 +27,7 @@ def get_dummy_processor():
|
||||
|
||||
@lru_cache()
|
||||
def import_processors():
|
||||
package_name = "sglang.srt.managers.multimodal_processors"
|
||||
package_name = "sglang.srt.multimodal.processors"
|
||||
package = importlib.import_module(package_name)
|
||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||
if not ispkg:
|
||||
|
||||
@@ -1,591 +0,0 @@
|
||||
import concurrent
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.utils import encode_video, load_audio, load_image
|
||||
|
||||
|
||||
class MultimodalInputFormat(Enum):
    """Closed set of representations in which multimodal inputs can arrive."""

    # Raw image objects that still have to go through the processor.
    RAW_IMAGES = "raw_images"
    # Dict inputs that carry a "precomputed_features" entry.
    PRECOMPUTED_FEATURES = "precomputed_features"
    # Dict inputs that carry a "pixel_values" entry.
    PIXEL_VALUES = "pixel_values"
    # Audio inputs.
    AUDIO = "audio"
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BaseMultiModalProcessorOutput:
    """Prompt text together with the multimodal payloads loaded for it."""

    # Prompt text in which each video/image frame is represented by an image token.
    input_text: str

    # Frames loaded from images and videos, in the order they were given.
    images: Optional[list[Union[Image.Image, dict]]] = None

    # Loaded audio inputs.
    audios: Optional[list[Union[np.ndarray, dict]]] = None

    def normalize(self):
        """Replace empty ``images`` / ``audios`` lists with ``None``, in place."""
        for attr in ("images", "audios"):
            current = getattr(self, attr, None)
            if isinstance(current, list) and not current:
                setattr(self, attr, None)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MultimodalSpecialTokens:
    """Placeholder tokens for each modality plus the regexes used to find them.

    A token may be given as a string or as a token id. A regex is derived
    automatically from a string token unless an explicit ``*_token_regex`` is
    supplied. Token ids cannot be turned into a regex until they have been
    converted to strings with :meth:`convert_to_strs`, so their regex is
    built at that point instead of at construction time.
    """

    image_token: Optional[Union[int, str, List[str]]] = None
    video_token: Optional[Union[int, str, List[str]]] = None
    audio_token: Optional[Union[int, str, List[str]]] = None

    image_token_regex: Optional[re.Pattern] = None
    video_token_regex: Optional[re.Pattern] = None
    audio_token_regex: Optional[re.Pattern] = None

    def __post_init__(self):
        self._fill_missing_regexes()

    def _fill_missing_regexes(self):
        """Compile a literal-match regex for every string token that lacks one."""
        for modality in ("image", "video", "audio"):
            regex_attr = f"{modality}_token_regex"
            token = getattr(self, f"{modality}_token")
            # Only strings can be escaped; ids are handled after conversion.
            if getattr(self, regex_attr) is None and isinstance(token, str):
                setattr(self, regex_attr, re.compile(re.escape(token)))

    def convert_to_str(self, token: Union[str, int], processor) -> str:
        """Return ``token`` as a string, resolving ids through the tokenizer.

        ``None`` and strings are returned unchanged; anything else is passed
        to ``processor.tokenizer.convert_ids_to_tokens``.
        """
        if token is None or isinstance(token, str):
            return token
        return processor.tokenizer.convert_ids_to_tokens([token])[0]

    def convert_to_strs(self, processor):
        """Convert all three tokens to strings in place and build missing regexes."""
        self.image_token = self.convert_to_str(self.image_token, processor)
        self.video_token = self.convert_to_str(self.video_token, processor)
        self.audio_token = self.convert_to_str(self.audio_token, processor)
        # Tokens that were ids at construction time can get their regex now.
        self._fill_missing_regexes()

    def collect(self) -> re.Pattern:
        """Return one capturing regex that matches any configured token.

        The single capturing group makes ``re.split`` keep the matched tokens
        in its output. When no token regex is configured, a pattern that never
        matches is returned so that splitting leaves the text untouched.
        """
        patterns = []
        flags = 0
        for regex in (
            self.image_token_regex,
            self.video_token_regex,
            self.audio_token_regex,
        ):
            if regex is not None:
                patterns.append(regex.pattern)
                flags |= regex.flags
        if not patterns:
            # "()" would match the empty string at every position.
            return re.compile(r"(?!)")
        combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
        return re.compile(combined, flags)
|
||||
|
||||
|
||||
class BaseMultimodalProcessor(ABC):
|
||||
models = []
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
self.hf_config = hf_config
|
||||
self._processor = _processor
|
||||
self.arch = hf_config.architectures[0]
|
||||
self.server_args = server_args
|
||||
# FIXME: not accurate, model and image specific
|
||||
self.NUM_TOKEN_PER_FRAME = 330
|
||||
|
||||
self.io_executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
|
||||
)
|
||||
self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
|
||||
mp_context=mp.get_context("fork"),
|
||||
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
|
||||
)
|
||||
|
||||
def process_mm_data(
    self, input_text, images=None, videos=None, audios=None, **kwargs
):
    """
    Run the wrapped processor on the prompt and the supplied multimodal data.

    Modalities that are ``None`` are not forwarded. Extra keyword arguments
    are passed through to the processor call unchanged. A tensor returned
    under ``"pixel_values"`` is moved to the CPU before returning.
    """
    supplied = (("images", images), ("videos", videos), ("audios", audios))
    for modality_name, payload in supplied:
        if payload is not None:
            kwargs[modality_name] = payload

    processor = self._processor
    uses_fast_image_processor = hasattr(processor, "image_processor") and isinstance(
        processor.image_processor, BaseImageProcessorFast
    )
    if uses_fast_image_processor:
        # Fast image processors are asked to run on the GPU.
        kwargs["device"] = "cuda"

    outputs = processor(
        text=[input_text],
        padding=True,
        return_tensors="pt",
        **kwargs,
    )

    if "pixel_values" in outputs and isinstance(
        outputs["pixel_values"], torch.Tensor
    ):
        outputs["pixel_values"] = outputs["pixel_values"].to("cpu")
    return outputs
|
||||
|
||||
@abstractmethod
async def process_mm_data_async(
    self,
    image_data,
    input_text,
    request_obj,
    max_req_input_len,
    **kwargs,
) -> Optional[Dict[str, Any]]:
    """Asynchronously process the multimodal data of one request.

    Abstract hook that subclasses must override; this base body does
    nothing and yields ``None``.
    """
    return None
|
||||
|
||||
def get_estimated_frames_list(self, image_data):
    """
    Estimate the frame count of every visual input.

    Args:
        image_data: visual inputs; string entries prefixed with ``"video:"``
            are treated as video paths, every other entry counts as one frame.

    Returns:
        A list with one frame count per entry of ``image_data`` (empty when
        there is no input).
    """
    if not image_data:
        return []

    estimated_frames_list = []
    for image in image_data:
        if isinstance(image, str) and image.startswith("video:"):
            # Lazy import because decord is not available on some arm
            # platforms; importing it only when a video is actually present
            # keeps image-only requests working there.
            from decord import VideoReader, cpu

            path = image[len("video:") :]
            # Estimate frames for the video
            vr = VideoReader(path, ctx=cpu(0))
            num_frames = len(vr)
        else:
            # For images, each contributes one frame
            num_frames = 1
        estimated_frames_list.append(num_frames)

    return estimated_frames_list
|
||||
|
||||
@staticmethod
|
||||
def _load_single_item(
|
||||
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
|
||||
):
|
||||
"""Static method that can be pickled for multiprocessing"""
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
try:
|
||||
if is_audio:
|
||||
return load_audio(data)
|
||||
elif is_video:
|
||||
path = data[len("video:") :]
|
||||
return encode_video(path, frame_count_limit)
|
||||
else:
|
||||
img, _ = load_image(data)
|
||||
return img.convert("RGB") if discard_alpha_channel else img
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error while loading data {data}: {e}")
|
||||
|
||||
def submit_data_loading_tasks(
    self,
    text_parts: List[str],
    multimodal_tokens: MultimodalSpecialTokens,
    image_data: Optional[list] = None,
    audio_data: Optional[list] = None,
    discard_alpha_channel: bool = True,
):
    """
    Submit one loading task per multimodal token found in ``text_parts``.

    Args:
        text_parts: the prompt split into plain text and special tokens.
        multimodal_tokens: supplies the image/audio regexes used to recognise
            token parts.
        image_data: visual inputs, consumed in order by image tokens.
        audio_data: audio inputs, consumed in order by audio tokens.
        discard_alpha_channel: forwarded to the loader.

    Returns:
        ``(futures, task_info)``: the submitted futures and, aligned with
        them, tuples of ``(modality, data, frame_count_limit)``.
    """
    # The signature advertises None defaults, so treat None as "no data".
    image_data = [] if image_data is None else image_data
    audio_data = [] if audio_data is None else audio_data

    # TODO(mick): load from server_args, env, or sampling_params
    MAX_NUM_FRAMES = 30
    estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
    total_frame_count = sum(estimated_frames_list)
    # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
    # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
    scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))

    assert len(image_data) == len(estimated_frames_list)

    futures = []
    task_info = []
    image_index, audio_index = 0, 0

    for text_part in text_parts:
        if (
            multimodal_tokens.image_token_regex
            and multimodal_tokens.image_token_regex.match(text_part)
        ):
            data = image_data[image_index]
            is_video = isinstance(data, str) and data.startswith("video:")
            estimated_frames = estimated_frames_list[image_index]
            # Every input keeps at least one frame after scaling.
            frame_count_limit = max(1, int(estimated_frames * scaling_factor))
            futures.append(
                self.io_executor.submit(
                    BaseMultimodalProcessor._load_single_item,
                    data,
                    is_video,
                    False,
                    frame_count_limit,
                    discard_alpha_channel,
                )
            )
            task_info.append((Modality.IMAGE, data, frame_count_limit))
            image_index += 1
        elif (
            multimodal_tokens.audio_token_regex
            and multimodal_tokens.audio_token_regex.match(text_part)
        ):
            data = audio_data[audio_index]
            futures.append(
                self.io_executor.submit(
                    BaseMultimodalProcessor._load_single_item,
                    data,
                    False,
                    True,
                    None,
                    discard_alpha_channel,
                )
            )
            task_info.append((Modality.AUDIO, data, None))
            audio_index += 1

    return futures, task_info
|
||||
|
||||
def load_mm_data(
    self,
    prompt: str | List[int],
    multimodal_tokens: MultimodalSpecialTokens,
    max_req_input_len: int,
    image_data: Optional[list] = None,
    audio_data: Optional[list] = None,
    return_text: Optional[bool] = True,
    discard_alpha_channel: bool = True,
) -> BaseMultiModalProcessorOutput:
    """
    Load the multimodal data referenced by the prompt and rebuild the prompt text.

    Each frame of video/image will be replaced by a single image token

    Args:
        prompt: prompt text, or a list of token ids that is decoded with the
            processor's tokenizer first.
        multimodal_tokens: special tokens denoting a single multimodal data
            item, e.g. image token or audio token. Converted to strings in place.
        max_req_input_len: accepted for interface compatibility; not used here.
        image_data: visual inputs, matched to image tokens in order.
        audio_data: audio inputs, matched to audio tokens in order.
        return_text: must be truthy; anything else raises NotImplementedError.
        discard_alpha_channel: if True, discards the alpha channel in the returned images

    Returns:
        The rebuilt prompt text with the loaded images and audios (empty
        lists are normalised to ``None``).
    """
    if not return_text:
        raise NotImplementedError()
    if image_data is None:
        image_data = []

    multimodal_tokens.convert_to_strs(self._processor)
    multimodal_tokens_pattern = multimodal_tokens.collect()

    if isinstance(prompt, list):
        assert len(prompt) and isinstance(prompt[0], int)
        prompt = self._processor.tokenizer.decode(prompt)

    assert isinstance(prompt, str)
    # split text into list of normal text and special tokens
    text_parts = re.split(multimodal_tokens_pattern, prompt)

    futures, task_info = self.submit_data_loading_tasks(
        text_parts=text_parts,
        multimodal_tokens=multimodal_tokens,
        image_data=image_data,
        audio_data=audio_data,
        discard_alpha_channel=discard_alpha_channel,
    )

    image_regex = multimodal_tokens.image_token_regex
    audio_regex = multimodal_tokens.audio_token_regex

    # Process results
    images, audios = [], []
    text_chunks = []
    task_ptr = 0

    for text_part in text_parts:
        # Loading tasks are submitted only for image and audio tokens, so a
        # part consumes a task slot only under those same conditions. Other
        # matches of the combined pattern (video tokens) are kept as text,
        # which keeps task_ptr aligned with task_info.
        has_task = bool(image_regex and image_regex.match(text_part)) or bool(
            audio_regex and audio_regex.match(text_part)
        )
        if not has_task:
            text_chunks.append(text_part)
            continue

        task_type, data, _frame_limit = task_info[task_ptr]
        result = futures[task_ptr].result()
        task_ptr += 1

        if task_type == Modality.IMAGE:
            # If data is already processed it will be a
            # dictionary. In this case we want to keep the
            # expanded tokens in text_part. Otherwise, we will
            # call the processor code, so keep only a single image
            # token.
            mm_tokens = (
                text_part if isinstance(data, dict) else multimodal_tokens.image_token
            )
            frames = [result] if not isinstance(result, list) else result
            if frames:
                images += frames
                text_chunks.append(mm_tokens * len(frames))
        elif task_type == Modality.AUDIO:
            mm_tokens = (
                text_part if isinstance(data, dict) else multimodal_tokens.audio_token
            )
            audios.append(result)
            text_chunks.append(mm_tokens)
            # TODO: handle video

    out = BaseMultiModalProcessorOutput(
        input_text="".join(text_chunks),
        images=images,
        audios=audios,
    )
    out.normalize()
    return out
|
||||
|
||||
@staticmethod
|
||||
def get_mm_items_offset(
|
||||
input_ids: torch.Tensor, mm_token_id: int
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Get a set of range for mm_items from input_ids
|
||||
Example:
|
||||
input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
|
||||
mm_token_id = 3
|
||||
return result = [(2,4),(6,7)]
|
||||
"""
|
||||
mask = input_ids == mm_token_id
|
||||
|
||||
start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
|
||||
end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
|
||||
|
||||
return list(zip(start_positions.tolist(), end_positions.tolist()))
|
||||
|
||||
@staticmethod
|
||||
def get_mm_items_offset_by_pair(
|
||||
input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
|
||||
) -> List[Tuple[int, int]]:
|
||||
indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
|
||||
indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1
|
||||
|
||||
return list(zip(indices_start.tolist(), indices_end.tolist()))
|
||||
|
||||
@staticmethod
|
||||
def _extract_processor_features(
|
||||
items: List[dict], attr_name: str
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Helper function to concat extracted attributes from processor output.
|
||||
"""
|
||||
values = [value for item in items if (value := item.get(attr_name)) is not None]
|
||||
return torch.cat(values) if values else None
|
||||
|
||||
# When we assume that all the items have the same attributes
|
||||
def _extract_processor_features_from_all_attributes(
|
||||
self, items: List[dict]
|
||||
) -> dict:
|
||||
values = {}
|
||||
# Verify all items have the same keys
|
||||
first_keys = set(items[0].keys())
|
||||
for item in items[1:]:
|
||||
if set(item.keys()) != first_keys:
|
||||
raise ValueError(
|
||||
f"All items must have the same attributes. "
|
||||
f"First item has {first_keys}, but found {set(item.keys())}"
|
||||
)
|
||||
|
||||
# Process each attribute
|
||||
for k, v in items[0].items():
|
||||
if isinstance(v, list):
|
||||
values[k] = self._extract_processor_features(items, k)
|
||||
else:
|
||||
# Verify all items have the same value for non-list attributes
|
||||
for item in items[1:]:
|
||||
if item[k] != v:
|
||||
raise ValueError(
|
||||
f"All items must have the same value for attribute {k}. "
|
||||
f"First item has {v}, but found {item[k]}"
|
||||
)
|
||||
values[k] = v
|
||||
return values
|
||||
|
||||
def process_and_combine_mm_data(
    self, base_output: BaseMultiModalProcessorOutput
) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
    """
    Turn loaded multimodal data into one combined item plus the input ids.

    The inputs (images if present, otherwise audios) are categorised into a
    single format, handled by the matching routine, and the resulting item
    gets its token offsets filled in. Text-only input is just tokenized.

    Returns:
        Tuple of (combined_mm_item, input_ids); the item is None for
        text-only input.
    """

    def tokenize_text(input_text: str) -> torch.Tensor:
        """Tokenize input text into a flat tensor of ids."""
        encoded = self._processor.tokenizer(
            input_text,
            return_tensors="pt",
            add_special_tokens=True,
        )
        return encoded.input_ids.flatten()

    def copy_processor_fields(ret, item: MultimodalDataItem) -> None:
        """Copy every processor output field the item knows, except input_ids."""
        for key, value in ret.items():
            if key != "input_ids" and hasattr(item, key):
                setattr(item, key, value)

    def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
        """Determine the single format shared by all inputs."""
        try:
            has_image = False
            has_pixel_values = False
            has_precomputed_features = False
            has_audio = False

            for mm_input in mm_inputs:
                if isinstance(mm_input, Image.Image):
                    has_image = True
                    continue
                if isinstance(mm_input, np.ndarray):
                    has_audio = True
                    continue
                if not isinstance(mm_input, dict):
                    raise ValueError(
                        f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
                    )
                # For dicts, precomputed features take precedence over pixel values.
                if mm_input.get("precomputed_features", None) is not None:
                    has_precomputed_features = True
                elif mm_input.get("pixel_values", None) is not None:
                    has_pixel_values = True
                else:
                    raise ValueError(
                        f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
                    )

            # Validate format consistency
            format_count = sum(
                (has_image, has_pixel_values, has_precomputed_features, has_audio)
            )
            if format_count > 1:
                raise ValueError(
                    "Unsupported: mixture of multimodal input formats. "
                    f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
                    f"precomputed_features={has_precomputed_features}, audio={has_audio}"
                )

            if has_image:
                return MultimodalInputFormat.RAW_IMAGES
            if has_precomputed_features:
                return MultimodalInputFormat.PRECOMPUTED_FEATURES
            if has_pixel_values:
                return MultimodalInputFormat.PIXEL_VALUES
            if has_audio:
                return MultimodalInputFormat.AUDIO
            raise ValueError("No valid multimodal input format found")
        except Exception as e:
            raise ValueError(f"Failed to categorize inputs: {e}")

    def process_raw_images(
        base_output: BaseMultiModalProcessorOutput,
    ) -> Tuple[MultimodalDataItem, torch.Tensor]:
        """Process raw Image.Image objects using the wrapped processor."""
        ret = self.process_mm_data(
            input_text=base_output.input_text,
            images=base_output.images,
        )
        combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
        copy_processor_fields(ret, combined_mm_item)
        return combined_mm_item, ret["input_ids"].flatten()

    def process_precomputed_features(
        base_output: BaseMultiModalProcessorOutput,
    ) -> Tuple[MultimodalDataItem, torch.Tensor]:
        """Process inputs with precomputed features."""
        combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
        combined_mm_item.precomputed_features = self._extract_processor_features(
            base_output.images, "precomputed_features"
        )
        return combined_mm_item, tokenize_text(base_output.input_text)

    def process_pixel_values(
        base_output: BaseMultiModalProcessorOutput,
    ) -> Tuple[MultimodalDataItem, torch.Tensor]:
        """Process inputs with pixel values."""
        values = self._extract_processor_features_from_all_attributes(
            base_output.images
        )
        combined_mm_item = MultimodalDataItem.from_dict(values)
        return combined_mm_item, tokenize_text(base_output.input_text)

    def process_audio(
        base_output: BaseMultiModalProcessorOutput,
    ) -> Tuple[MultimodalDataItem, torch.Tensor]:
        """Process inputs with audio."""
        ret = self.process_mm_data(
            input_text=base_output.input_text,
            audio=base_output.audios,  # Note: "audio" is for gemma3n only
        )
        combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
        copy_processor_fields(ret, combined_mm_item)
        return combined_mm_item, ret["input_ids"].flatten()

    def finalize_mm_item(
        combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
    ) -> MultimodalDataItem:
        """Fill in the token offsets that belong to the item's modality."""
        modality = combined_mm_item.modality
        if modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
            offsets_attr, token_id_attr = "image_offsets", "IM_TOKEN_ID"
        elif modality == Modality.AUDIO:
            offsets_attr, token_id_attr = "audio_offsets", "AUDIO_TOKEN_ID"
        elif modality == Modality.VIDEO:
            offsets_attr, token_id_attr = "video_offsets", "VIDEO_TOKEN_ID"
        else:
            raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
        # The token id attribute is looked up only for the matching modality.
        offsets = self.get_mm_items_offset(
            input_ids=input_ids,
            mm_token_id=getattr(self, token_id_attr),
        )
        setattr(combined_mm_item, offsets_attr, offsets)
        return combined_mm_item

    # Main logic - determine input type and handle text-only case
    mm_inputs = base_output.images or base_output.audios
    if not mm_inputs:
        return None, tokenize_text(base_output.input_text)

    # Categorize input formats
    input_format = categorize_mm_inputs(mm_inputs)

    # Process based on format
    handlers = {
        MultimodalInputFormat.RAW_IMAGES: process_raw_images,
        MultimodalInputFormat.PRECOMPUTED_FEATURES: process_precomputed_features,
        MultimodalInputFormat.PIXEL_VALUES: process_pixel_values,
        MultimodalInputFormat.AUDIO: process_audio,
    }
    handler = handlers.get(input_format)
    if handler is None:
        raise ValueError(f"Unknown input format: {input_format}")
    combined_mm_item, input_ids = handler(base_output)

    # Finalize with common processing
    combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
    return combined_mm_item, input_ids
|
||||
@@ -1,44 +0,0 @@
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.clip import CLIPModel
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
|
||||
class ClipImageProcessor(BaseMultimodalProcessor):
    """Multimodal processor for CLIP models (images only)."""

    models = [CLIPModel]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    async def process_mm_data_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
        """Load the images and run the processor on them together with the text.

        Args:
            image_data: one image or a list of images; a single non-list
                value is wrapped into a list.
            input_text: prompt text, or a list of token ids that is decoded
                with the processor's tokenizer.

        Returns:
            The processor output extended with ``data_hashes``, a flat
            ``input_ids`` list and ``mm_items``; ``None`` when there is no
            image data.
        """
        if not image_data:
            return None

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        if not isinstance(image_data, list):
            image_data = [image_data]

        # image_data is a non-empty list at this point (empty input returned
        # above), so every entry is loaded; no single-image fallback is needed.
        images = [load_image(image)[0] for image in image_data]

        image_inputs = self.process_mm_data(input_text=input_text, images=images)
        image_inputs["data_hashes"] = [hash(str(image_data))]
        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
        image_inputs["mm_items"] = [
            MultimodalDataItem(
                pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
            )
        ]

        return image_inputs
|
||||
@@ -1,90 +0,0 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
|
||||
|
||||
|
||||
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
    """Multimodal processor for DeepseekVL2 models (images only)."""

    models = [DeepseekVL2ForCausalLM]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<image>"

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs
    ):
        """Load the images, run the processor and package a single image item.

        Returns ``None`` when there is no image data; otherwise a dict with
        ``mm_items``, ``input_ids`` and ``im_token_id``.
        """
        if not image_data:
            return None

        # Accept a single image as well as a list of images.
        if not isinstance(image_data, list):
            image_data = [image_data]

        special_tokens = MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN)
        base_output = self.load_mm_data(
            input_text,
            image_data=image_data,
            multimodal_tokens=special_tokens,
            max_req_input_len=max_req_input_len,
        )
        res = self.process_mm_data(
            input_text=base_output.input_text,
            images=base_output.images,
            max_req_input_len=max_req_input_len,
            conversations=base_output.input_text,
        )

        images_seq_mask = res["images_seq_mask"]
        # Stack the single crop tensor to add a leading dimension.
        batched_images_spatial_crop = torch.stack(
            [res["images_spatial_crop"]], dim=0
        )

        input_ids = res["input_ids"]
        image_offsets = self.get_mm_items_offset(
            input_ids=input_ids, mm_token_id=self._processor.image_token_id
        )
        item = MultimodalDataItem(
            pixel_values=res["images"],
            image_offsets=image_offsets,
            modality=Modality.IMAGE,
            image_emb_mask=images_seq_mask,
            image_spatial_crop=batched_images_spatial_crop,
        )

        return {
            "mm_items": [item],
            "input_ids": input_ids.tolist(),
            "im_token_id": self._processor.image_token_id,
        }
|
||||
@@ -1,63 +0,0 @@
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
|
||||
|
||||
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
|
||||
# will be removed in the future
|
||||
|
||||
|
||||
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
    """Multimodal processor for Gemma3 models (images only)."""

    models = [Gemma3ForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        # Token ids taken from the model config.
        self.IM_START_TOKEN_ID = hf_config.boi_token_index
        self.IM_END_TOKEN_ID = hf_config.eoi_token_index
        self.IM_TOKEN_ID = hf_config.image_token_index
        # The single, pre-expanded image token.
        self.IMAGE_TOKEN = "<start_of_image>"
        # The regex that matches expanded image tokens.
        self.IMAGE_TOKEN_REGEX = re.compile(
            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
        )

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes, Dict]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        """Load the images and combine them into a single multimodal item.

        Returns ``None`` when there is no image data; otherwise a dict with
        ``input_ids``, ``mm_items`` and the image start/end token ids.
        """
        if not image_data:
            return None
        # A bare string is treated as a single image.
        if isinstance(image_data, str):
            image_data = [image_data]

        special_tokens = MultimodalSpecialTokens(
            image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
        )
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=special_tokens,
            max_req_input_len=max_req_input_len,
            discard_alpha_channel=True,
        )

        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
        mm_items = [combined_mm_item] if combined_mm_item is not None else []

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
        }
|
||||
@@ -1,97 +0,0 @@
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
|
||||
|
||||
|
||||
class Gemma3nSGLangProcessor(SGLangBaseProcessor):
    """Multimodal processor for Gemma3n supporting image and audio inputs."""

    models = [Gemma3nForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

        # Placeholder token for images, plus a regex that matches either a
        # bare <start_of_image> or a full <start_of_image>...<end_of_image> span.
        self.IMAGE_TOKEN = "<image_soft_token>"
        self.IMAGE_TOKEN_REGEX = re.compile(
            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
        )

        # Same scheme for audio.
        self.AUDIO_TOKEN = "<audio_soft_token>"
        self.AUDIO_TOKEN_REGEX = re.compile(
            r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
        )

        # Special token ids are read straight from the HF config.
        self.IM_TOKEN_ID = hf_config.image_token_id
        self.IM_START_TOKEN_ID = hf_config.boi_token_id
        self.IM_END_TOKEN_ID = hf_config.eoi_token_id

        self.AUDIO_TOKEN_ID = hf_config.audio_token_id
        self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
        self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id

    async def process_mm_data_async(
        self,
        image_data: Optional[List[Union[str, bytes, Dict]]] = None,
        audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
        input_text: str = "",
        request_obj=None,
        max_req_input_len: int = 0,
        *args,
        **kwargs,
    ):
        """Process multimodal data including images and audio.

        Args:
            image_data: Image inputs; a single string is wrapped in a list.
            audio_data: Audio inputs; a single string is wrapped in a list.
                When empty, ``request_obj.audio_data`` is used instead.
            input_text: Prompt passed to ``load_mm_data``.
            request_obj: Optional request object; only its ``audio_data``
                attribute is read here.
            max_req_input_len: Forwarded to ``load_mm_data``.

        Returns:
            ``None`` when there is neither image nor audio data, otherwise a
            dict with ``input_ids``, ``mm_items`` and the image/audio
            start/end token ids.
        """
        # The original unconditionally read `request_obj.audio_data`, which
        # crashed for the default `request_obj=None` and discarded an
        # explicitly passed `audio_data`. Only fall back to the request
        # object when no audio was passed directly.
        if not audio_data and request_obj is not None:
            audio_data = getattr(request_obj, "audio_data", None)

        if not image_data and not audio_data:
            return None

        if isinstance(image_data, str):
            image_data = [image_data]

        if isinstance(audio_data, str):
            audio_data = [audio_data]

        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            audio_data=audio_data,
            max_req_input_len=max_req_input_len,
            multimodal_tokens=MultimodalSpecialTokens(
                image_token=self.IMAGE_TOKEN,
                image_token_regex=self.IMAGE_TOKEN_REGEX,
                audio_token=self.AUDIO_TOKEN,
                audio_token_regex=self.AUDIO_TOKEN_REGEX,
            ),
        )

        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
            "audio_start_id": self.AUDIO_START_TOKEN_ID,
            "audio_end_id": self.AUDIO_END_TOKEN_ID,
        }
|
||||
@@ -1,245 +0,0 @@
|
||||
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.internvl import InternVLChatModel
|
||||
|
||||
|
||||
class InternVLImageProcessor(BaseMultimodalProcessor):
    """Image processor for InternVL chat models.

    Splits each input image into square tiles, normalizes them, and expands
    every ``<image>`` placeholder in the prompt into the matching run of
    ``<IMG_CONTEXT>`` tokens wrapped in ``<img>`` / ``</img>``.
    """

    models = [InternVLChatModel]

    def __init__(self, hf_config, server_args, _image_processor):
        super().__init__(hf_config, server_args, _image_processor)
        side = hf_config.force_image_size or hf_config.vision_config.image_size
        patch = hf_config.vision_config.patch_size

        self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
        self.IMG_START_TOKEN = "<img>"
        self.IMG_END_TOKEN = "</img>"
        self.IMG_TOKEN = "<image>"

        # Number of context tokens contributed by a single tile.
        patches_per_side = side // patch
        self.num_image_token = int(
            patches_per_side**2 * (hf_config.downsample_ratio**2)
        )

        # `self._processor` is used as the tokenizer here.
        to_id = self._processor.convert_tokens_to_ids
        self.img_start_token_id = to_id(self.IMG_START_TOKEN)
        self.img_end_token_id = to_id(self.IMG_END_TOKEN)
        self.img_context_token_id = to_id(self.IMG_CONTEXT_TOKEN)

    @staticmethod
    def build_transform(input_size):
        """Return a callable turning a PIL image into a normalized CHW tensor.

        The image is converted to RGB, resized to ``input_size`` x
        ``input_size`` with bicubic resampling, scaled to [0, 1] and
        normalized with the ImageNet mean/std below.
        """
        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)

        # Per-channel statistics shaped for broadcasting over (C, H, W).
        mean = torch.tensor(IMAGENET_MEAN).view(-1, 1, 1)
        std = torch.tensor(IMAGENET_STD).view(-1, 1, 1)

        def transform(img):
            if img.mode != "RGB":
                img = img.convert("RGB")
            img = img.resize((input_size, input_size), Image.Resampling.BICUBIC)
            # HWC uint8 -> CHW float32 in [0, 1]
            chw = (np.array(img).astype(np.float32) / 255.0).transpose(2, 0, 1)
            return (torch.from_numpy(chw) - mean) / std

        return transform

    @staticmethod
    def dynamic_preprocess(
        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
    ):
        """Tile ``image`` into ``image_size`` squares on the best-fitting grid.

        Candidate grids are all (columns, rows) pairs whose tile count lies in
        [``min_num``, ``max_num``]. The grid whose aspect ratio is closest to
        the image's is chosen; the image is resized to that grid and cropped
        into tiles in row-major order. With ``use_thumbnail`` and more than
        one tile, a whole-image thumbnail is appended.
        """
        width, height = image.size
        source_ratio = width / height

        # Enumerate candidate grids, ordered by total tile count.
        grids = {
            (cols, rows)
            for n in range(min_num, max_num + 1)
            for cols in range(1, n + 1)
            for rows in range(1, n + 1)
            if min_num <= cols * rows <= max_num
        }
        grids = sorted(grids, key=lambda grid: grid[0] * grid[1])

        # Pick the grid closest in aspect ratio; on an exact tie, switch to
        # the later grid only when the source image area exceeds half of
        # that grid's pixel area.
        best_grid = (1, 1)
        smallest_gap = float("inf")
        area = width * height
        for grid in grids:
            gap = abs(source_ratio - grid[0] / grid[1])
            if gap < smallest_gap:
                smallest_gap = gap
                best_grid = grid
            elif gap == smallest_gap:
                if area > 0.5 * image_size * image_size * grid[0] * grid[1]:
                    best_grid = grid

        target_width = image_size * best_grid[0]
        target_height = image_size * best_grid[1]
        blocks = best_grid[0] * best_grid[1]

        resized = image.resize((target_width, target_height))
        tiles_per_row = target_width // image_size
        tiles = []
        for index in range(blocks):
            row, col = divmod(index, tiles_per_row)
            left = col * image_size
            top = row * image_size
            tiles.append(resized.crop((left, top, left + image_size, top + image_size)))
        assert len(tiles) == blocks

        if use_thumbnail and len(tiles) != 1:
            tiles.append(image.resize((image_size, image_size)))
        return tiles

    @staticmethod
    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
        """Return ``num_segments`` frame indices spread over a time window.

        ``bound`` is an optional (start, end) pair that is multiplied by
        ``fps`` to get frame positions; without it a very wide window is used
        and then clamped to [``first_idx``, ``max_frame``].
        """
        start, end = (bound[0], bound[1]) if bound else (-100000, 100000)
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / num_segments
        half_segment = seg_size / 2
        indices = [
            int(start_idx + half_segment + np.round(seg_size * idx))
            for idx in range(num_segments)
        ]
        return np.array(indices)

    @staticmethod
    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
        """Sample frames from a video and return their tiled pixel values.

        Returns a tuple of (all tiles concatenated along dim 0, list with the
        number of tiles produced for each sampled frame).
        """
        reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        last_frame = len(reader) - 1
        fps = float(reader.get_avg_fps())

        per_frame_tiles = []
        per_frame_counts = []
        transform = InternVLImageProcessor.build_transform(input_size=input_size)
        sampled = InternVLImageProcessor.get_index(
            bound, fps, last_frame, first_idx=0, num_segments=num_segments
        )
        for frame_index in sampled:
            frame = Image.fromarray(reader[frame_index].asnumpy()).convert("RGB")
            tiles = InternVLImageProcessor.dynamic_preprocess(
                frame, image_size=input_size, use_thumbnail=True, max_num=max_num
            )
            stacked = torch.stack([transform(tile) for tile in tiles])
            per_frame_counts.append(stacked.shape[0])
            per_frame_tiles.append(stacked)

        return torch.cat(per_frame_tiles), per_frame_counts

    async def process_mm_data_async(
        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
    ):
        """Tile the request's images and expand ``<image>`` placeholders.

        Returns ``None`` when there is no image data, otherwise a dict with
        the tokenized ``input_ids``, a single image ``MultimodalDataItem`` and
        the image start/end/context token ids.
        """
        if not image_data:
            return None

        # Ensure image_data is a list
        if isinstance(image_data, str):
            image_data = [image_data]

        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
            max_req_input_len=max_req_input_len,
            discard_alpha_channel=True,
        )

        transform = InternVLImageProcessor.build_transform(input_size=448)

        tiles_per_image = []
        tile_tensors = []
        for image in base_output.images:
            try:
                # TODO: video input
                tiles = InternVLImageProcessor.dynamic_preprocess(
                    image, image_size=448, use_thumbnail=True, max_num=12
                )
                stacked = torch.stack([transform(tile) for tile in tiles])
                tile_tensors.append(stacked.to(torch.bfloat16).cuda())
                tiles_per_image.append(stacked.shape[0])
            except FileNotFoundError as e:
                print(e)
                return None

        pixel_values = torch.cat(tile_tensors, dim=0)

        # Replace the placeholders one at a time, in image order, with the
        # number of context tokens that image's tiles require.
        for num_patches in tiles_per_image:
            context_run = self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
            expanded = self.IMG_START_TOKEN + context_run + self.IMG_END_TOKEN
            input_text = input_text.replace("<image>", expanded, 1)

        encoded = self._processor(input_text, return_tensors="pt")
        input_ids = encoded["input_ids"].flatten()
        image_offsets = self.get_mm_items_offset(
            input_ids=input_ids,
            mm_token_id=self.img_context_token_id,
        )
        image_item = MultimodalDataItem(
            pixel_values=pixel_values,
            modality=Modality.IMAGE,
            image_offsets=image_offsets,
        )

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": [image_item],
            "im_start_id": self.img_start_token_id,
            "im_end_id": self.img_end_token_id,
            "im_token_id": self.img_context_token_id,
        }
|
||||
@@ -1,66 +0,0 @@
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
|
||||
|
||||
|
||||
class JanusProImageProcessor(BaseMultimodalProcessor):
    """Image processor registered for ``MultiModalityCausalLM`` (Janus-Pro)."""

    models = [MultiModalityCausalLM]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
        """Load the images, run the processor and package the result.

        Returns ``None`` when ``image_data`` is empty, otherwise a dict with
        one image ``MultimodalDataItem``, the flattened ``input_ids`` and the
        image start/end/placeholder token ids taken from the processor.
        """
        if not image_data:
            return None

        # Accept a single item as well as a list of items.
        if not isinstance(image_data, list):
            image_data = [image_data]

        processor = self._processor
        special_tokens = MultimodalSpecialTokens(image_token=processor.image_token)

        base_out = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=special_tokens,
            max_req_input_len=max_req_input_len,
        )

        # The loaded text is passed under both keyword names.
        res = self.process_mm_data(
            input_text=base_out.input_text,
            prompt=base_out.input_text,
            images=base_out.images,
        )

        input_ids = res["input_ids"].flatten()
        image_offsets = self.get_mm_items_offset(
            input_ids=input_ids, mm_token_id=processor.image_id
        )
        image_item = MultimodalDataItem(
            pixel_values=res["pixel_values"],
            image_emb_mask=res["images_emb_mask"],
            image_offsets=image_offsets,
            modality=Modality.IMAGE,
        )

        return {
            "mm_items": [image_item],
            "input_ids": input_ids.tolist(),
            "im_start_id": processor.image_start_id,
            "im_end_id": processor.image_end_id,
            "im_token_id": processor.image_id,
        }
|
||||
@@ -1,55 +0,0 @@
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
|
||||
|
||||
|
||||
# Compatible with KimiVLForConditionalGeneration
|
||||
class KimiVLImageProcessor(SGLangBaseProcessor):
    """Image processor registered for ``KimiVLForConditionalGeneration``."""

    models = [KimiVLForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        # A run of one or more media-pad tokens marks one image placeholder.
        self.IMAGE_TOKEN = "<|media_pad|>"
        self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
        self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes, Dict]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        """Load the images and combine them into a single multimodal item.

        Returns ``None`` when ``image_data`` is empty, otherwise a dict with
        ``input_ids``, ``mm_items`` (empty if nothing was combined) and the
        image placeholder token id.
        """
        if not image_data:
            return None

        # A bare string is treated as a single image reference.
        if isinstance(image_data, str):
            image_data = [image_data]

        special_tokens = MultimodalSpecialTokens(
            image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
        )
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=special_tokens,
            max_req_input_len=max_req_input_len,
        )

        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
        mm_items = [] if combined_mm_item is None else [combined_mm_item]

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
            "im_token_id": self.IM_TOKEN_ID,
        }
|
||||
@@ -1,213 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers.models.auto.processing_auto import (
|
||||
PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
|
||||
)
|
||||
|
||||
import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.models.llava import (
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaLlamaForCausalLM,
|
||||
LlavaMistralForCausalLM,
|
||||
LlavaQwenForCausalLM,
|
||||
)
|
||||
from sglang.srt.models.llavavid import LlavaVidForCausalLM
|
||||
from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
|
||||
from sglang.srt.utils import load_image, logger
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseMultimodalProcessor):
    """Image/video-frame processor for the LLaVA model families listed below."""

    models = [
        LlavaLlamaForCausalLM,
        LlavaVidForCausalLM,
        LlavaQwenForCausalLM,
        LlavaMistralForCausalLM,
    ]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[str] = None,
        processor=None,
    ):
        """Load one image (or video) and convert it to pixel values.

        Args:
            image_data: Input handed to ``load_image``.
            image_aspect_ratio: ``"pad"``, ``"anyres"`` / ``"anyres_max..."``
                or anything else for plain preprocessing.
            image_grid_pinpoints: Passed to ``process_anyres_image`` in the
                anyres case.
            processor: Object exposing an ``image_processor`` attribute.

        Returns:
            ``(pixel_values, image_hash, image_size)`` on success. On any
            exception inside the processing step the traceback is logged and
            ``None`` is returned, so callers that unpack the result will fail
            at that point.
        """
        image_processor = processor.image_processor

        try:
            image, image_size = load_image(image_data)
            image_hash = hash(image_data)

            if image_size is not None:
                # It is a video with multiple images
                frames = image_processor(image)["pixel_values"]
                pixel_values = np.stack(
                    [frame.astype(np.float16) for frame in frames], axis=0
                )
                return pixel_values, image_hash, image_size

            # It is an image
            if image_aspect_ratio == "pad":
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in image_processor.image_mean),
                )
                pixel_values = image_processor(image.convert("RGB"))[
                    "pixel_values"
                ][0]
            elif image_aspect_ratio == "anyres" or (
                image_aspect_ratio is not None
                and "anyres_max" in image_aspect_ratio
            ):
                pixel_values = process_anyres_image(
                    image, image_processor, image_grid_pinpoints
                )
            else:
                pixel_values = image_processor(image)["pixel_values"][0]

            if isinstance(pixel_values, np.ndarray):
                pixel_values = pixel_values.astype(np.float16)

            return pixel_values, image_hash, image.size
        except Exception:
            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
            return None

    async def _process_single_image(
        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
    ):
        """Run ``_process_single_image_task`` in the executor if one exists."""
        if self.cpu_executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.cpu_executor,
                LlavaImageProcessor._process_single_image_task,
                image_data,
                aspect_ratio,
                grid_pinpoints,
                self._processor,
            )

        # The task reads `processor.image_processor` itself, so it must get
        # the full processor here, exactly as in the executor branch. Passing
        # `self._processor.image_processor` made the task look up
        # `.image_processor` on the image processor and fail.
        return self._process_single_image_task(
            image_data,
            aspect_ratio,
            grid_pinpoints,
            self._processor,
        )

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Convert the request's images into a single ``MultimodalDataItem``.

        Returns ``None`` when ``image_data`` is empty, otherwise
        ``{"mm_items": [item]}``.

        Raises:
            ValueError: If ``image_data`` is neither a string nor a non-empty
                list.
        """
        if not image_data:
            return None

        modalities = request_obj.modalities or ["image"]
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        # `aspect_ratio` may be None when the config does not define it;
        # guard the substring test so that case yields None, not a TypeError.
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints")
            and aspect_ratio is not None
            and "anyres" in aspect_ratio
            else None
        )

        if isinstance(image_data, str):
            image_data = [image_data]

        if isinstance(image_data, list) and len(image_data) > 0:
            if "multi-images" in modalities or "video" in modalities:
                # Multiple images
                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
                pending = [
                    self._process_single_image(img_data, aspect_ratio, grid_pinpoints)
                    for img_data in image_data
                ]
                results = await asyncio.gather(*pending)

                pixel_values, image_sizes = [], []
                for pixel_v, _image_hash, image_s in results:
                    pixel_values.append(pixel_v)
                    image_sizes.append(image_s)

                if isinstance(pixel_values[0], np.ndarray):
                    pixel_values = np.stack(pixel_values, axis=0)
            else:
                # A single image
                pixel_values, _image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
                image_sizes = [image_size]
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        modality = Modality.IMAGE
        # An empty modalities list keeps the default instead of raising
        # IndexError on `[0]`.
        if isinstance(request_obj.modalities, list) and request_obj.modalities:
            if request_obj.modalities[0] == "multi-images":
                modality = Modality.MULTI_IMAGES
            elif request_obj.modalities[0] == "video":
                modality = Modality.VIDEO

        return {
            "mm_items": [
                MultimodalDataItem(
                    pixel_values=pixel_values,
                    image_sizes=image_sizes,
                    modality=modality,
                )
            ],
        }
|
||||
|
||||
|
||||
class LlavaMultimodalProcessor(BaseMultimodalProcessor):
    """
    This is a wrapper class used to identify the multimodal processor for Llava architectures' vision model.

    All processing is delegated to ``self.inner``, an instance of the sglang
    processor class selected from the vision config's ``model_type``.
    """

    models = [LlavaForConditionalGeneration, Mistral3ForConditionalGeneration]

    def _get_sgl_processor_cls(self, model_type: str):
        """Return the registered sglang processor class for ``model_type``.

        The HF processor name mapped to ``model_type`` is compared against the
        ``__name__`` of every class in sglang's ``PROCESSOR_MAPPING``.

        Raises:
            ValueError: If no HF name is mapped or no registered class matches.
        """
        if hf_name := HF_MAPPING_NAMES.get(model_type):
            registered = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
            match = next((p for p in registered if p.__name__ == hf_name), None)
            if match is not None:
                return match
        raise ValueError(
            f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
        )

    def __init__(self, hf_config, server_args, _processor):
        assert hasattr(hf_config, "vision_config")
        assert hasattr(hf_config, "text_config")
        self.vision_config = hf_config.vision_config
        self.text_config = hf_config.text_config
        self.hf_config = hf_config

        # Use a default so a missing `model_type` reaches the descriptive
        # ValueError below instead of raising AttributeError from getattr.
        vision_type = getattr(self.vision_config, "model_type", None)
        if not vision_type:
            raise ValueError(
                f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
            )

        self.inner = self._get_sgl_processor_cls(vision_type)(
            hf_config, server_args, _processor
        )

    async def process_mm_data_async(self, *args, **kwargs):
        """Forward the call unchanged to the wrapped processor."""
        return await self.inner.process_mm_data_async(*args, **kwargs)
|
||||
@@ -1,160 +0,0 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.minicpmo import MiniCPMO
|
||||
from sglang.srt.models.minicpmv import MiniCPMV
|
||||
|
||||
|
||||
# Compatible with both 'O' and 'V'
|
||||
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
    """Image and audio processor shared by ``MiniCPMV`` and ``MiniCPMO``."""

    models = [MiniCPMV, MiniCPMO]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.image_token = "(<image>./</image>)"
        self.audio_token = "(<audio>./</audio>)"

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
        """Process the request's images and audio into multimodal items.

        Audio is read from ``request_obj.audio_data``. Returns ``None`` when
        there is neither image nor audio data or when ``load_mm_data`` yields
        nothing; otherwise a dict with the items, ``input_ids`` and the
        special token ids collected from the tokenizer.

        Raises:
            ValueError: If pixel values / target sizes have an unexpected
                type or inconsistent lengths.
        """
        audio_data = request_obj.audio_data
        if not (image_data or audio_data):
            return None

        # Anything that is not already a list is wrapped as a single entry.
        if not isinstance(image_data, list):
            image_data = [image_data]
        if not isinstance(audio_data, list):
            audio_data = [audio_data]

        special_tokens = MultimodalSpecialTokens(
            image_token=self.image_token,
            audio_token=self.audio_token,
        )
        base_output = self.load_mm_data(
            prompt=input_text,
            max_req_input_len=max_req_input_len,
            audio_data=audio_data,
            image_data=image_data,
            multimodal_tokens=special_tokens,
        )
        if base_output is None:
            return None

        res = self.process_mm_data(
            input_text=base_output.input_text,
            images=base_output.images,
            audios=base_output.audios,
        )

        # Collect special token ids
        tokenizer = self._processor.tokenizer
        slice_start_id = slice_end_id = None
        audio_start_id = audio_end_id = None
        if tokenizer.slice_start_id:
            slice_start_id = tokenizer.slice_start_id
            slice_end_id = tokenizer.slice_end_id
        if hasattr(tokenizer, "audio_start_id"):
            audio_start_id = tokenizer.audio_start_id
            audio_end_id = tokenizer.audio_end_id

        im_start_id = tokenizer.im_start_id
        im_end_id = tokenizer.im_end_id
        im_token_id = tokenizer.unk_id
        pixel_values = res["pixel_values"]
        tgt_sizes = res["tgt_sizes"]

        # Validate the processor output before flattening it.
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
            )
        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of target sizes. Got type: {type(tgt_sizes)}"
            )
        if len(pixel_values) != len(tgt_sizes):
            raise ValueError(
                f"Inconsistent batch lengths, found: {len(pixel_values)} vs. {len(tgt_sizes)}"
            )

        # Flatten the nested per-image entries into two parallel flat lists.
        pixel_values_flat: List[torch.Tensor] = []
        tgt_sizes_flat: List[torch.Tensor] = []
        for per_image_pixels, per_image_sizes in zip(pixel_values, tgt_sizes):
            if len(per_image_pixels) != len(per_image_sizes):
                raise ValueError(
                    f"Inconsistent N lengths, found: {len(per_image_pixels)} vs {len(per_image_sizes)}"
                )
            for pixel_n, tgt_n in zip(per_image_pixels, per_image_sizes):
                pixel_values_flat.append(pixel_n)
                tgt_sizes_flat.append(tgt_n)
        pixel_values = pixel_values_flat

        input_ids = res["input_ids"].flatten()

        # Image offsets are the sorted union of the im_start/im_end pairs and
        # the slice_start/slice_end pairs.
        image_offsets = self.get_mm_items_offset_by_pair(
            input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
        )
        slice_offsets = self.get_mm_items_offset_by_pair(
            input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
        )
        image_offsets = sorted([*image_offsets, *slice_offsets])

        items = []
        if len(pixel_values) != 0:
            items.append(
                MultimodalDataItem(
                    pixel_values=pixel_values,
                    image_offsets=image_offsets,
                    tgt_size=tgt_sizes_flat,
                    modality=Modality.IMAGE,
                )
            )

        has_audio = (
            "audio_features" in res
            and res["audio_features"] is not None
            and len(res["audio_features"]) != 0
        )
        if has_audio:
            audio_offsets = None
            if audio_start_id is not None and audio_end_id is not None:
                audio_offsets = self.get_mm_items_offset_by_pair(
                    input_ids=input_ids,
                    mm_start_id=audio_start_id,
                    mm_end_id=audio_end_id,
                )
            items.append(
                MultimodalDataItem(
                    audio_features=[res["audio_features"]],
                    audio_feature_lens=res["audio_feature_lens"],
                    audio_offsets=audio_offsets,
                    modality=Modality.AUDIO,
                )
            )

        return {
            "mm_items": items,
            "input_ids": input_ids.tolist(),
            "audio_start_id": audio_start_id,
            "audio_end_id": audio_end_id,
            "im_token_id": im_token_id,
            "im_start_id": im_start_id,
            "im_end_id": im_end_id,
            "slice_start_id": slice_start_id,
            "slice_end_id": slice_end_id,
        }
|
||||
@@ -1,46 +0,0 @@
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.mllama import MllamaForConditionalGeneration
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseMultimodalProcessor):
    """Image processor registered for ``MllamaForConditionalGeneration``."""

    models = [MllamaForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    async def process_mm_data_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
        """Load the images and run the processor on text plus images.

        Args:
            image_data: One image input or a list of them.
            input_text: Prompt string, or a non-empty list of token ids that
                is decoded back to text with the tokenizer.

        Returns:
            ``None`` when ``image_data`` is empty, otherwise the processor
            output with ``input_ids`` converted to a plain list and an added
            ``mm_items`` entry holding one image ``MultimodalDataItem``.
        """
        if not image_data:
            return None

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        if not isinstance(image_data, list):
            image_data = [image_data]

        # `image_data` is always a non-empty list at this point (empty input
        # returned above), so the former `else` branch for an empty list was
        # unreachable and has been removed.
        images = [load_image(image)[0] for image in image_data]

        image_inputs = self.process_mm_data(input_text=input_text, images=images)
        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
        image_inputs["mm_items"] = [
            MultimodalDataItem(
                pixel_values=image_inputs["pixel_values"],
                aspect_ratio_id=image_inputs["aspect_ratio_ids"],
                aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
                modality=Modality.IMAGE,
            )
        ]

        return image_inputs
|
||||
@@ -1,152 +0,0 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from transformers.image_utils import SizeDict
|
||||
from transformers.models.llama4.image_processing_llama4_fast import (
|
||||
find_supported_resolutions,
|
||||
get_best_fit,
|
||||
)
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
|
||||
|
||||
|
||||
class Mllama4ImageProcessor(BaseMultimodalProcessor):
    """Multimodal processor registered for ``Llama4ForConditionalGeneration``.

    Loads the request's images, runs the Hugging Face processor over the
    prompt and images, and augments the processor output with per-image
    tiling metadata plus the token ids and ``mm_items`` entries that the
    rest of SGLang reads.
    """

    models = [Llama4ForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        """Cache the config sections and special-token ids used per request."""
        super().__init__(hf_config, server_args, _processor)
        self.vision_config = hf_config.vision_config
        self.text_config = hf_config.text_config
        self.boi_token_index = hf_config.boi_token_index
        self.eoi_token_index = hf_config.eoi_token_index
        self.image_token_index = hf_config.image_token_index
        self.multimodal_tokens = MultimodalSpecialTokens(
            image_token=_processor.image_token
        )

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        max_req_input_len=None,
        *args,
        **kwargs,
    ):
        """Build the multimodal inputs for one request.

        Args:
            image_data: Images attached to the request; when empty or
                ``None`` the method returns ``None`` immediately.
            input_text: The prompt, either as a string or as a non-empty
                list of token ids (which is decoded back to text first).
            max_req_input_len: Forwarded to ``load_mm_data``; falls back to
                4096 when falsy.

        Returns:
            The processor output mapping extended with ``input_ids`` (a flat
            Python list), ``im_start_id`` / ``im_end_id`` / ``im_token_id``,
            ``mm_items`` and, when pixel values are present,
            ``aspect_ratios``, ``patches_per_image`` and (if the patch and
            end-of-image tokens exist in the vocabulary and at least one
            end-of-image token is found) ``embed_is_patch``.
        """
        if not image_data:
            return None

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        # Load the images and the prompt through the base-class helper.
        processed_data = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=self.multimodal_tokens,
            max_req_input_len=max_req_input_len or 4096,
            image_data=image_data,
            return_text=True,
        )

        processor = self._processor

        # Run the Hugging Face processor over the prompt and images.
        processor_output = self.process_mm_data(
            input_text=processed_data.input_text,
            images=processed_data.images,
        )

        # Derive tiling metadata only when the processor produced pixels.
        if "pixel_values" in processor_output:
            image_processor = processor.image_processor
            tokenizer = processor.tokenizer

            tile_size = self.vision_config.image_size
            max_num_tiles = getattr(self.vision_config, "max_patches", 1)

            possible_resolutions = find_supported_resolutions(
                max_num_chunks=max_num_tiles,
                patch_size=SizeDict(height=tile_size, width=tile_size),
            )
            # The candidate resolutions do not depend on the image, so the
            # tensor is built once instead of once per image.
            resolution_tensor = torch.tensor(possible_resolutions)

            # Best-fitting canvas for each image.
            best_fit_sizes = [
                get_best_fit(
                    (image.size[1], image.size[0]),  # (height, width)
                    resolution_tensor,
                    resize_to_max_canvas=image_processor.resize_to_max_canvas,
                )
                for image in processed_data.images
            ]

            # Number of tiles along each axis of the chosen canvas.
            aspect_ratios = [
                (image_size[0] // tile_size, image_size[1] // tile_size)
                for image_size in best_fit_sizes
            ]

            # A single-tile image contributes one patch; otherwise one extra
            # patch is added on top of the r_h * r_w tiles.
            patches_per_image = [
                1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
            ]

            processor_output["aspect_ratios"] = aspect_ratios
            processor_output["patches_per_image"] = torch.tensor(patches_per_image)

            # Build embed_is_patch: one boolean mask per image segment.
            vocab = tokenizer.get_vocab()
            patch_id = vocab.get(processor.img_patch_token, -1)
            image_end_id = vocab.get(processor.end_of_img_token, -1)

            if patch_id != -1 and image_end_id != -1:
                input_ids = processor_output["input_ids"].view(-1)

                # Drop a leading BOS token if present.
                if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
                    input_ids = input_ids[1:]

                image_end_indices = (input_ids == image_end_id).nonzero().view(-1)

                if image_end_indices.size(0) > 0:
                    # Split right after every end-of-image token except the
                    # last, so trailing text stays with the final segment.
                    split_indices = (image_end_indices + 1)[:-1]
                    processor_output["embed_is_patch"] = [
                        segment == patch_id
                        for segment in torch.tensor_split(input_ids, split_indices)
                        if segment.numel() > 0
                    ]

        # Convert to the format expected by SGLang.
        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]

        processor_output["im_start_id"] = self.boi_token_index
        processor_output["im_end_id"] = self.eoi_token_index
        processor_output["im_token_id"] = self.image_token_index

        image_offsets = self.get_mm_items_offset(
            input_ids=torch.tensor(processor_output["input_ids"]),
            mm_token_id=self.image_token_index,
        )

        # NOTE(review): "pixel_values" is read unconditionally here although
        # the block above treats it as optional; a missing key raises
        # KeyError. Confirm whether that case can occur.
        processor_output["mm_items"] = [
            MultimodalDataItem(
                pixel_values=processor_output["pixel_values"],
                modality=Modality.IMAGE,
                image_offsets=image_offsets,
            )
        ]

        return processor_output
|
||||
@@ -1,87 +0,0 @@
|
||||
import logging
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.phi4mm import Phi4MMForCausalLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IMAGE_SPECIAL_TOKEN = "<|endoftext10|>"
|
||||
_IMAGE_SPECIAL_TOKEN_ID = 200010
|
||||
|
||||
|
||||
class Phi4MMImageProcessor(BaseMultimodalProcessor):
    """Multimodal processor registered for ``Phi4MMForCausalLM``.

    Only images are processed; any audio attached to the request is
    dropped with a warning.
    """

    models = [Phi4MMForCausalLM]

    def __init__(self, hf_config, server_args, _processor):
        """Register the image placeholder token used to locate images."""
        super().__init__(hf_config, server_args, _processor)
        self.multimodal_tokens = MultimodalSpecialTokens(
            image_token=_IMAGE_SPECIAL_TOKEN,
        )

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
        """Build the multimodal inputs for one request.

        Args:
            image_data: Image(s) attached to the request; a single
                non-list value is wrapped into a list.
            input_text: The prompt forwarded to ``load_mm_data``.
            request_obj: The request; its ``audio_data`` attribute is read.
            max_req_input_len: Forwarded to ``load_mm_data``.

        Returns:
            ``None`` when the request has neither image nor audio data or
            when ``load_mm_data`` returns ``None``; otherwise a dict with
            ``mm_items``, ``input_ids`` (flat Python list) and
            ``im_token_id``.
        """
        audio_data = request_obj.audio_data

        if not image_data and not audio_data:
            return None

        if not isinstance(image_data, list):
            image_data = [image_data]

        # Audio is not supported yet. Warn only when the request actually
        # carried audio: wrapping a missing value (e.g. None) into a
        # one-element list first would make it truthy and emit a spurious
        # warning on image-only requests. The audio is discarded either way.
        if audio_data:
            logger.warning(
                "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
            )
        audio_data = []

        base_output = self.load_mm_data(
            prompt=input_text,
            max_req_input_len=max_req_input_len,
            audio_data=audio_data,
            image_data=image_data,
            multimodal_tokens=self.multimodal_tokens,
        )
        if base_output is None:
            return None

        res = self.process_mm_data(
            input_text=base_output.input_text,
            images=base_output.images,
            audios=base_output.audios,
        )

        input_ids = res["input_ids"].flatten()
        image_offsets = self.get_mm_items_offset(
            input_ids=input_ids,
            mm_token_id=_IMAGE_SPECIAL_TOKEN_ID,
        )

        items = [
            MultimodalDataItem(
                pixel_values=res["input_image_embeds"],
                image_sizes=res["image_sizes"],
                image_emb_mask=res["image_attention_mask"],
                image_offsets=image_offsets,
                modality=Modality.IMAGE,
            )
        ]

        return {
            "mm_items": items,
            "input_ids": input_ids.tolist(),
            "im_token_id": _IMAGE_SPECIAL_TOKEN_ID,
        }
|
||||
@@ -1,127 +0,0 @@
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.models.pixtral.image_processing_pixtral import (
|
||||
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||
)
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.pixtral import PixtralVisionModel
|
||||
|
||||
|
||||
class PixtralProcessor(BaseMultimodalProcessor):
    """Multimodal processor registered for ``PixtralVisionModel``.

    Images are resized to a whole number of patches before being handed to
    the Hugging Face processor.
    """

    models = [PixtralVisionModel]

    PAD_TOKEN = "<pad>"
    IMG_BREAK_TOKEN_ID = 12
    IMG_END_TOKEN_ID = 13

    def __init__(self, hf_config, server_args, _processor):
        """Read image/patch sizes from the config and register a pad token."""
        super().__init__(hf_config, server_args, _processor)

        vision_config = hf_config.vision_config
        self.vision_config = vision_config
        self.image_size = vision_config.image_size
        self.patch_size = vision_config.patch_size

        # Fall back to the model's default id when the config has none.
        self.image_token_id = getattr(
            hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
        )
        self.multimodal_tokens = MultimodalSpecialTokens(
            image_token=_processor.image_token
        )

        pad_token = getattr(hf_config, "pad_token", self.PAD_TOKEN)
        _processor.tokenizer.add_special_tokens({"pad_token": pad_token})

    def get_patch_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return ``(columns, rows)`` of the patch grid for an image.

        The image is first scaled down (aspect ratio kept) when either side
        exceeds ``self.image_size``; the grid is then computed by the
        Hugging Face Pixtral helper using square ``self.patch_size`` patches.
        """
        side_limit = self.image_size
        patch_side = self.patch_size

        scale = max(image_width / side_limit, image_height / side_limit)
        if scale > 1:
            image_width = math.floor(image_width / scale)
            image_height = math.floor(image_height / scale)

        rows, cols = _get_pixtral_hf_num_image_tokens(
            (image_height, image_width),
            (patch_side, patch_side),
        )
        return cols, rows

    async def _resize(self, image):
        """Resize ``image`` so both sides are exact multiples of the patch size."""
        cols, rows = self.get_patch_grid_size(
            image_width=image.size[0],
            image_height=image.size[1],
        )
        return image.resize((cols * self.patch_size, rows * self.patch_size))

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Build the multimodal inputs for one request.

        Returns ``None`` when there is no image data. Otherwise returns the
        processor output; when it contains ``pixel_values`` it is extended
        with a flat ``input_ids`` list, ``mm_items``, ``im_end_id`` and
        ``im_token_id``.
        """
        if not image_data:
            return None

        if isinstance(image_data, str):
            image_data = [image_data]

        mm_data = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=self.multimodal_tokens,
            max_req_input_len=kwargs.get("max_req_input_len", 4096),
            image_data=image_data,
            return_text=True,
        )

        if mm_data.images:
            mm_data.images = await asyncio.gather(
                *(self._resize(img) for img in mm_data.images)
            )

        processor_output = self.process_mm_data(
            input_text=mm_data.input_text,
            images=mm_data.images,
        )

        # Nothing to annotate when the processor produced no pixels.
        if "pixel_values" not in processor_output:
            return processor_output

        flat_ids = processor_output["input_ids"].view(-1)
        offsets = self.get_mm_items_offset(
            input_ids=flat_ids,
            mm_token_id=self.image_token_id,
        )
        image_item = MultimodalDataItem(
            pixel_values=processor_output["pixel_values"],
            image_sizes=processor_output["image_sizes"],
            modality=Modality.IMAGE,
            image_offsets=offsets,
        )

        processor_output.update(
            input_ids=flat_ids.tolist(),
            mm_items=[image_item],
            # there's no im_start_id for pixtral, only im_token and im_end_token
            im_end_id=self.IMG_END_TOKEN_ID,
            im_token_id=self.image_token_id,
        )
        return processor_output
|
||||
@@ -1,169 +0,0 @@
|
||||
import asyncio
|
||||
import math
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
||||
|
||||
|
||||
# Compatible with Qwen2VL and Qwen2_5VL
|
||||
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
    """Multimodal processor shared by the Qwen2-VL and Qwen2.5-VL models."""

    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        """Cache special-token ids and the image resizing limits."""
        super().__init__(hf_config, server_args, _processor)

        # The single, pre-expanded image token.
        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
        # The regex that matches expanded image tokens.
        self.IMAGE_TOKEN_REGEX = re.compile(
            r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
        )

        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
        self.IM_TOKEN_ID = hf_config.image_token_id
        self.VIDEO_TOKEN_ID = hf_config.video_token_id
        self.vision_start_token_id = hf_config.vision_start_token_id
        self.vision_end_token_id = hf_config.vision_end_token_id

        self.NUM_TOKEN_PER_FRAME = 770
        self.IMAGE_FACTOR = 28
        self.MIN_PIXELS = 4 * 28 * 28
        self.MAX_PIXELS = 16384 * 28 * 28
        self.MAX_RATIO = 200

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes, Dict]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        """Build the multimodal inputs (including M-RoPE positions) for a request.

        Returns ``None`` when ``process_and_combine_mm_data`` yields no
        combined item; otherwise a dict with ``input_ids``, ``mm_items``,
        the image/video token ids, ``mrope_positions`` and
        ``mrope_position_delta``.
        """

        def round_by_factor(number: int, factor: int) -> int:
            """Returns the closest integer to 'number' that is divisible by 'factor'."""
            return round(number / factor) * factor

        def ceil_by_factor(number: int, factor: int) -> int:
            """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
            return math.ceil(number / factor) * factor

        def floor_by_factor(number: int, factor: int) -> int:
            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
            return math.floor(number / factor) * factor

        def smart_resize(
            height: int,
            width: int,
            factor: int = self.IMAGE_FACTOR,
            min_pixels: int = self.MIN_PIXELS,
            max_pixels: int = self.MAX_PIXELS,
        ) -> tuple[int, int]:
            """
            Pick a target (height, width) such that:

            1. Both dimensions are divisible by 'factor'.

            2. The pixel count lies within ['min_pixels', 'max_pixels'].

            3. The aspect ratio stays as close to the original as possible.
            """
            if max(height, width) / min(height, width) > self.MAX_RATIO:
                raise ValueError(
                    f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
                )

            new_h = max(factor, round_by_factor(height, factor))
            new_w = max(factor, round_by_factor(width, factor))
            area = new_h * new_w

            if area > max_pixels:
                # Too large: shrink both sides, rounding down.
                shrink = math.sqrt((height * width) / max_pixels)
                new_h = floor_by_factor(height / shrink, factor)
                new_w = floor_by_factor(width / shrink, factor)
            elif area < min_pixels:
                # Too small: grow both sides, rounding up.
                grow = math.sqrt(min_pixels / (height * width))
                new_h = ceil_by_factor(height * grow, factor)
                new_w = ceil_by_factor(width * grow, factor)
            return new_h, new_w

        def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
            """Resize 'image' to the dimensions chosen by smart_resize."""
            width, height = image.size
            target_h, target_w = smart_resize(
                height,
                width,
                factor=size_factor,
                min_pixels=self.MIN_PIXELS,
                max_pixels=self.MAX_PIXELS,
            )
            return image.resize((target_w, target_h))

        async def resize_image_async(image):
            return resize_image(image)

        if isinstance(image_data, str):
            image_data = [image_data]

        special_tokens = MultimodalSpecialTokens(
            image_token=self.IMAGE_TOKEN,
            image_token_regex=self.IMAGE_TOKEN_REGEX,
        )
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=special_tokens,
            max_req_input_len=max_req_input_len,
        )

        # Qwen-specific: resize images if they are raw Image objects
        loaded = base_output.images
        if loaded and isinstance(loaded[0], Image.Image):
            base_output.images = await asyncio.gather(
                *[resize_image_async(img) for img in loaded]
            )

        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)

        if combined_mm_item is None:
            # Note(Xinyuan): This is the case where image loading fails.
            return None

        video_grid_thw = None  # TODO
        second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
        vision_config = self.hf_config.vision_config

        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
            spatial_merge_size=vision_config.spatial_merge_size,
            image_token_id=self.IM_TOKEN_ID,
            video_token_id=self.VIDEO_TOKEN_ID,
            vision_start_token_id=self.vision_start_token_id,
            model_type=self.hf_config.model_type,
            tokens_per_second=getattr(vision_config, "tokens_per_second", None),
            input_ids=input_ids.unsqueeze(0),
            image_grid_thw=combined_mm_item.image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
        )
        mrope_positions = mrope_positions.squeeze(1)

        return {
            "input_ids": input_ids.tolist(),
            "mm_items": [combined_mm_item],
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
            "im_token_id": self.IM_TOKEN_ID,
            "video_token_id": self.VIDEO_TOKEN_ID,
            "mrope_positions": mrope_positions,
            "mrope_position_delta": mrope_position_delta,
        }
|
||||
@@ -1,85 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Type, cast
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
ImageDataItem,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.vila import VILAForConditionalGeneration
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
class VILAProcessor(ProcessorMixin):
    """Typing stub for the VILA processor.

    Declares only the ``tokenizer`` attribute so that code using the
    processor can be type-checked; it adds no behavior of its own.
    """

    tokenizer: PreTrainedTokenizerBase
|
||||
|
||||
|
||||
class VILAMultimodalProcessor(BaseMultimodalProcessor):
    """Multimodal processor registered for ``VILAForConditionalGeneration``."""

    models: List[Type[nn.Module]] = [VILAForConditionalGeneration]

    _processor: VILAProcessor

    def __init__(
        self,
        hf_config: PretrainedConfig,
        server_args: ServerArgs,
        _processor: VILAProcessor,
    ) -> None:
        """Delegate construction to the base processor."""
        super().__init__(hf_config, server_args, _processor)

    async def process_mm_data_async(
        self,
        image_data: Optional[ImageDataItem | List[ImageDataItem]],
        input_text: str | List[int],
        request_obj: GenerateReqInput | EmbeddingReqInput,
        max_req_input_len: int,
        **kwargs,
    ) -> Optional[Dict[str, Any]]:
        """Build the multimodal inputs for one request.

        Returns ``None`` when no image data is supplied; otherwise a dict
        holding the flat ``input_ids`` list and a single-entry ``mm_items``
        list describing the images.
        """
        if not image_data:
            return None

        # Accept a single item as well as a list of items.
        images = image_data if isinstance(image_data, list) else [image_data]

        tokenizer = self._processor.tokenizer
        special_tokens = MultimodalSpecialTokens(image_token=tokenizer.image_token)

        mm_data = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=special_tokens,
            max_req_input_len=max_req_input_len,
            image_data=images,
        )

        inputs = self.process_mm_data(
            input_text=mm_data.input_text,
            images=mm_data.images,
        )

        token_ids = inputs.input_ids[0]
        offsets = self.get_mm_items_offset(
            input_ids=token_ids,
            mm_token_id=cast(int, tokenizer.image_token_id),
        )

        image_item = MultimodalDataItem(
            modality=Modality.IMAGE,
            image_offsets=offsets,
            pixel_values=inputs.pixel_values,
        )
        mm_items: List[MultimodalDataItem] = [image_item]

        return {
            "input_ids": token_ids.tolist(),
            "mm_items": mm_items,
        }
|
||||
Reference in New Issue
Block a user