vlm: support video as an input modality (#5888)

This commit is contained in:
Mick
2025-07-10 14:48:35 +08:00
committed by GitHub
parent 4ed57807c2
commit b5e3d6031c
42 changed files with 887 additions and 524 deletions

View File

@@ -88,9 +88,11 @@ class Conversation:
stop_str: Union[str, List[str]] = None
# The string that represents an image token in the prompt
image_token: str = "<image>"
video_token: str = "<video>"
audio_token: str = "<audio>"
image_data: Optional[List[str]] = None
video_data: Optional[List[str]] = None
modalities: Optional[List[str]] = None
stop_token_ids: Optional[int] = None
@@ -380,11 +382,15 @@ class Conversation:
self.messages.append([role, message])
def append_image(self, image: str):
"""Append a new message."""
"""Append a new image."""
self.image_data.append(image)
def append_video(self, video: str):
"""Append a new video."""
self.video_data.append(video)
def append_audio(self, audio: str):
"""Append a new message."""
"""Append a new audio."""
self.audio_data.append(audio)
def update_last_message(self, message: str):
@@ -433,6 +439,7 @@ class Conversation:
sep2=self.sep2,
stop_str=self.stop_str,
image_token=self.image_token,
video_token=self.video_token,
audio_token=self.audio_token,
)
@@ -495,8 +502,12 @@ def generate_embedding_convs(
sep2=conv_template.sep2,
stop_str=conv_template.stop_str,
image_data=[],
video_data=[],
audio_data=[],
modalities=[],
image_token=conv_template.image_token,
video_token=conv_template.video_token,
audio_token=conv_template.audio_token,
)
real_content = ""
@@ -557,10 +568,12 @@ def generate_chat_conv(
sep2=conv.sep2,
stop_str=conv.stop_str,
image_data=[],
video_data=[],
audio_data=[],
modalities=[],
image_token=conv.image_token,
audio_token=conv.audio_token,
video_token=conv.video_token,
)
if isinstance(request.messages, str):
@@ -602,6 +615,7 @@ def generate_chat_conv(
image_token = ""
audio_token = conv.audio_token
video_token = conv.video_token
for content in message.content:
if content.type == "text":
if num_image_url > 16:
@@ -614,6 +628,9 @@ def generate_chat_conv(
else:
real_content += image_token
conv.append_image(content.image_url.url)
elif content.type == "video_url":
real_content += video_token
conv.append_video(content.video_url.url)
elif content.type == "audio_url":
real_content += audio_token
conv.append_audio(content.audio_url.url)
@@ -810,6 +827,7 @@ register_conv_template(
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=["<|im_end|>"],
image_token="<|vision_start|><|image_pad|><|vision_end|>",
video_token="<|vision_start|><|video_pad|><|vision_end|>",
)
)
@@ -870,6 +888,7 @@ register_conv_template(
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
video_token="(<video>./</video>)",
)
)

View File

@@ -267,6 +267,10 @@ class ChatCompletionMessageContentImageURL(BaseModel):
detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentVideoURL(BaseModel):
    """URL reference to a video input (a file name, a URL, or a base64 encoded string)."""

    url: str


class ChatCompletionMessageContentAudioURL(BaseModel):
    """URL reference to an audio input (a file name, a URL, or a base64 encoded string)."""

    url: str
@@ -277,6 +281,11 @@ class ChatCompletionMessageContentImagePart(BaseModel):
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
class ChatCompletionMessageContentVideoPart(BaseModel):
    """A ``video_url`` content part of a chat message (OpenAI-style content list)."""

    # Discriminator used when dispatching on content.type in the chat template code.
    type: Literal["video_url"]
    video_url: ChatCompletionMessageContentVideoURL


class ChatCompletionMessageContentAudioPart(BaseModel):
    """An ``audio_url`` content part of a chat message (OpenAI-style content list)."""

    type: Literal["audio_url"]
    audio_url: ChatCompletionMessageContentAudioURL
@@ -285,6 +294,7 @@ class ChatCompletionMessageContentAudioPart(BaseModel):
ChatCompletionMessageContentPart = Union[
ChatCompletionMessageContentTextPart,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentVideoPart,
ChatCompletionMessageContentAudioPart,
]
@@ -629,6 +639,7 @@ class MessageProcessingResult:
prompt_ids: Union[str, List[int]]
image_data: Optional[Any]
audio_data: Optional[Any]
video_data: Optional[Any]
modalities: List[str]
stop: List[str]
tool_call_constraint: Optional[Any] = None

View File

@@ -82,6 +82,7 @@ class OpenAIServingChat(OpenAIServingBase):
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=processed_messages.image_data,
video_data=processed_messages.video_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
@@ -143,6 +144,7 @@ class OpenAIServingChat(OpenAIServingBase):
prompt_ids = []
openai_compatible_messages = []
image_data = []
video_data = []
audio_data = []
modalities = []
@@ -158,6 +160,7 @@ class OpenAIServingChat(OpenAIServingBase):
msg_dict,
template_content_format,
image_data,
video_data,
audio_data,
modalities,
)
@@ -214,11 +217,13 @@ class OpenAIServingChat(OpenAIServingBase):
stop = request.stop
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
video_data = video_data if video_data else None
modalities = modalities if modalities else []
return MessageProcessingResult(
prompt=prompt,
prompt_ids=prompt_ids,
image_data=image_data,
video_data=video_data,
audio_data=audio_data,
modalities=modalities,
stop=stop,
@@ -260,6 +265,7 @@ class OpenAIServingChat(OpenAIServingBase):
prompt = conv.get_prompt()
image_data = conv.image_data if conv.image_data else None
video_data = conv.video_data if conv.video_data else None
audio_data = conv.audio_data if conv.audio_data else None
modalities = conv.modalities if conv.modalities else []
stop = copy.copy(conv.stop_str or [] if not request.ignore_eos else [])
@@ -277,6 +283,7 @@ class OpenAIServingChat(OpenAIServingBase):
prompt=prompt,
prompt_ids=prompt_ids,
image_data=image_data,
video_data=video_data,
audio_data=audio_data,
modalities=modalities,
stop=stop,

View File

@@ -110,6 +110,7 @@ def process_content_for_template_format(
msg_dict: dict,
content_format: str,
image_data: list,
video_data: list,
audio_data: list,
modalities: list,
) -> dict:
@@ -120,6 +121,7 @@ def process_content_for_template_format(
msg_dict: Message dictionary with content
content_format: 'string' or 'openai' (detected via AST analysis)
image_data: List to append extracted image URLs
video_data: List to append extracted video URLs
audio_data: List to append extracted audio URLs
modalities: List to append modalities
@@ -143,6 +145,12 @@ def process_content_for_template_format(
modalities.append(chunk.get("modalities"))
# Normalize to simple 'image' type for template compatibility
processed_content_parts.append({"type": "image"})
elif chunk_type == "video_url":
video_data.append(chunk["video_url"]["url"])
if chunk.get("modalities"):
modalities.append(chunk.get("modalities"))
# Normalize to simple 'video' type for template compatibility
processed_content_parts.append({"type": "video"})
elif chunk_type == "audio_url":
audio_data.append(chunk["audio_url"]["url"])
# Normalize to simple 'audio' type

View File

@@ -65,6 +65,8 @@ class GenerateReqInput:
] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
video_data: Optional[Union[List[List[str]], List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
data_parallel_rank: Optional[int] = None
def contains_mm_input(self) -> bool:
    """Return True if this request carries any multimodal payload.

    Checks image, video, and audio data; ``has_valid_data`` treats None /
    empty containers as absent.
    """
    # NOTE(review): the diff rendering showed two return statements here
    # (old and new); only the composite one below is kept — the old
    # image/audio-only return would have shadowed the video check.
    return (
        has_valid_data(self.image_data)
        or has_valid_data(self.video_data)
        or has_valid_data(self.audio_data)
    )
def normalize_batch_and_arguments(self):
"""
@@ -232,6 +238,7 @@ class GenerateReqInput:
self._normalize_rid(num)
self._normalize_lora_paths(num)
self._normalize_image_data(num)
self._normalize_video_data(num)
self._normalize_audio_data(num)
self._normalize_sampling_params(num)
self._normalize_logprob_params(num)
@@ -300,6 +307,15 @@ class GenerateReqInput:
self.image_data = wrapped_images * self.parallel_sample_num
self.modalities = ["image"] * num
def _normalize_video_data(self, num):
    """Normalize video data for batch processing.

    After this call ``self.video_data`` is a per-request list of length
    proportional to ``num``.
    """
    if self.video_data is None:
        # No videos: one explicit None per request.
        self.video_data = [None] * num
    elif not isinstance(self.video_data, list):
        # A single video shared by all `num` requests in the batch.
        self.video_data = [self.video_data] * num
    else:
        # Already a per-request list; repeat once per parallel sample.
        # NOTE(review): unlike this branch, the single-video branch above
        # does not multiply by parallel_sample_num — confirm that asymmetry
        # is intended (compare _normalize_image_data / _normalize_audio_data).
        self.video_data = self.video_data * self.parallel_sample_num
def _normalize_audio_data(self, num):
"""Normalize audio data for batch processing."""
if self.audio_data is None:
@@ -408,6 +424,7 @@ class GenerateReqInput:
self.input_embeds[i] if self.input_embeds is not None else None
),
image_data=self.image_data[i],
video_data=self.video_data[i],
audio_data=self.audio_data[i],
sampling_params=self.sampling_params[i],
rid=self.rid[i],
@@ -507,6 +524,8 @@ class EmbeddingReqInput:
image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
] = None
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
video_data: Optional[Union[List[str], str]] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
@@ -578,7 +597,11 @@ class EmbeddingReqInput:
return self.rid
def contains_mm_input(self) -> bool:
    """Return True if this embedding request carries any multimodal payload.

    Mirrors GenerateReqInput.contains_mm_input: image, video, and audio are
    all consulted; ``has_valid_data`` treats None / empty containers as absent.
    """
    # NOTE(review): the diff rendering showed both the old image/audio-only
    # return and the new composite one; keeping only the composite return so
    # video_data is actually checked.
    return (
        has_valid_data(self.image_data)
        or has_valid_data(self.video_data)
        or has_valid_data(self.audio_data)
    )
def __getitem__(self, i):
if self.is_cross_encoder_request:

View File

@@ -4,7 +4,7 @@ Multi-modality utils
import hashlib
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
This function will replace the data-tokens in between with pad_values accordingly
"""
# One pad value per multimodal item; used to mark data-token spans below.
pad_values = [item.pad_value for item in mm_inputs.mm_items]
# (removed leftover debug print of mm_inputs.mm_items — it ran on every
# batch and spammed stdout in production)
data_token_pairs = self.data_token_id_pairs
mm_inputs.data_offsets = []
if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
return ret_input_ids
embedding_cache = None
embedding_cache: Optional[MultiModalCache] = None
def init_embedding_cache(max_size: int):
def init_embedding_cache(max_size: int = 0):
global embedding_cache
embedding_cache = MultiModalCache(max_size)
@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
continue
embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
items_offset = items_offset_list[i]
assert items_offset is not None, items_offset
embedding_items_hash = get_embedding_hash(embedding_items_per_req)
# if all items has been prefixed, we do not need to calculate embedding
if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -380,11 +382,9 @@ def embed_mm_inputs(
extend_seq_lens: List[int],
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
image_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
multimodal_model: nn.Module = None,
data_embedding_func_mapping: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
) -> Optional[torch.Tensor]:
@@ -397,8 +397,6 @@ def embed_mm_inputs(
extend_seq_lens: Sequence lengths for each request
input_ids: Input token IDs tensor
input_embedding: Embedding layer for text tokens
image_data_embedding_func: Function to embed image data
audio_data_embedding_func: Function to embed audio data
placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
Returns:
@@ -415,88 +413,53 @@ def embed_mm_inputs(
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks = [], []
# 2. Get multimodal embedding separately
# TODO: make this more generic
# Try get image embedding if any
if (
any(True for item in item_flatten_list if item.is_image())
and image_data_embedding_func
):
items = [item for item in item_flatten_list if item.is_image()]
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
# Try get mm embedding if any
for modality in Modality.all():
items = [
item for item in item_flatten_list if item.is_modality(modality=modality)
]
embedder = (
None
if data_embedding_func_mapping is None
else data_embedding_func_mapping.get(modality, None)
)
# calculate per request items length offset
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
items_offsets = []
for i, mm_inputs in enumerate(mm_inputs_list):
image_items = [item for item in mm_inputs.mm_items if item.is_image()]
items_size[i + 1] = len(image_items)
items_offsets.append(
flatten_nested_list(
[
item.image_offsets
for item in mm_inputs.mm_items
if item.is_image()
]
)
if embedder is None:
# "image", "video", etc
modality_id = modality.name.lower()
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
if len(items) != 0 and embedder is not None:
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
items_size = torch.cumsum(items_size, dim=0).tolist()
embedding, mask = get_embedding_and_mask(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
# Try get audio embedding if any
if (
any(True for item in item_flatten_list if item.is_audio())
and audio_data_embedding_func
):
items = [item for item in item_flatten_list if item.is_audio()]
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
items_offsets = []
# calculate per request items length offset
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
for i, mm_inputs in enumerate(mm_inputs_list):
audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
items_size[i + 1] = len(audio_items)
items_offsets.append(
flatten_nested_list(
[
item.audio_offsets
for item in mm_inputs.mm_items
if item.is_audio()
]
# Calculate per-request item counts and token offsets for this modality.
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
items_offsets = []
for i, mm_inputs in enumerate(mm_inputs_list):
    mm_items = [
        item
        for item in mm_inputs.mm_items
        if item.is_modality(modality=modality)
    ]
    items_size[i + 1] = len(mm_items)
    # Offsets must come from the SAME filtered items that were counted
    # above; iterating all mm_inputs.mm_items would leak offsets of other
    # modalities into this modality's offset list.
    items_offsets.append(
        flatten_nested_list([item.offsets for item in mm_items])
    )
# Prefix-sum so items_size[i]..items_size[i+1] indexes request i's items.
items_size = torch.cumsum(items_size, dim=0).tolist()
embedding, mask = get_embedding_and_mask(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
embedding, mask = get_embedding_and_mask(
data_embedding_func=embedder,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
# 3. Get input embeddings
vocab_size = input_embedding.num_embeddings
@@ -523,11 +486,9 @@ def general_mm_embed_routine(
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
language_model: nn.Module,
image_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
audio_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
multimodal_model: Optional[nn.Module] = None,
data_embedding_funcs: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
**kwargs,
@@ -572,8 +533,8 @@ def general_mm_embed_routine(
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
input_embedding=embed_tokens,
image_data_embedding_func=image_data_embedding_func,
audio_data_embedding_func=audio_data_embedding_func,
multimodal_model=multimodal_model,
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
)
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models

View File

@@ -185,6 +185,10 @@ class Modality(Enum):
f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
)
@staticmethod
def all():
    """Return the primary input modalities, in embedding-dispatch order.

    NOTE(review): MULTI_IMAGES is deliberately(?) absent — confirm callers
    that iterate this list only need to dispatch on IMAGE/VIDEO/AUDIO.
    """
    return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
@dataclasses.dataclass
class MultimodalDataItem:
@@ -200,7 +204,7 @@ class MultimodalDataItem:
hash: int = None
pad_value: int = None
image_sizes: Tuple[int, int] = None
image_offsets: Optional[list] = None
offsets: Optional[list] = None
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
@@ -253,12 +257,17 @@ class MultimodalDataItem:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
elif self.is_video():
self.hash = hash_feature(self.pixel_values_videos)
else:
self.hash = hash_feature(self.pixel_values)
assert self.hash is not None
self.pad_value = self.hash % (1 << 30)
def is_modality(self, modality: Modality) -> bool:
    """Return True iff this item is tagged with the given modality."""
    return modality == self.modality
def is_audio(self):
return (self.modality == Modality.AUDIO) and (
self.precomputed_features is not None
@@ -268,7 +277,7 @@ class MultimodalDataItem:
def is_image(self):
return (
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
@@ -277,7 +286,7 @@ class MultimodalDataItem:
def is_video(self):
return (self.modality == Modality.VIDEO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
)
def is_valid(self) -> bool:
@@ -351,6 +360,7 @@ class MultimodalInputs:
"im_token_id",
"im_start_id",
"im_end_id",
"video_token_id",
"slice_start_id",
"slice_end_id",
"audio_start_id",
@@ -364,11 +374,12 @@ class MultimodalInputs:
return ret
def contains_image_inputs(self) -> bool:
""" """
return any(item.is_image() for item in self.mm_items)
def contains_video_inputs(self) -> bool:
return any(item.is_video() for item in self.mm_items)
def contains_audio_inputs(self) -> bool:
""" """
return any(item.is_audio() for item in self.mm_items)
def contains_mm_input(self) -> bool:

View File

@@ -453,8 +453,20 @@ class ForwardBatch:
for mm_input in self.mm_inputs
)
def contains_video_inputs(self) -> bool:
    """Whether any request's multimodal inputs in this batch hold video items."""
    if self.mm_inputs is None:
        return False
    return any(
        mm is not None and mm.contains_video_inputs() for mm in self.mm_inputs
    )
def contains_mm_inputs(self) -> bool:
    """Whether this batch carries any multimodal input (audio, video, or image)."""
    # NOTE(review): the diff rendering showed two return statements (old
    # audio/image-only and new composite); only the composite one is kept —
    # the old return would have shadowed the video check entirely.
    return (
        self.contains_audio_inputs()
        or self.contains_video_inputs()
        or self.contains_image_inputs()
    )
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch

View File

@@ -1989,7 +1989,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
language_model=self.language_model,
positions=positions,
)

View File

@@ -227,7 +227,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
language_model=self.language_model,
)

View File

@@ -374,7 +374,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
input_ids=llm_input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
positions=positions,
)

View File

@@ -1,7 +1,7 @@
import logging
import re
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
@@ -25,6 +25,7 @@ from sglang.srt.managers.mm_utils import (
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
@@ -434,8 +435,10 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
Modality.AUDIO: self.get_audio_feature,
},
positions=positions,
per_layer_inputs=per_layer_inputs,
)

View File

@@ -29,7 +29,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_janus_pro import DropPath
@@ -523,7 +527,9 @@ class InternVLChatModel(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)

View File

@@ -67,7 +67,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
@@ -168,7 +172,9 @@ class KimiVLForConditionalGeneration(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)

View File

@@ -787,7 +787,9 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
forward_batch=forward_batch,
get_embedding=get_embedding,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
placeholder_tokens=None, # using mm_item.pad_value
positions=positions,
)

View File

@@ -142,7 +142,7 @@ class LlavaVidForCausalLM(nn.Module):
)
image_offsets = [
flatten_nested_list(
[item.image_offsets for item in image_inputs[i].mm_items]
[item.offsets for item in image_inputs[i].mm_items]
)
for i in range(bs)
if need_vision[i]

View File

@@ -1827,8 +1827,7 @@ class MiniCPMO(MiniCPMBaseModel):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.llm,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
multimodal_model=self,
positions=positions,
)
return hidden_states

View File

@@ -573,7 +573,7 @@ class MiniCPMBaseModel(nn.Module):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
language_model=self.llm,
positions=positions,
)

View File

@@ -6,8 +6,11 @@ from typing import List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import Llama4Config, Llama4VisionModel
from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
from transformers import Llama4Config
from transformers.models.llama4.modeling_llama4 import (
Llama4MultiModalProjector,
Llama4VisionModel,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -16,7 +19,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cpu
@@ -166,7 +173,9 @@ class Llama4ForConditionalGeneration(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=image_embedding_func,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)

View File

@@ -31,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.idefics2 import Idefics2VisionTransformer
@@ -439,7 +443,9 @@ class Phi4MMForCausalLM(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)

View File

@@ -56,7 +56,6 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
@@ -507,11 +506,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
video_embeds = self.visual(
pixel_values_videos, grid_thw=video_input["video_grid_thw"]
)
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
    """Embed video inputs with the shared visual encoder.

    Args:
        items: video MultimodalDataItems; each carries ``pixel_values_videos``
            (flattened patches, expected 2-D) and ``video_grid_thw``
            (per-video grid shape, expected 2-D).

    Returns:
        The video embeddings produced by ``self.visual``.
    """
    # In qwen-vl the last dim matches across items, so all patches can be
    # concatenated along dim 0 and encoded in a single pass.
    # (direct attribute access instead of getattr with a constant string,
    # matching the sibling qwen2_vl implementation)
    pixel_values = torch.cat(
        [item.pixel_values_videos for item in items], dim=0
    ).type(self.visual.dtype)
    video_grid_thw = torch.cat([item.video_grid_thw for item in items], dim=0)
    assert pixel_values.dim() == 2, pixel_values.dim()
    assert video_grid_thw.dim() == 2, video_grid_thw.dim()
    return self.visual(pixel_values, grid_thw=video_grid_thw)
def get_input_embeddings(self):
@@ -553,7 +556,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
positions=positions,
)

View File

@@ -493,6 +493,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
    """Embed video inputs with the shared visual encoder.

    Concatenates every item's ``pixel_values_videos`` / ``video_grid_thw``
    along dim 0 and runs a single ``self.visual`` forward pass.
    """
    # in qwen-vl, last dim is the same
    pixel_values = torch.cat(
        [item.pixel_values_videos for item in items], dim=0
    ).type(self.visual.dtype)
    video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
    # Both tensors are expected flattened to 2-D before the encoder call.
    assert pixel_values.dim() == 2, pixel_values.dim()
    assert video_grid_thw.dim() == 2, video_grid_thw.dim()
    video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
    return video_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
video_embeds = self.visual(
@@ -538,7 +549,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
positions=positions,
)

View File

@@ -17,7 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -223,7 +227,9 @@ class VILAForConditionalGeneration(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.llm,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
get_embedding=get_embedding,
positions=positions,
)

View File

@@ -5,7 +5,7 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -14,7 +14,7 @@ from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import encode_video, load_audio, load_image
from sglang.srt.utils import load_audio, load_image, load_video, logger
@dataclasses.dataclass
@@ -25,14 +25,22 @@ class BaseMultiModalProcessorOutput:
# frames loaded from image and video, in given order
images: Optional[list[Union[Image.Image, dict]]] = None
# videos
videos: Optional[list[Union[torch.Tensor, dict]]] = None
# audios
audios: Optional[list[Union[np.ndarray, dict]]] = None
def normalize(self):
    """Collapse empty media lists to None so callers can truth-test fields.

    Includes "videos" alongside images/audios — the videos field exists on
    this dataclass and was missing from the original field list, so empty
    video lists were never normalized.
    """
    for field_name in ["images", "videos", "audios"]:
        field = getattr(self, field_name, None)
        if field is not None and isinstance(field, list) and len(field) == 0:
            setattr(self, field_name, None)
def organize_results(self) -> List[Tuple[Modality, Any]]:
    """
    :return: a list of (modality, data) pairs, images first, then videos,
        then audios, preserving per-modality order.
    """
    # Guard with `or []`: the dataclass defaults are None (and normalize()
    # turns empty lists into None), so iterating the raw fields would raise
    # TypeError whenever a modality is absent.
    return (
        [(Modality.IMAGE, data) for data in self.images or []]
        + [(Modality.VIDEO, data) for data in self.videos or []]
        + [(Modality.AUDIO, data) for data in self.audios or []]
    )
@dataclasses.dataclass
@@ -41,6 +49,10 @@ class MultimodalSpecialTokens:
video_token: Optional[Union[int, str, List[str]]] = None
audio_token: Optional[Union[int, str, List[str]]] = None
image_token_regex: Optional[re.Pattern] = None
video_token_regex: Optional[re.Pattern] = None
audio_token_regex: Optional[re.Pattern] = None
def convert_to_str(self, token: Union[str, int], processor) -> str:
if token is None:
return token
@@ -53,11 +65,29 @@ class MultimodalSpecialTokens:
self.video_token = self.convert_to_str(self.video_token, processor)
self.audio_token = self.convert_to_str(self.audio_token, processor)
image_token_regex: Optional[re.Pattern] = None
video_token_regex: Optional[re.Pattern] = None
audio_token_regex: Optional[re.Pattern] = None
def get_modality_of_token(self, token) -> Optional[Modality]:
"""
:return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
"""
modality = {
self.image_token: Modality.IMAGE,
self.video_token: Modality.VIDEO,
self.audio_token: Modality.AUDIO,
}.get(token)
if modality:
return modality
def __post_init__(self):
for regex, modality in [
(self.image_token_regex, Modality.IMAGE),
(self.video_token_regex, Modality.VIDEO),
(self.audio_token_regex, Modality.AUDIO),
]:
if regex and regex.match(token):
return modality
return None
def parse_regex(self):
if self.image_token_regex is None and self.image_token is not None:
self.image_token_regex = re.compile(re.escape(self.image_token))
if self.video_token_regex is None and self.video_token is not None:
@@ -65,7 +95,7 @@ class MultimodalSpecialTokens:
if self.audio_token_regex is None and self.audio_token is not None:
self.audio_token_regex = re.compile(re.escape(self.audio_token))
def collect(self) -> re.Pattern:
def combine_regex(self) -> re.Pattern:
tokens = [
self.image_token_regex,
self.video_token_regex,
@@ -105,6 +135,7 @@ class BaseMultimodalProcessor(ABC):
self.ATTR_NAME_TO_MODALITY = {
# Image-related attributes
"pixel_values": Modality.IMAGE,
"pixel_values_videos": Modality.VIDEO,
"image_sizes": Modality.IMAGE,
"image_grid_thw": Modality.IMAGE,
"image_emb_mask": Modality.IMAGE,
@@ -120,7 +151,7 @@ class BaseMultimodalProcessor(ABC):
"input_features": Modality.AUDIO,
"input_features_mask": Modality.AUDIO,
# Video-related attributes
"video_grid_thws": Modality.VIDEO,
"video_grid_thw": Modality.VIDEO,
# Generic attributes that could apply to multiple modalities
# "precomputed_features" - handled specially as it can be any modality
}
@@ -196,20 +227,25 @@ class BaseMultimodalProcessor(ABC):
@staticmethod
def _load_single_item(
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True
):
"""Static method that can be pickled for multiprocessing"""
"""
Load a single multimodal data.
If data is precomputed, returns directly.
Static method that can be pickled for multiprocessing"""
if isinstance(data, dict):
return data
try:
if is_audio:
return load_audio(data)
elif is_video:
path = data[len("video:") :]
return encode_video(path, frame_count_limit)
else:
if modality == Modality.IMAGE:
img, _ = load_image(data)
return img.convert("RGB") if discard_alpha_channel else img
elif modality == Modality.VIDEO:
return load_video(data, frame_count_limit)
elif modality == Modality.AUDIO:
return load_audio(data)
except Exception as e:
raise RuntimeError(f"Error while loading data {data}: {e}")
@@ -217,75 +253,78 @@ class BaseMultimodalProcessor(ABC):
self,
text_parts: List[str],
multimodal_tokens: MultimodalSpecialTokens,
image_data: Optional[list] = None,
audio_data: Optional[list] = None,
data_iterators: dict,
discard_alpha_channel: bool = True,
):
image_estimated_frames_iter: Optional[iter] = None,
image_scaling_factor: float = 1.0,
max_image_frames: int = 30,
) -> Tuple[List, List]:
"""
load multimodal data parallelly
load multimodal data parallelly using iterators.
"""
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
# Submit all tasks
futures = []
task_info = []
image_index, audio_index = 0, 0
for text_part in text_parts:
if (
multimodal_tokens.image_token_regex
and multimodal_tokens.image_token_regex.match(text_part)
):
data = image_data[image_index]
is_video = isinstance(data, str) and data.startswith("video:")
estimated_frames = estimated_frames_list[image_index]
frame_count_limit = max(1, int(estimated_frames * scaling_factor))
modality = multimodal_tokens.get_modality_of_token(text_part)
if modality is not None:
data_iterator = data_iterators.get(modality)
if data_iterator is None:
raise ValueError(f"No data iterator found for token: {text_part}")
try:
data = next(data_iterator)
except StopIteration:
raise ValueError(
f"Mismatch: More '{text_part}' tokens found than corresponding data items provided."
)
frame_count_limit = None
if modality == Modality.IMAGE and image_estimated_frames_iter:
try:
estimated_frames = next(image_estimated_frames_iter)
# Use the pre-calculated scaling factor and max frames
frame_count_limit = max(
1, int(estimated_frames * image_scaling_factor)
)
# Ensure we don't exceed the absolute max (redundant if scaling_factor handles it)
# frame_count_limit = min(frame_count_limit, max_image_frames)
except StopIteration:
raise ValueError(
"Mismatch between image tokens and estimated frame counts."
)
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
is_video,
False,
modality,
frame_count_limit,
discard_alpha_channel,
)
)
task_info.append((Modality.IMAGE, data, frame_count_limit))
image_index += 1
elif (
multimodal_tokens.audio_token_regex
and multimodal_tokens.audio_token_regex.match(text_part)
):
data = audio_data[audio_index]
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
False,
True,
None,
discard_alpha_channel,
)
task_info.append((modality, data, frame_count_limit))
for modality, iterator in data_iterators.items():
try:
next(iterator)
logger.warning(
f"Warning: More {modality.name.lower()} data items provided than corresponding tokens found in the prompt."
)
task_info.append((Modality.AUDIO, data, None))
audio_index += 1
except StopIteration:
pass
except Exception:
pass
return futures, task_info
def load_mm_data(
self,
prompt: str | List[int],
prompt: str,
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
video_data: Optional[list] = None,
audio_data: Optional[list] = None,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
@@ -299,14 +338,9 @@ class BaseMultimodalProcessor(ABC):
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if not return_text:
raise NotImplementedError()
if image_data is None:
image_data = []
multimodal_tokens.convert_to_strs(self._processor)
multimodal_tokens_pattern = multimodal_tokens.collect()
multimodal_tokens.parse_regex()
multimodal_tokens_pattern = multimodal_tokens.combine_regex()
if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
prompt = self._processor.tokenizer.decode(prompt)
@@ -317,59 +351,84 @@ class BaseMultimodalProcessor(ABC):
# split text into list of normal text and special tokens
text_parts = re.split(multimodal_tokens_pattern, prompt)
# collect all data
data_iterators = {}
if multimodal_tokens.image_token and image_data:
data_iterators[Modality.IMAGE] = iter(image_data)
if multimodal_tokens.video_token and video_data:
data_iterators[Modality.VIDEO] = iter(video_data)
if multimodal_tokens.audio_token and audio_data:
data_iterators[Modality.AUDIO] = iter(audio_data)
# futures: the futures of loaded data
# task_info: modality, raw_data, and other metadata of each data
futures, task_info = self.submit_data_loading_tasks(
text_parts=text_parts,
multimodal_tokens=multimodal_tokens,
image_data=image_data,
audio_data=audio_data,
data_iterators=data_iterators,
discard_alpha_channel=discard_alpha_channel,
)
task_info_iter = iter(task_info)
futures_iter = iter(futures)
# Process results
images, audios = [], []
new_text = ""
task_ptr = 0
images, videos, audios = [], [], []
new_text_parts = []
for text_part in text_parts:
if multimodal_tokens_pattern.match(text_part):
task_type, data, frame_limit = task_info[task_ptr]
result = futures[task_ptr].result()
task_ptr += 1
try:
if multimodal_tokens_pattern.match(text_part):
modality, raw_data, frame_limit = next(task_info_iter)
is_precomputed = isinstance(raw_data, dict)
result = next(futures_iter).result()
if task_type == Modality.IMAGE:
# If data is already processed it will be a
# dictionary. In this case we want to keep the
# expanded tokens in text_part. Otherwise, we will
# call the processor code, so keep only a single image
# token.
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.image_token
)
frames = [result] if not isinstance(result, list) else result
if frames:
images += frames
new_text += mm_tokens * len(frames)
elif task_type == Modality.AUDIO:
# audio
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.audio_token
)
audios.append(result)
new_text += mm_tokens
# TODO: handle video
else:
new_text += text_part
if modality == Modality.IMAGE:
# If data is already processed it will be a
# dictionary(precomputed). In this case we want to keep the
# expanded tokens in text_part. Otherwise, we will
# call the processor code, so keep only a single image
# token.
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.image_token
)
frames = [result] if not isinstance(result, list) else result
if frames:
# only for minicpmv
images += frames
new_text_parts += mm_tokens * len(frames)
elif modality == Modality.VIDEO:
# load as video
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.video_token
)
videos += [result]
new_text_parts += mm_tokens
elif modality == Modality.AUDIO:
# audio
mm_tokens = (
text_part
if is_precomputed
else multimodal_tokens.audio_token
)
audios += [result]
new_text_parts += mm_tokens
else:
# normal text
new_text_parts += [text_part]
out = BaseMultiModalProcessorOutput(
input_text=new_text,
except Exception as e:
raise RuntimeError(
f"An exception occurred while loading multimodal data: {e}"
)
return BaseMultiModalProcessorOutput(
images=images,
audios=audios,
videos=videos,
input_text="".join(new_text_parts),
)
out.normalize()
return out
@staticmethod
def get_mm_items_offset(
@@ -460,21 +519,19 @@ class BaseMultimodalProcessor(ABC):
)
except ValueError:
modality = Modality.IMAGE
if modality:
# Create item if needed
if modality not in items:
items[modality] = MultimodalDataItem(modality=modality)
# Set attribute
if hasattr(items[modality], attr_name):
setattr(items[modality], attr_name, value)
setattr(items[modality], attr_name, value)
return list(items.values())
def _process_and_collect_mm_items(
self, input_text: str, images=None, audios=None, videos=None, **kwargs
) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
"""
Helper method to process multimodal data and create mm_items in one step.
@@ -488,11 +545,11 @@ class BaseMultimodalProcessor(ABC):
input_ids = ret["input_ids"].flatten()
collected_items = self.collect_mm_items_from_processor_output(ret)
return collected_items, input_ids
return collected_items, input_ids, ret
def process_and_combine_mm_data(
self, base_output: BaseMultiModalProcessorOutput
) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
"""
Process multimodal data and return the combined multimodal items and input_ids.
Supports mixed modalities (images and audio in the same request).
@@ -501,8 +558,7 @@ class BaseMultimodalProcessor(ABC):
Tuple of (list of mm_items, input_ids)
"""
# Collect all items and categorize them
all_items = (base_output.images or []) + (base_output.audios or [])
all_items = base_output.organize_results()
# Handle text-only case
if not all_items:
input_ids = self._processor.tokenizer(
@@ -510,19 +566,20 @@ class BaseMultimodalProcessor(ABC):
return_tensors="pt",
add_special_tokens=True,
).input_ids.flatten()
return [], input_ids
return [], input_ids, {}
dict_items, raw_images, raw_audios = [], [], []
for item in all_items:
dict_items, raw_images, raw_audios, raw_videos = [], [], [], []
for modality, item in all_items:
if isinstance(item, dict):
dict_items.append(item)
elif isinstance(item, Image.Image):
elif modality == Modality.IMAGE:
raw_images.append(item)
elif isinstance(item, np.ndarray):
elif modality == Modality.AUDIO:
raw_audios.append(item)
elif modality == Modality.VIDEO:
raw_videos.append(item)
else:
raise ValueError(f"Unknown multimodal item type: {type(item)}")
# Process items and get input_ids
all_collected_items = []
input_ids = None
@@ -534,13 +591,16 @@ class BaseMultimodalProcessor(ABC):
)
# Handle raw items (need processing)
if raw_images or raw_audios:
collected_items, input_ids = self._process_and_collect_mm_items(
if raw_images or raw_audios or raw_videos:
collected_items, input_ids, ret = self._process_and_collect_mm_items(
input_text=base_output.input_text,
images=raw_images,
audios=raw_audios,
videos=raw_videos,
)
all_collected_items.extend(collected_items)
else:
ret = None
# Fallback tokenization if no raw items were processed
if input_ids is None:
@@ -553,21 +613,21 @@ class BaseMultimodalProcessor(ABC):
# Add offsets to all items
for mm_item in all_collected_items:
if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
mm_item.image_offsets = self.get_mm_items_offset(
mm_item.offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.IM_TOKEN_ID,
)
elif mm_item.modality == Modality.AUDIO:
mm_item.audio_offsets = self.get_mm_items_offset(
mm_item.offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.AUDIO_TOKEN_ID,
)
elif mm_item.modality == Modality.VIDEO:
mm_item.video_offsets = self.get_mm_items_offset(
mm_item.offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.VIDEO_TOKEN_ID,
)
else:
raise ValueError(f"Unknown modality: {mm_item.modality}")
return all_collected_items, input_ids
return all_collected_items, input_ids, ret

View File

@@ -69,7 +69,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
)
item = MultimodalDataItem(
pixel_values=res["images"],
image_offsets=image_offsets,
offsets=image_offsets,
modality=Modality.IMAGE,
image_emb_mask=images_seq_mask,
image_spatial_crop=batched_images_spatial_crop,

View File

@@ -36,6 +36,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
*args,
**kwargs,
):
print(f"{image_data=}")
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
@@ -46,8 +47,9 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
discard_alpha_channel=True,
)
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
print(f"{base_output=}")
print(f"{mm_items=}")
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,

View File

@@ -72,7 +72,7 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
),
)
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
return {
"input_ids": input_ids.tolist(),

View File

@@ -225,7 +225,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
MultimodalDataItem(
pixel_values=pixel_values,
modality=Modality.IMAGE,
image_offsets=image_offsets,
offsets=image_offsets,
)
]

View File

@@ -49,7 +49,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
MultimodalDataItem(
pixel_values=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
image_offsets=image_offsets,
offsets=image_offsets,
modality=Modality.IMAGE,
)
],

View File

@@ -39,7 +39,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
max_req_input_len=max_req_input_len,
)
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
return {
"input_ids": input_ids.tolist(),

View File

@@ -19,6 +19,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
super().__init__(hf_config, server_args, _processor)
self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)"
self.video_token = "(<video>./</video>)"
async def process_mm_data_async(
self,
@@ -36,6 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.image_token,
video_token=self.video_token,
audio_token=self.audio_token,
),
)
@@ -113,7 +115,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
if len(pixel_values) != 0:
item = MultimodalDataItem(
pixel_values=pixel_values,
image_offsets=image_offsets,
offsets=image_offsets,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
)
@@ -135,11 +137,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
item = MultimodalDataItem(
audio_features=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
audio_offsets=audio_offsets,
offsets=audio_offsets,
modality=Modality.AUDIO,
)
items += [item]
return {
"mm_items": items,
"input_ids": input_ids.tolist(),

View File

@@ -144,7 +144,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
MultimodalDataItem(
pixel_values=processor_output["pixel_values"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
offsets=image_offsets,
)
]

View File

@@ -65,7 +65,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
pixel_values=res["input_image_embeds"],
image_sizes=res["image_sizes"],
image_emb_mask=res["image_attention_mask"],
image_offsets=image_offsets,
offsets=image_offsets,
modality=Modality.IMAGE,
)
]

View File

@@ -106,7 +106,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
pixel_values=processor_output["pixel_values"],
image_sizes=processor_output["image_sizes"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
offsets=image_offsets,
)
]

View File

@@ -1,9 +1,13 @@
import asyncio
import math
import os
import re
from typing import Dict, List, Union
from typing import List, Union
import torch
import torchvision
from PIL import Image
from torchvision.transforms import InterpolationMode
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
@@ -12,6 +16,185 @@ from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
from sglang.utils import logger
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_TOTAL_PIXELS = int(
float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9))
)
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """Rescale ``(height, width)`` so that:

    1. both output dimensions are divisible by ``factor``;
    2. the output pixel count lies within ``[min_pixels, max_pixels]``;
    3. the aspect ratio is preserved as closely as possible.

    Raises ValueError when the aspect ratio exceeds ``MAX_RATIO``.
    """
    longer, shorter = max(height, width), min(height, width)
    if longer / shorter > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    # Start from the nearest factor-aligned dimensions (never below factor).
    new_h = max(factor, round_by_factor(height, factor))
    new_w = max(factor, round_by_factor(width, factor))
    area = new_h * new_w
    if area > max_pixels:
        # Shrink both sides by the same ratio, rounding down to stay in budget.
        shrink = math.sqrt((height * width) / max_pixels)
        new_h = floor_by_factor(height / shrink, factor)
        new_w = floor_by_factor(width / shrink, factor)
    elif area < min_pixels:
        # Grow both sides by the same ratio, rounding up to meet the minimum.
        grow = math.sqrt(min_pixels / (height * width))
        new_h = ceil_by_factor(height * grow, factor)
        new_w = ceil_by_factor(width * grow, factor)
    return new_h, new_w
def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
    """Return *image* resized so it satisfies the :func:`smart_resize`
    constraints (factor-aligned dimensions within the module pixel budget)."""
    width, height = image.size
    new_h, new_w = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )
    # PIL expects (width, height) while smart_resize returns (height, width).
    return image.resize((new_w, new_h))
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of 'factor' that is closest to 'number'."""
    nearest_multiple = round(number / factor)
    return nearest_multiple * factor
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of 'factor' that is >= 'number'."""
    quotient = math.ceil(number / factor)
    return quotient * factor
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of 'factor' that is <= 'number'."""
    quotient = math.floor(number / factor)
    return quotient * factor
async def resize_image_async(image):
    """Async wrapper around :func:`resize_image`.

    NOTE(review): this does the resize synchronously on the event loop; it
    only exists so callers can ``asyncio.gather()`` many resizes uniformly.
    """
    return resize_image(image)
def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """calculate the number of frames for video used for model inputs.
    Args:
        ele (dict): a dict contains the configuration of video.
            support either `fps` or `nframes`:
            - nframes: the number of frames to extract for model inputs.
            - fps: the fps to extract frames for model inputs.
            - min_frames: the minimum number of frames of the video, only used when fps is provided.
            - max_frames: the maximum number of frames of the video, only used when fps is provided.
        total_frames (int): the original total number of frames of the video.
        video_fps (int | float): the original fps of the video.
    Raises:
        ValueError: if both `fps` and `nframes` are provided, or if the
            computed nframes falls outside [FRAME_FACTOR, total_frames].
    Returns:
        int: the number of frames for video used for model inputs.
    """
    # Explicit raise instead of `assert`: assertions are stripped when
    # Python runs with -O, which would silently skip this validation.
    if "fps" in ele and "nframes" in ele:
        raise ValueError("Only accept either `fps` or `nframes`")
    if "nframes" in ele:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        fps = ele.get("fps", FPS)
        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        max_frames = floor_by_factor(
            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
        )
        nframes = total_frames / video_fps * fps
        if nframes > total_frames:
            logger.warning(
                f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]"
            )
        # Clamp into [min_frames, max_frames] and never beyond the real
        # frame count, then round down to a FRAME_FACTOR multiple.
        nframes = min(max(nframes, min_frames), max_frames, total_frames)
        nframes = floor_by_factor(nframes, FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes <= total_frames):
        raise ValueError(
            f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
        )
    return nframes
# process video, qwen-specific
async def preprocess_video(
    vr,
    image_factor: int = IMAGE_FACTOR,
    # vr: VideoReader, image_factor: int = IMAGE_FACTOR
) -> torch.Tensor:
    """Sample, decode and resize frames from a decord VideoReader into a
    float TCHW tensor sized for the Qwen-VL vision tower.

    Args:
        vr: a decord ``VideoReader`` (typed loosely so decord stays a lazy
            import elsewhere in this module).
        image_factor: output spatial dimensions are rounded to multiples of
            this factor.

    Returns:
        torch.Tensor: frames in (T, C, H, W) layout, dtype float32.
    """
    # Per-request overrides (fps/nframes/min_pixels/total_pixels/resized_*)
    # would be read from `ele`; it is currently always empty, but thread it
    # through consistently so overrides take effect if it is ever populated.
    ele = {}
    total_frames, video_fps = len(vr), vr.get_avg_fps()
    # BUGFIX(latent): a literal `{}` was previously passed here instead of
    # `ele`, so any future config in `ele` would have been ignored.
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
    nframes, _, height, width = video.shape
    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
    # Budget pixels per frame from the total-pixel budget, but never go
    # below ~105% of the per-frame minimum.
    max_pixels = max(
        min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
        int(min_pixels * 1.05),
    )
    max_pixels_supposed = ele.get("max_pixels", max_pixels)
    if max_pixels_supposed > max_pixels:
        logger.warning(
            f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
        )
    max_pixels = min(max_pixels_supposed, max_pixels)
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=image_factor,
        )
    else:
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=image_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    video = torchvision.transforms.functional.resize(
        video,
        [resized_height, resized_width],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ).float()
    return video
# Compatible with Qwen2VL and Qwen2_5VL
@@ -37,104 +220,44 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.MIN_PIXELS = 4 * 28 * 28
self.MAX_PIXELS = 16384 * 28 * 28
self.MAX_RATIO = 200
# TODO(mick): move all MultimodalSpecialTokens initializations into processor init
self.mm_special_tokens = MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
video_token=self.VIDEO_TOKEN_ID,
)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes, Dict]],
image_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
),
video_data=request_obj.video_data,
multimodal_tokens=self.mm_special_tokens,
max_req_input_len=max_req_input_len,
)
def smart_resize(
height: int,
width: int,
factor: int = self.IMAGE_FACTOR,
min_pixels: int = self.MIN_PIXELS,
max_pixels: int = self.MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > self.MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
width, height = image.size
min_pixels = self.MIN_PIXELS
max_pixels = self.MAX_PIXELS
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
async def resize_image_async(image):
return resize_image(image)
# Qwen-specific: resize images if they are raw Image objects
if base_output.images and isinstance(base_output.images[0], Image.Image):
resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks)
video_grid_thw = None # TODO
if base_output.videos:
base_output.videos = [
await preprocess_video(video) for video in base_output.videos
]
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
if not mm_items:
# Note(Xinyuan): This is the case where image loading fails.
return None
combined_mm_item = mm_items[0] # only image is supported for now
video_grid_thw = None # TODO
second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output)
input_ids = input_ids.flatten()
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.IM_TOKEN_ID,
@@ -145,9 +268,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=input_ids.unsqueeze(0),
image_grid_thw=combined_mm_item.image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
image_grid_thw=getattr(ret, "image_grid_thw", None),
video_grid_thw=getattr(ret, "video_grid_thw", None),
second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
)
mrope_positions = mrope_positions.squeeze(1)

View File

@@ -57,7 +57,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
image_data=image_data,
)
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
return {
"input_ids": input_ids.tolist(),

View File

@@ -728,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
return audio
def encode_video(video_path, frame_count_limit=None):
    """Decode a local video file into a list of PIL images sampled at ~1 fps.

    Args:
        video_path: path to a video file on local disk.
        frame_count_limit: if set, uniformly subsample to at most this many
            frames; a limit of 0 returns an empty list.

    Returns:
        List of ``PIL.Image`` frames; empty when the file is missing or the
        limit is 0.
    """
    # Lazy import because decord is not available on some arm platforms.
    from decord import VideoReader, cpu
    if not os.path.exists(video_path):
        logger.error(f"Video {video_path} does not exist")
        return []
    if frame_count_limit == 0:
        return []
    def uniform_sample(l, n):
        # Pick n indices evenly spread across l, centered within each stride.
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]
    vr = VideoReader(video_path, ctx=cpu(0))
    # Take roughly one frame per second of video.
    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
    frame_indices = [i for i in range(0, len(vr), sample_fps)]
    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
        frame_indices = uniform_sample(frame_indices, frame_count_limit)
    frames = vr.get_batch(frame_indices).asnumpy()
    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
    return frames
def load_image(
image_file: Union[Image.Image, str, bytes],
) -> tuple[Image.Image, tuple[int, int]]:
@@ -774,9 +747,6 @@ def load_image(
elif image_file.startswith("data:"):
image_file = image_file.split(",")[1]
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file)
elif isinstance(image_file, str):
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
else:
@@ -785,6 +755,61 @@ def load_image(
return image, image_size
def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
    """Open a video from raw bytes, an HTTP(S) URL, a data URI, a base64
    string, or a local file path.

    Args:
        video_file: the video source in any of the supported forms above.
        use_gpu: when True, try a GPU decord context and fall back to CPU.

    Returns:
        A decord ``VideoReader`` over the video.

    Raises:
        ValueError: for unsupported input types.
    """
    # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
    from decord import VideoReader, cpu, gpu

    # BUGFIX: `use_gpu` was previously ignored — the GPU context was probed
    # unconditionally. Honor the flag (default True preserves old behavior).
    ctx = cpu(0)
    if use_gpu:
        try:
            from decord.bridge import decord_bridge

            gpu_ctx = gpu(0)
            # Probe the device; raises when no usable GPU context exists.
            _ = decord_bridge.get_ctx_device(gpu_ctx)
            ctx = gpu_ctx
        except Exception:
            ctx = cpu(0)

    tmp_file = None

    def _open_from_bytes(data: bytes):
        # Spill bytes to a named temp file, since decord reads from paths.
        nonlocal tmp_file
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        tmp_file.write(data)
        tmp_file.close()
        return VideoReader(tmp_file.name, ctx=ctx)

    try:
        if isinstance(video_file, bytes):
            vr = _open_from_bytes(video_file)
        elif isinstance(video_file, str):
            if video_file.startswith(("http://", "https://")):
                timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
                response = requests.get(video_file, stream=True, timeout=timeout)
                response.raise_for_status()
                # Stream to disk in chunks to avoid buffering the whole
                # video in memory.
                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)
                tmp_file.close()
                vr = VideoReader(tmp_file.name, ctx=ctx)
            elif video_file.startswith("data:"):
                _, encoded = video_file.split(",", 1)
                vr = _open_from_bytes(base64.b64decode(encoded))
            elif os.path.isfile(video_file):
                vr = VideoReader(video_file, ctx=ctx)
            else:
                # Assume a bare base64-encoded string.
                vr = _open_from_bytes(base64.b64decode(video_file))
        else:
            raise ValueError(f"Unsupported video input type: {type(video_file)}")
        return vr
    finally:
        # NOTE(review): the temp file is unlinked while the returned
        # VideoReader may still hold it open. decord keeps the fd open so
        # this works on POSIX, but the unlink will fail on Windows — confirm.
        if tmp_file and os.path.exists(tmp_file.name):
            os.unlink(tmp_file.name)
def suppress_other_loggers():
warnings.filterwarnings(
"ignore", category=UserWarning, message="The given NumPy array is not writable"