refactor: multimodal data (#4754)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Multimodality utils
|
||||
Multi-modality utils
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
@@ -9,11 +9,13 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
global_server_args_dict,
|
||||
logger,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import print_warning_once
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
@@ -26,7 +28,7 @@ class MultiModalityDataPaddingPattern:
|
||||
|
||||
@abstractmethod
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: MultimodalInputs
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
Pad the input ids sequence containing data tokens, and replace them with pad_values
|
||||
@@ -49,13 +51,13 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
"""
|
||||
This function will replace the data-tokens inbetween with pad_values accordingly
|
||||
"""
|
||||
pad_values = mm_inputs.pad_values
|
||||
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
||||
data_token_pairs = self.data_token_id_pairs
|
||||
mm_inputs.image_offsets = []
|
||||
mm_inputs.data_offsets = []
|
||||
if data_token_pairs is None:
|
||||
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
|
||||
if data_token_pairs is None:
|
||||
logger.warning(
|
||||
print_warning_once(
|
||||
"No data_token_pairs provided, RadixAttention might be influenced."
|
||||
)
|
||||
return input_ids
|
||||
@@ -77,10 +79,10 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
|
||||
if input_ids[start_idx] in start_token_ids:
|
||||
data_idx += 1
|
||||
mm_inputs.image_offsets += [start_idx]
|
||||
mm_inputs.data_offsets += [start_idx]
|
||||
|
||||
if data_idx >= len(mm_inputs.pad_values):
|
||||
data_idx = len(mm_inputs.pad_values) - 1
|
||||
if data_idx >= len(pad_values):
|
||||
data_idx = len(pad_values) - 1
|
||||
|
||||
num_tokens = end_idx - start_idx - 1
|
||||
pad_value = pad_values[data_idx]
|
||||
@@ -94,68 +96,19 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
return padded_ids
|
||||
|
||||
|
||||
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
|
||||
"""In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
|
||||
which needs first to be expanded to multiple tokens, then replaced with their padding values
|
||||
|
||||
This strategy should be used when a single data token represents content that should
|
||||
be expanded to multiple tokens during processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
|
||||
) -> None:
|
||||
self.num_data_token_calc_func = num_data_token_calc_func
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
This function will follow the procedure of:
|
||||
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
|
||||
2. the padded data tokens will be replaced with their pad_values
|
||||
"""
|
||||
image_grid_thws = mm_inputs.image_grid_thws
|
||||
pad_values = mm_inputs.pad_values
|
||||
|
||||
image_indices = [
|
||||
idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
|
||||
]
|
||||
|
||||
mm_inputs.image_offsets = []
|
||||
|
||||
input_ids_with_image = []
|
||||
for image_cnt, _ in enumerate(image_grid_thws):
|
||||
# print(f"image_cnt {image_cnt}")
|
||||
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
|
||||
if image_cnt == 0:
|
||||
non_image_tokens = input_ids[: image_indices[image_cnt]]
|
||||
else:
|
||||
non_image_tokens = input_ids[
|
||||
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
|
||||
]
|
||||
input_ids_with_image.extend(non_image_tokens)
|
||||
mm_inputs.image_offsets.append(len(input_ids_with_image))
|
||||
pad_ids = pad_values * (
|
||||
(num_image_tokens + len(pad_values)) // len(pad_values)
|
||||
)
|
||||
input_ids_with_image.extend(pad_ids[:num_image_tokens])
|
||||
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
|
||||
|
||||
return input_ids_with_image
|
||||
|
||||
|
||||
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
|
||||
"""In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
|
||||
"""In this pattern, data tokens should be represented as repetitions of a single token
|
||||
e.g. <image><image>....<image>, or <audio><audio>...<audio>
|
||||
"""
|
||||
|
||||
def __init__(self, image_token_id: torch.Tensor) -> None:
|
||||
self.image_token_id = image_token_id
|
||||
|
||||
def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
|
||||
def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
|
||||
"""
|
||||
This function will replace the data-tokens in between with pad_values accordingly
|
||||
"""
|
||||
pad_values = image_inputs.pad_values
|
||||
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
||||
assert len(pad_values) != 0
|
||||
|
||||
input_ids_tensor = torch.tensor(input_ids)
|
||||
@@ -170,138 +123,227 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
|
||||
return input_ids_tensor.tolist()
|
||||
|
||||
|
||||
def get_embedding_and_mask(
|
||||
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
|
||||
embedding_items: List[MultimodalDataItem],
|
||||
placeholder_tensor: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Get the multimodal embedding and its mask from input_ids
|
||||
|
||||
"""
|
||||
# 1. Get the embedding
|
||||
embedding = data_embedding_func(embedding_items)
|
||||
|
||||
# 2. Check the embedding
|
||||
if embedding.dim() == 2:
|
||||
num_mm_tokens_in_embedding = embedding.shape[0]
|
||||
else:
|
||||
num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
|
||||
|
||||
# the mask of multimodal tokens from input_ids
|
||||
special_multimodal_mask = torch.isin(
|
||||
input_ids,
|
||||
placeholder_tensor,
|
||||
).unsqueeze(-1)
|
||||
|
||||
num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
|
||||
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
|
||||
logger.warning(
|
||||
f"Number of tokens in multimodal embedding does not match those in the input text."
|
||||
f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
|
||||
"tokens from multimodal embeddings."
|
||||
)
|
||||
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
|
||||
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
|
||||
# a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
|
||||
# extend_start_loc and extend_seq_lens
|
||||
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
|
||||
if chunked_prefill_size != -1:
|
||||
logger.warning(
|
||||
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
|
||||
)
|
||||
# extract from the end: this is a compromise
|
||||
if embedding.dim() == 2:
|
||||
embedding = embedding[-num_mm_tokens_in_input_ids:, :]
|
||||
else:
|
||||
num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
|
||||
embedding = embedding[-num_multimodal:, :]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Insufficient multimodal embedding length. This is an internal error"
|
||||
)
|
||||
|
||||
return embedding, special_multimodal_mask
|
||||
|
||||
|
||||
def embed_mm_inputs(
|
||||
mm_input: MultimodalInputs,
|
||||
mm_inputs: MultimodalInputs,
|
||||
input_ids: torch.Tensor,
|
||||
input_embedding: nn.Embedding,
|
||||
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
|
||||
image_data_embedding_func: Callable[
|
||||
[List[MultimodalDataItem]], torch.Tensor
|
||||
] = None,
|
||||
audio_data_embedding_func: Callable[
|
||||
[List[MultimodalDataItem]], torch.Tensor
|
||||
] = None,
|
||||
placeholder_token_ids: List[int] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Calculate the image embeddings if necessary, then scatter the result with
|
||||
the help of a boolean mask denoting the embed locations
|
||||
Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
|
||||
|
||||
Returns:
|
||||
final embedding: Optional[torch.Tensor]
|
||||
Args:
|
||||
placeholder_token_ids: denoting the token of multimodal data in input_ids.
|
||||
If none, the pad_values of multimodal items are used
|
||||
|
||||
Returns:
|
||||
final embedding: Optional[torch.Tensor]
|
||||
"""
|
||||
if mm_input is None:
|
||||
|
||||
if mm_inputs is None:
|
||||
return None
|
||||
|
||||
placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
|
||||
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
|
||||
# we assume that multimodal data are represented with its pad_values in input_ids
|
||||
placeholder_token_ids = placeholder_token_ids or [
|
||||
item.pad_value for item in mm_inputs.mm_items
|
||||
]
|
||||
|
||||
# boolean masking the special tokens
|
||||
special_image_mask = torch.isin(
|
||||
input_ids,
|
||||
torch.tensor(placeholder_token_ids, device=input_ids.device),
|
||||
).unsqueeze(-1)
|
||||
placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
|
||||
|
||||
num_image_tokens_in_input_ids = special_image_mask.sum()
|
||||
# print(f"{num_image_tokens_in_input_ids}")
|
||||
# print(f"{input_ids}")
|
||||
placeholder_masks = torch.isin(input_ids, placeholder_tensor)
|
||||
|
||||
# return
|
||||
if num_image_tokens_in_input_ids == 0:
|
||||
# unexpected
|
||||
appearing_pad_values = torch.unique(
|
||||
input_ids[placeholder_masks], return_counts=False
|
||||
)
|
||||
|
||||
if appearing_pad_values.numel() == 0:
|
||||
# all been prefixed
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
else:
|
||||
# print(f"Getting image feature")
|
||||
image_embedding = mm_data_embedding_func(mm_input)
|
||||
appearing_items = [
|
||||
item
|
||||
for item in mm_inputs.mm_items
|
||||
if item.pad_value is not None and item.pad_value in appearing_pad_values
|
||||
]
|
||||
|
||||
# print(f"image_embedding: {image_embedding.shape}")
|
||||
|
||||
if image_embedding.dim() == 2:
|
||||
num_image_tokens_in_embedding = image_embedding.shape[0]
|
||||
else:
|
||||
num_image_tokens_in_embedding = (
|
||||
image_embedding.shape[0] * image_embedding.shape[1]
|
||||
)
|
||||
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
|
||||
num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
|
||||
image_embedding = image_embedding[:num_image, :]
|
||||
logger.warning(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
|
||||
"tokens from image embeddings."
|
||||
using_all_items = False
|
||||
if len(appearing_items) == 0:
|
||||
# This happens mostly when arg placeholder_token_ids is passed
|
||||
logger.warning_once(
|
||||
"No multimodal data item's pad value exist in placeholder ids. Using all items"
|
||||
)
|
||||
using_all_items = True
|
||||
appearing_items = mm_inputs.mm_items
|
||||
|
||||
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
|
||||
# a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
|
||||
# extend_start_loc and extend_seq_lens
|
||||
if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
|
||||
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
|
||||
if chunked_prefill_size != -1:
|
||||
logger.warning(
|
||||
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
|
||||
embeddings, masks = [], []
|
||||
|
||||
# 2. Get multimodal embedding separately
|
||||
# TODO: make this more generic
|
||||
# Try get image embedding if any
|
||||
if (
|
||||
any(True for item in appearing_items if item.is_image())
|
||||
and image_data_embedding_func
|
||||
):
|
||||
items = [item for item in appearing_items if item.is_image()]
|
||||
embedding, mask = get_embedding_and_mask(
|
||||
data_embedding_func=image_data_embedding_func,
|
||||
embedding_items=items,
|
||||
placeholder_tensor=(
|
||||
placeholder_tensor
|
||||
if using_all_items
|
||||
else torch.tensor(
|
||||
[item.pad_value for item in items],
|
||||
device=input_ids.device,
|
||||
)
|
||||
),
|
||||
input_ids=input_ids,
|
||||
)
|
||||
embeddings += [embedding]
|
||||
masks += [mask]
|
||||
|
||||
# Try get audio embedding if any
|
||||
if (
|
||||
any(True for item in appearing_items if item.is_audio())
|
||||
and audio_data_embedding_func
|
||||
):
|
||||
items = [item for item in appearing_items if item.is_audio()]
|
||||
embedding, mask = get_embedding_and_mask(
|
||||
data_embedding_func=audio_data_embedding_func,
|
||||
embedding_items=items,
|
||||
placeholder_tensor=(
|
||||
placeholder_tensor
|
||||
if using_all_items
|
||||
else torch.tensor(
|
||||
[item.pad_value for item in items],
|
||||
device=input_ids.device,
|
||||
)
|
||||
),
|
||||
input_ids=input_ids,
|
||||
)
|
||||
embeddings += [embedding]
|
||||
masks += [mask]
|
||||
|
||||
# 3. Get input embeddings
|
||||
vocab_size = input_embedding.num_embeddings
|
||||
# Important: clamp after getting original image regions
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# Important: clamp after getting original multimodal regions
|
||||
# Clamp input ids. This is because the input_ids for the multimodal tokens are
|
||||
# filled with the hash values of the multimodal for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=vocab_size - 1)
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
special_image_mask,
|
||||
image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def embed_image_embedding(
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_embedding: torch.Tensor,
|
||||
image_bounds: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
scatter image_embedding into inputs_embeds according to image_bounds
|
||||
"""
|
||||
if len(image_bounds) > 0:
|
||||
image_indices = torch.stack(
|
||||
[
|
||||
torch.arange(start, end, dtype=torch.long)
|
||||
for start, end in image_bounds.tolist()
|
||||
]
|
||||
).to(inputs_embeds.device)
|
||||
|
||||
inputs_embeds.scatter_(
|
||||
0,
|
||||
image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
|
||||
image_embedding.view(-1, image_embedding.shape[-1]),
|
||||
)
|
||||
# 4. scatter embeddings into input embedding
|
||||
for embedding, mask in zip(embeddings, masks):
|
||||
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
mask,
|
||||
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def general_mm_embed_routine(
|
||||
input_ids: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
embed_tokens: nn.Embedding,
|
||||
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
|
||||
language_model: nn.Module,
|
||||
image_data_embedding_func: Callable[
|
||||
[List[MultimodalDataItem]], torch.Tensor
|
||||
] = None,
|
||||
audio_data_embedding_func: Callable[
|
||||
[List[MultimodalDataItem]], torch.Tensor
|
||||
] = None,
|
||||
placeholder_token_ids: List[int] = None,
|
||||
):
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
a general wrapper function to get final input embeds from multimodal models
|
||||
with a language model as causal model
|
||||
A general wrapper function to get final input embeds from multimodal models with a language model as causal model
|
||||
|
||||
Args:
|
||||
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
|
||||
image_data_embedding_func : the function returning the image embedding
|
||||
audio_data_embedding_func : the function returning the image embedding
|
||||
|
||||
Returns:
|
||||
inputs_embedding
|
||||
forwarded hidden states
|
||||
|
||||
"""
|
||||
|
||||
assert hasattr(language_model, "get_input_embeddings")
|
||||
embed_tokens = language_model.get_input_embeddings()
|
||||
if (
|
||||
not forward_batch.forward_mode.is_decode()
|
||||
and forward_batch.contains_mm_inputs()
|
||||
):
|
||||
image = forward_batch.merge_mm_inputs()
|
||||
mm_input = forward_batch.merge_mm_inputs()
|
||||
inputs_embeds = embed_mm_inputs(
|
||||
mm_input=image,
|
||||
mm_inputs=mm_input,
|
||||
input_ids=input_ids,
|
||||
input_embedding=embed_tokens,
|
||||
mm_data_embedding_func=mm_data_embedding_func,
|
||||
image_data_embedding_func=image_data_embedding_func,
|
||||
audio_data_embedding_func=audio_data_embedding_func,
|
||||
placeholder_token_ids=placeholder_token_ids,
|
||||
)
|
||||
# once used, mm_inputs is useless
|
||||
@@ -310,7 +352,13 @@ def general_mm_embed_routine(
|
||||
else:
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
return inputs_embeds
|
||||
hidden_states = language_model(
|
||||
input_ids=None,
|
||||
forward_batch=forward_batch,
|
||||
input_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_multimodal_data_bounds(
|
||||
@@ -322,15 +370,13 @@ def get_multimodal_data_bounds(
|
||||
Returns:
|
||||
[bounds_count, 2]
|
||||
"""
|
||||
# All the images in the batch should share the same special image
|
||||
# bound token ids.
|
||||
# All the multimodal data in the batch should share the same special bound token ids.
|
||||
start_tokens = [s for s, _e in token_pairs]
|
||||
end_tokens = [e for _s, e in token_pairs]
|
||||
|
||||
assert all(isinstance(t, int) for t in start_tokens)
|
||||
assert all(isinstance(t, int) for t in end_tokens)
|
||||
|
||||
# print(input_ids)
|
||||
start_cond = torch.isin(
|
||||
input_ids, torch.tensor(start_tokens, device=input_ids.device)
|
||||
)
|
||||
@@ -339,7 +385,7 @@ def get_multimodal_data_bounds(
|
||||
(data_start_tokens,) = torch.where(start_cond)
|
||||
(data_end_tokens,) = torch.where(end_cond)
|
||||
|
||||
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
|
||||
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
|
||||
if len(data_start_tokens) != len(data_end_tokens):
|
||||
if (
|
||||
len(data_start_tokens) + 1 == len(data_end_tokens)
|
||||
@@ -352,14 +398,14 @@ def get_multimodal_data_bounds(
|
||||
data_start_tokens,
|
||||
]
|
||||
)
|
||||
valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
|
||||
valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
|
||||
|
||||
if valid_image_nums == 0:
|
||||
if valid_mm_data_nums == 0:
|
||||
return torch.zeros((0, 2), device=input_ids.device)
|
||||
|
||||
# Filter out pairs where start_token >= end_token
|
||||
valid_pairs = []
|
||||
for i in range(valid_image_nums):
|
||||
for i in range(valid_mm_data_nums):
|
||||
start_token = data_start_tokens[i]
|
||||
end_token = data_end_tokens[i]
|
||||
if start_token < end_token:
|
||||
|
||||
@@ -64,5 +64,3 @@ def get_mm_processor(
|
||||
f"No processor registered for architecture: {hf_config.architectures}.\n"
|
||||
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
|
||||
)
|
||||
|
||||
self.image_proce
|
||||
|
||||
@@ -8,18 +8,10 @@ from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import transformers
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.utils import load_audio, load_image, logger
|
||||
|
||||
global global_processor
|
||||
|
||||
|
||||
def get_global_processor():
|
||||
global global_processor
|
||||
return global_processor
|
||||
from sglang.srt.utils import encode_video, load_audio, load_image, logger
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
|
||||
# input_text, with each frame of video/image represented with a image_token
|
||||
input_text: str
|
||||
|
||||
mm_data_hashes: Optional[list[int]]
|
||||
# images
|
||||
image_sizes: Optional[list[int]]
|
||||
# frames loaded from image and video, in given order
|
||||
images: Optional[list[PIL.Image]] = None
|
||||
|
||||
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
|
||||
audios: Optional[list[np.ndarray]] = None
|
||||
|
||||
def normalize(self):
|
||||
for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
|
||||
for field_name in ["image_sizes", "images", "audios"]:
|
||||
field = getattr(self, field_name, None)
|
||||
if field is not None and isinstance(field, list) and len(field) == 0:
|
||||
setattr(self, field_name, None)
|
||||
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
|
||||
# FIXME: not accurate, model and image specific
|
||||
self.NUM_TOKEN_PER_FRAME = 330
|
||||
|
||||
# Initialize global processor first
|
||||
init_global_processor(self, server_args)
|
||||
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
self.io_executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
|
||||
)
|
||||
self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
|
||||
mp_context=mp.get_context("fork"),
|
||||
initargs=(
|
||||
self,
|
||||
server_args,
|
||||
),
|
||||
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
|
||||
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
|
||||
)
|
||||
|
||||
def _build_processor(self, server_args):
|
||||
"""Init the global processor for multi modal models."""
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
def process_mm_data(
|
||||
self, input_text, images=None, videos=None, audios=None, **kwargs
|
||||
):
|
||||
"""
|
||||
process multimodal data with transformers AutoProcessor
|
||||
"""
|
||||
if images is not None:
|
||||
kwargs["images"] = images
|
||||
if videos is not None:
|
||||
kwargs["videos"] = videos
|
||||
if audios is not None:
|
||||
kwargs["audios"] = audios
|
||||
|
||||
return get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
processor = self._processor
|
||||
result = processor.__call__(
|
||||
text=[input_text],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def process_mm_data_async(
|
||||
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
|
||||
|
||||
return estimated_frames_list
|
||||
|
||||
@staticmethod
|
||||
def encode_video(video_path, frame_count_limit=None):
|
||||
if not os.path.exists(video_path):
|
||||
logger.error(f"Video {video_path} does not exist")
|
||||
return []
|
||||
|
||||
if frame_count_limit == 0:
|
||||
return []
|
||||
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_indices = [i for i in range(0, len(vr), sample_fps)]
|
||||
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
|
||||
frame_indices = uniform_sample(frame_indices, frame_count_limit)
|
||||
|
||||
frames = vr.get_batch(frame_indices).asnumpy()
|
||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||
return frames
|
||||
|
||||
def load_mm_data(
|
||||
self,
|
||||
input_ids: list[int],
|
||||
prompt: str,
|
||||
multimodal_tokens: MultimodalSpecialTokens,
|
||||
max_req_input_len: int,
|
||||
image_data: Optional[list] = None,
|
||||
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
|
||||
else:
|
||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
||||
|
||||
if isinstance(input_ids, list) and return_text:
|
||||
assert len(input_ids) and isinstance(input_ids[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_ids)
|
||||
assert isinstance(prompt, str)
|
||||
|
||||
if isinstance(prompt, list) and return_text:
|
||||
assert len(prompt) and isinstance(prompt[0], int)
|
||||
prompt = self._processor.tokenizer.decode(prompt)
|
||||
else:
|
||||
input_text = input_ids
|
||||
prompt = prompt
|
||||
if return_text:
|
||||
import re
|
||||
|
||||
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
+ ")"
|
||||
)
|
||||
# split text into list of normal text and special tokens
|
||||
text_parts = re.split(pattern, input_text)
|
||||
text_parts = re.split(pattern, prompt)
|
||||
|
||||
# TODO(mick): load from server_args, env, or sampling_params
|
||||
MAX_NUM_FRAMES = 30
|
||||
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
):
|
||||
# video
|
||||
path = image_file[len("video:") :]
|
||||
frames = BaseMultimodalProcessor.encode_video(
|
||||
frames = encode_video(
|
||||
path, frame_count_limit=frames_to_process
|
||||
)
|
||||
else:
|
||||
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
|
||||
raise RuntimeError(f"An exception occurred while loading images: {e}")
|
||||
|
||||
out = BaseMultiModalProcessorOutput(
|
||||
mm_data_hashes=hashes,
|
||||
image_sizes=image_sizes,
|
||||
images=images,
|
||||
audios=audios,
|
||||
input_text=new_text,
|
||||
)
|
||||
out.normalize()
|
||||
return out
|
||||
|
||||
|
||||
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
|
||||
"""
|
||||
Init the global processor for multimodal models."""
|
||||
global global_processor
|
||||
transformers.logging.set_verbosity_error()
|
||||
global_processor = sglang_processor._build_processor(server_args=server_args)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.clip import CLIPModel
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(images, input_text):
|
||||
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
||||
return get_global_processor()(
|
||||
images=images, text=input_text, return_tensors="pt"
|
||||
)
|
||||
|
||||
async def _process_single_image(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
ClipImageProcessor._process_single_image_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(
|
||||
images=images, text=[input_text], return_tensors="pt"
|
||||
)
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
|
||||
else:
|
||||
images = load_image(image_data[0])[0]
|
||||
|
||||
image_inputs = await self._process_single_image(images, input_text)
|
||||
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||
image_inputs["data_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
image_inputs["mm_items"] = [
|
||||
MultimodalDataItem(
|
||||
pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
|
||||
)
|
||||
]
|
||||
|
||||
return image_inputs
|
||||
|
||||
@@ -16,15 +16,14 @@
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
|
||||
|
||||
|
||||
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<image>"
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(image, input_text, max_req_input_len):
|
||||
processor = get_global_processor()
|
||||
res = processor.__call__(
|
||||
conversations=input_text, images=image, max_req_input_len=max_req_input_len
|
||||
)
|
||||
|
||||
image_token_id = processor.image_token_id
|
||||
|
||||
res["im_token_id"] = image_token_id
|
||||
return res
|
||||
|
||||
async def _process_images(self, image_data, input_text, max_req_input_len):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
DeepseekVL2ImageProcessor._process_images_task,
|
||||
image_data,
|
||||
input_text,
|
||||
max_req_input_len,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._process_images_task(
|
||||
image_data, input_text, max_req_input_len
|
||||
)
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def _process_images(self, image_data, input_text, max_req_input_len):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
DeepseekVL2ImageProcessor._process_images_task,
|
||||
image_data,
|
||||
input_text,
|
||||
max_req_input_len,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._process_images_task(
|
||||
image_data, input_text, max_req_input_len
|
||||
)
|
||||
return image_inputs
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
|
||||
):
|
||||
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
images, image_sizes = [], []
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
input_ids,
|
||||
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
res = await self._process_images(
|
||||
base_output.images, base_output.input_text, max_req_input_len
|
||||
res = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
max_req_input_len=max_req_input_len,
|
||||
conversations=base_output.input_text,
|
||||
)
|
||||
images_seq_mask = res["images_seq_mask"]
|
||||
images_spatial_crop = res["images_spatial_crop"]
|
||||
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
batched_images_spatial_crop.append(images_spatial_crop)
|
||||
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
||||
|
||||
items = []
|
||||
item = MultimodalDataItem(
|
||||
pixel_values=res["images"],
|
||||
modality=Modality.IMAGE,
|
||||
image_emb_mask=images_seq_mask,
|
||||
image_spatial_crop=batched_images_spatial_crop,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
return {
|
||||
"mm_items": items,
|
||||
"input_ids": res["input_ids"].tolist(),
|
||||
"pixel_values": res["images"],
|
||||
"im_token_id": res["im_token_id"],
|
||||
"data_hashes": base_output.mm_data_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"images_emb_mask": images_seq_mask,
|
||||
"image_spatial_crop": batched_images_spatial_crop,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"im_token_id": self._processor.image_token_id,
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
|
||||
|
||||
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
|
||||
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||
|
||||
async def _process_single_image(self, images, input_text) -> dict:
|
||||
if isinstance(images, list) and len(images) == 0:
|
||||
images = None
|
||||
processor = get_global_processor()
|
||||
result = processor.__call__(
|
||||
text=[input_text],
|
||||
images=images,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
# if RGBA, this needs to be set
|
||||
# images_kwargs={
|
||||
# "input_data_format": ChannelDimension.FIRST
|
||||
# }
|
||||
)
|
||||
|
||||
pixel_values = getattr(result, "pixel_values", None)
|
||||
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
input_ids=input_ids,
|
||||
prompt=input_ids,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
ret = await self._process_single_image(
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text, images=base_output.images
|
||||
)
|
||||
|
||||
items = []
|
||||
for i, image in enumerate(base_output.images):
|
||||
item = MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"][i],
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
return {
|
||||
"mm_items": items,
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"pixel_values": ret["pixel_values"],
|
||||
"data_hashes": base_output.mm_data_hashes,
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
|
||||
|
||||
|
||||
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text):
|
||||
processor = get_global_processor()
|
||||
result = processor.__call__(
|
||||
prompt=input_text, images=images, return_tensors="pt"
|
||||
)
|
||||
return {
|
||||
"input_ids": result["input_ids"],
|
||||
"pixel_values": result["pixel_values"],
|
||||
"images_emb_mask": result["images_emb_mask"],
|
||||
"im_start_id": processor.image_start_id,
|
||||
"im_end_id": processor.image_end_id,
|
||||
"im_token_id": processor.image_id,
|
||||
}
|
||||
|
||||
async def _process_images(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
JanusProImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(
|
||||
images=images, text=input_text, return_tensors="pt"
|
||||
)
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
processor = self._processor
|
||||
|
||||
base_out = self.load_mm_data(
|
||||
input_ids=input_ids,
|
||||
prompt=input_ids,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token="<image_placeholder>"
|
||||
),
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
images = base_out.images
|
||||
res = await self._process_images(images=images, input_text=base_out.input_text)
|
||||
# print(res)
|
||||
# print(base_out)
|
||||
# print("", res["images_emb_mask"].shape)
|
||||
res = self.process_mm_data(
|
||||
input_text=base_out.input_text,
|
||||
prompt=base_out.input_text,
|
||||
images=images,
|
||||
)
|
||||
return {
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
pixel_values=res["pixel_values"],
|
||||
image_emb_mask=res["images_emb_mask"],
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
],
|
||||
"input_ids": res["input_ids"].flatten().tolist(),
|
||||
"pixel_values": res["pixel_values"],
|
||||
"images_emb_mask": res["images_emb_mask"],
|
||||
"data_hashes": base_out.mm_data_hashes,
|
||||
"im_start_id": res["im_start_id"],
|
||||
"im_end_id": res["im_end_id"],
|
||||
"im_token_id": res["im_token_id"],
|
||||
"im_start_id": processor.image_start_id,
|
||||
"im_end_id": processor.image_end_id,
|
||||
"im_token_id": processor.image_id,
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import numpy as np
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
|
||||
from sglang.srt.models.llavavid import LlavaVidForCausalLM
|
||||
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
image_data: Union[str, bytes],
|
||||
image_aspect_ratio: Optional[str] = None,
|
||||
image_grid_pinpoints: Optional[str] = None,
|
||||
image_processor=None,
|
||||
processor=None,
|
||||
):
|
||||
processor = get_global_processor()
|
||||
|
||||
image_processor = image_processor or processor.image_processor
|
||||
image_processor = processor.image_processor
|
||||
|
||||
try:
|
||||
image, image_size = load_image(image_data)
|
||||
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
async def _process_single_image(
|
||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||
):
|
||||
if self.executor is not None:
|
||||
if self.cpu_executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
self.cpu_executor,
|
||||
LlavaImageProcessor._process_single_image_task,
|
||||
image_data,
|
||||
aspect_ratio,
|
||||
grid_pinpoints,
|
||||
self._processor,
|
||||
)
|
||||
else:
|
||||
return self._process_single_image_task(
|
||||
image_data, aspect_ratio, grid_pinpoints
|
||||
image_data,
|
||||
aspect_ratio,
|
||||
grid_pinpoints,
|
||||
self._processor.image_processor,
|
||||
)
|
||||
|
||||
async def process_mm_data_async(
|
||||
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data[0], aspect_ratio, grid_pinpoints
|
||||
)
|
||||
data_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
else:
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
modality = Modality.IMAGE
|
||||
if isinstance(request_obj.modalities, list):
|
||||
if request_obj.modalities[0] == "multi-images":
|
||||
modality = Modality.MULTI_IMAGES
|
||||
elif request_obj.modalities[0] == "video":
|
||||
modality = Modality.VIDEO
|
||||
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"data_hashes": data_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
pixel_values=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
modality=modality,
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.minicpmo import MiniCPMO
|
||||
from sglang.srt.models.minicpmv import MiniCPMV
|
||||
|
||||
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
self.image_token = "(<image>./</image>)"
|
||||
self.audio_token = "(<audio>./</audio>)"
|
||||
|
||||
@staticmethod
|
||||
def _process_data_task(input_text, images=None, audios=None):
|
||||
def process_data_task(self, input_text, images=None, audios=None):
|
||||
|
||||
if isinstance(images, list) and len(images) == 0:
|
||||
images = None
|
||||
if isinstance(audios, list) and len(audios) == 0:
|
||||
audios = None
|
||||
result = get_global_processor().__call__(
|
||||
processor = self._processor
|
||||
args = {}
|
||||
if isinstance(processor, BaseImageProcessorFast):
|
||||
args["device"] = "cuda"
|
||||
result = self._processor.__call__(
|
||||
text=input_text,
|
||||
images=images,
|
||||
audios=audios,
|
||||
return_tensors="pt",
|
||||
chunk_input=True,
|
||||
**args,
|
||||
)
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
"audio_bounds": getattr(result, "audio_bounds", None),
|
||||
}
|
||||
|
||||
async def _process_data(self, images, input_text, audios=None):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
multimodal_data_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MiniCPMMultimodalProcessor._process_data_task,
|
||||
input_text,
|
||||
images,
|
||||
audios,
|
||||
)
|
||||
else:
|
||||
multimodal_data_inputs = self._processor(
|
||||
images=images, text=input_text, audios=audios, return_tensors="pt"
|
||||
)
|
||||
|
||||
return multimodal_data_inputs
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audio_data = [audio_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
input_ids=input_ids,
|
||||
prompt=input_ids,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
res = await self._process_data(
|
||||
images=base_output.images,
|
||||
res = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
audios=base_output.audios,
|
||||
)
|
||||
|
||||
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
tgt_sizes_flat += [tgt_n]
|
||||
|
||||
pixel_values = pixel_values_flat
|
||||
if len(tgt_sizes_flat) == 0:
|
||||
tgt_sizes = None
|
||||
else:
|
||||
tgt_sizes = torch.stack(tgt_sizes_flat)
|
||||
if not isinstance(res["audio_features"], list):
|
||||
res["audio_features"] = [res["audio_features"]]
|
||||
|
||||
items = []
|
||||
if len(pixel_values) != 0:
|
||||
item = MultimodalDataItem(
|
||||
pixel_values=pixel_values,
|
||||
tgt_size=tgt_sizes_flat,
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
if (
|
||||
"audio_features" in res
|
||||
and res["audio_features"] is not None
|
||||
and len(res["audio_features"]) != 0
|
||||
):
|
||||
item = MultimodalDataItem(
|
||||
audio_features=[res["audio_features"]],
|
||||
audio_feature_lens=res["audio_feature_lens"],
|
||||
modality=Modality.AUDIO,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
return {
|
||||
"mm_items": items,
|
||||
"input_ids": res["input_ids"].flatten().tolist(),
|
||||
"pixel_values": pixel_values,
|
||||
"tgt_sizes": tgt_sizes,
|
||||
"data_hashes": base_output.mm_data_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"audio_start_id": audio_start_id,
|
||||
"audio_end_id": audio_end_id,
|
||||
"audio_features": res["audio_features"],
|
||||
"audio_bounds": res["audio_bounds"],
|
||||
"audio_feature_lens": res["audio_feature_lens"],
|
||||
"im_token_id": im_token_id,
|
||||
"im_start_id": tokenizer.im_start_id,
|
||||
"im_end_id": tokenizer.im_end_id,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.mllama import MllamaForConditionalGeneration
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(images, input_text):
|
||||
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
||||
return get_global_processor()(images, input_text, return_tensors="pt")
|
||||
|
||||
async def _process_single_image(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MllamaImageProcessor._process_single_image_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(images, input_text, return_tensors="pt")
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
else:
|
||||
images = load_image(image_data[0])[0]
|
||||
|
||||
image_inputs = await self._process_single_image(images, input_text)
|
||||
image_inputs["data_hashes"] = [hash(str(image_data))]
|
||||
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
image_inputs["mm_items"] = [
|
||||
MultimodalDataItem(
|
||||
pixel_values=image_inputs["pixel_values"],
|
||||
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
|
||||
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
]
|
||||
|
||||
return image_inputs
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import math
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
||||
|
||||
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
self.MAX_PIXELS = 16384 * 28 * 28
|
||||
self.MAX_RATIO = 200
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text, _hf_config):
|
||||
if isinstance(images, list) and len(images) == 0:
|
||||
images = None
|
||||
result = get_global_processor().__call__(
|
||||
text=[input_text], images=images, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": getattr(result, "pixel_values", None),
|
||||
"image_grid_thw": getattr(result, "image_grid_thw", None),
|
||||
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
|
||||
"video_grid_thws": getattr(result, "video_grid_thws", None),
|
||||
}
|
||||
|
||||
async def _process_single_image(self, images, input_text) -> dict:
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
Qwen2_5VLImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
self.hf_config,
|
||||
)
|
||||
else:
|
||||
return self._process_images_task(images, input_text, self.hf_config)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
prompt,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
start = time.time()
|
||||
if not image_data:
|
||||
return None
|
||||
if isinstance(image_data, str):
|
||||
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
input_ids=input_ids,
|
||||
prompt=prompt,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
max_req_input_len=max_req_input_len,
|
||||
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
images = [resize_image(image) for image in base_output.images]
|
||||
async def resize_image_async(image):
|
||||
return resize_image(image)
|
||||
|
||||
ret = await self._process_single_image(
|
||||
images=images, input_text=base_output.input_text
|
||||
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
||||
resized_images = await asyncio.gather(*resize_tasks)
|
||||
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=resized_images,
|
||||
)
|
||||
|
||||
image_grid_thws = torch.concat([ret["image_grid_thw"]])
|
||||
video_grid_thws = None
|
||||
return {
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"pixel_values": ret["pixel_values"],
|
||||
"data_hashes": base_output.mm_data_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"image_grid_thws": image_grid_thws,
|
||||
"video_grid_thws": video_grid_thws,
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"],
|
||||
image_grid_thws=image_grid_thws,
|
||||
# TODO
|
||||
video_grid_thws=None,
|
||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
],
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
"im_token_id": self.image_token_id,
|
||||
"video_token_id": self.video_token_id,
|
||||
"second_per_grid_ts": ret["second_per_grid_ts"],
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
|
||||
}
|
||||
|
||||
|
||||
class Modality(Enum):
    """
    Kind of data held by a multimodal data item.

    Members are numbered automatically in declaration order.
    """

    # a single image
    IMAGE = auto()
    # selected by processors when the request modality is "multi-images"
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MultimodalInputs:
|
||||
"""The image related inputs."""
|
||||
class MultimodalDataItem:
|
||||
"""
|
||||
A single multimodal data item, from a single image/video/audio or other source
|
||||
"""
|
||||
|
||||
pixel_values: Union[torch.Tensor, np.array]
|
||||
data_hashes: Optional[list] = None
|
||||
image_sizes: Optional[list] = None
|
||||
image_offsets: Optional[list] = None
|
||||
image_pad_len: Optional[list] = None
|
||||
pad_values: Optional[list] = None
|
||||
modalities: Optional[list] = None
|
||||
num_image_tokens: Optional[int] = None
|
||||
modality: Modality
|
||||
|
||||
# Llava related
|
||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||
hash: int = None
|
||||
pad_value: int = None
|
||||
|
||||
aspect_ratio_id: Optional[List[torch.Tensor]] = None
|
||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# QWen2-VL related
|
||||
# [num_of_images, t, h, w]
|
||||
image_grid_thws: torch.Tensor = None
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
# Qwen2-VL video related
|
||||
video_token_id: Optional[int] = None
|
||||
video_grid_thws: List[Tuple[int, int, int]] = None
|
||||
image_sizes: Tuple[int, int] = None
|
||||
image_offsets: Optional[list] = None
|
||||
|
||||
# the real data, pixel_values or audio_features
|
||||
# data: Union[List[torch.Tensor], List[np.array]]
|
||||
pixel_values: Union[torch.Tensor, np.array] = None
|
||||
image_grid_thws: Union[torch.Tensor, np.array] = None
|
||||
video_grid_thws: Union[torch.Tensor, np.array] = None
|
||||
|
||||
image_emb_mask: Optional[torch.Tensor] = None
|
||||
image_spatial_crop: Optional[torch.Tensor] = None
|
||||
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# deepseek vl2 related
|
||||
images_emb_mask: Optional[List[torch.Tensor]] = None
|
||||
image_spatial_crop: Optional[List[torch.Tensor]] = None
|
||||
# [num_images, (n, w, h)]
|
||||
tgt_size: Tuple[int, int] = None
|
||||
|
||||
# The id of the single-image placeholder token
|
||||
audio_features: Union[torch.Tensor, np.array] = None
|
||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@staticmethod
|
||||
def is_empty_list(l):
|
||||
if l is None:
|
||||
return True
|
||||
return len([item for item in flatten_nested_list(l) if item is not None]) == 0
|
||||
|
||||
def set_pad_value(self):
    """
    Set the pad value after first hashing the data.

    The hash is computed over `audio_features` when `is_audio()` is true and
    over `pixel_values` otherwise. It is stored in `self.hash`, and
    `self.pad_value` is that hash reduced modulo 2**30.
    """

    def hash_feature(f):
        """Hash a feature by content (list, torch.Tensor, np.ndarray or hashable)."""
        if isinstance(f, list):
            # Hash element by element so tensors/arrays nested in the list are
            # hashed by content as well, not by object identity.
            return hash(tuple(hash_feature(x) for x in flatten_nested_list(f)))
        if isinstance(f, torch.Tensor):
            # torch.Tensor.__hash__ is identity-based: hash(tensor) differs for
            # two tensors with identical contents and may repeat for unrelated
            # tensors once an object id is reused. Hash the raw bytes instead.
            flat = f.detach().cpu().contiguous().reshape(-1)
            # Reinterpret as uint8 so dtypes numpy cannot represent still work.
            return hash(flat.view(torch.uint8).numpy().tobytes())
        if isinstance(f, np.ndarray):
            return hash(np.ascontiguousarray(f).tobytes())
        return hash(f)

    if self.is_audio():
        self.hash = hash_feature(self.audio_features)
    else:
        self.hash = hash_feature(self.pixel_values)

    # Python's % with a positive modulus is non-negative, so the pad value
    # always lies in [0, 2**30).
    self.pad_value = self.hash % (1 << 30)
|
||||
|
||||
def is_audio(self):
    """Return True when this item is tagged AUDIO and its audio_features are non-empty."""
    if self.modality != Modality.AUDIO:
        return False
    return not MultimodalDataItem.is_empty_list(self.audio_features)
|
||||
|
||||
def is_image(self):
    """Return True when this item is tagged IMAGE or MULTI_IMAGES and its pixel_values are non-empty."""
    if self.modality not in (Modality.IMAGE, Modality.MULTI_IMAGES):
        return False
    return not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||
|
||||
def is_video(self):
    """Return True when this item is tagged VIDEO and its pixel_values are non-empty."""
    if self.modality != Modality.VIDEO:
        return False
    return not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||
|
||||
def validate(self):
    """Placeholder for checks on this item; currently does nothing."""
    # TODO
    return None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MultimodalInputs:
|
||||
"""The multimodal data related inputs."""
|
||||
|
||||
# items of data
|
||||
mm_items: List[MultimodalDataItem]
|
||||
image_pad_len: Optional[list] = None
|
||||
num_image_tokens: Optional[int] = None
|
||||
|
||||
# QWen2-VL related
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
|
||||
# image
|
||||
im_token_id: Optional[torch.Tensor] = None
|
||||
|
||||
# All the images in the batch should share the same special image
|
||||
# bound token ids.
|
||||
im_start_id: Optional[int] = None
|
||||
im_end_id: Optional[int] = None
|
||||
slice_start_id: Optional[int] = None
|
||||
slice_end_id: Optional[int] = None
|
||||
# [num_images, 2 (w, h)]
|
||||
tgt_sizes: Optional[list] = None
|
||||
|
||||
# video
|
||||
video_token_id: Optional[int] = None
|
||||
|
||||
# audio
|
||||
audio_start_id: Optional[torch.Tensor] = None
|
||||
audio_end_id: Optional[torch.Tensor] = None
|
||||
audio_features: Optional[List[torch.Tensor]] = None
|
||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict):
|
||||
ret = MultimodalInputs(
|
||||
pixel_values=obj["pixel_values"],
|
||||
data_hashes=obj["data_hashes"],
|
||||
mm_items=obj["mm_items"],
|
||||
)
|
||||
|
||||
assert isinstance(ret.mm_items, list)
|
||||
ret.mm_items = [
|
||||
item
|
||||
for item in ret.mm_items
|
||||
if item.is_audio() or item.is_image() or item.is_video()
|
||||
]
|
||||
|
||||
assert len(ret.mm_items) != 0
|
||||
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
||||
# errors in cuda kernels. See also llava.py for example.
|
||||
ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
|
||||
for item in ret.mm_items:
|
||||
item.set_pad_value()
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
"modalities",
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
"images_emb_mask",
|
||||
"image_spatial_crop",
|
||||
"im_token_id",
|
||||
"im_start_id",
|
||||
"im_end_id",
|
||||
"slice_start_id",
|
||||
"slice_end_id",
|
||||
"tgt_sizes",
|
||||
"audio_start_id",
|
||||
"audio_end_id",
|
||||
"audio_features",
|
||||
"audio_feature_lens",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
setattr(ret, arg, obj[arg])
|
||||
|
||||
# validate
|
||||
assert (
|
||||
isinstance(ret.pixel_values, torch.Tensor)
|
||||
or isinstance(ret.pixel_values, np.ndarray)
|
||||
or isinstance(ret.pixel_values, list)
|
||||
)
|
||||
|
||||
assert ret.audio_features is None or isinstance(ret.audio_features, list)
|
||||
|
||||
return ret
|
||||
|
||||
def contains_image_inputs(self) -> bool:
|
||||
""" """
|
||||
return self.pixel_values is not None and self.pixel_values != []
|
||||
return any(item.is_image() for item in self.mm_items)
|
||||
|
||||
def contains_audio_inputs(self) -> bool:
|
||||
""" """
|
||||
return self.audio_features is not None and self.audio_features != []
|
||||
return any(item.is_audio() for item in self.mm_items)
|
||||
|
||||
def collect_image_inputs(self) -> List[torch.Tensor]:
    """
    Gather the pixel_values of every entry in `mm_items` whose `is_image()`
    is true, preserving the order of `mm_items`.
    """
    image_values = []
    for mm_item in self.mm_items:
        if mm_item.is_image():
            image_values.append(mm_item.pixel_values)
    return image_values
|
||||
|
||||
def merge(self, other: MultimodalInputs):
|
||||
"""
|
||||
merge image inputs when requests are being merged
|
||||
"""
|
||||
if isinstance(self.pixel_values, list):
|
||||
# in some rare cases, pixel values are list of patches with different shapes
|
||||
# e.g. minicpm
|
||||
self.pixel_values += other.pixel_values
|
||||
else:
|
||||
assert (
|
||||
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
|
||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||
|
||||
# args would be stacked along first dim
|
||||
# usually these are already tensors
|
||||
stack_args = [
|
||||
# TODO: merge with image_grid_thws, basically the same thing
|
||||
"tgt_sizes",
|
||||
"image_spatial_crop",
|
||||
]
|
||||
for arg in stack_args:
|
||||
if getattr(self, arg, None) is None:
|
||||
setattr(self, arg, getattr(other, arg, None))
|
||||
elif getattr(other, arg, None) is not None:
|
||||
# self and other both not None
|
||||
setattr(
|
||||
self,
|
||||
arg,
|
||||
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
|
||||
)
|
||||
|
||||
if self.image_grid_thws is None:
|
||||
self.image_grid_thws = other.image_grid_thws
|
||||
elif other.image_grid_thws is not None:
|
||||
self.image_grid_thws = torch.concat(
|
||||
[self.image_grid_thws, other.image_grid_thws]
|
||||
)
|
||||
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
|
||||
# errors in cuda kernels. See also llava.py for example.
|
||||
self.data_hashes += other.data_hashes
|
||||
self.pad_values = [x % (1 << 30) for x in self.data_hashes]
|
||||
|
||||
# args needed to be merged
|
||||
optional_args = [
|
||||
"audio_features",
|
||||
"image_sizes",
|
||||
"items",
|
||||
"image_offsets",
|
||||
"image_pad_len",
|
||||
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"images_emb_mask",
|
||||
]
|
||||
for arg in optional_args:
|
||||
self_arg = getattr(self, arg, None)
|
||||
|
||||
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||
|
||||
|
||||
Reference in New Issue
Block a user