Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)
Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
ImageDataItem,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
@@ -150,9 +151,9 @@ class Engine(EngineBase):
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
List[List[ImageDataItem]],
|
||||
List[ImageDataItem],
|
||||
ImageDataItem,
|
||||
]
|
||||
] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
@@ -221,9 +222,9 @@ class Engine(EngineBase):
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[
|
||||
List[List[Union[Image, str]]],
|
||||
List[Union[Image, str]],
|
||||
Union[Image, str],
|
||||
List[List[ImageDataItem]],
|
||||
List[ImageDataItem],
|
||||
ImageDataItem,
|
||||
]
|
||||
] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
|
||||
@@ -40,6 +40,10 @@ class SessionParams:
|
||||
replace: Optional[bool] = None
|
||||
|
||||
|
||||
AudioDataItem = Union[str, Dict]
|
||||
ImageDataItem = Union[Image, str, Dict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateReqInput:
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
@@ -55,10 +59,10 @@ class GenerateReqInput:
|
||||
# - List of lists of images (multiple images per request)
|
||||
# See also python/sglang/srt/utils.py:load_image for more details.
|
||||
image_data: Optional[
|
||||
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
||||
Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
|
||||
] = None
|
||||
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
||||
audio_data: Optional[Union[List[str], str]] = None
|
||||
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
|
||||
# The sampling_params. See descriptions below.
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
# The request id.
|
||||
|
||||
@@ -368,13 +368,13 @@ def general_mm_embed_routine(
|
||||
input_ids: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
language_model: nn.Module,
|
||||
image_data_embedding_func: Callable[
|
||||
[List[MultimodalDataItem]], torch.Tensor
|
||||
image_data_embedding_func: Optional[
|
||||
Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||
] = None,
|
||||
audio_data_embedding_func: Callable[
|
||||
[List[MultimodalDataItem]], torch.Tensor
|
||||
audio_data_embedding_func: Optional[
|
||||
Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||
] = None,
|
||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
||||
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -389,7 +389,6 @@ def general_mm_embed_routine(
|
||||
forwarded hidden states
|
||||
|
||||
"""
|
||||
|
||||
assert hasattr(language_model, "get_input_embeddings")
|
||||
embed_tokens = language_model.get_input_embeddings()
|
||||
if (
|
||||
|
||||
@@ -3,16 +3,16 @@ import concurrent.futures
|
||||
import dataclasses
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.utils import encode_video, load_audio, load_image
|
||||
|
||||
|
||||
@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput:
|
||||
input_text: str
|
||||
|
||||
# frames loaded from image and video, in given order
|
||||
images: Optional[list[PIL.Image]] = None
|
||||
images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
|
||||
|
||||
# audios
|
||||
audios: Optional[list[np.ndarray]] = None
|
||||
audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
|
||||
|
||||
def normalize(self):
|
||||
for field_name in ["image_sizes", "images", "audios"]:
|
||||
for field_name in ["images", "audios"]:
|
||||
field = getattr(self, field_name, None)
|
||||
if field is not None and isinstance(field, list) and len(field) == 0:
|
||||
setattr(self, field_name, None)
|
||||
@@ -40,12 +40,32 @@ class MultimodalSpecialTokens:
|
||||
video_token: Optional[str] = None
|
||||
audio_token: Optional[str] = None
|
||||
|
||||
def collect(self) -> list[str]:
|
||||
return [
|
||||
token
|
||||
for token in [self.image_token, self.video_token, self.audio_token]
|
||||
if token
|
||||
image_token_regex: Optional[re.Pattern] = None
|
||||
video_token_regex: Optional[re.Pattern] = None
|
||||
audio_token_regex: Optional[re.Pattern] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.image_token_regex is None and self.image_token is not None:
|
||||
self.image_token_regex = re.compile(re.escape(self.image_token))
|
||||
if self.video_token_regex is None and self.video_token is not None:
|
||||
self.video_token_regex = re.compile(re.escape(self.video_token))
|
||||
if self.audio_token_regex is None and self.audio_token is not None:
|
||||
self.audio_token_regex = re.compile(re.escape(self.audio_token))
|
||||
|
||||
def collect(self) -> re.Pattern:
|
||||
tokens = [
|
||||
self.image_token_regex,
|
||||
self.video_token_regex,
|
||||
self.audio_token_regex,
|
||||
]
|
||||
patterns = []
|
||||
flags = 0
|
||||
for t in tokens:
|
||||
if t is not None:
|
||||
patterns.append(t.pattern)
|
||||
flags |= t.flags
|
||||
combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
|
||||
return re.compile(combined, flags)
|
||||
|
||||
|
||||
class BaseMultimodalProcessor(ABC):
|
||||
@@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC):
|
||||
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
|
||||
):
|
||||
"""Static method that can be pickled for multiprocessing"""
|
||||
if isinstance(data, dict):
|
||||
return MultimodalDataItem.from_dict(data)
|
||||
if isinstance(data, MultimodalDataItem):
|
||||
return data
|
||||
try:
|
||||
if is_audio:
|
||||
return load_audio(data)
|
||||
@@ -175,7 +199,10 @@ class BaseMultimodalProcessor(ABC):
|
||||
image_index, audio_index = 0, 0
|
||||
|
||||
for text_part in text_parts:
|
||||
if text_part == multimodal_tokens.image_token:
|
||||
if (
|
||||
multimodal_tokens.image_token_regex
|
||||
and multimodal_tokens.image_token_regex.match(text_part)
|
||||
):
|
||||
data = image_data[image_index]
|
||||
is_video = isinstance(data, str) and data.startswith("video:")
|
||||
estimated_frames = estimated_frames_list[image_index]
|
||||
@@ -192,7 +219,10 @@ class BaseMultimodalProcessor(ABC):
|
||||
)
|
||||
task_info.append((Modality.IMAGE, data, frame_count_limit))
|
||||
image_index += 1
|
||||
elif text_part == multimodal_tokens.audio_token:
|
||||
elif (
|
||||
multimodal_tokens.audio_token_regex
|
||||
and multimodal_tokens.audio_token_regex.match(text_part)
|
||||
):
|
||||
data = audio_data[audio_index]
|
||||
futures.append(
|
||||
self.io_executor.submit(
|
||||
@@ -228,17 +258,22 @@ class BaseMultimodalProcessor(ABC):
|
||||
discard_alpha_channel: if True, discards the alpha channel in the returned images
|
||||
|
||||
"""
|
||||
if not return_text:
|
||||
raise NotImplementedError()
|
||||
|
||||
if image_data is None:
|
||||
image_data = []
|
||||
if isinstance(multimodal_tokens.image_token, int):
|
||||
multimodal_tokens.image_token = (
|
||||
self._processor.tokenizer.convert_ids_to_tokens(
|
||||
multimodal_tokens.image_token
|
||||
multimodal_tokens.image_token = re.compile(
|
||||
re.escape(
|
||||
self._processor.tokenizer.convert_ids_to_tokens(
|
||||
multimodal_tokens.image_token
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
||||
multimodal_tokens_pattern = multimodal_tokens.collect()
|
||||
|
||||
if isinstance(prompt, list) and return_text:
|
||||
assert len(prompt) and isinstance(prompt[0], int)
|
||||
@@ -247,16 +282,8 @@ class BaseMultimodalProcessor(ABC):
|
||||
prompt = prompt
|
||||
|
||||
assert isinstance(prompt, str)
|
||||
if return_text:
|
||||
import re
|
||||
|
||||
pattern = (
|
||||
"("
|
||||
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
|
||||
+ ")"
|
||||
)
|
||||
# split text into list of normal text and special tokens
|
||||
text_parts = re.split(pattern, prompt)
|
||||
# split text into list of normal text and special tokens
|
||||
text_parts = re.split(multimodal_tokens_pattern, prompt)
|
||||
|
||||
futures, task_info = self.submit_data_loading_tasks(
|
||||
text_parts=text_parts,
|
||||
@@ -266,26 +293,40 @@ class BaseMultimodalProcessor(ABC):
|
||||
discard_alpha_channel=discard_alpha_channel,
|
||||
)
|
||||
# Process results
|
||||
image_sizes, images, audios = [], [], []
|
||||
images, audios = [], []
|
||||
new_text = ""
|
||||
task_ptr = 0
|
||||
|
||||
for text_part in text_parts:
|
||||
if text_part in multimodal_tokens.collect():
|
||||
if multimodal_tokens_pattern.match(text_part):
|
||||
task_type, data, frame_limit = task_info[task_ptr]
|
||||
result = futures[task_ptr].result()
|
||||
task_ptr += 1
|
||||
|
||||
if task_type == Modality.IMAGE:
|
||||
# If data is already processed it will be a
|
||||
# dictionary. In this case we want to keep the
|
||||
# expanded tokens in text_part. Otherwise, we will
|
||||
# call the processor code, so keep only a single image
|
||||
# token.
|
||||
mm_tokens = (
|
||||
text_part
|
||||
if isinstance(data, dict)
|
||||
else multimodal_tokens.image_token
|
||||
)
|
||||
frames = [result] if not isinstance(result, list) else result
|
||||
if frames:
|
||||
image_sizes += frames[0].size * len(frames)
|
||||
images += frames
|
||||
new_text += multimodal_tokens.image_token * len(frames)
|
||||
new_text += mm_tokens * len(frames)
|
||||
elif task_type == Modality.AUDIO:
|
||||
# audio
|
||||
mm_tokens = (
|
||||
text_part
|
||||
if isinstance(data, dict)
|
||||
else multimodal_tokens.audio_token
|
||||
)
|
||||
audios.append(result)
|
||||
new_text += multimodal_tokens.audio_token
|
||||
new_text += mm_tokens
|
||||
# TODO: handle video
|
||||
else:
|
||||
new_text += text_part
|
||||
@@ -297,3 +338,16 @@ class BaseMultimodalProcessor(ABC):
|
||||
)
|
||||
out.normalize()
|
||||
return out
|
||||
|
||||
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
|
||||
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
|
||||
if not mm_inputs:
|
||||
return True
|
||||
ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
|
||||
if ret and not all(
|
||||
isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
|
||||
)
|
||||
return ret
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import List, Union
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
# The single, pre-expanded image token.
|
||||
self.IMAGE_TOKEN = "<start_of_image>"
|
||||
# The regex that matches expanded image tokens.
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
|
||||
)
|
||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
@@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
image_token_regex = self.IMAGE_TOKEN_REGEX
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=image_token, image_token_regex=image_token_regex
|
||||
),
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text, images=base_output.images
|
||||
input_text=base_output.input_text,
|
||||
images=None if images_are_preprocessed else base_output.images,
|
||||
)
|
||||
|
||||
items = []
|
||||
for i, image in enumerate(base_output.images):
|
||||
if images_are_preprocessed:
|
||||
pixel_values = image.pixel_values
|
||||
precomputed_features = image.precomputed_features
|
||||
else:
|
||||
pixel_values = ret["pixel_values"][i]
|
||||
precomputed_features = None
|
||||
|
||||
item = MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"][i],
|
||||
pixel_values=pixel_values,
|
||||
precomputed_features=precomputed_features,
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Union
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
# The single, pre-expanded image token.
|
||||
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
# The regex that matches expanded image tokens.
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
||||
)
|
||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||
self.image_token_id = hf_config.image_token_id
|
||||
@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN,
|
||||
image_token_regex=self.IMAGE_TOKEN_REGEX,
|
||||
),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
@@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
async def resize_image_async(image):
|
||||
return resize_image(image)
|
||||
|
||||
if base_output.images:
|
||||
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
||||
if base_output.images and not images_are_preprocessed:
|
||||
resize_tasks = [resize_image_async(image) for image in base_output.images]
|
||||
base_output.images = await asyncio.gather(*resize_tasks)
|
||||
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
images=None if images_are_preprocessed else base_output.images,
|
||||
)
|
||||
|
||||
input_ids = ret["input_ids"].flatten().tolist()
|
||||
image_grid_thw = None
|
||||
video_grid_thw = None # TODO
|
||||
items = []
|
||||
|
||||
input_ids = ret["input_ids"].flatten().tolist()
|
||||
if "pixel_values" in ret:
|
||||
if base_output.images:
|
||||
if images_are_preprocessed:
|
||||
image_grid_thw = torch.concat(
|
||||
[
|
||||
torch.as_tensor(item.image_grid_thws)
|
||||
for item in base_output.images
|
||||
]
|
||||
)
|
||||
all_pixel_values = [
|
||||
item.pixel_values
|
||||
for item in base_output.images
|
||||
if item.pixel_values is not None
|
||||
]
|
||||
all_precomputed_features = [
|
||||
item.precomputed_features
|
||||
for item in base_output.images
|
||||
if item.precomputed_features is not None
|
||||
]
|
||||
pixel_values = (
|
||||
torch.concat(all_pixel_values) if all_pixel_values else None
|
||||
)
|
||||
precomputed_features = (
|
||||
torch.concat(all_precomputed_features)
|
||||
if all_precomputed_features
|
||||
else None
|
||||
)
|
||||
else:
|
||||
image_grid_thw = ret["image_grid_thw"]
|
||||
pixel_values = ret["pixel_values"]
|
||||
precomputed_features = None
|
||||
items += [
|
||||
MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"],
|
||||
image_grid_thws=torch.concat([ret["image_grid_thw"]]),
|
||||
# TODO
|
||||
video_grid_thws=None,
|
||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thws=image_grid_thw,
|
||||
video_grid_thws=video_grid_thw,
|
||||
precomputed_features=precomputed_features,
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
]
|
||||
@@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
self.hf_config.vision_config, "tokens_per_second", None
|
||||
),
|
||||
input_ids=torch.tensor(input_ids).unsqueeze(0),
|
||||
image_grid_thw=ret.get("image_grid_thw", None),
|
||||
video_grid_thw=ret.get("video_grid_thw", None),
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=ret.get("second_per_grid_ts", None),
|
||||
)
|
||||
mrope_positions = mrope_positions.squeeze(1)
|
||||
|
||||
@@ -177,10 +177,10 @@ class MultimodalDataItem:
|
||||
image_offsets: Optional[list] = None
|
||||
|
||||
# the real data, pixel_values or audio_features
|
||||
# data: Union[List[torch.Tensor], List[np.array]]
|
||||
pixel_values: Union[torch.Tensor, np.array] = None
|
||||
image_grid_thws: Union[torch.Tensor, np.array] = None
|
||||
video_grid_thws: Union[torch.Tensor, np.array] = None
|
||||
# data: Union[List[torch.Tensor], List[np.ndarray]]
|
||||
pixel_values: Union[torch.Tensor, np.ndarray] = None
|
||||
image_grid_thws: Union[torch.Tensor, np.ndarray] = None
|
||||
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
|
||||
|
||||
image_emb_mask: Optional[torch.Tensor] = None
|
||||
image_spatial_crop: Optional[torch.Tensor] = None
|
||||
@@ -189,9 +189,11 @@ class MultimodalDataItem:
|
||||
# [num_images, (n, w, h)]
|
||||
tgt_size: Tuple[int, int] = None
|
||||
|
||||
audio_features: Union[torch.Tensor, np.array] = None
|
||||
audio_features: Union[torch.Tensor, np.ndarray] = None
|
||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||
|
||||
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||
|
||||
@staticmethod
|
||||
def is_empty_list(l):
|
||||
if l is None:
|
||||
@@ -249,7 +251,9 @@ class MultimodalDataItem:
|
||||
return tensor_hash([f])
|
||||
return data_hash(f)
|
||||
|
||||
if self.is_audio():
|
||||
if self.precomputed_features is not None:
|
||||
self.hash = hash_feature(self.precomputed_features)
|
||||
elif self.is_audio():
|
||||
self.hash = hash_feature(self.audio_features)
|
||||
else:
|
||||
self.hash = hash_feature(self.pixel_values)
|
||||
@@ -258,19 +262,24 @@ class MultimodalDataItem:
|
||||
self.pad_value = self.hash % (1 << 30)
|
||||
|
||||
def is_audio(self):
|
||||
return (
|
||||
self.modality == Modality.AUDIO
|
||||
) and not MultimodalDataItem.is_empty_list(self.audio_features)
|
||||
return (self.modality == Modality.AUDIO) and (
|
||||
self.precomputed_features is not None
|
||||
or not MultimodalDataItem.is_empty_list(self.audio_features)
|
||||
)
|
||||
|
||||
def is_image(self):
|
||||
return (
|
||||
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
|
||||
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||
) and (
|
||||
self.precomputed_features is not None
|
||||
or not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||
)
|
||||
|
||||
def is_video(self):
|
||||
return (
|
||||
self.modality == Modality.VIDEO
|
||||
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||
return (self.modality == Modality.VIDEO) and (
|
||||
self.precomputed_features is not None
|
||||
or not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||
)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return self.is_image() or self.is_video() or self.is_audio()
|
||||
@@ -279,6 +288,16 @@ class MultimodalDataItem:
|
||||
...
|
||||
# TODO
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict):
|
||||
kwargs = dict(obj)
|
||||
modality = kwargs.pop("modality")
|
||||
if isinstance(modality, str):
|
||||
modality = Modality[modality]
|
||||
ret = MultimodalDataItem(modality=modality, **kwargs)
|
||||
ret.validate()
|
||||
return ret
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MultimodalInputs:
|
||||
|
||||
@@ -54,7 +54,7 @@ class SessionReqNode:
|
||||
prefix += " -- " + self.childs[0].req.rid
|
||||
ret = self.childs[0]._str_helper(prefix)
|
||||
for child in self.childs[1:]:
|
||||
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
|
||||
prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
|
||||
ret += child._str_helper(prefix)
|
||||
return ret
|
||||
|
||||
|
||||
@@ -278,6 +278,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
if any(item.precomputed_features is not None for item in items):
|
||||
if not all(item.precomputed_features is not None for item in items):
|
||||
raise NotImplementedError(
|
||||
"MM inputs where only some items are precomputed."
|
||||
)
|
||||
return torch.concat([item.precomputed_features for item in items])
|
||||
pixel_values = torch.stack(
|
||||
flatten_nested_list([item.pixel_values for item in items]), dim=0
|
||||
)
|
||||
|
||||
@@ -497,6 +497,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
if any(item.precomputed_features is not None for item in items):
|
||||
if not all(item.precomputed_features is not None for item in items):
|
||||
raise NotImplementedError(
|
||||
"MM inputs where only some items are precomputed."
|
||||
)
|
||||
return torch.concat([item.precomputed_features for item in items])
|
||||
# in qwen-vl, last dim is the same
|
||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||
self.visual.dtype
|
||||
|
||||
@@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
if any(item.precomputed_features is not None for item in items):
|
||||
if not all(item.precomputed_features is not None for item in items):
|
||||
raise NotImplementedError(
|
||||
"MM inputs where only some items are precomputed."
|
||||
)
|
||||
return torch.concat([item.precomputed_features for item in items])
|
||||
# in qwen-vl, last dim is the same
|
||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||
self.visual.dtype
|
||||
|
||||
Reference in New Issue
Block a user