refactor: multimodal data (#4754)

This commit is contained in:
Mick
2025-04-01 00:57:51 +08:00
committed by GitHub
parent c7457191a0
commit 5cb552b1d4
36 changed files with 989 additions and 1138 deletions

View File

@@ -1,5 +1,5 @@
"""
Multimodality utils
Multi-modality utils
"""
from abc import abstractmethod
@@ -9,11 +9,13 @@ import torch
from torch import nn
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
logger,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import print_warning_once
from sglang.utils import logger
@@ -26,7 +28,7 @@ class MultiModalityDataPaddingPattern:
@abstractmethod
def pad_input_tokens(
self, input_ids: List[int], image_inputs: MultimodalInputs
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
Pad the input ids sequence containing data tokens, and replace them with pad_values
@@ -49,13 +51,13 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
"""
This function will replace the data-tokens inbetween with pad_values accordingly
"""
pad_values = mm_inputs.pad_values
pad_values = [item.pad_value for item in mm_inputs.mm_items]
data_token_pairs = self.data_token_id_pairs
mm_inputs.image_offsets = []
mm_inputs.data_offsets = []
if data_token_pairs is None:
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
if data_token_pairs is None:
logger.warning(
print_warning_once(
"No data_token_pairs provided, RadixAttention might be influenced."
)
return input_ids
@@ -77,10 +79,10 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
if input_ids[start_idx] in start_token_ids:
data_idx += 1
mm_inputs.image_offsets += [start_idx]
mm_inputs.data_offsets += [start_idx]
if data_idx >= len(mm_inputs.pad_values):
data_idx = len(mm_inputs.pad_values) - 1
if data_idx >= len(pad_values):
data_idx = len(pad_values) - 1
num_tokens = end_idx - start_idx - 1
pad_value = pad_values[data_idx]
@@ -94,68 +96,19 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
return padded_ids
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
"""In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
which needs first to be expanded to multiple tokens, then replaced with their padding values
This strategy should be used when a single data token represents content that should
be expanded to multiple tokens during processing.
"""
def __init__(
self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
) -> None:
self.num_data_token_calc_func = num_data_token_calc_func
def pad_input_tokens(
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
This function will follow the procedure of:
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
2. the padded data tokens will be replaced with their pad_values
"""
image_grid_thws = mm_inputs.image_grid_thws
pad_values = mm_inputs.pad_values
image_indices = [
idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
]
mm_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
# print(f"image_cnt {image_cnt}")
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
else:
non_image_tokens = input_ids[
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
mm_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
input_ids_with_image.extend(pad_ids[:num_image_tokens])
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
"""In this pattern, data tokens should be represented as repetitions of a single token
e.g. <image><image>....<image>, or <audio><audio>...<audio>
"""
def __init__(self, image_token_id: torch.Tensor) -> None:
self.image_token_id = image_token_id
def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
"""
This function will replace the data-tokens in between with pad_values accordingly
"""
pad_values = image_inputs.pad_values
pad_values = [item.pad_value for item in mm_inputs.mm_items]
assert len(pad_values) != 0
input_ids_tensor = torch.tensor(input_ids)
@@ -170,138 +123,227 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
return input_ids_tensor.tolist()
def get_embedding_and_mask(
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
embedding_items: List[MultimodalDataItem],
placeholder_tensor: torch.Tensor,
input_ids: torch.Tensor,
):
"""
Get the multimodal embedding and its mask from input_ids
"""
# 1. Get the embedding
embedding = data_embedding_func(embedding_items)
# 2. Check the embedding
if embedding.dim() == 2:
num_mm_tokens_in_embedding = embedding.shape[0]
else:
num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
# the mask of multimodal tokens from input_ids
special_multimodal_mask = torch.isin(
input_ids,
placeholder_tensor,
).unsqueeze(-1)
num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
logger.warning(
f"Number of tokens in multimodal embedding does not match those in the input text."
f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
"tokens from multimodal embeddings."
)
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
# a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
# extend_start_loc and extend_seq_lens
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
if chunked_prefill_size != -1:
logger.warning(
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
)
# extract from the end: this is a compromise
if embedding.dim() == 2:
embedding = embedding[-num_mm_tokens_in_input_ids:, :]
else:
num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
embedding = embedding[-num_multimodal:, :]
else:
raise RuntimeError(
"Insufficient multimodal embedding length. This is an internal error"
)
return embedding, special_multimodal_mask
def embed_mm_inputs(
mm_input: MultimodalInputs,
mm_inputs: MultimodalInputs,
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
image_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]:
"""
Calculate the image embeddings if necessary, then scatter the result with
the help of a boolean mask denoting the embed locations
Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
Returns:
final embedding: Optional[torch.Tensor]
Args:
placeholder_token_ids: denoting the token of multimodal data in input_ids.
If none, the pad_values of multimodal items are used
Returns:
final embedding: Optional[torch.Tensor]
"""
if mm_input is None:
if mm_inputs is None:
return None
placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
# we assume that multimodal data are represented with its pad_values in input_ids
placeholder_token_ids = placeholder_token_ids or [
item.pad_value for item in mm_inputs.mm_items
]
# boolean masking the special tokens
special_image_mask = torch.isin(
input_ids,
torch.tensor(placeholder_token_ids, device=input_ids.device),
).unsqueeze(-1)
placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
num_image_tokens_in_input_ids = special_image_mask.sum()
# print(f"{num_image_tokens_in_input_ids}")
# print(f"{input_ids}")
placeholder_masks = torch.isin(input_ids, placeholder_tensor)
# return
if num_image_tokens_in_input_ids == 0:
# unexpected
appearing_pad_values = torch.unique(
input_ids[placeholder_masks], return_counts=False
)
if appearing_pad_values.numel() == 0:
# all been prefixed
inputs_embeds = input_embedding(input_ids)
else:
# print(f"Getting image feature")
image_embedding = mm_data_embedding_func(mm_input)
appearing_items = [
item
for item in mm_inputs.mm_items
if item.pad_value is not None and item.pad_value in appearing_pad_values
]
# print(f"image_embedding: {image_embedding.shape}")
if image_embedding.dim() == 2:
num_image_tokens_in_embedding = image_embedding.shape[0]
else:
num_image_tokens_in_embedding = (
image_embedding.shape[0] * image_embedding.shape[1]
)
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
image_embedding = image_embedding[:num_image, :]
logger.warning(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
"tokens from image embeddings."
using_all_items = False
if len(appearing_items) == 0:
# This happens mostly when arg placeholder_token_ids is passed
logger.warning_once(
"No multimodal data item's pad value exist in placeholder ids. Using all items"
)
using_all_items = True
appearing_items = mm_inputs.mm_items
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
# a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
# extend_start_loc and extend_seq_lens
if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
if chunked_prefill_size != -1:
logger.warning(
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
embeddings, masks = [], []
# 2. Get multimodal embedding separately
# TODO: make this more generic
# Try get image embedding if any
if (
any(True for item in appearing_items if item.is_image())
and image_data_embedding_func
):
items = [item for item in appearing_items if item.is_image()]
embedding, mask = get_embedding_and_mask(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
placeholder_tensor
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
),
input_ids=input_ids,
)
embeddings += [embedding]
masks += [mask]
# Try get audio embedding if any
if (
any(True for item in appearing_items if item.is_audio())
and audio_data_embedding_func
):
items = [item for item in appearing_items if item.is_audio()]
embedding, mask = get_embedding_and_mask(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
placeholder_tensor
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
),
input_ids=input_ids,
)
embeddings += [embedding]
masks += [mask]
# 3. Get input embeddings
vocab_size = input_embedding.num_embeddings
# Important: clamp after getting original image regions
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Important: clamp after getting original multimodal regions
# Clamp input ids. This is because the input_ids for the multimodal tokens are
# filled with the hash values of the multimodal for the prefix matching in the radix attention.
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask,
image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
)
return inputs_embeds
def embed_image_embedding(
inputs_embeds: torch.Tensor,
image_embedding: torch.Tensor,
image_bounds: torch.Tensor,
) -> torch.Tensor:
"""
scatter image_embedding into inputs_embeds according to image_bounds
"""
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(inputs_embeds.device)
inputs_embeds.scatter_(
0,
image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
image_embedding.view(-1, image_embedding.shape[-1]),
)
# 4. scatter embeddings into input embedding
for embedding, mask in zip(embeddings, masks):
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(
mask,
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
)
return inputs_embeds
def general_mm_embed_routine(
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
embed_tokens: nn.Embedding,
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
language_model: nn.Module,
image_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
placeholder_token_ids: List[int] = None,
):
**kwargs,
) -> torch.Tensor:
"""
a general wrapper function to get final input embeds from multimodal models
with a language model as causal model
A general wrapper function to get final input embeds from multimodal models with a language model as causal model
Args:
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
image_data_embedding_func : the function returning the image embedding
audio_data_embedding_func : the function returning the image embedding
Returns:
inputs_embedding
forwarded hidden states
"""
assert hasattr(language_model, "get_input_embeddings")
embed_tokens = language_model.get_input_embeddings()
if (
not forward_batch.forward_mode.is_decode()
and forward_batch.contains_mm_inputs()
):
image = forward_batch.merge_mm_inputs()
mm_input = forward_batch.merge_mm_inputs()
inputs_embeds = embed_mm_inputs(
mm_input=image,
mm_inputs=mm_input,
input_ids=input_ids,
input_embedding=embed_tokens,
mm_data_embedding_func=mm_data_embedding_func,
image_data_embedding_func=image_data_embedding_func,
audio_data_embedding_func=audio_data_embedding_func,
placeholder_token_ids=placeholder_token_ids,
)
# once used, mm_inputs is useless
@@ -310,7 +352,13 @@ def general_mm_embed_routine(
else:
inputs_embeds = embed_tokens(input_ids)
return inputs_embeds
hidden_states = language_model(
input_ids=None,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
**kwargs,
)
return hidden_states
def get_multimodal_data_bounds(
@@ -322,15 +370,13 @@ def get_multimodal_data_bounds(
Returns:
[bounds_count, 2]
"""
# All the images in the batch should share the same special image
# bound token ids.
# All the multimodal data in the batch should share the same special bound token ids.
start_tokens = [s for s, _e in token_pairs]
end_tokens = [e for _s, e in token_pairs]
assert all(isinstance(t, int) for t in start_tokens)
assert all(isinstance(t, int) for t in end_tokens)
# print(input_ids)
start_cond = torch.isin(
input_ids, torch.tensor(start_tokens, device=input_ids.device)
)
@@ -339,7 +385,7 @@ def get_multimodal_data_bounds(
(data_start_tokens,) = torch.where(start_cond)
(data_end_tokens,) = torch.where(end_cond)
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
if len(data_start_tokens) != len(data_end_tokens):
if (
len(data_start_tokens) + 1 == len(data_end_tokens)
@@ -352,14 +398,14 @@ def get_multimodal_data_bounds(
data_start_tokens,
]
)
valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
if valid_image_nums == 0:
if valid_mm_data_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
# Filter out pairs where start_token >= end_token
valid_pairs = []
for i in range(valid_image_nums):
for i in range(valid_mm_data_nums):
start_token = data_start_tokens[i]
end_token = data_end_tokens[i]
if start_token < end_token:

View File

@@ -64,5 +64,3 @@ def get_mm_processor(
f"No processor registered for architecture: {hf_config.architectures}.\n"
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
)
self.image_proce

View File

@@ -8,18 +8,10 @@ from typing import Optional
import numpy as np
import PIL
import transformers
from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.utils import load_audio, load_image, logger
global global_processor
def get_global_processor():
global global_processor
return global_processor
from sglang.srt.utils import encode_video, load_audio, load_image, logger
@dataclasses.dataclass
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
# input_text, with each frame of video/image represented with a image_token
input_text: str
mm_data_hashes: Optional[list[int]]
# images
image_sizes: Optional[list[int]]
# frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
audios: Optional[list[np.ndarray]] = None
def normalize(self):
for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
for field_name in ["image_sizes", "images", "audios"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
# Initialize global processor first
init_global_processor(self, server_args)
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
self.io_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
)
self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
mp_context=mp.get_context("fork"),
initargs=(
self,
server_args,
),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
)
def _build_processor(self, server_args):
"""Init the global processor for multi modal models."""
from sglang.srt.hf_transformers_utils import get_processor
def process_mm_data(
self, input_text, images=None, videos=None, audios=None, **kwargs
):
"""
process multimodal data with transformers AutoProcessor
"""
if images is not None:
kwargs["images"] = images
if videos is not None:
kwargs["videos"] = videos
if audios is not None:
kwargs["audios"] = audios
return get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
processor = self._processor
result = processor.__call__(
text=[input_text],
padding=True,
return_tensors="pt",
**kwargs,
)
return result
@abstractmethod
async def process_mm_data_async(
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
return estimated_frames_list
@staticmethod
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
def load_mm_data(
self,
input_ids: list[int],
prompt: str,
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
else:
multimodal_tokens.image_token = multimodal_tokens.image_token
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
assert isinstance(prompt, str)
if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
prompt = self._processor.tokenizer.decode(prompt)
else:
input_text = input_ids
prompt = prompt
if return_text:
import re
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
+ ")"
)
# split text into list of normal text and special tokens
text_parts = re.split(pattern, input_text)
text_parts = re.split(pattern, prompt)
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
):
# video
path = image_file[len("video:") :]
frames = BaseMultimodalProcessor.encode_video(
frames = encode_video(
path, frame_count_limit=frames_to_process
)
else:
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
raise RuntimeError(f"An exception occurred while loading images: {e}")
out = BaseMultiModalProcessorOutput(
mm_data_hashes=hashes,
image_sizes=image_sizes,
images=images,
audios=audios,
input_text=new_text,
)
out.normalize()
return out
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
"""
Init the global processor for multimodal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = sglang_processor._build_processor(server_args=server_args)

View File

@@ -1,10 +1,9 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.clip import CLIPModel
from sglang.srt.utils import load_image
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(
images=images, text=input_text, return_tensors="pt"
)
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
ClipImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=[input_text], return_tensors="pt"
)
return image_inputs
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
)
]
return image_inputs

View File

@@ -16,15 +16,14 @@
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import asyncio
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image>"
@staticmethod
def _process_images_task(image, input_text, max_req_input_len):
processor = get_global_processor()
res = processor.__call__(
conversations=input_text, images=image, max_req_input_len=max_req_input_len
)
image_token_id = processor.image_token_id
res["im_token_id"] = image_token_id
return res
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def process_mm_data_async(
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
):
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
if not isinstance(image_data, list):
image_data = [image_data]
images, image_sizes = [], []
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids,
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
)
res = await self._process_images(
base_output.images, base_output.input_text, max_req_input_len
res = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
max_req_input_len=max_req_input_len,
conversations=base_output.input_text,
)
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
batched_images_spatial_crop.append(images_spatial_crop)
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
items = []
item = MultimodalDataItem(
pixel_values=res["images"],
modality=Modality.IMAGE,
image_emb_mask=images_seq_mask,
image_spatial_crop=batched_images_spatial_crop,
)
items += [item]
return {
"mm_items": items,
"input_ids": res["input_ids"].tolist(),
"pixel_values": res["images"],
"im_token_id": res["im_token_id"],
"data_hashes": base_output.mm_data_hashes,
"image_sizes": image_sizes,
"images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,
"modalities": request_obj.modalities or ["image"],
"im_token_id": self._processor.image_token_id,
}

View File

@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
async def _process_single_image(self, images, input_text) -> dict:
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
result = processor.__call__(
text=[input_text],
images=images,
padding=True,
return_tensors="pt",
# if RGBA, this needs to be set
# images_kwargs={
# "input_data_format": ChannelDimension.FIRST
# }
)
pixel_values = getattr(result, "pixel_values", None)
return {
"input_ids": result.input_ids,
"pixel_values": pixel_values,
}
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids=input_ids,
prompt=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
ret = await self._process_single_image(
ret = self.process_mm_data(
input_text=base_output.input_text, images=base_output.images
)
items = []
for i, image in enumerate(base_output.images):
item = MultimodalDataItem(
pixel_values=ret["pixel_values"][i],
modality=Modality.IMAGE,
)
items += [item]
return {
"mm_items": items,
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"data_hashes": base_output.mm_data_hashes,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}

View File

@@ -1,11 +1,10 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_images_task(images, input_text):
processor = get_global_processor()
result = processor.__call__(
prompt=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result["input_ids"],
"pixel_values": result["pixel_values"],
"images_emb_mask": result["images_emb_mask"],
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
}
async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
JanusProImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)
return image_inputs
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
if not isinstance(image_data, list):
image_data = [image_data]
processor = self._processor
base_out = self.load_mm_data(
input_ids=input_ids,
prompt=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token="<image_placeholder>"
),
multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
max_req_input_len=max_req_input_len,
)
images = base_out.images
res = await self._process_images(images=images, input_text=base_out.input_text)
# print(res)
# print(base_out)
# print("", res["images_emb_mask"].shape)
res = self.process_mm_data(
input_text=base_out.input_text,
prompt=base_out.input_text,
images=images,
)
return {
"mm_items": [
MultimodalDataItem(
pixel_values=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
modality=Modality.IMAGE,
)
],
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"images_emb_mask": res["images_emb_mask"],
"data_hashes": base_out.mm_data_hashes,
"im_start_id": res["im_start_id"],
"im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"],
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
}

View File

@@ -5,8 +5,8 @@ import numpy as np
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
from sglang.srt.models.llavavid import LlavaVidForCausalLM
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
processor=None,
):
processor = get_global_processor()
image_processor = image_processor or processor.image_processor
image_processor = processor.image_processor
try:
image, image_size = load_image(image_data)
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
if self.cpu_executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
self.cpu_executor,
LlavaImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
self._processor,
)
else:
return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints
image_data,
aspect_ratio,
grid_pinpoints,
self._processor.image_processor,
)
async def process_mm_data_async(
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
data_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
modality = Modality.IMAGE
if isinstance(request_obj.modalities, list):
if request_obj.modalities[0] == "multi-images":
modality = Modality.MULTI_IMAGES
elif request_obj.modalities[0] == "video":
modality = Modality.VIDEO
return {
"pixel_values": pixel_values,
"data_hashes": data_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
"mm_items": [
MultimodalDataItem(
pixel_values=pixel_values,
image_sizes=image_sizes,
modality=modality,
)
],
}

View File

@@ -1,13 +1,13 @@
import asyncio
from typing import List, Union
import torch
from transformers import BaseImageProcessorFast
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)"
@staticmethod
def _process_data_task(input_text, images=None, audios=None):
def process_data_task(self, input_text, images=None, audios=None):
if isinstance(images, list) and len(images) == 0:
images = None
if isinstance(audios, list) and len(audios) == 0:
audios = None
result = get_global_processor().__call__(
processor = self._processor
args = {}
if isinstance(processor, BaseImageProcessorFast):
args["device"] = "cuda"
result = self._processor.__call__(
text=input_text,
images=images,
audios=audios,
return_tensors="pt",
chunk_input=True,
**args,
)
return {
"input_ids": result.input_ids,
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
"audio_bounds": getattr(result, "audio_bounds", None),
}
async def _process_data(self, images, input_text, audios=None):
if self.executor is not None:
loop = asyncio.get_event_loop()
multimodal_data_inputs = await loop.run_in_executor(
self.executor,
MiniCPMMultimodalProcessor._process_data_task,
input_text,
images,
audios,
)
else:
multimodal_data_inputs = self._processor(
images=images, text=input_text, audios=audios, return_tensors="pt"
)
return multimodal_data_inputs
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data = [audio_data]
base_output = self.load_mm_data(
input_ids=input_ids,
prompt=input_ids,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
if base_output is None:
return None
res = await self._process_data(
images=base_output.images,
res = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
audios=base_output.audios,
)
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
if len(tgt_sizes_flat) == 0:
tgt_sizes = None
else:
tgt_sizes = torch.stack(tgt_sizes_flat)
if not isinstance(res["audio_features"], list):
res["audio_features"] = [res["audio_features"]]
items = []
if len(pixel_values) != 0:
item = MultimodalDataItem(
pixel_values=pixel_values,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
)
items += [item]
if (
"audio_features" in res
and res["audio_features"] is not None
and len(res["audio_features"]) != 0
):
item = MultimodalDataItem(
audio_features=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
modality=Modality.AUDIO,
)
items += [item]
return {
"mm_items": items,
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"audio_features": res["audio_features"],
"audio_bounds": res["audio_bounds"],
"audio_feature_lens": res["audio_feature_lens"],
"im_token_id": im_token_id,
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,

View File

@@ -1,10 +1,9 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"],
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
modality=Modality.IMAGE,
)
]
return image_inputs

View File

@@ -1,6 +1,5 @@
import asyncio
import math
import time
from typing import List, Union
import torch
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.MAX_PIXELS = 16384 * 28 * 28
self.MAX_RATIO = 200
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
result = get_global_processor().__call__(
text=[input_text], images=images, padding=True, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": getattr(result, "pixel_values", None),
"image_grid_thw": getattr(result, "image_grid_thw", None),
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
"video_grid_thws": getattr(result, "video_grid_thws", None),
}
async def _process_single_image(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2_5VLImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
prompt,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
start = time.time()
if not image_data:
return None
if isinstance(image_data, str):
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids=input_ids,
prompt=prompt,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
images = [resize_image(image) for image in base_output.images]
async def resize_image_async(image):
return resize_image(image)
ret = await self._process_single_image(
images=images, input_text=base_output.input_text
resize_tasks = [resize_image_async(image) for image in base_output.images]
resized_images = await asyncio.gather(*resize_tasks)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=resized_images,
)
image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": image_grid_thws,
"video_grid_thws": video_grid_thws,
"mm_items": [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=image_grid_thws,
# TODO
video_grid_thws=None,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
modality=Modality.IMAGE,
)
],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"],
}

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from enum import Enum, auto
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
}
class Modality(Enum):
IMAGE = auto()
MULTI_IMAGES = auto()
VIDEO = auto()
AUDIO = auto()
@dataclasses.dataclass
class MultimodalInputs:
"""The image related inputs."""
class MultimodalDataItem:
"""
A single multimodal data, from a single image/video/audio or other
"""
pixel_values: Union[torch.Tensor, np.array]
data_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
image_pad_len: Optional[list] = None
pad_values: Optional[list] = None
modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
modality: Modality
# Llava related
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
hash: int = None
pad_value: int = None
aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# QWen2-VL related
# [num_of_images, t, h, w]
image_grid_thws: torch.Tensor = None
mrope_position_delta: Optional[torch.Tensor] = None
# Qwen2-VL video related
video_token_id: Optional[int] = None
video_grid_thws: List[Tuple[int, int, int]] = None
image_sizes: Tuple[int, int] = None
image_offsets: Optional[list] = None
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.array]]
pixel_values: Union[torch.Tensor, np.array] = None
image_grid_thws: Union[torch.Tensor, np.array] = None
video_grid_thws: Union[torch.Tensor, np.array] = None
image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# deepseek vl2 related
images_emb_mask: Optional[List[torch.Tensor]] = None
image_spatial_crop: Optional[List[torch.Tensor]] = None
# [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None
# The id of the single-image placeholder token
audio_features: Union[torch.Tensor, np.array] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
@staticmethod
def is_empty_list(l):
if l is None:
return True
return len([item for item in flatten_nested_list(l) if item is not None]) == 0
def set_pad_value(self):
"""
Set the pad value after first hashign the data
"""
def hash_feature(f):
if isinstance(f, list):
return hash(tuple(flatten_nested_list(f)))
elif isinstance(f, np.ndarray):
arr = np.ascontiguousarray(f)
arr_bytes = arr.tobytes()
return hash(arr_bytes)
return hash(f)
if self.is_audio():
self.hash = hash_feature(self.audio_features)
else:
self.hash = hash_feature(self.pixel_values)
assert self.hash is not None
self.pad_value = self.hash % (1 << 30)
def is_audio(self):
return (
self.modality == Modality.AUDIO
) and not MultimodalDataItem.is_empty_list(self.audio_features)
def is_image(self):
return (
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
def is_video(self):
return (
self.modality == Modality.VIDEO
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
def validate(self):
...
# TODO
@dataclasses.dataclass
class MultimodalInputs:
"""The multimodal data related inputs."""
# items of data
mm_items: List[MultimodalDataItem]
image_pad_len: Optional[list] = None
num_image_tokens: Optional[int] = None
# QWen2-VL related
mrope_position_delta: Optional[torch.Tensor] = None
# image
im_token_id: Optional[torch.Tensor] = None
# All the images in the batch should share the same special image
# bound token ids.
im_start_id: Optional[int] = None
im_end_id: Optional[int] = None
slice_start_id: Optional[int] = None
slice_end_id: Optional[int] = None
# [num_images, 2 (w, h)]
tgt_sizes: Optional[list] = None
# video
video_token_id: Optional[int] = None
# audio
audio_start_id: Optional[torch.Tensor] = None
audio_end_id: Optional[torch.Tensor] = None
audio_features: Optional[List[torch.Tensor]] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
@staticmethod
def from_dict(obj: dict):
ret = MultimodalInputs(
pixel_values=obj["pixel_values"],
data_hashes=obj["data_hashes"],
mm_items=obj["mm_items"],
)
assert isinstance(ret.mm_items, list)
ret.mm_items = [
item
for item in ret.mm_items
if item.is_audio() or item.is_image() or item.is_video()
]
assert len(ret.mm_items) != 0
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example.
ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
for item in ret.mm_items:
item.set_pad_value()
optional_args = [
"image_sizes",
"modalities",
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"images_emb_mask",
"image_spatial_crop",
"im_token_id",
"im_start_id",
"im_end_id",
"slice_start_id",
"slice_end_id",
"tgt_sizes",
"audio_start_id",
"audio_end_id",
"audio_features",
"audio_feature_lens",
]
for arg in optional_args:
if arg in obj:
setattr(ret, arg, obj[arg])
# validate
assert (
isinstance(ret.pixel_values, torch.Tensor)
or isinstance(ret.pixel_values, np.ndarray)
or isinstance(ret.pixel_values, list)
)
assert ret.audio_features is None or isinstance(ret.audio_features, list)
return ret
def contains_image_inputs(self) -> bool:
""" """
return self.pixel_values is not None and self.pixel_values != []
return any(item.is_image() for item in self.mm_items)
def contains_audio_inputs(self) -> bool:
""" """
return self.audio_features is not None and self.audio_features != []
return any(item.is_audio() for item in self.mm_items)
def collect_image_inputs(self) -> List[torch.Tensor]:
return [item.pixel_values for item in self.mm_items if item.is_image()]
def merge(self, other: MultimodalInputs):
"""
merge image inputs when requests are being merged
"""
if isinstance(self.pixel_values, list):
# in some rare cases, pixel values are list of patches with different shapes
# e.g. minicpm
self.pixel_values += other.pixel_values
else:
assert (
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
# args would be stacked along first dim
# usually these are already tensors
stack_args = [
# TODO: merge with image_grid_thws, basically the same thing
"tgt_sizes",
"image_spatial_crop",
]
for arg in stack_args:
if getattr(self, arg, None) is None:
setattr(self, arg, getattr(other, arg, None))
elif getattr(other, arg, None) is not None:
# self and other both not None
setattr(
self,
arg,
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
)
if self.image_grid_thws is None:
self.image_grid_thws = other.image_grid_thws
elif other.image_grid_thws is not None:
self.image_grid_thws = torch.concat(
[self.image_grid_thws, other.image_grid_thws]
)
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example.
self.data_hashes += other.data_hashes
self.pad_values = [x % (1 << 30) for x in self.data_hashes]
# args needed to be merged
optional_args = [
"audio_features",
"image_sizes",
"items",
"image_offsets",
"image_pad_len",
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",
"images_emb_mask",
]
for arg in optional_args:
self_arg = getattr(self, arg, None)

View File

@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

View File

@@ -1,11 +1,6 @@
import json
import logging
import time
from collections import defaultdict
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple
import torch
from typing import Optional
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req