vlm: support video as an input modality (#5888)
@@ -88,9 +88,11 @@ class Conversation:
     stop_str: Union[str, List[str]] = None
     # The string that represents an image token in the prompt
     image_token: str = "<image>"
+    video_token: str = "<video>"
     audio_token: str = "<audio>"

     image_data: Optional[List[str]] = None
+    video_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None
     stop_token_ids: Optional[int] = None

@@ -380,11 +382,15 @@ class Conversation:
         self.messages.append([role, message])

     def append_image(self, image: str):
-        """Append a new message."""
+        """Append a new image."""
         self.image_data.append(image)

+    def append_video(self, video: str):
+        """Append a new video."""
+        self.video_data.append(video)
+
     def append_audio(self, audio: str):
-        """Append a new message."""
+        """Append a new audio."""
         self.audio_data.append(audio)

     def update_last_message(self, message: str):
@@ -433,6 +439,7 @@ class Conversation:
             sep2=self.sep2,
             stop_str=self.stop_str,
             image_token=self.image_token,
+            video_token=self.video_token,
             audio_token=self.audio_token,
         )

@@ -495,8 +502,12 @@ def generate_embedding_convs(
            sep2=conv_template.sep2,
            stop_str=conv_template.stop_str,
            image_data=[],
+            video_data=[],
            audio_data=[],
            modalities=[],
            image_token=conv_template.image_token,
+            video_token=conv_template.video_token,
            audio_token=conv_template.audio_token,
        )
        real_content = ""

@@ -557,10 +568,12 @@ def generate_chat_conv(
        sep2=conv.sep2,
        stop_str=conv.stop_str,
        image_data=[],
+        video_data=[],
        audio_data=[],
        modalities=[],
        image_token=conv.image_token,
        audio_token=conv.audio_token,
+        video_token=conv.video_token,
    )

    if isinstance(request.messages, str):
@@ -602,6 +615,7 @@ def generate_chat_conv(
            image_token = ""

            audio_token = conv.audio_token
+            video_token = conv.video_token
            for content in message.content:
                if content.type == "text":
                    if num_image_url > 16:
@@ -614,6 +628,9 @@ def generate_chat_conv(
                        else:
                            real_content += image_token
                        conv.append_image(content.image_url.url)
+                    elif content.type == "video_url":
+                        real_content += video_token
+                        conv.append_video(content.video_url.url)
                    elif content.type == "audio_url":
                        real_content += audio_token
                        conv.append_audio(content.audio_url.url)
@@ -810,6 +827,7 @@ register_conv_template(
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        stop_str=["<|im_end|>"],
        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        video_token="<|vision_start|><|video_pad|><|vision_end|>",
    )
)

@@ -870,6 +888,7 @@ register_conv_template(
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
+        video_token="(<video>./</video>)",
    )
)

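For orientation, the new Conversation plumbing boils down to: the raw video URL is stored in video_data, while the template-level video_token is what ends up in the prompt text (the token substitution itself happens in generate_chat_conv). A toy sketch with a stripped-down stand-in, not the real Conversation class:

from dataclasses import dataclass, field
from typing import List

@dataclass
class MiniConv:  # toy stand-in for sglang's Conversation
    video_token: str = "<video>"
    video_data: List[str] = field(default_factory=list)
    prompt: str = ""

    def append_video(self, video: str) -> None:
        """Append a new video."""
        self.video_data.append(video)

conv = MiniConv()
conv.append_video("https://example.com/clip.mp4")  # illustrative URL
conv.prompt += conv.video_token  # in the real code this happens in generate_chat_conv
print(conv.prompt, conv.video_data)  # <video> ['https://example.com/clip.mp4']
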
@@ -267,6 +267,10 @@ class ChatCompletionMessageContentImageURL(BaseModel):
    detail: Optional[Literal["auto", "low", "high"]] = "auto"


+class ChatCompletionMessageContentVideoURL(BaseModel):
+    url: str
+
+
class ChatCompletionMessageContentAudioURL(BaseModel):
    url: str

@@ -277,6 +281,11 @@ class ChatCompletionMessageContentImagePart(BaseModel):
    modalities: Optional[Literal["image", "multi-images", "video"]] = "image"


+class ChatCompletionMessageContentVideoPart(BaseModel):
+    type: Literal["video_url"]
+    video_url: ChatCompletionMessageContentVideoURL
+
+
class ChatCompletionMessageContentAudioPart(BaseModel):
    type: Literal["audio_url"]
    audio_url: ChatCompletionMessageContentAudioURL
@@ -285,6 +294,7 @@ class ChatCompletionMessageContentAudioPart(BaseModel):
ChatCompletionMessageContentPart = Union[
    ChatCompletionMessageContentTextPart,
    ChatCompletionMessageContentImagePart,
+    ChatCompletionMessageContentVideoPart,
    ChatCompletionMessageContentAudioPart,
]

@@ -629,6 +639,7 @@ class MessageProcessingResult:
    prompt_ids: Union[str, List[int]]
    image_data: Optional[Any]
    audio_data: Optional[Any]
+    video_data: Optional[Any]
    modalities: List[str]
    stop: List[str]
    tool_call_constraint: Optional[Any] = None

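With the protocol types above, a client can send a video part the same way it sends an image part. A hedged end-to-end sketch; the model name, port, and video URL are illustrative assumptions, not taken from this diff:

import requests  # any HTTP client works; assumed installed

# Hypothetical request against a local sglang OpenAI-compatible endpoint.
payload = {
    "model": "Qwen/Qwen2-VL-7B-Instruct",  # illustrative model
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this clip."},
                {
                    "type": "video_url",  # matches ChatCompletionMessageContentVideoPart
                    "video_url": {"url": "https://example.com/clip.mp4"},
                },
            ],
        }
    ],
}
resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
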
@@ -82,6 +82,7 @@ class OpenAIServingChat(OpenAIServingBase):
        adapted_request = GenerateReqInput(
            **prompt_kwargs,
            image_data=processed_messages.image_data,
+            video_data=processed_messages.video_data,
            audio_data=processed_messages.audio_data,
            sampling_params=sampling_params,
            return_logprob=request.logprobs,
@@ -143,6 +144,7 @@ class OpenAIServingChat(OpenAIServingBase):
        prompt_ids = []
        openai_compatible_messages = []
        image_data = []
+        video_data = []
        audio_data = []
        modalities = []

@@ -158,6 +160,7 @@ class OpenAIServingChat(OpenAIServingBase):
                msg_dict,
                template_content_format,
                image_data,
+                video_data,
                audio_data,
                modalities,
            )
@@ -214,11 +217,13 @@ class OpenAIServingChat(OpenAIServingBase):
            stop = request.stop
        image_data = image_data if image_data else None
        audio_data = audio_data if audio_data else None
+        video_data = video_data if video_data else None
        modalities = modalities if modalities else []
        return MessageProcessingResult(
            prompt=prompt,
            prompt_ids=prompt_ids,
            image_data=image_data,
+            video_data=video_data,
            audio_data=audio_data,
            modalities=modalities,
            stop=stop,
@@ -260,6 +265,7 @@ class OpenAIServingChat(OpenAIServingBase):
        prompt = conv.get_prompt()

        image_data = conv.image_data if conv.image_data else None
+        video_data = conv.video_data if conv.video_data else None
        audio_data = conv.audio_data if conv.audio_data else None
        modalities = conv.modalities if conv.modalities else []
        stop = copy.copy(conv.stop_str or [] if not request.ignore_eos else [])
@@ -277,6 +283,7 @@ class OpenAIServingChat(OpenAIServingBase):
            prompt=prompt,
            prompt_ids=prompt_ids,
            image_data=image_data,
+            video_data=video_data,
            audio_data=audio_data,
            modalities=modalities,
            stop=stop,

@@ -110,6 +110,7 @@ def process_content_for_template_format(
    msg_dict: dict,
    content_format: str,
    image_data: list,
+    video_data: list,
    audio_data: list,
    modalities: list,
) -> dict:
@@ -120,6 +121,7 @@ def process_content_for_template_format(
        msg_dict: Message dictionary with content
        content_format: 'string' or 'openai' (detected via AST analysis)
        image_data: List to append extracted image URLs
+        video_data: List to append extracted video URLs
        audio_data: List to append extracted audio URLs
        modalities: List to append modalities

@@ -143,6 +145,12 @@ def process_content_for_template_format(
                modalities.append(chunk.get("modalities"))
            # Normalize to simple 'image' type for template compatibility
            processed_content_parts.append({"type": "image"})
+        elif chunk_type == "video_url":
+            video_data.append(chunk["video_url"]["url"])
+            if chunk.get("modalities"):
+                modalities.append(chunk.get("modalities"))
+            # Normalize to simple 'video' type for template compatibility
+            processed_content_parts.append({"type": "video"})
        elif chunk_type == "audio_url":
            audio_data.append(chunk["audio_url"]["url"])
            # Normalize to simple 'audio' type

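To make the normalization concrete, here is a standalone illustration of the video_url branch above; the chunk dict is a hypothetical client payload:

chunk = {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}}

video_data: list = []
processed_content_parts: list = []

if chunk["type"] == "video_url":
    video_data.append(chunk["video_url"]["url"])       # raw URL kept for the processor
    processed_content_parts.append({"type": "video"})  # bare marker for the chat template

print(video_data)               # ['https://example.com/clip.mp4']
print(processed_content_parts)  # [{'type': 'video'}]
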
@@ -65,6 +65,8 @@ class GenerateReqInput:
    ] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[List[str]], List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
    data_parallel_rank: Optional[int] = None

    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

    def normalize_batch_and_arguments(self):
        """
@@ -232,6 +238,7 @@ class GenerateReqInput:
            self._normalize_rid(num)
            self._normalize_lora_paths(num)
            self._normalize_image_data(num)
+            self._normalize_video_data(num)
            self._normalize_audio_data(num)
            self._normalize_sampling_params(num)
            self._normalize_logprob_params(num)
@@ -300,6 +307,15 @@ class GenerateReqInput:
                self.image_data = wrapped_images * self.parallel_sample_num
                self.modalities = ["image"] * num

+    def _normalize_video_data(self, num):
+        """Normalize video data for batch processing."""
+        if self.video_data is None:
+            self.video_data = [None] * num
+        elif not isinstance(self.video_data, list):
+            self.video_data = [self.video_data] * num
+        elif isinstance(self.video_data, list):
+            self.video_data = self.video_data * self.parallel_sample_num
+
    def _normalize_audio_data(self, num):
        """Normalize audio data for batch processing."""
        if self.audio_data is None:
@@ -408,6 +424,7 @@ class GenerateReqInput:
                self.input_embeds[i] if self.input_embeds is not None else None
            ),
            image_data=self.image_data[i],
+            video_data=self.video_data[i],
            audio_data=self.audio_data[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
@@ -507,6 +524,8 @@ class EmbeddingReqInput:
    image_data: Optional[
        Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
    ] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[str], str]] = None
    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
    audio_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
@@ -578,7 +597,11 @@ class EmbeddingReqInput:
        return self.rid

    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

    def __getitem__(self, i):
        if self.is_cross_encoder_request:

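A behavior sketch of the new video fields in GenerateReqInput, assuming parallel_sample_num == 1; has_valid_data is a toy stand-in for sglang's helper of the same name:

def has_valid_data(data) -> bool:
    # Treat None and empty containers as "no data".
    return data is not None and data != [] and data != ""

class MiniReq:  # stand-in for GenerateReqInput
    def __init__(self, image_data=None, video_data=None, audio_data=None):
        self.image_data, self.video_data, self.audio_data = image_data, video_data, audio_data
        self.parallel_sample_num = 1

    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
            or has_valid_data(self.video_data)
            or has_valid_data(self.audio_data)
        )

    def _normalize_video_data(self, num: int) -> None:
        """Broadcast a single video to every request in the batch."""
        if self.video_data is None:
            self.video_data = [None] * num
        elif not isinstance(self.video_data, list):
            self.video_data = [self.video_data] * num

req = MiniReq(video_data="clip.mp4")
assert req.contains_mm_input()
req._normalize_video_data(num=3)
print(req.video_data)  # ['clip.mp4', 'clip.mp4', 'clip.mp4']
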
@@ -4,7 +4,7 @@ Multi-modality utils

import hashlib
from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
        return ret_input_ids


-embedding_cache = None
+embedding_cache: Optional[MultiModalCache] = None


-def init_embedding_cache(max_size: int):
+def init_embedding_cache(max_size: int = 0):
    global embedding_cache
    embedding_cache = MultiModalCache(max_size)

@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
            continue
        embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
        items_offset = items_offset_list[i]
+        assert items_offset is not None, items_offset
        embedding_items_hash = get_embedding_hash(embedding_items_per_req)
        # if all items has been prefixed, we do not need to calculate embedding
        if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -380,11 +382,9 @@ def embed_mm_inputs(
    extend_seq_lens: List[int],
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
+    multimodal_model: nn.Module = None,
+    data_embedding_func_mapping: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
+    ] = None,
    placeholder_tokens: dict[Modality, List[int]] = None,
) -> Optional[torch.Tensor]:
@@ -397,8 +397,6 @@ def embed_mm_inputs(
        extend_seq_lens: Sequence lengths for each request
        input_ids: Input token IDs tensor
        input_embedding: Embedding layer for text tokens
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
        placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

    Returns:
@@ -415,88 +413,53 @@ def embed_mm_inputs(
        item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

    embeddings, masks = [], []

    # 2. Get multimodal embedding separately
-    # TODO: make this more generic
-    # Try get image embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_image())
-        and image_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_image()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        items_offsets = []
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
-            items_size[i + 1] = len(image_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.image_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_image()
-                    ]
-                )
-            )
-        items_size = torch.cumsum(items_size, dim=0)
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=image_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
-
-    # Try get audio embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_audio())
-        and audio_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_audio()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        items_offsets = []
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
-            items_size[i + 1] = len(audio_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.audio_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_audio()
-                    ]
-                )
-            )
-        items_size = torch.cumsum(items_size, dim=0)
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=audio_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
+    # Try get mm embedding if any
+    for modality in Modality.all():
+        items = [
+            item for item in item_flatten_list if item.is_modality(modality=modality)
+        ]
+        embedder = (
+            None
+            if data_embedding_func_mapping is None
+            else data_embedding_func_mapping.get(modality, None)
+        )
+        if embedder is None:
+            # "image", "video", etc
+            modality_id = modality.name.lower()
+            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
+        if len(items) != 0 and embedder is not None:
+            placeholder_tensor = torch.tensor(
+                [item.pad_value for item in items],
+                device=input_ids.device,
+            )
+            # calculate per request items length offset
+            items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+            items_offsets = []
+            for i, mm_inputs in enumerate(mm_inputs_list):
+                mm_items = [
+                    item
+                    for item in mm_inputs.mm_items
+                    if item.is_modality(modality=modality)
+                ]
+                items_size[i + 1] = len(mm_items)
+                items_offsets.append(
+                    flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
+                )
+            items_size = torch.cumsum(items_size, dim=0).tolist()
+
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=embedder,
+                embedding_items=items,
+                placeholder_tensor=placeholder_tensor,
+                input_ids=input_ids,
+                items_size=items_size,
+                prefix_length=extend_prefix_lens,
+                extend_length=extend_seq_lens,
+                items_offset_list=items_offsets,
+            )
+            embeddings += [embedding]
+            masks += [mask]

    # 3. Get input embeddings
    vocab_size = input_embedding.num_embeddings
@@ -523,11 +486,9 @@ def general_mm_embed_routine(
    input_ids: torch.Tensor,
    forward_batch: ForwardBatch,
    language_model: nn.Module,
-    image_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
-    audio_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
+    multimodal_model: Optional[nn.Module] = None,
+    data_embedding_funcs: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
+    ] = None,
    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
    **kwargs,
@@ -572,8 +533,8 @@ def general_mm_embed_routine(
            extend_seq_lens=extend_seq_lens,
            input_ids=input_ids,
            input_embedding=embed_tokens,
-            image_data_embedding_func=image_data_embedding_func,
-            audio_data_embedding_func=audio_data_embedding_func,
+            multimodal_model=multimodal_model,
+            data_embedding_func_mapping=data_embedding_funcs,
            placeholder_tokens=placeholder_tokens,
        )
        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models

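The refactor above replaces the per-modality if-blocks with one loop that resolves an embedder per modality: an explicit mapping entry wins, otherwise it falls back to a get_<modality>_feature method on the model. A minimal sketch of that lookup with a toy model standing in for the real multimodal module:

from enum import Enum

class Modality(Enum):
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"

class ToyModel:
    def get_video_feature(self, items):
        return f"video embedding for {len(items)} item(s)"

def resolve_embedder(model, mapping, modality):
    # Explicit mapping entry first, then the naming-convention fallback.
    embedder = None if mapping is None else mapping.get(modality)
    if embedder is None:
        embedder = getattr(model, f"get_{modality.name.lower()}_feature", None)
    return embedder

model = ToyModel()
fn = resolve_embedder(model, None, Modality.VIDEO)
print(fn([object()]))                                  # falls back to get_video_feature
print(resolve_embedder(model, None, Modality.AUDIO))   # None: no embedder available
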
@@ -185,6 +185,10 @@ class Modality(Enum):
            f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
        )

+    @staticmethod
+    def all():
+        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
+

@dataclasses.dataclass
class MultimodalDataItem:
@@ -200,7 +204,7 @@ class MultimodalDataItem:
    hash: int = None
    pad_value: int = None
    image_sizes: Tuple[int, int] = None
-    image_offsets: Optional[list] = None
+    offsets: Optional[list] = None

    # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.ndarray]]
@@ -253,12 +257,17 @@ class MultimodalDataItem:
                self.hash = hash_feature(self.audio_features)
            elif self.input_features is not None:
                self.hash = hash_feature(self.input_features)
+        elif self.is_video():
+            self.hash = hash_feature(self.pixel_values_videos)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)

+    def is_modality(self, modality: Modality) -> bool:
+        return self.modality == modality
+
    def is_audio(self):
        return (self.modality == Modality.AUDIO) and (
            self.precomputed_features is not None
@@ -268,7 +277,7 @@ class MultimodalDataItem:

    def is_image(self):
        return (
-            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+            self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
        ) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
@@ -277,7 +286,7 @@ class MultimodalDataItem:
    def is_video(self):
        return (self.modality == Modality.VIDEO) and (
            self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+            or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
        )

    def is_valid(self) -> bool:
@@ -351,6 +360,7 @@ class MultimodalInputs:
            "im_token_id",
            "im_start_id",
            "im_end_id",
+            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
@@ -364,11 +374,12 @@ class MultimodalInputs:
        return ret

    def contains_image_inputs(self) -> bool:
-        """ """
        return any(item.is_image() for item in self.mm_items)

+    def contains_video_inputs(self) -> bool:
+        return any(item.is_video() for item in self.mm_items)
+
    def contains_audio_inputs(self) -> bool:
-        """ """
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:

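A quick illustration of how Modality.all() and is_modality() are meant to combine; MultimodalDataItem is reduced to a bare stand-in here:

from dataclasses import dataclass
from enum import Enum, auto

class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def all():
        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]

@dataclass
class Item:  # stand-in for MultimodalDataItem
    modality: Modality

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

items = [Item(Modality.IMAGE), Item(Modality.VIDEO), Item(Modality.VIDEO)]
for modality in Modality.all():
    matched = [it for it in items if it.is_modality(modality)]
    print(modality.name, len(matched))  # IMAGE 1, VIDEO 2, AUDIO 0
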
@@ -453,8 +453,20 @@ class ForwardBatch:
            for mm_input in self.mm_inputs
        )

+    def contains_video_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_video_inputs()
+            for mm_input in self.mm_inputs
+        )
+
    def contains_mm_inputs(self) -> bool:
-        return self.contains_audio_inputs() or self.contains_image_inputs()
+        return (
+            self.contains_audio_inputs()
+            or self.contains_video_inputs()
+            or self.contains_image_inputs()
+        )

    def _compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch

@@ -1989,7 +1989,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
            language_model=self.language_model,
            positions=positions,
        )

@@ -227,7 +227,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
            language_model=self.language_model,
        )

@@ -374,7 +374,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
            input_ids=llm_input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
            positions=positions,
        )

@@ -1,7 +1,7 @@
import logging
import re
from functools import lru_cache
-from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
+from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union

import torch
from torch import nn
@@ -25,6 +25,7 @@ from sglang.srt.managers.mm_utils import (
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
+    Modality,
    MultimodalDataItem,
    MultimodalInputs,
    flatten_nested_list,
@@ -434,8 +435,10 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.AUDIO: self.get_audio_feature,
+            },
            positions=positions,
            per_layer_inputs=per_layer_inputs,
        )

@@ -29,7 +29,11 @@ from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
    general_mm_embed_routine,
)
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_janus_pro import DropPath
@@ -523,7 +527,9 @@ class InternVLChatModel(nn.Module):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
            positions=positions,
        )

@@ -67,7 +67,11 @@ from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
@@ -168,7 +172,9 @@ class KimiVLForConditionalGeneration(nn.Module):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
            positions=positions,
        )

@@ -787,7 +787,9 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
            forward_batch=forward_batch,
            get_embedding=get_embedding,
            language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
            placeholder_tokens=None,  # using mm_item.pad_value
            positions=positions,
        )

@@ -142,7 +142,7 @@ class LlavaVidForCausalLM(nn.Module):
            )
            image_offsets = [
                flatten_nested_list(
-                    [item.image_offsets for item in image_inputs[i].mm_items]
+                    [item.offsets for item in image_inputs[i].mm_items]
                )
                for i in range(bs)
                if need_vision[i]

@@ -1827,8 +1827,7 @@ class MiniCPMO(MiniCPMBaseModel):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            multimodal_model=self,
            positions=positions,
        )
        return hidden_states

@@ -573,7 +573,7 @@ class MiniCPMBaseModel(nn.Module):
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
            language_model=self.llm,
            positions=positions,
        )

@@ -6,8 +6,11 @@ from typing import List, Optional, Set, Tuple

import torch
from torch import nn
-from transformers import Llama4Config, Llama4VisionModel
-from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
+from transformers import Llama4Config
+from transformers.models.llama4.modeling_llama4 import (
+    Llama4MultiModalProjector,
+    Llama4VisionModel,
+)

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -16,7 +19,11 @@ from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cpu
@@ -166,7 +173,9 @@ class Llama4ForConditionalGeneration(nn.Module):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
-            image_data_embedding_func=image_embedding_func,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
            positions=positions,
        )

@@ -31,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.idefics2 import Idefics2VisionTransformer
@@ -439,7 +443,9 @@ class Phi4MMForCausalLM(nn.Module):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
            positions=positions,
        )

@@ -56,7 +56,6 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)
@@ -507,11 +506,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        return image_embeds

-    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
-        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(
-            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
-        )
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [getattr(item, "pixel_values_videos") for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
        return video_embeds

    def get_input_embeddings(self):
@@ -553,7 +556,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
            positions=positions,
        )

@@ -493,6 +493,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        return image_embeds

+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [item.pixel_values_videos for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+        return video_embeds
+
    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(
@@ -538,7 +549,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
            positions=positions,
        )

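The get_video_feature contract in both Qwen-VL variants is simply "concatenate the per-item video tensors along dim 0, then run the visual encoder". A toy check of the concatenation step with random tensors (the visual encoder itself is omitted; shapes are illustrative):

import torch

items = [
    {"pixel_values_videos": torch.randn(4, 8), "video_grid_thw": torch.tensor([[1, 2, 2]])},
    {"pixel_values_videos": torch.randn(6, 8), "video_grid_thw": torch.tensor([[1, 2, 3]])},
]
# All items share the last dim, so a dim-0 concat stacks their patches.
pixel_values = torch.cat([it["pixel_values_videos"] for it in items], dim=0)
video_grid_thw = torch.cat([it["video_grid_thw"] for it in items], dim=0)
assert pixel_values.dim() == 2 and video_grid_thw.dim() == 2
print(pixel_values.shape, video_grid_thw.shape)  # torch.Size([10, 8]) torch.Size([2, 3])
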
@@ -17,7 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -223,7 +227,9 @@ class VILAForConditionalGeneration(nn.Module):
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
            get_embedding=get_embedding,
            positions=positions,
        )

@@ -5,7 +5,7 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -14,7 +14,7 @@ from PIL import Image
from transformers import BaseImageProcessorFast

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import encode_video, load_audio, load_image
+from sglang.srt.utils import load_audio, load_image, load_video, logger


@dataclasses.dataclass
@@ -25,14 +25,22 @@ class BaseMultiModalProcessorOutput:
    # frames loaded from image and video, in given order
    images: Optional[list[Union[Image.Image, dict]]] = None

+    # videos
+    videos: Optional[list[Union[torch.Tensor, dict]]] = None
+
    # audios
    audios: Optional[list[Union[np.ndarray, dict]]] = None

-    def normalize(self):
-        for field_name in ["images", "audios"]:
-            field = getattr(self, field_name, None)
-            if field is not None and isinstance(field, list) and len(field) == 0:
-                setattr(self, field_name, None)
+    def organize_results(self) -> List[Tuple[Modality, Any]]:
+        """
+
+        :return: a list of results, with their corresponding modalities
+        """
+        return (
+            [(Modality.IMAGE, data) for data in self.images]
+            + [(Modality.VIDEO, data) for data in self.videos]
+            + [(Modality.AUDIO, data) for data in self.audios]
+        )


@dataclasses.dataclass
@@ -41,6 +49,10 @@ class MultimodalSpecialTokens:
    video_token: Optional[Union[int, str, List[str]]] = None
    audio_token: Optional[Union[int, str, List[str]]] = None

+    image_token_regex: Optional[re.Pattern] = None
+    video_token_regex: Optional[re.Pattern] = None
+    audio_token_regex: Optional[re.Pattern] = None
+
    def convert_to_str(self, token: Union[str, int], processor) -> str:
        if token is None:
            return token
@@ -53,11 +65,29 @@ class MultimodalSpecialTokens:
        self.video_token = self.convert_to_str(self.video_token, processor)
        self.audio_token = self.convert_to_str(self.audio_token, processor)

-    image_token_regex: Optional[re.Pattern] = None
-    video_token_regex: Optional[re.Pattern] = None
-    audio_token_regex: Optional[re.Pattern] = None
-
-    def __post_init__(self):
+    def get_modality_of_token(self, token) -> Optional[Modality]:
+        """
+        :return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
+        """
+        modality = {
+            self.image_token: Modality.IMAGE,
+            self.video_token: Modality.VIDEO,
+            self.audio_token: Modality.AUDIO,
+        }.get(token)
+        if modality:
+            return modality
+
+        for regex, modality in [
+            (self.image_token_regex, Modality.IMAGE),
+            (self.video_token_regex, Modality.VIDEO),
+            (self.audio_token_regex, Modality.AUDIO),
+        ]:
+            if regex and regex.match(token):
+                return modality
+
+        return None
+
+    def parse_regex(self):
        if self.image_token_regex is None and self.image_token is not None:
            self.image_token_regex = re.compile(re.escape(self.image_token))
        if self.video_token_regex is None and self.video_token is not None:
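get_modality_of_token is the hinge of the new generic loading path: an exact special-token match first, then the per-modality regexes. A self-contained sketch of that resolution order:

import re
from enum import Enum, auto

class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()

image_token, video_token, audio_token = "<image>", "<video>", "<audio>"
regexes = [
    (re.compile(re.escape(image_token)), Modality.IMAGE),
    (re.compile(re.escape(video_token)), Modality.VIDEO),
    (re.compile(re.escape(audio_token)), Modality.AUDIO),
]

def get_modality_of_token(token: str):
    # Exact match against the special tokens wins.
    direct = {image_token: Modality.IMAGE,
              video_token: Modality.VIDEO,
              audio_token: Modality.AUDIO}.get(token)
    if direct:
        return direct
    # Otherwise fall back to the compiled regexes.
    for regex, modality in regexes:
        if regex.match(token):
            return modality
    return None

print(get_modality_of_token("<video>"))  # Modality.VIDEO
print(get_modality_of_token("hello"))    # None
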
@@ -65,7 +95,7 @@ class MultimodalSpecialTokens:
        if self.audio_token_regex is None and self.audio_token is not None:
            self.audio_token_regex = re.compile(re.escape(self.audio_token))

-    def collect(self) -> re.Pattern:
+    def combine_regex(self) -> re.Pattern:
        tokens = [
            self.image_token_regex,
            self.video_token_regex,
@@ -105,6 +135,7 @@ class BaseMultimodalProcessor(ABC):
        self.ATTR_NAME_TO_MODALITY = {
            # Image-related attributes
            "pixel_values": Modality.IMAGE,
+            "pixel_values_videos": Modality.VIDEO,
            "image_sizes": Modality.IMAGE,
            "image_grid_thw": Modality.IMAGE,
            "image_emb_mask": Modality.IMAGE,
@@ -120,7 +151,7 @@ class BaseMultimodalProcessor(ABC):
            "input_features": Modality.AUDIO,
            "input_features_mask": Modality.AUDIO,
            # Video-related attributes
-            "video_grid_thws": Modality.VIDEO,
+            "video_grid_thw": Modality.VIDEO,
            # Generic attributes that could apply to multiple modalities
            # "precomputed_features" - handled specially as it can be any modality
        }
@@ -196,20 +227,25 @@ class BaseMultimodalProcessor(ABC):

    @staticmethod
    def _load_single_item(
-        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+        data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True
    ):
-        """Static method that can be pickled for multiprocessing"""
+        """
+        Load a single multimodal data.
+
+        If data is precomputed, returns directly.
+
+        Static method that can be pickled for multiprocessing"""
        if isinstance(data, dict):
            return data
        try:
-            if is_audio:
-                return load_audio(data)
-            elif is_video:
-                path = data[len("video:") :]
-                return encode_video(path, frame_count_limit)
-            else:
+            if modality == Modality.IMAGE:
                img, _ = load_image(data)
                return img.convert("RGB") if discard_alpha_channel else img
+            elif modality == Modality.VIDEO:
+                return load_video(data, frame_count_limit)
+            elif modality == Modality.AUDIO:
+                return load_audio(data)
        except Exception as e:
            raise RuntimeError(f"Error while loading data {data}: {e}")

@@ -217,75 +253,78 @@ class BaseMultimodalProcessor(ABC):
        self,
        text_parts: List[str],
        multimodal_tokens: MultimodalSpecialTokens,
-        image_data: Optional[list] = None,
-        audio_data: Optional[list] = None,
+        data_iterators: dict,
        discard_alpha_channel: bool = True,
-    ):
+        image_estimated_frames_iter: Optional[iter] = None,
+        image_scaling_factor: float = 1.0,
+        max_image_frames: int = 30,
+    ) -> Tuple[List, List]:
        """
-        load multimodal data parallelly
+        load multimodal data parallelly using iterators.
        """
-
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
+
        # Submit all tasks
        futures = []
        task_info = []
-        image_index, audio_index = 0, 0

        for text_part in text_parts:
-            if (
-                multimodal_tokens.image_token_regex
-                and multimodal_tokens.image_token_regex.match(text_part)
-            ):
-                data = image_data[image_index]
-                is_video = isinstance(data, str) and data.startswith("video:")
-                estimated_frames = estimated_frames_list[image_index]
-                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+            modality = multimodal_tokens.get_modality_of_token(text_part)
+            if modality is not None:
+                data_iterator = data_iterators.get(modality)
+                if data_iterator is None:
+                    raise ValueError(f"No data iterator found for token: {text_part}")
+
+                try:
+                    data = next(data_iterator)
+                except StopIteration:
+                    raise ValueError(
+                        f"Mismatch: More '{text_part}' tokens found than corresponding data items provided."
+                    )
+
+                frame_count_limit = None
+                if modality == Modality.IMAGE and image_estimated_frames_iter:
+                    try:
+                        estimated_frames = next(image_estimated_frames_iter)
+                        # Use the pre-calculated scaling factor and max frames
+                        frame_count_limit = max(
+                            1, int(estimated_frames * image_scaling_factor)
+                        )
+                        # Ensure we don't exceed the absolute max (redundant if scaling_factor handles it)
+                        # frame_count_limit = min(frame_count_limit, max_image_frames)
+                    except StopIteration:
+                        raise ValueError(
+                            "Mismatch between image tokens and estimated frame counts."
+                        )
+
                futures.append(
                    self.io_executor.submit(
                        BaseMultimodalProcessor._load_single_item,
                        data,
-                        is_video,
-                        False,
+                        modality,
                        frame_count_limit,
                        discard_alpha_channel,
                    )
                )
-                task_info.append((Modality.IMAGE, data, frame_count_limit))
-                image_index += 1
-            elif (
-                multimodal_tokens.audio_token_regex
-                and multimodal_tokens.audio_token_regex.match(text_part)
-            ):
-                data = audio_data[audio_index]
-                futures.append(
-                    self.io_executor.submit(
-                        BaseMultimodalProcessor._load_single_item,
-                        data,
-                        False,
-                        True,
-                        None,
-                        discard_alpha_channel,
-                    )
-                )
-                task_info.append((Modality.AUDIO, data, None))
-                audio_index += 1
+                task_info.append((modality, data, frame_count_limit))
+
+        for modality, iterator in data_iterators.items():
+            try:
+                next(iterator)
+                logger.warning(
+                    f"Warning: More {modality.name.lower()} data items provided than corresponding tokens found in the prompt."
+                )
+            except StopIteration:
+                pass
+            except Exception:
+                pass

        return futures, task_info

    def load_mm_data(
        self,
-        prompt: str | List[int],
+        prompt: str,
        multimodal_tokens: MultimodalSpecialTokens,
        max_req_input_len: int,
        image_data: Optional[list] = None,
+        video_data: Optional[list] = None,
        audio_data: Optional[list] = None,
        return_text: Optional[bool] = True,
        discard_alpha_channel: bool = True,
@@ -299,14 +338,9 @@ class BaseMultimodalProcessor(ABC):
            discard_alpha_channel: if True, discards the alpha channel in the returned images

        """
        if not return_text:
            raise NotImplementedError()
-        if image_data is None:
-            image_data = []

        multimodal_tokens.convert_to_strs(self._processor)
-        multimodal_tokens_pattern = multimodal_tokens.collect()
+        multimodal_tokens.parse_regex()
+        multimodal_tokens_pattern = multimodal_tokens.combine_regex()
        if isinstance(prompt, list) and return_text:
            assert len(prompt) and isinstance(prompt[0], int)
            prompt = self._processor.tokenizer.decode(prompt)
@@ -317,59 +351,84 @@ class BaseMultimodalProcessor(ABC):
        # split text into list of normal text and special tokens
        text_parts = re.split(multimodal_tokens_pattern, prompt)

+        # collect all data
+        data_iterators = {}
+        if multimodal_tokens.image_token and image_data:
+            data_iterators[Modality.IMAGE] = iter(image_data)
+        if multimodal_tokens.video_token and video_data:
+            data_iterators[Modality.VIDEO] = iter(video_data)
+        if multimodal_tokens.audio_token and audio_data:
+            data_iterators[Modality.AUDIO] = iter(audio_data)
+
+        # futures: the futures of loaded data
+        # task_info: modality, raw_data, and other metadata of each data
        futures, task_info = self.submit_data_loading_tasks(
            text_parts=text_parts,
            multimodal_tokens=multimodal_tokens,
-            image_data=image_data,
-            audio_data=audio_data,
+            data_iterators=data_iterators,
            discard_alpha_channel=discard_alpha_channel,
        )
+        task_info_iter = iter(task_info)
+        futures_iter = iter(futures)

        # Process results
-        images, audios = [], []
-        new_text = ""
-        task_ptr = 0
-
+        images, videos, audios = [], [], []
+        new_text_parts = []
        for text_part in text_parts:
-            if multimodal_tokens_pattern.match(text_part):
-                task_type, data, frame_limit = task_info[task_ptr]
-                result = futures[task_ptr].result()
-                task_ptr += 1
-
-                if task_type == Modality.IMAGE:
-                    # If data is already processed it will be a
-                    # dictionary. In this case we want to keep the
-                    # expanded tokens in text_part. Otherwise, we will
-                    # call the processor code, so keep only a single image
-                    # token.
-                    mm_tokens = (
-                        text_part
-                        if isinstance(data, dict)
-                        else multimodal_tokens.image_token
-                    )
-                    frames = [result] if not isinstance(result, list) else result
-                    if frames:
-                        images += frames
-                        new_text += mm_tokens * len(frames)
-                elif task_type == Modality.AUDIO:
-                    # audio
-                    mm_tokens = (
-                        text_part
-                        if isinstance(data, dict)
-                        else multimodal_tokens.audio_token
-                    )
-                    audios.append(result)
-                    new_text += mm_tokens
-                # TODO: handle video
-            else:
-                new_text += text_part
+            try:
+                if multimodal_tokens_pattern.match(text_part):
+                    modality, raw_data, frame_limit = next(task_info_iter)
+                    is_precomputed = isinstance(raw_data, dict)
+                    result = next(futures_iter).result()
+
+                    if modality == Modality.IMAGE:
+                        # If data is already processed it will be a
+                        # dictionary(precomputed). In this case we want to keep the
+                        # expanded tokens in text_part. Otherwise, we will
+                        # call the processor code, so keep only a single image
+                        # token.
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.image_token
+                        )
+                        frames = [result] if not isinstance(result, list) else result
+                        if frames:
+                            # only for minicpmv
+                            images += frames
+                            new_text_parts += mm_tokens * len(frames)
+                    elif modality == Modality.VIDEO:
+                        # load as video
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.video_token
+                        )
+                        videos += [result]
+                        new_text_parts += mm_tokens
+                    elif modality == Modality.AUDIO:
+                        # audio
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.audio_token
+                        )
+                        audios += [result]
+                        new_text_parts += mm_tokens
+                else:
+                    # normal text
+                    new_text_parts += [text_part]
+            except Exception as e:
+                raise RuntimeError(
+                    f"An exception occurred while loading multimodal data: {e}"
+                )

-        out = BaseMultiModalProcessorOutput(
-            input_text=new_text,
+        return BaseMultiModalProcessorOutput(
            images=images,
            audios=audios,
+            videos=videos,
+            input_text="".join(new_text_parts),
        )
-        out.normalize()
-        return out

    @staticmethod
    def get_mm_items_offset(
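The iterator-based pairing above can be illustrated in isolation: each multimodal token consumes the next item of its modality, and a leftover item signals more data than tokens. Tokens and file names here are illustrative:

from enum import Enum, auto

class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()

token_to_modality = {"<image>": Modality.IMAGE, "<video>": Modality.VIDEO}
data_iterators = {
    Modality.IMAGE: iter(["img0.png"]),
    Modality.VIDEO: iter(["clip0.mp4", "clip1.mp4"]),
}

text_parts = ["Compare ", "<image>", " with ", "<video>", "."]
for part in text_parts:
    modality = token_to_modality.get(part)
    if modality is None:
        continue  # plain text part
    try:
        data = next(data_iterators[modality])
    except StopIteration:
        raise ValueError(f"More '{part}' tokens than data items.")
    print(f"load {data} as {modality.name}")

# Any iterator that still yields has unmatched data items.
for modality, it in data_iterators.items():
    if next(it, None) is not None:
        print(f"warning: unused {modality.name} data")  # clip1.mp4 is left over
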
@@ -460,21 +519,19 @@ class BaseMultimodalProcessor(ABC):
                )
            except ValueError:
                modality = Modality.IMAGE

            if modality:
                # Create item if needed
                if modality not in items:
                    items[modality] = MultimodalDataItem(modality=modality)

                # Set attribute
-                if hasattr(items[modality], attr_name):
-                    setattr(items[modality], attr_name, value)
+                setattr(items[modality], attr_name, value)

        return list(items.values())

    def _process_and_collect_mm_items(
        self, input_text: str, images=None, audios=None, videos=None, **kwargs
-    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
        """
        Helper method to process multimodal data and create mm_items in one step.

@@ -488,11 +545,11 @@ class BaseMultimodalProcessor(ABC):
        input_ids = ret["input_ids"].flatten()
        collected_items = self.collect_mm_items_from_processor_output(ret)

-        return collected_items, input_ids
+        return collected_items, input_ids, ret

    def process_and_combine_mm_data(
        self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
        """
        Process multimodal data and return the combined multimodal items and input_ids.
        Supports mixed modalities (images and audio in the same request).
@@ -501,8 +558,7 @@ class BaseMultimodalProcessor(ABC):
            Tuple of (list of mm_items, input_ids)
        """
        # Collect all items and categorize them
-        all_items = (base_output.images or []) + (base_output.audios or [])
-
+        all_items = base_output.organize_results()
        # Handle text-only case
        if not all_items:
            input_ids = self._processor.tokenizer(
@@ -510,19 +566,20 @@ class BaseMultimodalProcessor(ABC):
                return_tensors="pt",
                add_special_tokens=True,
            ).input_ids.flatten()
-            return [], input_ids
+            return [], input_ids, {}

-        dict_items, raw_images, raw_audios = [], [], []
-        for item in all_items:
+        dict_items, raw_images, raw_audios, raw_videos = [], [], [], []
+        for modality, item in all_items:
            if isinstance(item, dict):
                dict_items.append(item)
-            elif isinstance(item, Image.Image):
+            elif modality == Modality.IMAGE:
                raw_images.append(item)
-            elif isinstance(item, np.ndarray):
+            elif modality == Modality.AUDIO:
                raw_audios.append(item)
+            elif modality == Modality.VIDEO:
+                raw_videos.append(item)
            else:
                raise ValueError(f"Unknown multimodal item type: {type(item)}")

        # Process items and get input_ids
        all_collected_items = []
        input_ids = None
@@ -534,13 +591,16 @@ class BaseMultimodalProcessor(ABC):
            )

        # Handle raw items (need processing)
-        if raw_images or raw_audios:
-            collected_items, input_ids = self._process_and_collect_mm_items(
+        if raw_images or raw_audios or raw_videos:
+            collected_items, input_ids, ret = self._process_and_collect_mm_items(
                input_text=base_output.input_text,
                images=raw_images,
                audios=raw_audios,
+                videos=raw_videos,
            )
            all_collected_items.extend(collected_items)
+        else:
+            ret = None

        # Fallback tokenization if no raw items were processed
        if input_ids is None:
@@ -553,21 +613,21 @@ class BaseMultimodalProcessor(ABC):
        # Add offsets to all items
        for mm_item in all_collected_items:
            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                mm_item.image_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                    input_ids=input_ids,
                    mm_token_id=self.IM_TOKEN_ID,
                )
            elif mm_item.modality == Modality.AUDIO:
-                mm_item.audio_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                    input_ids=input_ids,
                    mm_token_id=self.AUDIO_TOKEN_ID,
                )
            elif mm_item.modality == Modality.VIDEO:
-                mm_item.video_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                    input_ids=input_ids,
                    mm_token_id=self.VIDEO_TOKEN_ID,
                )
            else:
                raise ValueError(f"Unknown modality: {mm_item.modality}")

-        return all_collected_items, input_ids
+        return all_collected_items, input_ids, ret

@@ -69,7 +69,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
)
|
||||
item = MultimodalDataItem(
|
||||
pixel_values=res["images"],
|
||||
image_offsets=image_offsets,
|
||||
offsets=image_offsets,
|
||||
modality=Modality.IMAGE,
|
||||
image_emb_mask=images_seq_mask,
|
||||
image_spatial_crop=batched_images_spatial_crop,
|
||||
|
||||
@@ -36,6 +36,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
*args,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
@@ -46,8 +47,9 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
discard_alpha_channel=True,
)

mm_items, input_ids = self.process_and_combine_mm_data(base_output)
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
@@ -72,7 +72,7 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
),
)

mm_items, input_ids = self.process_and_combine_mm_data(base_output)
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

return {
"input_ids": input_ids.tolist(),

@@ -225,7 +225,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
MultimodalDataItem(
pixel_values=pixel_values,
modality=Modality.IMAGE,
image_offsets=image_offsets,
offsets=image_offsets,
)
]

@@ -49,7 +49,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
MultimodalDataItem(
pixel_values=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
image_offsets=image_offsets,
offsets=image_offsets,
modality=Modality.IMAGE,
)
],

@@ -39,7 +39,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
max_req_input_len=max_req_input_len,
)

mm_items, input_ids = self.process_and_combine_mm_data(base_output)
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

return {
"input_ids": input_ids.tolist(),
@@ -19,6 +19,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
super().__init__(hf_config, server_args, _processor)
self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)"
self.video_token = "(<video>./</video>)"

async def process_mm_data_async(
self,
@@ -36,6 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.image_token,
video_token=self.video_token,
audio_token=self.audio_token,
),
)
@@ -113,7 +115,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
if len(pixel_values) != 0:
item = MultimodalDataItem(
pixel_values=pixel_values,
image_offsets=image_offsets,
offsets=image_offsets,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
)
@@ -135,11 +137,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
item = MultimodalDataItem(
audio_features=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
audio_offsets=audio_offsets,
offsets=audio_offsets,
modality=Modality.AUDIO,
)
items += [item]

return {
"mm_items": items,
"input_ids": input_ids.tolist(),
@@ -144,7 +144,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
MultimodalDataItem(
pixel_values=processor_output["pixel_values"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
offsets=image_offsets,
)
]

@@ -65,7 +65,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
pixel_values=res["input_image_embeds"],
image_sizes=res["image_sizes"],
image_emb_mask=res["image_attention_mask"],
image_offsets=image_offsets,
offsets=image_offsets,
modality=Modality.IMAGE,
)
]

@@ -106,7 +106,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
pixel_values=processor_output["pixel_values"],
image_sizes=processor_output["image_sizes"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
offsets=image_offsets,
)
]
@@ -1,9 +1,13 @@
import asyncio
import math
import os
import re
from typing import Dict, List, Union
from typing import List, Union

import torch
import torchvision
from PIL import Image
from torchvision.transforms import InterpolationMode

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
@@ -12,6 +16,185 @@ from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
from sglang.utils import logger

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_TOTAL_PIXELS = int(
float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9))
)

VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def smart_resize(
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:

1. Both dimensions (height and width) are divisible by 'factor'.

2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
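For example, with the defaults above a 1080x1920 frame rounds each side to the nearest multiple of 28 and lands inside the pixel budget:

# Worked example (module defaults, factor=28):
h, w = smart_resize(1080, 1920)
assert (h, w) == (1092, 1932)  # 39 * 28 and 69 * 28
assert MIN_PIXELS <= h * w <= MAX_PIXELS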
def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
width, height = image.size
min_pixels = MIN_PIXELS
max_pixels = MAX_PIXELS
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image

def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor

def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor

def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor

async def resize_image_async(image):
return resize_image(image)
def smart_nframes(
ele: dict,
total_frames: int,
video_fps: int | float,
) -> int:
"""Calculate the number of video frames to use as model input.

Args:
ele (dict): a dict containing the video configuration.
It supports either `fps` or `nframes`:
- nframes: the number of frames to extract for model inputs.
- fps: the fps at which to extract frames for model inputs.
- min_frames: the minimum number of frames to extract; only used when fps is provided.
- max_frames: the maximum number of frames to extract; only used when fps is provided.
total_frames (int): the original total number of frames of the video.
video_fps (int | float): the original fps of the video.

Raises:
ValueError: if nframes is not in the interval [FRAME_FACTOR, total_frames].

Returns:
int: the number of frames to use as model input.
"""
assert not (
"fps" in ele and "nframes" in ele
), "Only accept either `fps` or `nframes`"
if "nframes" in ele:
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
else:
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(
ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
)
nframes = total_frames / video_fps * fps
if nframes > total_frames:
logger.warning(
f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]"
)
nframes = min(max(nframes, min_frames), max_frames, total_frames)
nframes = floor_by_factor(nframes, FRAME_FACTOR)
if not (FRAME_FACTOR <= nframes <= total_frames):
raise ValueError(
f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
)
return nframes
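As a worked example, a 300-frame clip at 30 fps with an empty config samples at the default FPS of 2.0: 300 / 30 * 2.0 = 20 frames, already inside [4, 300] and a multiple of FRAME_FACTOR:

assert smart_nframes({}, total_frames=300, video_fps=30) == 20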
# Process video, Qwen-specific.
async def preprocess_video(
vr,  # decord.VideoReader; left unannotated so decord stays a lazy import
image_factor: int = IMAGE_FACTOR,
) -> torch.Tensor:
ele = {}  # per-request video config; empty means use the module defaults
total_frames, video_fps = len(vr), vr.get_avg_fps()
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
max_pixels = max(
min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
int(min_pixels * 1.05),
)
max_pixels_supposed = ele.get("max_pixels", max_pixels)
if max_pixels_supposed > max_pixels:
logger.warning(
f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
)
max_pixels = min(max_pixels_supposed, max_pixels)
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=image_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=image_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = torchvision.transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
return video
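A sketch of how these pieces are meant to compose (assuming the `load_video` helper added to sglang.srt.utils later in this diff, and a hypothetical local clip):

import asyncio

from sglang.srt.utils import load_video  # added later in this diff

async def demo():
    vr = load_video("clip.mp4")          # decord VideoReader
    frames = await preprocess_video(vr)  # float TCHW tensor, sides divisible by 28
    print(frames.shape)                  # e.g. torch.Size([20, 3, 252, 448])

asyncio.run(demo())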
# Compatible with Qwen2VL and Qwen2_5VL
@@ -37,104 +220,44 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.MIN_PIXELS = 4 * 28 * 28
self.MAX_PIXELS = 16384 * 28 * 28
self.MAX_RATIO = 200
# TODO(mick): move all MultimodalSpecialTokens initializations into processor init
self.mm_special_tokens = MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
video_token=self.VIDEO_TOKEN_ID,
)

async def process_mm_data_async(
self,
image_data: List[Union[str, bytes, Dict]],
image_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):

base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
),
video_data=request_obj.video_data,
multimodal_tokens=self.mm_special_tokens,
max_req_input_len=max_req_input_len,
)

def smart_resize(
height: int,
width: int,
factor: int = self.IMAGE_FACTOR,
min_pixels: int = self.MIN_PIXELS,
max_pixels: int = self.MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:

1. Both dimensions (height and width) are divisible by 'factor'.

2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > self.MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar

def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
width, height = image.size
min_pixels = self.MIN_PIXELS
max_pixels = self.MAX_PIXELS
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image

def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor

def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor

def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor

async def resize_image_async(image):
return resize_image(image)

# Qwen-specific: resize images if they are raw Image objects
if base_output.images and isinstance(base_output.images[0], Image.Image):
resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks)

video_grid_thw = None  # TODO
if base_output.videos:
base_output.videos = [
await preprocess_video(video) for video in base_output.videos
]

mm_items, input_ids = self.process_and_combine_mm_data(base_output)

if not mm_items:
# Note(Xinyuan): This is the case where image loading fails.
return None

combined_mm_item = mm_items[0]  # only image is supported for now
video_grid_thw = None  # TODO
second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output)

input_ids = input_ids.flatten()
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.IM_TOKEN_ID,
@@ -145,9 +268,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=input_ids.unsqueeze(0),
image_grid_thw=combined_mm_item.image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
image_grid_thw=getattr(ret, "image_grid_thw", None),
video_grid_thw=getattr(ret, "video_grid_thw", None),
second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
)
mrope_positions = mrope_positions.squeeze(1)
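Pulling the grid metadata from `ret` (the raw processor output) rather than from a single combined item is what lets image-only, video-only, and mixed requests share this path; fields for absent modalities simply fall back to None:

# Hedged illustration: ret carries only the fields the processor produced,
# so modalities missing from the request yield None here instead of raising.
image_grid_thw = getattr(ret, "image_grid_thw", None)  # None for video-only requests
video_grid_thw = getattr(ret, "video_grid_thw", None)  # None for image-only requests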
@@ -57,7 +57,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
image_data=image_data,
)

mm_items, input_ids = self.process_and_combine_mm_data(base_output)
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

return {
"input_ids": input_ids.tolist(),

@@ -728,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
return audio

def encode_video(video_path, frame_count_limit=None):
# Lazy import because decord is not available on some arm platforms.
from decord import VideoReader, cpu

if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []

if frame_count_limit == 0:
return []

def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]

vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)  # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)

frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
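For reference, the nested `uniform_sample` in this removed helper picked the midpoint of each of `n` equal-width buckets of indices; frame sampling now happens in `preprocess_video` via `torch.linspace` instead. A worked example of the bucket-midpoint logic:

indices = list(range(10))
gap = len(indices) / 5            # 2.0
picked = [int(i * gap + gap / 2) for i in range(5)]
assert picked == [1, 3, 5, 7, 9]  # one index from the middle of each bucket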
def load_image(
image_file: Union[Image.Image, str, bytes],
) -> tuple[Image.Image, tuple[int, int]]:
@@ -774,9 +747,6 @@ def load_image(
elif image_file.startswith("data:"):
image_file = image_file.split(",")[1]
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file)
elif isinstance(image_file, str):
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
else:
@@ -785,6 +755,61 @@ def load_image(
return image, image_size

def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
# We import decord here to avoid a strange Segmentation fault (core dumped) issue.
from decord import VideoReader, cpu, gpu

try:
from decord.bridge import decord_bridge

ctx = gpu(0)
_ = decord_bridge.get_ctx_device(ctx)
except Exception:
ctx = cpu(0)

tmp_file = None
vr = None
try:
if isinstance(video_file, bytes):
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_file)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
elif isinstance(video_file, str):
if video_file.startswith(("http://", "https://")):
timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
response = requests.get(video_file, stream=True, timeout=timeout)
response.raise_for_status()
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
for chunk in response.iter_content(chunk_size=8192):
tmp_file.write(chunk)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
elif video_file.startswith("data:"):
_, encoded = video_file.split(",", 1)
video_bytes = base64.b64decode(encoded)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
elif os.path.isfile(video_file):
vr = VideoReader(video_file, ctx=ctx)
else:
video_bytes = base64.b64decode(video_file)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
else:
raise ValueError(f"Unsupported video input type: {type(video_file)}")

return vr

finally:
if tmp_file and os.path.exists(tmp_file.name):
os.unlink(tmp_file.name)

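A sketch of the input forms the loader accepts (hypothetical paths and URLs; the bytes, data-URL, and bare-base64 branches all round-trip through a temporary .mp4 file):

vr = load_video("/data/clip.mp4")                     # local file, read in place
vr = load_video("https://example.com/clip.mp4")       # fetched with REQUEST_TIMEOUT
vr = load_video(open("/data/clip.mp4", "rb").read())  # raw bytes
print(len(vr), vr.get_avg_fps())                      # frame count, average fps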
def suppress_other_loggers():
warnings.filterwarnings(
"ignore", category=UserWarning, message="The given NumPy array is not writable"