refactor: multimodal data (#4754)

This commit is contained in:
Mick
2025-04-01 00:57:51 +08:00
committed by GitHub
parent c7457191a0
commit 5cb552b1d4
36 changed files with 989 additions and 1138 deletions

View File

@@ -72,17 +72,38 @@ def eval_mmmu(args):
if suffix: if suffix:
contents += [{"type": "text", "text": suffix}] contents += [{"type": "text", "text": suffix}]
messages = [{"role": "user", "content": contents}] messages = [{"role": "user", "content": contents}]
model_inputs = processor.apply_chat_template( try:
messages, model_inputs = processor.tokenizer.apply_chat_template(
tokenize=True, messages,
return_dict=True, tokenize=True,
add_generation_prompt=True, return_dict=True,
return_tensors="pt", add_generation_prompt=True,
).to(model.device) return_tensors="pt",
input_len = model_inputs["input_ids"].shape[-1] ).to(model.device)
generation = model.generate(**model_inputs, generation_config=generation_config) input_len = model_inputs["input_ids"].shape[-1]
generation = generation[0][input_len:] generation = model.generate(
response = processor.decode(generation, skip_special_tokens=True) **model_inputs, generation_config=generation_config
)
generation = generation[0][input_len:]
response = processor.decode(generation, skip_special_tokens=True)
except:
contents = []
if prefix:
contents += [prefix]
image = PIL.Image.open(sample["image_path"])
contents += [image]
if suffix:
contents += [suffix]
messages = [{"role": "user", "content": contents}]
response = model.chat(
msgs=messages,
tokenizer=processor.tokenizer,
sampling=False,
max_new_tokens=sampling_params["max_new_tokens"],
use_tts_template=False,
generate_audio=False,
temperature=0.0,
)
print(f"response: {response}") print(f"response: {response}")
process_result(response, sample, answer_dict, out_samples) process_result(response, sample, answer_dict, out_samples)

View File

@@ -442,6 +442,8 @@ def calculate_ins_level_acc(results: Dict):
def process_result(response, sample, answer_dict, out_samples): def process_result(response, sample, answer_dict, out_samples):
if response is None:
return
if sample["question_type"] == "multiple-choice": if sample["question_type"] == "multiple-choice":
pred_ans = parse_multi_choice_response( pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"] response, sample["all_choices"], sample["index2ans"]

View File

@@ -1,5 +1,5 @@
""" """
Multimodality utils Multi-modality utils
""" """
from abc import abstractmethod from abc import abstractmethod
@@ -9,11 +9,13 @@ import torch
from torch import nn from torch import nn
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs, MultimodalInputs,
global_server_args_dict, global_server_args_dict,
logger, logger,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import print_warning_once
from sglang.utils import logger from sglang.utils import logger
@@ -26,7 +28,7 @@ class MultiModalityDataPaddingPattern:
@abstractmethod @abstractmethod
def pad_input_tokens( def pad_input_tokens(
self, input_ids: List[int], image_inputs: MultimodalInputs self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]: ) -> List[int]:
""" """
Pad the input ids sequence containing data tokens, and replace them with pad_values Pad the input ids sequence containing data tokens, and replace them with pad_values
@@ -49,13 +51,13 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
""" """
This function will replace the data-tokens inbetween with pad_values accordingly This function will replace the data-tokens inbetween with pad_values accordingly
""" """
pad_values = mm_inputs.pad_values pad_values = [item.pad_value for item in mm_inputs.mm_items]
data_token_pairs = self.data_token_id_pairs data_token_pairs = self.data_token_id_pairs
mm_inputs.image_offsets = [] mm_inputs.data_offsets = []
if data_token_pairs is None: if data_token_pairs is None:
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id] data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
if data_token_pairs is None: if data_token_pairs is None:
logger.warning( print_warning_once(
"No data_token_pairs provided, RadixAttention might be influenced." "No data_token_pairs provided, RadixAttention might be influenced."
) )
return input_ids return input_ids
@@ -77,10 +79,10 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
if input_ids[start_idx] in start_token_ids: if input_ids[start_idx] in start_token_ids:
data_idx += 1 data_idx += 1
mm_inputs.image_offsets += [start_idx] mm_inputs.data_offsets += [start_idx]
if data_idx >= len(mm_inputs.pad_values): if data_idx >= len(pad_values):
data_idx = len(mm_inputs.pad_values) - 1 data_idx = len(pad_values) - 1
num_tokens = end_idx - start_idx - 1 num_tokens = end_idx - start_idx - 1
pad_value = pad_values[data_idx] pad_value = pad_values[data_idx]
@@ -94,68 +96,19 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
return padded_ids return padded_ids
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
"""In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
which needs first to be expanded to multiple tokens, then replaced with their padding values
This strategy should be used when a single data token represents content that should
be expanded to multiple tokens during processing.
"""
def __init__(
self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
) -> None:
self.num_data_token_calc_func = num_data_token_calc_func
def pad_input_tokens(
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
This function will follow the procedure of:
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
2. the padded data tokens will be replaced with their pad_values
"""
image_grid_thws = mm_inputs.image_grid_thws
pad_values = mm_inputs.pad_values
image_indices = [
idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
]
mm_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
# print(f"image_cnt {image_cnt}")
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
else:
non_image_tokens = input_ids[
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
mm_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
input_ids_with_image.extend(pad_ids[:num_image_tokens])
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern): class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)""" """In this pattern, data tokens should be represented as repetitions of a single token
e.g. <image><image>....<image>, or <audio><audio>...<audio>
"""
def __init__(self, image_token_id: torch.Tensor) -> None: def __init__(self, image_token_id: torch.Tensor) -> None:
self.image_token_id = image_token_id self.image_token_id = image_token_id
def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]: def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
""" """
This function will replace the data-tokens in between with pad_values accordingly This function will replace the data-tokens in between with pad_values accordingly
""" """
pad_values = image_inputs.pad_values pad_values = [item.pad_value for item in mm_inputs.mm_items]
assert len(pad_values) != 0 assert len(pad_values) != 0
input_ids_tensor = torch.tensor(input_ids) input_ids_tensor = torch.tensor(input_ids)
@@ -170,138 +123,227 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
return input_ids_tensor.tolist() return input_ids_tensor.tolist()
def get_embedding_and_mask(
    data_embedding_func: Callable[[List["MultimodalDataItem"]], torch.Tensor],
    embedding_items: List["MultimodalDataItem"],
    placeholder_tensor: torch.Tensor,
    input_ids: torch.Tensor,
):
    """
    Get the multimodal embedding and the mask of its placeholder tokens in input_ids.

    Args:
        data_embedding_func: called once with `embedding_items`; returns the
            multimodal embedding. A 2-D result is counted as `shape[0]` tokens,
            any other rank as `shape[0] * shape[1]` tokens.
            NOTE(review): the last dim is assumed to be the hidden size - confirm.
        embedding_items: the items handed to `data_embedding_func`.
        placeholder_tensor: token ids that mark multimodal positions in `input_ids`.
        input_ids: the token ids of the current forward pass.

    Returns:
        (embedding, mask): `mask` is `torch.isin(input_ids, placeholder_tensor)`
        with a trailing singleton dim. `embedding` is returned untouched when its
        token count equals the number of masked positions. When `input_ids`
        holds fewer placeholders, `embedding` is flattened to (tokens, hidden)
        and only the trailing rows, one per masked position, are kept.

    Raises:
        RuntimeError: `input_ids` holds more placeholders than the embedding has tokens.
    """
    # 1. Get the embedding
    embedding = data_embedding_func(embedding_items)

    # 2. Check the embedding
    if embedding.dim() == 2:
        num_mm_tokens_in_embedding = embedding.shape[0]
    else:
        num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]

    # the mask of multimodal tokens from input_ids
    special_multimodal_mask = torch.isin(
        input_ids,
        placeholder_tensor,
    ).unsqueeze(-1)

    # Plain int: keeps the log message readable and makes the slicing below exact.
    num_mm_tokens_in_input_ids = int(special_multimodal_mask.sum())

    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
        logger.warning(
            f"Number of tokens in multimodal embedding does not match those in the input text. "
            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
            "tokens from multimodal embeddings."
        )
        if num_mm_tokens_in_input_ids > num_mm_tokens_in_embedding:
            raise RuntimeError(
                "Insufficient multimodal embedding length. This is an internal error"
            )

        # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
        #  a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
        #  extend_start_loc and extend_seq_lens
        chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
        if chunked_prefill_size != -1:
            logger.warning(
                "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
            )

        # extract from the end: this is a compromise.
        # Work on a (tokens, hidden) view so that exactly one row per masked
        # position is kept, whatever the rank of the embedding and even when
        # the cut falls inside an item. The start index is computed explicitly
        # because `x[-0:]` would keep the whole tensor instead of nothing.
        token_rows = embedding.reshape(-1, embedding.shape[-1])
        first_kept_row = token_rows.shape[0] - num_mm_tokens_in_input_ids
        embedding = token_rows[first_kept_row:, :]

    return embedding, special_multimodal_mask
def embed_mm_inputs( def embed_mm_inputs(
mm_input: MultimodalInputs, mm_inputs: MultimodalInputs,
input_ids: torch.Tensor, input_ids: torch.Tensor,
input_embedding: nn.Embedding, input_embedding: nn.Embedding,
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor], image_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
placeholder_token_ids: List[int] = None, placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
""" """
Calculate the image embeddings if necessary, then scatter the result with Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
the help of a boolean mask denoting the embed locations
Returns: Args:
final embedding: Optional[torch.Tensor] placeholder_token_ids: denoting the token of multimodal data in input_ids.
If none, the pad_values of multimodal items are used
Returns:
final embedding: Optional[torch.Tensor]
""" """
if mm_input is None:
if mm_inputs is None:
return None return None
placeholder_token_ids = placeholder_token_ids or mm_input.pad_values # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
# we assume that multimodal data are represented with its pad_values in input_ids
placeholder_token_ids = placeholder_token_ids or [
item.pad_value for item in mm_inputs.mm_items
]
# boolean masking the special tokens placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
special_image_mask = torch.isin(
input_ids,
torch.tensor(placeholder_token_ids, device=input_ids.device),
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum() placeholder_masks = torch.isin(input_ids, placeholder_tensor)
# print(f"{num_image_tokens_in_input_ids}")
# print(f"{input_ids}")
# return appearing_pad_values = torch.unique(
if num_image_tokens_in_input_ids == 0: input_ids[placeholder_masks], return_counts=False
# unexpected )
if appearing_pad_values.numel() == 0:
# all been prefixed
inputs_embeds = input_embedding(input_ids) inputs_embeds = input_embedding(input_ids)
else: else:
# print(f"Getting image feature") appearing_items = [
image_embedding = mm_data_embedding_func(mm_input) item
for item in mm_inputs.mm_items
if item.pad_value is not None and item.pad_value in appearing_pad_values
]
# print(f"image_embedding: {image_embedding.shape}") using_all_items = False
if len(appearing_items) == 0:
if image_embedding.dim() == 2: # This happens mostly when arg placeholder_token_ids is passed
num_image_tokens_in_embedding = image_embedding.shape[0] logger.warning_once(
else: "No multimodal data item's pad value exist in placeholder ids. Using all items"
num_image_tokens_in_embedding = (
image_embedding.shape[0] * image_embedding.shape[1]
)
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
image_embedding = image_embedding[:num_image, :]
logger.warning(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
"tokens from image embeddings."
) )
using_all_items = True
appearing_items = mm_inputs.mm_items
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding embeddings, masks = [], []
# a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
# extend_start_loc and extend_seq_lens # 2. Get multimodal embedding separately
if num_image_tokens_in_input_ids > num_image_tokens_in_embedding: # TODO: make this more generic
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"] # Try get image embedding if any
if chunked_prefill_size != -1: if (
logger.warning( any(True for item in appearing_items if item.is_image())
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill" and image_data_embedding_func
):
items = [item for item in appearing_items if item.is_image()]
embedding, mask = get_embedding_and_mask(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
placeholder_tensor
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
) )
),
input_ids=input_ids,
)
embeddings += [embedding]
masks += [mask]
# Try get audio embedding if any
if (
any(True for item in appearing_items if item.is_audio())
and audio_data_embedding_func
):
items = [item for item in appearing_items if item.is_audio()]
embedding, mask = get_embedding_and_mask(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
placeholder_tensor
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
),
input_ids=input_ids,
)
embeddings += [embedding]
masks += [mask]
# 3. Get input embeddings
vocab_size = input_embedding.num_embeddings vocab_size = input_embedding.num_embeddings
# Important: clamp after getting original image regions # Important: clamp after getting original multimodal regions
# Clamp input ids. This is because the input_ids for the image tokens are # Clamp input ids. This is because the input_ids for the multimodal tokens are
# filled with the hash values of the image for the prefix matching in the radix attention. # filled with the hash values of the multimodal for the prefix matching in the radix attention.
# There values are useless because their embeddings will be replaced by vision embeddings anyway. # There values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=vocab_size - 1) input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids) inputs_embeds = input_embedding(input_ids)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to( # 4. scatter embeddings into input embedding
inputs_embeds.device for embedding, mask in zip(embeddings, masks):
) mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(
inputs_embeds = inputs_embeds.masked_scatter( mask,
special_image_mask, embedding.to(inputs_embeds.device, inputs_embeds.dtype),
image_embedding.to(inputs_embeds.device, inputs_embeds.dtype), )
)
return inputs_embeds
def embed_image_embedding(
inputs_embeds: torch.Tensor,
image_embedding: torch.Tensor,
image_bounds: torch.Tensor,
) -> torch.Tensor:
"""
scatter image_embedding into inputs_embeds according to image_bounds
"""
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(inputs_embeds.device)
inputs_embeds.scatter_(
0,
image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
image_embedding.view(-1, image_embedding.shape[-1]),
)
return inputs_embeds return inputs_embeds
def general_mm_embed_routine( def general_mm_embed_routine(
input_ids: torch.Tensor, input_ids: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
embed_tokens: nn.Embedding, language_model: nn.Module,
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor], image_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
placeholder_token_ids: List[int] = None, placeholder_token_ids: List[int] = None,
): **kwargs,
) -> torch.Tensor:
""" """
a general wrapper function to get final input embeds from multimodal models A general wrapper function to get final input embeds from multimodal models with a language model as causal model
with a language model as causal model
Args: Args:
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
image_data_embedding_func : the function returning the image embedding
audio_data_embedding_func : the function returning the audio embedding
Returns:
inputs_embedding
forwarded hidden states
""" """
assert hasattr(language_model, "get_input_embeddings")
embed_tokens = language_model.get_input_embeddings()
if ( if (
not forward_batch.forward_mode.is_decode() not forward_batch.forward_mode.is_decode()
and forward_batch.contains_mm_inputs() and forward_batch.contains_mm_inputs()
): ):
image = forward_batch.merge_mm_inputs() mm_input = forward_batch.merge_mm_inputs()
inputs_embeds = embed_mm_inputs( inputs_embeds = embed_mm_inputs(
mm_input=image, mm_inputs=mm_input,
input_ids=input_ids, input_ids=input_ids,
input_embedding=embed_tokens, input_embedding=embed_tokens,
mm_data_embedding_func=mm_data_embedding_func, image_data_embedding_func=image_data_embedding_func,
audio_data_embedding_func=audio_data_embedding_func,
placeholder_token_ids=placeholder_token_ids, placeholder_token_ids=placeholder_token_ids,
) )
# once used, mm_inputs is useless # once used, mm_inputs is useless
@@ -310,7 +352,13 @@ def general_mm_embed_routine(
else: else:
inputs_embeds = embed_tokens(input_ids) inputs_embeds = embed_tokens(input_ids)
return inputs_embeds hidden_states = language_model(
input_ids=None,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
**kwargs,
)
return hidden_states
def get_multimodal_data_bounds( def get_multimodal_data_bounds(
@@ -322,15 +370,13 @@ def get_multimodal_data_bounds(
Returns: Returns:
[bounds_count, 2] [bounds_count, 2]
""" """
# All the images in the batch should share the same special image # All the multimodal data in the batch should share the same special bound token ids.
# bound token ids.
start_tokens = [s for s, _e in token_pairs] start_tokens = [s for s, _e in token_pairs]
end_tokens = [e for _s, e in token_pairs] end_tokens = [e for _s, e in token_pairs]
assert all(isinstance(t, int) for t in start_tokens) assert all(isinstance(t, int) for t in start_tokens)
assert all(isinstance(t, int) for t in end_tokens) assert all(isinstance(t, int) for t in end_tokens)
# print(input_ids)
start_cond = torch.isin( start_cond = torch.isin(
input_ids, torch.tensor(start_tokens, device=input_ids.device) input_ids, torch.tensor(start_tokens, device=input_ids.device)
) )
@@ -339,7 +385,7 @@ def get_multimodal_data_bounds(
(data_start_tokens,) = torch.where(start_cond) (data_start_tokens,) = torch.where(start_cond)
(data_end_tokens,) = torch.where(end_cond) (data_end_tokens,) = torch.where(end_cond)
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
if len(data_start_tokens) != len(data_end_tokens): if len(data_start_tokens) != len(data_end_tokens):
if ( if (
len(data_start_tokens) + 1 == len(data_end_tokens) len(data_start_tokens) + 1 == len(data_end_tokens)
@@ -352,14 +398,14 @@ def get_multimodal_data_bounds(
data_start_tokens, data_start_tokens,
] ]
) )
valid_image_nums = min(len(data_start_tokens), len(data_end_tokens)) valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
if valid_image_nums == 0: if valid_mm_data_nums == 0:
return torch.zeros((0, 2), device=input_ids.device) return torch.zeros((0, 2), device=input_ids.device)
# Filter out pairs where start_token >= end_token # Filter out pairs where start_token >= end_token
valid_pairs = [] valid_pairs = []
for i in range(valid_image_nums): for i in range(valid_mm_data_nums):
start_token = data_start_tokens[i] start_token = data_start_tokens[i]
end_token = data_end_tokens[i] end_token = data_end_tokens[i]
if start_token < end_token: if start_token < end_token:

View File

@@ -64,5 +64,3 @@ def get_mm_processor(
f"No processor registered for architecture: {hf_config.architectures}.\n" f"No processor registered for architecture: {hf_config.architectures}.\n"
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}" f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
) )
self.image_proce

View File

@@ -8,18 +8,10 @@ from typing import Optional
import numpy as np import numpy as np
import PIL import PIL
import transformers
from decord import VideoReader, cpu from decord import VideoReader, cpu
from PIL import Image from PIL import Image
from sglang.srt.utils import load_audio, load_image, logger from sglang.srt.utils import encode_video, load_audio, load_image, logger
global global_processor
def get_global_processor():
global global_processor
return global_processor
@dataclasses.dataclass @dataclasses.dataclass
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
# input_text, with each frame of video/image represented with a image_token # input_text, with each frame of video/image represented with a image_token
input_text: str input_text: str
mm_data_hashes: Optional[list[int]]
# images
image_sizes: Optional[list[int]]
# frames loaded from image and video, in given order # frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None images: Optional[list[PIL.Image]] = None
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
audios: Optional[list[np.ndarray]] = None audios: Optional[list[np.ndarray]] = None
def normalize(self): def normalize(self):
for field_name in ["data_hashes", "image_sizes", "images", "audios"]: for field_name in ["image_sizes", "images", "audios"]:
field = getattr(self, field_name, None) field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0: if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None) setattr(self, field_name, None)
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
# FIXME: not accurate, model and image specific # FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330 self.NUM_TOKEN_PER_FRAME = 330
# Initialize global processor first self.io_executor = concurrent.futures.ThreadPoolExecutor(
init_global_processor(self, server_args) max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
)
self.executor = concurrent.futures.ProcessPoolExecutor( self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"), mp_context=mp.get_context("fork"),
initargs=( max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
self,
server_args,
),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
) )
def _build_processor(self, server_args): def process_mm_data(
"""Init the global processor for multi modal models.""" self, input_text, images=None, videos=None, audios=None, **kwargs
from sglang.srt.hf_transformers_utils import get_processor ):
"""
process multimodal data with transformers AutoProcessor
"""
if images is not None:
kwargs["images"] = images
if videos is not None:
kwargs["videos"] = videos
if audios is not None:
kwargs["audios"] = audios
return get_processor( processor = self._processor
server_args.tokenizer_path, result = processor.__call__(
tokenizer_mode=server_args.tokenizer_mode, text=[input_text],
trust_remote_code=server_args.trust_remote_code, padding=True,
return_tensors="pt",
**kwargs,
) )
return result
@abstractmethod @abstractmethod
async def process_mm_data_async( async def process_mm_data_async(
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
return estimated_frames_list return estimated_frames_list
@staticmethod
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
def load_mm_data( def load_mm_data(
self, self,
input_ids: list[int], prompt: str,
multimodal_tokens: MultimodalSpecialTokens, multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int, max_req_input_len: int,
image_data: Optional[list] = None, image_data: Optional[list] = None,
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
else: else:
multimodal_tokens.image_token = multimodal_tokens.image_token multimodal_tokens.image_token = multimodal_tokens.image_token
if isinstance(input_ids, list) and return_text: assert isinstance(prompt, str)
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids) if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
prompt = self._processor.tokenizer.decode(prompt)
else: else:
input_text = input_ids prompt = prompt
if return_text: if return_text:
import re import re
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
+ ")" + ")"
) )
# split text into list of normal text and special tokens # split text into list of normal text and special tokens
text_parts = re.split(pattern, input_text) text_parts = re.split(pattern, prompt)
# TODO(mick): load from server_args, env, or sampling_params # TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30 MAX_NUM_FRAMES = 30
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
): ):
# video # video
path = image_file[len("video:") :] path = image_file[len("video:") :]
frames = BaseMultimodalProcessor.encode_video( frames = encode_video(
path, frame_count_limit=frames_to_process path, frame_count_limit=frames_to_process
) )
else: else:
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
raise RuntimeError(f"An exception occurred while loading images: {e}") raise RuntimeError(f"An exception occurred while loading images: {e}")
out = BaseMultiModalProcessorOutput( out = BaseMultiModalProcessorOutput(
mm_data_hashes=hashes,
image_sizes=image_sizes,
images=images, images=images,
audios=audios, audios=audios,
input_text=new_text, input_text=new_text,
) )
out.normalize() out.normalize()
return out return out
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
"""
Init the global processor for multimodal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = sglang_processor._build_processor(server_args=server_args)

View File

@@ -1,10 +1,9 @@
import asyncio
from typing import List, Union from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor, BaseMultimodalProcessor,
get_global_processor,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.clip import CLIPModel from sglang.srt.models.clip import CLIPModel
from sglang.srt.utils import load_image from sglang.srt.utils import load_image
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(
images=images, text=input_text, return_tensors="pt"
)
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
ClipImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=[input_text], return_tensors="pt"
)
return image_inputs
async def process_mm_data_async( async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
): ):
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
else: else:
images = load_image(image_data[0])[0] images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text) image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["data_hashes"] = [hash(str(image_data))] image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
)
]
return image_inputs return image_inputs

View File

@@ -16,15 +16,14 @@
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import asyncio
import torch import torch
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor, BaseMultimodalProcessor,
MultimodalSpecialTokens, MultimodalSpecialTokens,
get_global_processor,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image>" self.IMAGE_TOKEN = "<image>"
@staticmethod
def _process_images_task(image, input_text, max_req_input_len):
processor = get_global_processor()
res = processor.__call__(
conversations=input_text, images=image, max_req_input_len=max_req_input_len
)
image_token_id = processor.image_token_id
res["im_token_id"] = image_token_id
return res
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def process_mm_data_async( async def process_mm_data_async(
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
): ):
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
if not isinstance(image_data, list): if not isinstance(image_data, list):
image_data = [image_data] image_data = [image_data]
images, image_sizes = [], []
image_token = self.IMAGE_TOKEN image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data( base_output = self.load_mm_data(
input_ids, input_ids,
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token), multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
) )
res = await self._process_images( res = self.process_mm_data(
base_output.images, base_output.input_text, max_req_input_len input_text=base_output.input_text,
images=base_output.images,
max_req_input_len=max_req_input_len,
conversations=base_output.input_text,
) )
images_seq_mask = res["images_seq_mask"] images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"] images_spatial_crop = res["images_spatial_crop"]
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
batched_images_spatial_crop.append(images_spatial_crop) batched_images_spatial_crop.append(images_spatial_crop)
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0) batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
items = []
item = MultimodalDataItem(
pixel_values=res["images"],
modality=Modality.IMAGE,
image_emb_mask=images_seq_mask,
image_spatial_crop=batched_images_spatial_crop,
)
items += [item]
return { return {
"mm_items": items,
"input_ids": res["input_ids"].tolist(), "input_ids": res["input_ids"].tolist(),
"pixel_values": res["images"], "im_token_id": self._processor.image_token_id,
"im_token_id": res["im_token_id"],
"data_hashes": base_output.mm_data_hashes,
"image_sizes": image_sizes,
"images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,
"modalities": request_obj.modalities or ["image"],
} }

View File

@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
) )
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens, MultimodalSpecialTokens,
get_global_processor,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index
async def _process_single_image(self, images, input_text) -> dict:
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
result = processor.__call__(
text=[input_text],
images=images,
padding=True,
return_tensors="pt",
# if RGBA, this needs to be set
# images_kwargs={
# "input_data_format": ChannelDimension.FIRST
# }
)
pixel_values = getattr(result, "pixel_values", None)
return {
"input_ids": result.input_ids,
"pixel_values": pixel_values,
}
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_token = self.IMAGE_TOKEN image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data( base_output = self.load_mm_data(
input_ids=input_ids, prompt=input_ids,
image_data=image_data, image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token), multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
discard_alpha_channel=True, discard_alpha_channel=True,
) )
ret = await self._process_single_image( ret = self.process_mm_data(
input_text=base_output.input_text, images=base_output.images input_text=base_output.input_text, images=base_output.images
) )
items = []
for i, image in enumerate(base_output.images):
item = MultimodalDataItem(
pixel_values=ret["pixel_values"][i],
modality=Modality.IMAGE,
)
items += [item]
return { return {
"mm_items": items,
"input_ids": ret["input_ids"].flatten().tolist(), "input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"data_hashes": base_output.mm_data_hashes,
"im_start_id": self.IM_START_TOKEN_ID, "im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID,
} }

View File

@@ -1,11 +1,10 @@
import asyncio
from typing import List, Union from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor, BaseMultimodalProcessor,
MultimodalSpecialTokens, MultimodalSpecialTokens,
get_global_processor,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_images_task(images, input_text):
processor = get_global_processor()
result = processor.__call__(
prompt=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result["input_ids"],
"pixel_values": result["pixel_values"],
"images_emb_mask": result["images_emb_mask"],
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
}
async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
JanusProImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)
return image_inputs
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
if not isinstance(image_data, list): if not isinstance(image_data, list):
image_data = [image_data] image_data = [image_data]
processor = self._processor
base_out = self.load_mm_data( base_out = self.load_mm_data(
input_ids=input_ids, prompt=input_ids,
image_data=image_data, image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens( multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
image_token="<image_placeholder>"
),
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
) )
images = base_out.images images = base_out.images
res = await self._process_images(images=images, input_text=base_out.input_text) res = self.process_mm_data(
# print(res) input_text=base_out.input_text,
# print(base_out) prompt=base_out.input_text,
# print("", res["images_emb_mask"].shape) images=images,
)
return { return {
"mm_items": [
MultimodalDataItem(
pixel_values=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
modality=Modality.IMAGE,
)
],
"input_ids": res["input_ids"].flatten().tolist(), "input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"], "im_start_id": processor.image_start_id,
"images_emb_mask": res["images_emb_mask"], "im_end_id": processor.image_end_id,
"data_hashes": base_out.mm_data_hashes, "im_token_id": processor.image_id,
"im_start_id": res["im_start_id"],
"im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"],
} }

View File

@@ -5,8 +5,8 @@ import numpy as np
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor, BaseMultimodalProcessor,
get_global_processor,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
from sglang.srt.models.llavavid import LlavaVidForCausalLM from sglang.srt.models.llavavid import LlavaVidForCausalLM
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
image_data: Union[str, bytes], image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None, image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None, image_grid_pinpoints: Optional[str] = None,
image_processor=None, processor=None,
): ):
processor = get_global_processor()
image_processor = image_processor or processor.image_processor image_processor = processor.image_processor
try: try:
image, image_size = load_image(image_data) image, image_size = load_image(image_data)
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
async def _process_single_image( async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
): ):
if self.executor is not None: if self.cpu_executor is not None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
self.executor, self.cpu_executor,
LlavaImageProcessor._process_single_image_task, LlavaImageProcessor._process_single_image_task,
image_data, image_data,
aspect_ratio, aspect_ratio,
grid_pinpoints, grid_pinpoints,
self._processor,
) )
else: else:
return self._process_single_image_task( return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints image_data,
aspect_ratio,
grid_pinpoints,
self._processor.image_processor,
) )
async def process_mm_data_async( async def process_mm_data_async(
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
pixel_values, image_hash, image_size = await self._process_single_image( pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints image_data[0], aspect_ratio, grid_pinpoints
) )
data_hashes = [image_hash]
image_sizes = [image_size] image_sizes = [image_size]
else: else:
raise ValueError(f"Invalid image data: {image_data}") raise ValueError(f"Invalid image data: {image_data}")
modality = Modality.IMAGE
if isinstance(request_obj.modalities, list):
if request_obj.modalities[0] == "multi-images":
modality = Modality.MULTI_IMAGES
elif request_obj.modalities[0] == "video":
modality = Modality.VIDEO
return { return {
"pixel_values": pixel_values, "mm_items": [
"data_hashes": data_hashes, MultimodalDataItem(
"image_sizes": image_sizes, pixel_values=pixel_values,
"modalities": request_obj.modalities or ["image"], image_sizes=image_sizes,
modality=modality,
)
],
} }

View File

@@ -1,13 +1,13 @@
import asyncio
from typing import List, Union from typing import List, Union
import torch import torch
from transformers import BaseImageProcessorFast
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor, BaseMultimodalProcessor,
MultimodalSpecialTokens, MultimodalSpecialTokens,
get_global_processor,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.minicpmo import MiniCPMO from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV from sglang.srt.models.minicpmv import MiniCPMV
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
self.image_token = "(<image>./</image>)" self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)" self.audio_token = "(<audio>./</audio>)"
@staticmethod def process_data_task(self, input_text, images=None, audios=None):
def _process_data_task(input_text, images=None, audios=None):
if isinstance(images, list) and len(images) == 0: if isinstance(images, list) and len(images) == 0:
images = None images = None
if isinstance(audios, list) and len(audios) == 0: if isinstance(audios, list) and len(audios) == 0:
audios = None audios = None
result = get_global_processor().__call__( processor = self._processor
args = {}
if isinstance(processor, BaseImageProcessorFast):
args["device"] = "cuda"
result = self._processor.__call__(
text=input_text, text=input_text,
images=images, images=images,
audios=audios, audios=audios,
return_tensors="pt", return_tensors="pt",
chunk_input=True, chunk_input=True,
**args,
) )
return { return {
"input_ids": result.input_ids, "input_ids": result.input_ids,
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
"audio_bounds": getattr(result, "audio_bounds", None), "audio_bounds": getattr(result, "audio_bounds", None),
} }
async def _process_data(self, images, input_text, audios=None):
if self.executor is not None:
loop = asyncio.get_event_loop()
multimodal_data_inputs = await loop.run_in_executor(
self.executor,
MiniCPMMultimodalProcessor._process_data_task,
input_text,
images,
audios,
)
else:
multimodal_data_inputs = self._processor(
images=images, text=input_text, audios=audios, return_tensors="pt"
)
return multimodal_data_inputs
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data = [audio_data] audio_data = [audio_data]
base_output = self.load_mm_data( base_output = self.load_mm_data(
input_ids=input_ids, prompt=input_ids,
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
audio_data=audio_data, audio_data=audio_data,
image_data=image_data, image_data=image_data,
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
if base_output is None: if base_output is None:
return None return None
res = await self._process_data( res = self.process_mm_data(
images=base_output.images,
input_text=base_output.input_text, input_text=base_output.input_text,
images=base_output.images,
audios=base_output.audios, audios=base_output.audios,
) )
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
tgt_sizes_flat += [tgt_n] tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat pixel_values = pixel_values_flat
if len(tgt_sizes_flat) == 0:
tgt_sizes = None items = []
else: if len(pixel_values) != 0:
tgt_sizes = torch.stack(tgt_sizes_flat) item = MultimodalDataItem(
if not isinstance(res["audio_features"], list): pixel_values=pixel_values,
res["audio_features"] = [res["audio_features"]] tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
)
items += [item]
if (
"audio_features" in res
and res["audio_features"] is not None
and len(res["audio_features"]) != 0
):
item = MultimodalDataItem(
audio_features=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
modality=Modality.AUDIO,
)
items += [item]
return { return {
"mm_items": items,
"input_ids": res["input_ids"].flatten().tolist(), "input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"audio_start_id": audio_start_id, "audio_start_id": audio_start_id,
"audio_end_id": audio_end_id, "audio_end_id": audio_end_id,
"audio_features": res["audio_features"],
"audio_bounds": res["audio_bounds"],
"audio_feature_lens": res["audio_feature_lens"],
"im_token_id": im_token_id, "im_token_id": im_token_id,
"im_start_id": tokenizer.im_start_id, "im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id, "im_end_id": tokenizer.im_end_id,

View File

@@ -1,10 +1,9 @@
import asyncio
from typing import List, Union from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor, BaseMultimodalProcessor,
get_global_processor,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.mllama import MllamaForConditionalGeneration from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image from sglang.srt.utils import load_image
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_mm_data_async( async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
): ):
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
else: else:
images = load_image(image_data[0])[0] images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text) image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"],
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
modality=Modality.IMAGE,
)
]
return image_inputs return image_inputs

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import math import math
import time
from typing import List, Union from typing import List, Union
import torch import torch
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
) )
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens, MultimodalSpecialTokens,
get_global_processor,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.MAX_PIXELS = 16384 * 28 * 28 self.MAX_PIXELS = 16384 * 28 * 28
self.MAX_RATIO = 200 self.MAX_RATIO = 200
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
result = get_global_processor().__call__(
text=[input_text], images=images, padding=True, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": getattr(result, "pixel_values", None),
"image_grid_thw": getattr(result, "image_grid_thw", None),
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
"video_grid_thws": getattr(result, "video_grid_thws", None),
}
async def _process_single_image(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2_5VLImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
input_ids, prompt,
request_obj, request_obj,
max_req_input_len, max_req_input_len,
*args, *args,
**kwargs, **kwargs,
): ):
start = time.time()
if not image_data: if not image_data:
return None return None
if isinstance(image_data, str): if isinstance(image_data, str):
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_token = self.IMAGE_TOKEN image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data( base_output = self.load_mm_data(
input_ids=input_ids, prompt=prompt,
image_data=image_data, image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token), multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor return math.floor(number / factor) * factor
images = [resize_image(image) for image in base_output.images] async def resize_image_async(image):
return resize_image(image)
ret = await self._process_single_image( resize_tasks = [resize_image_async(image) for image in base_output.images]
images=images, input_text=base_output.input_text resized_images = await asyncio.gather(*resize_tasks)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=resized_images,
) )
image_grid_thws = torch.concat([ret["image_grid_thw"]]) image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return { return {
"input_ids": ret["input_ids"].flatten().tolist(), "input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"], "mm_items": [
"data_hashes": base_output.mm_data_hashes, MultimodalDataItem(
"modalities": request_obj.modalities or ["image"], pixel_values=ret["pixel_values"],
"image_grid_thws": image_grid_thws, image_grid_thws=image_grid_thws,
"video_grid_thws": video_grid_thws, # TODO
video_grid_thws=None,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
modality=Modality.IMAGE,
)
],
"im_start_id": self.IM_START_TOKEN_ID, "im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id, "im_token_id": self.image_token_id,
"video_token_id": self.video_token_id, "video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"],
} }

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum, auto
# Copyright 2023-2024 SGLang Team # Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend from sglang.srt.utils import flatten_nested_list, get_compiler_backend
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
} }
class Modality(Enum):
    """Tag identifying which kind of multimodal payload a data item carries."""

    # Values are written out explicitly; they match what `auto()` would
    # assign to members declared in this order (1, 2, 3, 4).
    IMAGE = 1
    MULTI_IMAGES = 2
    VIDEO = 3
    AUDIO = 4
@dataclasses.dataclass @dataclasses.dataclass
class MultimodalInputs: class MultimodalDataItem:
"""The image related inputs.""" """
A single multimodal data, from a single image/video/audio or other
"""
pixel_values: Union[torch.Tensor, np.array] modality: Modality
data_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
image_pad_len: Optional[list] = None
pad_values: Optional[list] = None
modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
# Llava related hash: int = None
aspect_ratio_ids: Optional[List[torch.Tensor]] = None pad_value: int = None
aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# QWen2-VL related image_sizes: Tuple[int, int] = None
# [num_of_images, t, h, w] image_offsets: Optional[list] = None
image_grid_thws: torch.Tensor = None
mrope_position_delta: Optional[torch.Tensor] = None # the real data, pixel_values or audio_features
# Qwen2-VL video related # data: Union[List[torch.Tensor], List[np.array]]
video_token_id: Optional[int] = None pixel_values: Union[torch.Tensor, np.array] = None
video_grid_thws: List[Tuple[int, int, int]] = None image_grid_thws: Union[torch.Tensor, np.array] = None
video_grid_thws: Union[torch.Tensor, np.array] = None
image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None second_per_grid_ts: Optional[List[torch.Tensor]] = None
# deepseek vl2 related # [num_images, (n, w, h)]
images_emb_mask: Optional[List[torch.Tensor]] = None tgt_size: Tuple[int, int] = None
image_spatial_crop: Optional[List[torch.Tensor]] = None
# The id of the single-image placeholder token audio_features: Union[torch.Tensor, np.array] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
@staticmethod
def is_empty_list(l):
if l is None:
return True
return len([item for item in flatten_nested_list(l) if item is not None]) == 0
def set_pad_value(self):
    """
    Set `self.hash` and `self.pad_value` after first hashing the data.

    The payload is `audio_features` when `is_audio()` is true, otherwise
    `pixel_values`. The hash is computed from the payload's *content*
    (SHA-256 over dtype, shape and raw bytes), so two items holding equal
    data get the same hash and pad value, and the result does not depend
    on object identity or on Python's per-process hash salt.

    `pad_value` is the hash reduced modulo 2**30.
    """
    # Function-scope import keeps this block self-contained.
    import hashlib

    def feed(hasher, feature):
        """Fold `feature` into `hasher`, walking nested lists/tuples in order."""
        if isinstance(feature, (list, tuple)):
            for element in feature:
                feed(hasher, element)
        elif isinstance(feature, torch.Tensor):
            # The builtin hash() of a tensor is identity-based; hash the
            # bytes instead. Viewing as uint8 also covers dtypes numpy
            # cannot represent directly (e.g. bfloat16).
            tensor = feature.detach().cpu().contiguous()
            hasher.update(f"{tensor.dtype}{tuple(tensor.shape)}".encode())
            hasher.update(tensor.reshape(-1).view(torch.uint8).numpy().tobytes())
        elif isinstance(feature, np.ndarray):
            hasher.update(f"{feature.dtype}{feature.shape}".encode())
            hasher.update(np.ascontiguousarray(feature).tobytes())
        else:
            # Scalars, strings, None, ...: repr() is not salted per process.
            hasher.update(repr(feature).encode())

    def hash_feature(feature):
        """Return a 64-bit integer digest of `feature`'s content."""
        hasher = hashlib.sha256()
        feed(hasher, feature)
        return int.from_bytes(hasher.digest()[:8], byteorder="big")

    payload = self.audio_features if self.is_audio() else self.pixel_values
    self.hash = hash_feature(payload)
    self.pad_value = self.hash % (1 << 30)
def is_audio(self):
    """Return True when the item is tagged AUDIO and `audio_features` is not empty."""
    if self.modality != Modality.AUDIO:
        return False
    return not MultimodalDataItem.is_empty_list(self.audio_features)
def is_image(self):
    """
    Return True when the item is tagged IMAGE or MULTI_IMAGES and
    `pixel_values` is not empty.
    """
    if self.modality not in (Modality.IMAGE, Modality.MULTI_IMAGES):
        return False
    return not MultimodalDataItem.is_empty_list(self.pixel_values)
def is_video(self):
    """Return True when the item is tagged VIDEO and `pixel_values` is not empty."""
    if self.modality != Modality.VIDEO:
        return False
    return not MultimodalDataItem.is_empty_list(self.pixel_values)
def validate(self):
    """Placeholder for validating this item; currently performs no checks."""
    # TODO
    return None
@dataclasses.dataclass
class MultimodalInputs:
"""The multimodal data related inputs."""
# items of data
mm_items: List[MultimodalDataItem]
image_pad_len: Optional[list] = None
num_image_tokens: Optional[int] = None
# QWen2-VL related
mrope_position_delta: Optional[torch.Tensor] = None
# image
im_token_id: Optional[torch.Tensor] = None im_token_id: Optional[torch.Tensor] = None
# All the images in the batch should share the same special image
# bound token ids.
im_start_id: Optional[int] = None im_start_id: Optional[int] = None
im_end_id: Optional[int] = None im_end_id: Optional[int] = None
slice_start_id: Optional[int] = None slice_start_id: Optional[int] = None
slice_end_id: Optional[int] = None slice_end_id: Optional[int] = None
# [num_images, 2 (w, h)]
tgt_sizes: Optional[list] = None # video
video_token_id: Optional[int] = None
# audio # audio
audio_start_id: Optional[torch.Tensor] = None audio_start_id: Optional[torch.Tensor] = None
audio_end_id: Optional[torch.Tensor] = None audio_end_id: Optional[torch.Tensor] = None
audio_features: Optional[List[torch.Tensor]] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
@staticmethod @staticmethod
def from_dict(obj: dict): def from_dict(obj: dict):
ret = MultimodalInputs( ret = MultimodalInputs(
pixel_values=obj["pixel_values"], mm_items=obj["mm_items"],
data_hashes=obj["data_hashes"],
) )
assert isinstance(ret.mm_items, list)
ret.mm_items = [
item
for item in ret.mm_items
if item.is_audio() or item.is_image() or item.is_video()
]
assert len(ret.mm_items) != 0
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache. # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward, # Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example. # errors in cuda kernels. See also llava.py for example.
ret.pad_values = [x % (1 << 30) for x in ret.data_hashes] for item in ret.mm_items:
item.set_pad_value()
optional_args = [ optional_args = [
"image_sizes",
"modalities", "modalities",
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"images_emb_mask",
"image_spatial_crop",
"im_token_id", "im_token_id",
"im_start_id", "im_start_id",
"im_end_id", "im_end_id",
"slice_start_id", "slice_start_id",
"slice_end_id", "slice_end_id",
"tgt_sizes",
"audio_start_id", "audio_start_id",
"audio_end_id", "audio_end_id",
"audio_features",
"audio_feature_lens",
] ]
for arg in optional_args: for arg in optional_args:
if arg in obj: if arg in obj:
setattr(ret, arg, obj[arg]) setattr(ret, arg, obj[arg])
# validate
assert (
isinstance(ret.pixel_values, torch.Tensor)
or isinstance(ret.pixel_values, np.ndarray)
or isinstance(ret.pixel_values, list)
)
assert ret.audio_features is None or isinstance(ret.audio_features, list)
return ret return ret
def contains_image_inputs(self) -> bool: def contains_image_inputs(self) -> bool:
""" """ """ """
return self.pixel_values is not None and self.pixel_values != [] return any(item.is_image() for item in self.mm_items)
def contains_audio_inputs(self) -> bool: def contains_audio_inputs(self) -> bool:
""" """ """ """
return self.audio_features is not None and self.audio_features != [] return any(item.is_audio() for item in self.mm_items)
def collect_image_inputs(self) -> List[torch.Tensor]:
return [item.pixel_values for item in self.mm_items if item.is_image()]
def merge(self, other: MultimodalInputs): def merge(self, other: MultimodalInputs):
""" """
merge image inputs when requests are being merged merge image inputs when requests are being merged
""" """
if isinstance(self.pixel_values, list):
# in some rare cases, pixel values are list of patches with different shapes
# e.g. minicpm
self.pixel_values += other.pixel_values
else:
assert (
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
# args would be stacked along first dim
# usually these are already tensors
stack_args = [
# TODO: merge with image_grid_thws, basically the same thing
"tgt_sizes",
"image_spatial_crop",
]
for arg in stack_args:
if getattr(self, arg, None) is None:
setattr(self, arg, getattr(other, arg, None))
elif getattr(other, arg, None) is not None:
# self and other both not None
setattr(
self,
arg,
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
)
if self.image_grid_thws is None:
self.image_grid_thws = other.image_grid_thws
elif other.image_grid_thws is not None:
self.image_grid_thws = torch.concat(
[self.image_grid_thws, other.image_grid_thws]
)
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache. # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward, # Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example. # errors in cuda kernels. See also llava.py for example.
self.data_hashes += other.data_hashes
self.pad_values = [x % (1 << 30) for x in self.data_hashes]
# args needed to be merged # args needed to be merged
optional_args = [ optional_args = [
"audio_features", "items",
"image_sizes",
"image_offsets", "image_offsets",
"image_pad_len", "image_pad_len",
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",
"images_emb_mask",
] ]
for arg in optional_args: for arg in optional_args:
self_arg = getattr(self, arg, None) self_arg = getattr(self, arg, None)

View File

@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

View File

@@ -1,11 +1,6 @@
import json
import logging import logging
import time
from collections import defaultdict
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Tuple from typing import Optional
import torch
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

View File

@@ -355,11 +355,6 @@ class ForwardBatch:
for mm_input in valid_inputs[1:]: for mm_input in valid_inputs[1:]:
merged.merge(mm_input) merged.merge(mm_input)
if isinstance(merged.pixel_values, np.ndarray):
merged.pixel_values = torch.from_numpy(merged.pixel_values)
if isinstance(merged.audio_features, np.ndarray):
merged.audio_features = torch.from_numpy(merged.audio_features)
return merged return merged
def contains_image_inputs(self) -> bool: def contains_image_inputs(self) -> bool:

View File

@@ -251,17 +251,16 @@ class ModelRunner:
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type) self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
if self.is_multimodal: if self.is_multimodal:
self.mem_fraction_static *= 0.95 self.mem_fraction_static *= 0.90
logger.info( logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} " f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"because this is a multimodal model." f"because this is a multimodal model."
) )
if self.model_config.hf_config.architectures == [ logger.info(
"MllamaForConditionalGeneration" "Automatically turn off --chunked-prefill-size for multimodal model."
]: )
logger.info("Automatically turn off --chunked-prefill-size for mllama.") server_args.chunked_prefill_size = -1
server_args.chunked_prefill_size = -1
if self.model_config.hf_config.architectures == [ if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration" "Qwen2VLForConditionalGeneration"
@@ -269,18 +268,7 @@ class ModelRunner:
"Qwen2_5_VLForConditionalGeneration" "Qwen2_5_VLForConditionalGeneration"
]: ]:
# TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
logger.info( logger.info("Automatically disable radix cache for qwen-vl series.")
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
# TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True server_args.disable_radix_cache = True
if server_args.enable_deepep_moe: if server_args.enable_deepep_moe:

View File

@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.model_runner import ForwardBatch from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix, flatten_nested_list
class CLIPVisionEmbeddings(nn.Module): class CLIPVisionEmbeddings(nn.Module):
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values.to(self.device)) hidden_states = self.embeddings(pixel_values.to(self.device))
hidden_states = self.pre_layrnorm(hidden_states) hidden_states = self.pre_layrnorm(hidden_states)
@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
get_embedding: bool = True, get_embedding: bool = True,
): ):
assert get_embedding, "CLIPEmbeddingModel is only used for embedding" assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
image_inputs = None mm_inputs = []
if forward_batch.mm_inputs is not None: if forward_batch.mm_inputs is not None:
image_inputs = forward_batch.mm_inputs mm_inputs = forward_batch.mm_inputs
pixel_values_list = [
if image_inputs is not None and image_inputs[0] is not None: item.pixel_values
vision_outputs = self.vision_model(image_inputs[0].pixel_values) for item in flatten_nested_list(
[mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
)
]
if len(pixel_values_list) != 0:
pixel_values = torch.concat(pixel_values_list)
vision_outputs = self.vision_model(pixel_values)
pooled_output = vision_outputs[:, 0, :] pooled_output = vision_outputs[:, 0, :]
image_embeds = self.visual_projection(pooled_output) image_embeds = self.visual_projection(pooled_output)
image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1) image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)

View File

@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine, general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.llama import LlamaForCausalLM
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor: def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = image_input.pixel_values pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
bs, n = pixel_values.shape[0:2] bs, n = pixel_values.shape[0:2]
pixel_values = pixel_values.to( pixel_values = pixel_values.to(
device=self.vision_model.device, dtype=self.vision_model.dtype device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
return images_embeds return images_embeds
def get_input_embeddings(self) -> nn.Embedding: def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.model.embed_tokens return self.language_model.get_input_embeddings()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@@ -1984,22 +1984,17 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
get_embedding: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = general_mm_embed_routine(
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids, input_ids=input_ids,
forward_batch=forward_batch, forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(), image_data_embedding_func=self.get_image_feature,
mm_data_embedding_func=self.get_image_feature, language_model=self.language_model,
positions=positions,
) )
return self.language_model( return hidden_states
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
get_embedding=False,
)
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
return self.gen_aligner(self.gen_embed(image_ids)) return self.gen_aligner(self.gen_embed(image_ids))

View File

@@ -1308,6 +1308,9 @@ class DeepseekV2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.dp_size = get_attention_dp_size() self.dp_size = get_attention_dp_size()
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,

View File

@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
) )
from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternImageTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
return x return x
# todo
class DeepseekVL2ForCausalLM(nn.Module): class DeepseekVL2ForCausalLM(nn.Module):
def __init__( def __init__(
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
**kwargs: object, **kwargs: object,
): ):
input_embeds = self.language_model.model.embed_tokens(input_ids) hs = general_mm_embed_routine(
if (
forward_batch.forward_mode.is_extend()
and forward_batch.contains_image_inputs()
):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.mm_inputs):
if image is None:
continue
start_idx = extend_start_loc_cpu[idx]
end_idx = start_idx + extend_seq_lens_cpu[idx]
images_emb_mask = image.images_emb_mask.to(device="cuda")
image_features = self.get_image_feature(image)
input_embeds[start_idx:end_idx] = input_embeds[
start_idx:end_idx
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
outputs = self.language_model.forward(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
forward_batch=forward_batch, forward_batch=forward_batch,
input_embeds=input_embeds, image_data_embedding_func=self.get_image_feature,
language_model=self.language_model,
) )
return outputs return hs
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
weights_loader(param, loaded_weight) weights_loader(param, loaded_weight)
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
return input_ids helper = MultiModalityDataPaddingPatternImageTokens(
image_token_id=image_inputs.im_token_id
)
return helper.pad_input_tokens(input_ids, image_inputs)
def get_image_feature(self, image_input: MultimodalInputs): def get_image_feature(self, items: List[MultimodalDataItem]):
pixel_values = image_input.pixel_values.type(
next(self.vision.parameters()).dtype images_spatial_crop = torch.cat(
).to(device=next(self.vision.parameters()).device) [item.image_spatial_crop for item in items], dim=0
image_feature = self.vision.forward_features(pixel_values) )
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape assert images_spatial_crop.dim() == 3
h = w = int(hw**0.5)
tile_index = 0 # TODO: can it be batched ?
images_in_this_batch = [] images_in_this_batch = []
images_spatial_crop = image_input.image_spatial_crop for item in items:
for jdx in range(images_spatial_crop.shape[1]): assert item.pixel_values.dim() == 4
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx] image_feature = self.vision.forward_features(
if num_width_tiles == 0 or num_height_tiles == 0: item.pixel_values.type(next(self.vision.parameters()).dtype).to(
break device=next(self.vision.parameters()).device
num_tiles_in_image = num_width_tiles * num_height_tiles
# [hw, D]
global_features = images_embeds[tile_index]
# [num_height_tiles * num_width_tiles, hw, D]
local_features = images_embeds[
tile_index + 1 : tile_index + 1 + num_tiles_in_image
]
tile_index += num_tiles_in_image + 1
# format global and local features
# ----------------- global view add newline -----------------
# [hw, D] -> [h, w, D]
global_features = global_features.view(h, w, n_dim)
# [D] -> [h, 1, D]
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
# [h, w + 1, D] -> [h * (w + 1), D]
global_features = global_features.view(-1, n_dim)
# ----------------- local view add newline -----------------
# [num_height_tiles * num_width_tiles, h * w, D] ->
# [num_height_tiles * h, num_width_tiles * w, D]
local_features = rearrange(
local_features,
"(th tw) (h w) d -> (th h) (tw w) d",
th=num_height_tiles,
tw=num_width_tiles,
h=h,
w=w,
)
# [D] -> [num_height_tiles * h, 1, D]
new_lines_in_local = repeat(
self.image_newline,
"d -> (th h) 1 d",
th=num_height_tiles,
h=h,
)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
local_features = local_features.view(-1, n_dim)
# merge global and local tiles
if self.global_view_pos == "head":
global_local_features = torch.cat(
[
global_features,
self.view_seperator[None, :],
local_features,
]
) )
else: )
global_local_features = torch.cat( images_embeds = self.projector(image_feature)
[ _, hw, n_dim = images_embeds.shape
local_features, h = w = int(hw**0.5)
self.view_seperator[None, :], tile_index = 0
global_features, for jdx in range(item.image_spatial_crop.shape[1]):
] num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
break
num_tiles_in_image = num_width_tiles * num_height_tiles
# [hw, D]
global_features = images_embeds[tile_index]
# [num_height_tiles * num_width_tiles, hw, D]
local_features = images_embeds[
tile_index + 1 : tile_index + 1 + num_tiles_in_image
]
tile_index += num_tiles_in_image + 1
# format global and local features
# ----------------- global view add newline -----------------
# [hw, D] -> [h, w, D]
global_features = global_features.view(h, w, n_dim)
# [D] -> [h, 1, D]
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
global_features = torch.cat(
[global_features, new_lines_in_global], dim=1
) )
images_in_this_batch.append(global_local_features) # [h, w + 1, D] -> [h * (w + 1), D]
global_features = global_features.view(-1, n_dim)
# ----------------- local view add newline -----------------
# [num_height_tiles * num_width_tiles, h * w, D] ->
# [num_height_tiles * h, num_width_tiles * w, D]
local_features = rearrange(
local_features,
"(th tw) (h w) d -> (th h) (tw w) d",
th=num_height_tiles,
tw=num_width_tiles,
h=h,
w=w,
)
# [D] -> [num_height_tiles * h, 1, D]
new_lines_in_local = repeat(
self.image_newline,
"d -> (th h) 1 d",
th=num_height_tiles,
h=h,
)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
local_features = local_features.view(-1, n_dim)
# merge global and local tiles
if self.global_view_pos == "head":
global_local_features = torch.cat(
[
global_features,
self.view_seperator[None, :],
local_features,
]
)
else:
global_local_features = torch.cat(
[
local_features,
self.view_seperator[None, :],
global_features,
]
)
images_in_this_batch.append(global_local_features)
return torch.cat(images_in_this_batch, dim=0) return torch.cat(images_in_this_batch, dim=0)

View File

@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch import torch
from torch import nn from torch import nn
from transformers import ( from transformers import AutoModel, Gemma3Config, PreTrainedModel
AutoModel,
BatchFeature,
Gemma3Config,
Gemma3Processor,
PreTrainedModel,
)
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine, general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
""" """
return self.language_model.get_attention_sliding_window_size() return self.language_model.get_attention_sliding_window_size()
def get_image_feature(self, image_input: MultimodalInputs): def get_image_feature(self, items: List[MultimodalDataItem]):
""" """
Projects the last hidden state from the vision model into language model space. Projects the last hidden state from the vision model into language model space.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
Returns: Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
""" """
pixel_values = image_input.pixel_values pixel_values = torch.stack(
flatten_nested_list([item.pixel_values for item in items]), dim=0
)
pixel_values = pixel_values.to("cuda") pixel_values = pixel_values.to("cuda")
pixel_values = pixel_values.to(dtype=self.language_model.dtype()) pixel_values = pixel_values.to(dtype=self.language_model.dtype())
@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
image_features = self.multi_modal_projector(vision_outputs) image_features = self.multi_modal_projector(vision_outputs)
return image_features return image_features
def embed_mm_inputs(
self,
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
image_input: MultimodalInputs,
) -> torch.Tensor:
if input_ids is None:
raise ValueError("Unimplemented")
# boolean-masking image tokens
special_image_mask = torch.isin(
input_ids,
torch.tensor(image_input.pad_values, device=input_ids.device),
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum()
inputs_embeds = None
if num_image_tokens_in_input_ids == 0:
inputs_embeds = self.get_input_embeddings()(input_ids)
return inputs_embeds
else:
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
image_features = self.get_image_feature(image_input.pixel_values)
# print(f"image tokens from image embeddings: {image_features.numel()}")
num_image_tokens_in_embedding = (
image_features.shape[0] * image_features.shape[1]
)
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
num_image = num_image_tokens_in_input_ids // image_features.shape[1]
image_features = image_features[:num_image, :]
logger.warning(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
"tokens from image embeddings."
)
# Important: clamp after extracting original image boundaries
input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
image_features = image_features.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask, image_features
)
return inputs_embeds
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,
@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
else: else:
llm_input_ids = input_ids llm_input_ids = input_ids
inputs_embeds = general_mm_embed_routine( hs = general_mm_embed_routine(
input_ids=llm_input_ids, input_ids=llm_input_ids,
forward_batch=forward_batch, forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(), language_model=self.language_model,
mm_data_embedding_func=self.get_image_feature, image_data_embedding_func=self.get_image_feature,
)
outputs = self.language_model(
input_ids=None,
positions=positions, positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
**kwargs,
) )
return outputs return hs
def tie_weights(self): def tie_weights(self):
return self.language_model.tie_weights() return self.language_model.tie_weights()

View File

@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
else: else:
return self.pooler(hidden_states, forward_batch) return self.pooler(hidden_states, forward_batch)
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def get_hidden_dim(self, module_name): def get_hidden_dim(self, module_name):
# return input_dim, output_dim # return input_dim, output_dim
if module_name in ["q_proj", "o_proj", "qkv_proj"]: if module_name in ["q_proj", "o_proj", "qkv_proj"]:

View File

@@ -31,7 +31,7 @@ from transformers import (
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
from sglang.srt.mm_utils import ( from sglang.srt.mm_utils import (
get_anyres_image_grid_shape, get_anyres_image_grid_shape,
unpad_image, unpad_image,
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix, flatten_nested_list
class LlavaBaseForCausalLM(nn.Module): class LlavaBaseForCausalLM(nn.Module):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values image_sizes = flatten_nested_list(
[item.image_sizes for item in image_inputs.mm_items]
)
pad_values = [item.pad_value for item in image_inputs.mm_items]
# hardcode for spatial_unpad + anyres # hardcode for spatial_unpad + anyres
if image_inputs.modalities is not None and ( if any(
"multi-images" in image_inputs.modalities item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
or "video" in image_inputs.modalities for item in image_inputs.mm_items
): ):
image_aspect_ratio = "pad" image_aspect_ratio = "pad"
else: else:
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
math.ceil(self.image_size / self.patch_size / 2) ** 2 math.ceil(self.image_size / self.patch_size / 2) ** 2
) )
else: else:
new_image_feature_len = self.image_feature_len # multiimage new_image_feature_len = self.image_feature_len # multi-image
height = width = self.num_patches_per_side height = width = self.num_patches_per_side
if "anyres" in image_aspect_ratio: if "anyres" in image_aspect_ratio:
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
# old_len + pad_len - 1, because we need to remove image_token_id # old_len + pad_len - 1, because we need to remove image_token_id
input_ids = ( input_ids = (
input_ids[:offset] input_ids[:offset]
+ [pad_values[image_idx]] * new_image_feature_len + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
+ input_ids[offset + 1 :] + input_ids[offset + 1 :]
) )
offset_list.append(offset) offset_list.append(offset)
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
modalities_list = [] modalities_list = []
max_image_offset = [] max_image_offset = []
for im in image_inputs: for im in image_inputs:
if im and im.modalities is not None: if im:
modalities_list.extend(im.modalities) modalities_list.extend([item.modality for item in im.mm_items])
if im and im.image_offsets: if im and im.image_offsets:
max_image_offset.append( max_image_offset.append(
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)) np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
if need_vision.any(): if need_vision.any():
bs = forward_batch.batch_size bs = forward_batch.batch_size
pixel_values = [ pixel_values = flatten_nested_list(
image_inputs[i].pixel_values for i in range(bs) if need_vision[i] [
] [item.pixel_values for item in image_inputs[i].mm_items]
for i in range(bs)
if need_vision[i]
]
)
image_sizes = [ image_sizes = [
image_inputs[i].image_sizes for i in range(bs) if need_vision[i] flatten_nested_list(
[item.image_sizes for item in image_inputs[i].mm_items]
)
for i in range(bs)
if need_vision[i]
] ]
########## Encode Image ######## ########## Encode Image ########
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
new_image_features = [] new_image_features = []
height = width = self.num_patches_per_side height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features): for image_idx, image_feature in enumerate(image_features):
if modalities_list[image_idx] == "image": if modalities_list[image_idx] == Modality.IMAGE:
image_aspect_ratio = ( image_aspect_ratio = (
self.config.image_aspect_ratio self.config.image_aspect_ratio
) # single image ) # single image
elif ( elif (
modalities_list[image_idx] == "multi-images" modalities_list[image_idx] == Modality.MULTI_IMAGES
or modalities_list[image_idx] == "video" or modalities_list[image_idx] == Modality.VIDEO
): ):
image_aspect_ratio = "pad" # multi image image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = ( # image_aspect_ratio = (
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
if ( if (
image_feature.shape[0] > 1 image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio and "anyres" in image_aspect_ratio
and modalities_list[image_idx] == "image" and modalities_list[image_idx] == Modality.IMAGE
): ):
base_image_feature = image_feature[0] base_image_feature = image_feature[0]
image_feature = image_feature[1:] image_feature = image_feature[1:]
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
) )
image_feature = image_feature.unsqueeze(0) image_feature = image_feature.unsqueeze(0)
else: else:
if modalities_list[image_idx] == "video": # video if modalities_list[image_idx] == Modality.VIDEO: # video
# 2x2 pooling # 2x2 pooling
num_of_frames = image_feature.shape[0] num_of_frames = image_feature.shape[0]
image_feature = image_feature.view( image_feature = image_feature.view(

View File

@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.llama import LlamaForCausalLM
@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
) )
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
pad_values = image_inputs.pad_values pad_values = [item.pad_value for item in image_inputs.mm_items]
new_image_feature_len = self.image_feature_len new_image_feature_len = self.image_feature_len
pad_ids = pad_values * ( pad_ids = pad_values * (
@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
need_vision = start_positions <= np.array(max_image_offset) need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any(): if need_vision.any():
pixel_values = [ pixel_values = flatten_nested_list(
image_inputs[i].pixel_values for i in range(bs) if need_vision[i] [
] [item.pixel_values for item in image_inputs[i].mm_items]
for i in range(bs)
if need_vision[i]
]
)
image_offsets = [ image_offsets = [
image_inputs[i].image_offsets for i in range(bs) if need_vision[i] flatten_nested_list(
[item.image_offsets for item in image_inputs[i].mm_items]
)
for i in range(bs)
if need_vision[i]
] ]
########## Encode Image ######## ########## Encode Image ########
@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
"model.mm_projector.2": "multi_modal_projector.linear_2", "model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1", "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2", "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). "model.vision_tower.vision_tower": "vision_tower",
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline", "model.image_newline": "language_model.model.image_newline",
} }
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())

View File

@@ -40,16 +40,19 @@ from transformers.models.whisper.modeling_whisper import (
from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import ( from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
embed_mm_inputs, general_mm_embed_routine,
get_multimodal_data_bounds,
) )
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.minicpmv import ( from sglang.srt.models.minicpmv import (
Idefics2VisionTransformer, Idefics2VisionTransformer,
MiniCPMVBaseModel, MiniCPMBaseModel,
Resampler2_5, Resampler2_5,
) )
from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -1409,7 +1412,7 @@ class MultiModalProjector(nn.Module):
return hidden_states return hidden_states
class MiniCPMO(MiniCPMVBaseModel): class MiniCPMO(MiniCPMBaseModel):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
@@ -1537,7 +1540,7 @@ class MiniCPMO(MiniCPMVBaseModel):
return input_lengths_after_cnn, input_lengths_after_pooling return input_lengths_after_cnn, input_lengths_after_pooling
def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs): def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]):
r""" r"""
Extract audio embeddings in a streaming manner using cached key-value pairs. Extract audio embeddings in a streaming manner using cached key-value pairs.
@@ -1545,26 +1548,15 @@ class MiniCPMO(MiniCPMVBaseModel):
for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
for streaming scenarios. for streaming scenarios.
Args:
multimodal_input (dict):
- **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
- **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
Returns: Returns:
List[List[torch.Tensor]]: audio embeddings List[List[torch.Tensor]]: audio embeddings
""" """
# print("audio embedding") wavforms = flatten_nested_list(
[item.audio_features for item in items if item.audio_features]
wavforms = (
[]
if multimodal_input.audio_features is None
else multimodal_input.audio_features
) )
# list, [[x1, x2], [y1], [z1]] # list, [[x1, x2], [y1], [z1]]
audio_feature_lens_raw = ( audio_feature_lens_raw = flatten_nested_list(
[] [item.audio_feature_lens for item in items if item.audio_feature_lens]
if multimodal_input.audio_feature_lens is None
else multimodal_input.audio_feature_lens
) )
# exist audio # exist audio
@@ -1650,7 +1642,7 @@ class MiniCPMO(MiniCPMVBaseModel):
ret[i, start:ending] = True ret[i, start:ending] = True
return ret return ret
def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1): def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1):
r""" r"""
Extract full audio embeddings with optional chunk-based attention. Extract full audio embeddings with optional chunk-based attention.
@@ -1659,31 +1651,25 @@ class MiniCPMO(MiniCPMVBaseModel):
not use key-value caching and is suitable for non-streaming inference. not use key-value caching and is suitable for non-streaming inference.
Args: Args:
multimodal_input (dict):
- **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
- **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
attention (>0) during embedding computation. attention (>0) during embedding computation.
Returns: Returns:
List[List[torch.Tensor]]: audio embeddings List[List[torch.Tensor]]: audio embeddings
""" """
# print("audio embedding")
# (bs, 80, frames) or [], multi audios need filled in advance # (bs, 80, frames) or [], multi audios need filled in advance
wavforms = ( wavforms = flatten_nested_list(
[] [item.audio_features for item in items if item.audio_features]
if multimodal_input.audio_features is None
else multimodal_input.audio_features
) )
# list, [[x1, x2], [y1], [z1]] # list, [[x1, x2], [y1], [z1]]
audio_feature_lens_raw = ( audio_feature_lens_raw = flatten_nested_list(
[] [item.audio_feature_lens for item in items if item.audio_feature_lens]
if multimodal_input.audio_feature_lens is None
else multimodal_input.audio_feature_lens
) )
final_audio_embeds = [] final_audio_embeds = []
assert isinstance(wavforms, list)
assert isinstance(wavforms[0], torch.Tensor)
# exist audio # exist audio
for wavform in wavforms: for wavform in wavforms:
if len(wavform) > 0: if len(wavform) > 0:
@@ -1757,86 +1743,46 @@ class MiniCPMO(MiniCPMVBaseModel):
final_audio_embeds.append(target_audio_embeds) final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds return final_audio_embeds
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
embedding = self.get_omni_embedding(
items=items,
chunk_length=self.config.audio_chunk_length,
stream_input=False,
)
return embedding
def get_omni_embedding( def get_omni_embedding(
self, self,
input_ids, items: List[MultimodalDataItem],
multimodal_input: MultimodalInputs,
input_embeds: torch.Tensor,
forward_mode: ForwardMode,
chunk_length=-1, chunk_length=-1,
stream_input=False, stream_input=False,
): ):
""" """
Args: Args:
multimodal_input:
input_embeds:
chunk_length: whisper use full attention or chunk attention chunk_length: whisper use full attention or chunk attention
stream_input: use streaming audio embedding stream_input: use streaming audio embedding
Returns: Returns:
final embeddings with audio feature final embeddings with audio feature
""" """
input_embeds = input_embeds.unsqueeze(0)
if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
audio_bounds = get_multimodal_data_bounds(
input_ids=input_ids,
pad_values=multimodal_input.pad_values,
token_pairs=[
(multimodal_input.audio_start_id, multimodal_input.audio_end_id)
],
)
if audio_bounds.numel() == 0:
input_embeds = input_embeds.squeeze(0)
# TODO
logger.warn("Unimplemented logic. Please try disabling chunked prefill")
return input_embeds
audio_bounds = audio_bounds.unsqueeze(0)
bs = len(input_embeds)
if stream_input: if stream_input:
audio_embeddings = self.get_audio_embedding_streaming(multimodal_input) audio_embeddings = self.get_audio_embedding_streaming(items)
else: else:
audio_embeddings = self.get_audio_embedding( audio_embeddings = self.get_audio_embedding(items, chunk_length)
multimodal_input, chunk_length bs = len(audio_embeddings)
) # batch size
# batch size audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
assert len(audio_embeddings) == len(input_embeds)
if len(audio_embeddings) > 0:
if self.config.chunk_input:
for i in range(bs):
audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
device=input_embeds.device, dtype=input_embeds.dtype
)
audio_start_pos = 0
for bound in audio_bounds[i]:
audio_len = bound[1] - bound[0] + 1
input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
audio_start_pos : audio_start_pos + audio_len, :
]
audio_start_pos += audio_len
else:
for i in range(bs):
audio_embs = audio_embeddings[i]
bounds = audio_bounds[i]
for embs, bound in zip(audio_embs, bounds):
audio_indices = torch.arange(
bound[0], bound[1], dtype=torch.long
).to(input_embeds.device)
if embs.shape[0] != len(audio_indices): return audio_embs
raise ValueError(
f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
f"to input indices of length {len(audio_indices)}" # list of tensors
) pixel_values = flatten_nested_list([item.pixel_values for item in items])
input_embeds[i, audio_indices] = embs.to(input_embeds.dtype) tgt_sizes = torch.stack(
input_embeds = input_embeds.squeeze(0) flatten_nested_list([item.tgt_size for item in items]), dim=0
return input_embeds )
assert len(pixel_values) == tgt_sizes.shape[0]
def get_image_features(
self,
image_inputs: MultimodalInputs,
) -> torch.Tensor:
pixel_values = image_inputs.pixel_values
tgt_sizes = image_inputs.tgt_sizes
device = self.vpm.embeddings.position_embedding.weight.device device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [ all_pixel_values_lst = [
@@ -1845,10 +1791,10 @@ class MiniCPMO(MiniCPMVBaseModel):
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int) assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence( all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0 all_pixel_values_lst, batch_first=True, padding_value=0.0
) )
B, L, _ = all_pixel_values.shape B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros( patch_attn_mask = torch.zeros(
@@ -1875,53 +1821,23 @@ class MiniCPMO(MiniCPMVBaseModel):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = None
# TODO(mick): optimize the logic here: clamp, merge and embedding should happens at most once
if (
not forward_batch.forward_mode.is_decode()
and forward_batch.contains_image_inputs()
):
mm_inputs = forward_batch.merge_mm_inputs()
inputs_embeds = embed_mm_inputs(
mm_input=mm_inputs,
input_ids=input_ids,
input_embedding=self.get_input_embeddings(),
mm_data_embedding_func=self.get_image_features,
placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
)
input_ids = input_ids.clamp( mm_input = forward_batch.merge_mm_inputs()
min=0, max=self.get_input_embeddings().num_embeddings - 1 placeholder_token_ids = (
([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
if forward_batch.contains_mm_inputs()
else []
) )
if inputs_embeds is None: hidden_states = general_mm_embed_routine(
inputs_embeds = self.llm.get_input_embeddings(input_ids) input_ids=input_ids,
if (
not forward_batch.forward_mode.is_decode()
and self.config.init_audio
and forward_batch.contains_audio_inputs()
):
mm_input = forward_batch.merge_mm_inputs()
inputs_embeds = self.get_omni_embedding(
input_ids=input_ids,
multimodal_input=mm_input,
input_embeds=inputs_embeds,
forward_mode=forward_batch.forward_mode,
chunk_length=self.config.audio_chunk_length,
stream_input=False,
)
forward_batch.mm_inputs = None
hidden_states = self.llm.model(
input_ids=None,
positions=positions,
forward_batch=forward_batch, forward_batch=forward_batch,
input_embeds=inputs_embeds, language_model=self.llm,
) image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
return self.logits_processor( placeholder_token_ids=placeholder_token_ids,
input_ids, hidden_states, self.llm.lm_head, forward_batch positions=positions,
) )
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -54,12 +54,12 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine, general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix, flatten_nested_list
RawImageType = Union[Image.Image, torch.Tensor] RawImageType = Union[Image.Image, torch.Tensor]
@@ -661,7 +661,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
return tuple(int(x) for x in version_str.split(".")) return tuple(int(x) for x in version_str.split("."))
class MiniCPMVBaseModel(nn.Module): class MiniCPMBaseModel(nn.Module):
""" """
The abstract class of MiniCPMV can only be inherited, but cannot be The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated. instantiated.
@@ -853,7 +853,7 @@ class MiniCPMVBaseModel(nn.Module):
return vlm_embedding, vision_hidden_states return vlm_embedding, vision_hidden_states
def get_input_embeddings(self) -> nn.Embedding: def get_input_embeddings(self) -> nn.Embedding:
return self.llm.get_input_embedding() return self.llm.get_input_embeddings()
def forward( def forward(
self, self,
@@ -862,23 +862,14 @@ class MiniCPMVBaseModel(nn.Module):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = general_mm_embed_routine( hidden_states = general_mm_embed_routine(
input_ids=input_ids, input_ids=input_ids,
forward_batch=forward_batch, forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(), image_data_embedding_func=self.get_image_feature,
mm_data_embedding_func=self.get_image_features, language_model=self.llm,
)
hidden_states = self.llm.model(
input_ids=None,
positions=positions, positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
)
return self.logits_processor(
input_ids, hidden_states, self.llm.lm_head, forward_batch
) )
return hidden_states
def init_llm( def init_llm(
self, self,
@@ -913,11 +904,11 @@ class MiniCPMVBaseModel(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor: def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
class MiniCPMV2_6(MiniCPMVBaseModel): class MiniCPMV2_6(MiniCPMBaseModel):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@@ -1023,14 +1014,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
) )
return vision_embedding return vision_embedding
def get_image_features( def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
self,
image_inputs: MultimodalInputs,
) -> torch.Tensor:
# list of tensors # list of tensors
pixel_values = image_inputs.pixel_values pixel_values = flatten_nested_list([item.pixel_values for item in items])
tgt_sizes = torch.stack(
tgt_sizes = image_inputs.tgt_sizes flatten_nested_list([item.tgt_size for item in items]), dim=0
)
assert len(pixel_values) == tgt_sizes.shape[0]
device = self.vpm.embeddings.position_embedding.weight.device device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype dtype = self.vpm.embeddings.position_embedding.weight.dtype
@@ -1040,10 +1030,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int) assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence( all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0 all_pixel_values_lst, batch_first=True, padding_value=0.0
) )
B, L, _ = all_pixel_values.shape B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros( patch_attn_mask = torch.zeros(

View File

@@ -796,14 +796,16 @@ class MllamaForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config.text_config) self.logits_processor = LogitsProcessor(config.text_config)
self.capture_mode = False self.capture_mode = False
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pixel_values = image_inputs.pixel_values pixel_values = torch.cat(
pad_values = image_inputs.pad_values [item.pixel_values for item in mm_inputs.mm_items], dim=0
)
pad_values = [item.pad_value for item in mm_inputs.mm_items]
num_concurrent_media, num_tiles = pixel_values.shape[1:3] num_concurrent_media, num_tiles = pixel_values.shape[1:3]
num_patches = self.vision_model.num_patches num_patches = self.vision_model.num_patches
image_len = num_concurrent_media * num_tiles * num_patches image_len = num_concurrent_media * num_tiles * num_patches
image_inputs.num_image_tokens = image_len mm_inputs.num_image_tokens = image_len
pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values)) pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
@@ -815,10 +817,16 @@ class MllamaForConditionalGeneration(nn.Module):
# pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res) # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
max_num_images = max_num_tiles = bs = 0 max_num_images = max_num_tiles = bs = 0
for i, im in enumerate(forward_batch.mm_inputs): for i, mm_input in enumerate(forward_batch.mm_inputs):
if not forward_batch.encoder_cached[i] and im is not None:
max_num_images = max(max_num_images, im.pixel_values.shape[1]) if not forward_batch.encoder_cached[i] and mm_input is not None:
max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2]) pixel_values = torch.cat(
[item.pixel_values for item in mm_input.mm_items], dim=0
)
# max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
max_num_images = max(max_num_images, pixel_values.shape[1])
max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
bs += 1 bs += 1
if max_num_images * max_num_tiles * bs == 0: if max_num_images * max_num_tiles * bs == 0:
@@ -842,17 +850,24 @@ class MllamaForConditionalGeneration(nn.Module):
) )
i = 0 i = 0
encoder_lens_need = [] encoder_lens_need = []
for k, im in enumerate(forward_batch.mm_inputs):
if forward_batch.encoder_cached[k] or im is None: for k, mm_input in enumerate(forward_batch.mm_inputs):
if forward_batch.encoder_cached[k] or mm_input is None:
continue continue
encoder_lens_need.append(forward_batch.encoder_lens[k]) encoder_lens_need.append(forward_batch.encoder_lens[k])
for j in range(im.pixel_values.shape[1]): pixel_values = torch.cat(
img = im.pixel_values[0, j] [item.pixel_values for item in mm_input.mm_items], dim=0
)
for j in range(pixel_values.shape[1]):
img = pixel_values[0, j]
num_tiles = img.shape[0] num_tiles = img.shape[0]
batched_images[i, j, :num_tiles] = img batched_images[i, j, :num_tiles] = img
batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j] batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
batched_ar_mask[i, j, :num_tiles] = mm_input.mm_items[
0
].aspect_ratio_mask[0, j]
i += 1 i += 1
return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need

View File

@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
if hasattr(self.config, "scale_emb"): if hasattr(self.config, "scale_emb"):
return self.embed_tokens(input_ids) * self.config.scale_emb return self.get_input_embeddings()(input_ids) * self.config.scale_emb
else: else:
return self.embed_tokens(input_ids) return self.get_input_embeddings()(input_ids)
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def forward( def forward(
self, self,
@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embedding(input_ids)
def get_input_embedding(self) -> nn.Embedding: def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens return self.model.embed_tokens
@torch.no_grad() @torch.no_grad()

View File

@@ -30,22 +30,13 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import AutoModel, Qwen2VLConfig from transformers import Qwen2VLConfig
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLVisionConfig, Qwen2_5_VLVisionConfig,
) )
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine, general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.qwen2 import Qwen2Model
@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs # Get all special token IDs
im_start_id: int = image_inputs.im_start_id im_start_id: int = mm_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id im_end_id: int = mm_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)] media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, mm_inputs)
return pattern.pad_input_tokens(input_ids, image_inputs) def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, last dim is the same
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor: pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
pixel_values = image_input.pixel_values.type(self.visual.dtype) self.visual.dtype
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws) )
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
return image_embeds return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
f"(3, seq_len) positions, but got {positions.size()}" f"(3, seq_len) positions, but got {positions.size()}"
) )
inputs_embeds = general_mm_embed_routine( hidden_states = general_mm_embed_routine(
input_ids=input_ids, input_ids=input_ids,
forward_batch=forward_batch, forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(), language_model=self.model,
mm_data_embedding_func=self.get_image_feature, image_data_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=None,
positions=positions, positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
) )
if not get_embedding: if not get_embedding:
@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
("qkv_proj", "k_proj", "k"), (".qkv_proj", ".k_proj", "k"),
("qkv_proj", "v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
] ]

View File

@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine, general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.qwen2 import Qwen2Model
@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# Use grid_t * grid_w * grid_h to pad tokens for each image # Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash # add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs): def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs # Get all special token IDs
im_start_id: int = multi_modal_inputs.im_start_id im_start_id: int = mm_inputs.im_start_id
im_end_id: int = multi_modal_inputs.im_end_id im_end_id: int = mm_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)] media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, multi_modal_inputs) return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor: def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype) # in qwen-vl, last dim is the same
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws) pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
return image_embeds return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module):
"multimodal section rotary embedding requires " "multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}" f"(3, seq_len) positions, but got {positions.size()}"
) )
hidden_states = general_mm_embed_routine(
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids, input_ids=input_ids,
forward_batch=forward_batch, forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(), language_model=self.model,
mm_data_embedding_func=self.get_image_feature, image_data_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=None,
positions=positions, positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
) )
if not get_embedding: if get_embedding:
return self.pooler(hidden_states, forward_batch)
else:
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch input_ids, hidden_states, self.lm_head, forward_batch
) )
else:
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [

View File

@@ -897,6 +897,7 @@ def v1_chat_generate_request(
request_ids: List[str] = None, request_ids: List[str] = None,
): ):
input_ids = [] input_ids = []
prompts = []
sampling_params_list = [] sampling_params_list = []
image_data_list = [] image_data_list = []
audio_data_list = [] audio_data_list = []
@@ -916,6 +917,7 @@ def v1_chat_generate_request(
# - audio_data: None or a list of audio strings (URLs). # - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput. # None skips any image processing in GenerateReqInput.
strict_tag = None strict_tag = None
prompt = ""
if not isinstance(request.messages, str): if not isinstance(request.messages, str):
# Apply chat template and its stop strings. # Apply chat template and its stop strings.
tools = None tools = None
@@ -1005,11 +1007,13 @@ def v1_chat_generate_request(
image_data = None image_data = None
audio_data = None audio_data = None
modalities = [] modalities = []
prompt = request.messages
input_ids.append(prompt_ids) input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs) return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1) logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs or 0) top_logprobs_nums.append(request.top_logprobs or 0)
lora_paths.append(request.lora_path) lora_paths.append(request.lora_path)
prompts.append(prompt)
sampling_params = { sampling_params = {
"temperature": request.temperature, "temperature": request.temperature,
@@ -1063,10 +1067,14 @@ def v1_chat_generate_request(
audio_data_list.append(audio_data) audio_data_list.append(audio_data)
modalities_list.append(modalities) modalities_list.append(modalities)
if len(all_requests) == 1: if len(all_requests) == 1:
if isinstance(input_ids[0], str): if tokenizer_manager.model_config.is_multimodal:
prompt_kwargs = {"text": input_ids[0]} # processor will need text input
prompt_kwargs = {"text": prompts[0]}
else: else:
prompt_kwargs = {"input_ids": input_ids[0]} if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
else:
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0] sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0] image_data_list = image_data_list[0]
audio_data_list = audio_data_list[0] audio_data_list = audio_data_list[0]
@@ -1076,10 +1084,14 @@ def v1_chat_generate_request(
modalities_list = modalities_list[0] modalities_list = modalities_list[0]
lora_paths = lora_paths[0] lora_paths = lora_paths[0]
else: else:
if isinstance(input_ids[0], str): if tokenizer_manager.model_config.is_multimodal:
prompt_kwargs = {"text": input_ids} # processor will need text input
prompt_kwargs = {"text": prompts}
else: else:
prompt_kwargs = {"input_ids": input_ids} if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
adapted_request = GenerateReqInput( adapted_request = GenerateReqInput(
**prompt_kwargs, **prompt_kwargs,

View File

@@ -12,7 +12,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common utilities.""" """Common utilities."""
import base64 import base64
import builtins import builtins
import ctypes import ctypes
@@ -54,6 +53,7 @@ import torch.distributed
import torch.distributed as dist import torch.distributed as dist
import triton import triton
import zmq import zmq
from decord import VideoReader, cpu
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from PIL import Image from PIL import Image
@@ -513,13 +513,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
import soundfile as sf import soundfile as sf
from scipy.signal import resample from scipy.signal import resample
# print(f"loading {audio_file}")
# Load audio data # Load audio data
if isinstance(audio_file, bytes): if isinstance(audio_file, bytes):
audio, original_sr = sf.read(BytesIO(audio_file)) audio, original_sr = sf.read(BytesIO(audio_file))
elif audio_file.startswith("data:"): elif audio_file.startswith("data:"):
audio_file = audio_file.split(",")[1] audio_file = audio_file.split(",")[1]
audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file))) audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
elif audio_file.startswith("http://") or audio_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
response = requests.get(audio_file, stream=True, timeout=timeout)
audio_file = BytesIO(response.content)
response.close()
audio, original_sr = sf.read(audio_file)
elif isinstance(audio_file, str): elif isinstance(audio_file, str):
audio, original_sr = sf.read(audio_file) audio, original_sr = sf.read(audio_file)
else: else:
@@ -537,6 +542,30 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
return audio return audio
def encode_video(video_path, frame_count_limit=None):
    """Decode a video file into PIL images, sampling about one frame per second.

    Args:
        video_path: Filesystem path of the video to read.
        frame_count_limit: Maximum number of frames to return. ``None`` means
            no limit; ``0`` returns an empty list without opening the video.

    Returns:
        A list of ``PIL.Image`` frames. The list is empty when the file does
        not exist (an error is logged), when ``frame_count_limit`` is 0, or
        when no frame index ends up being selected.
    """
    if not os.path.exists(video_path):
        logger.error(f"Video {video_path} does not exist")
        return []

    if frame_count_limit == 0:
        return []

    def uniform_sample(candidates, n):
        # Split `candidates` into n equal-width buckets and take the element
        # at the centre of each bucket.
        gap = len(candidates) / n
        picks = [int(i * gap + gap / 2) for i in range(n)]
        return [candidates[i] for i in picks]

    vr = VideoReader(video_path, ctx=cpu(0))

    # Stride between sampled frames, i.e. roughly one frame per second.
    # Clamp to 1: an average FPS below 0.5 (or 0 from broken metadata) rounds
    # to 0, and range() raises ValueError on a zero step.
    stride = max(1, round(vr.get_avg_fps()))
    frame_indices = list(range(0, len(vr), stride))
    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
        frame_indices = uniform_sample(frame_indices, frame_count_limit)

    # Nothing selected (empty video or non-positive limit): skip the decoder.
    if not frame_indices:
        return []

    decoded = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(frame.astype("uint8")) for frame in decoded]
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]: def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
image = image_size = None image = image_size = None
@@ -1796,3 +1825,12 @@ def retry(
traceback.print_exc() traceback.print_exc()
time.sleep(delay) time.sleep(delay)
def flatten_nested_list(nested_list):
    """Flatten arbitrarily nested ``list`` objects into a single flat list.

    Only ``list`` instances are descended into; any other value (including
    tuples and ``None``) is kept as a leaf. A non-list argument is returned
    wrapped in a one-element list.
    """
    if not isinstance(nested_list, list):
        return [nested_list]

    flat = []
    for element in nested_list:
        flat.extend(flatten_nested_list(element))
    return flat

View File

@@ -155,9 +155,7 @@ class TestOpenAIVisionServer(CustomTestCase):
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_MAN_IRONING_URL},
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
},
"modalities": "multi-images", "modalities": "multi-images",
}, },
{ {
@@ -399,14 +397,14 @@ class TestOpenAIVisionServer(CustomTestCase):
{ {
"role": "user", "role": "user",
"content": [ "content": [
{
"type": "text",
"text": prompt,
},
{ {
"type": "audio_url", "type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"}, "audio_url": {"url": f"{audio_file_name}"},
}, },
{
"type": "text",
"text": prompt,
},
], ],
} }
] ]

View File

@@ -3,6 +3,7 @@
import unittest import unittest
from io import BytesIO from io import BytesIO
from typing import List
import numpy as np import numpy as np
import requests import requests
@@ -14,7 +15,11 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
@@ -195,14 +200,35 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
# sglang # sglang
model = self.get_sglang_model() model = self.get_sglang_model()
input_ids = inputs["input_ids"].to(self.device).flatten() input_ids = inputs["input_ids"].to(self.device).flatten()
pixel_values = inputs["pixel_values"]
tgt_sizes = inputs["tgt_sizes"]
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: "
f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
sglang_output = embed_mm_inputs( sglang_output = embed_mm_inputs(
mm_input=MultimodalInputs( mm_inputs=MultimodalInputs(
pixel_values=inputs["pixel_values"][0], mm_items=[
tgt_sizes=inputs["tgt_sizes"][0], MultimodalDataItem(
pixel_values=pixel_values_flat,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
pad_value=self.processor.tokenizer.unk_token_id,
)
]
), ),
input_ids=input_ids, input_ids=input_ids,
input_embedding=model.get_input_embeddings(), input_embedding=model.get_input_embeddings(),
mm_data_embedding_func=model.get_image_features, image_data_embedding_func=model.get_image_feature,
placeholder_token_ids=[ placeholder_token_ids=[
self.processor.tokenizer.unk_token_id, self.processor.tokenizer.unk_token_id,
], ],