Refactor mm processors and Enable mixed modality processing (#7629)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -125,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
||||
e.g. <image><image>....<image>, or <audio><audio>...<audio>
|
||||
"""
|
||||
|
||||
def __init__(self, token_ids: "Optional[List[int]]" = None) -> None:
    """Store the (optional) multimodal placeholder token ids.

    Args:
        token_ids: Token ids that mark multimodal placeholders. The argument
            is optional so that both construction styles used by the model
            classes work: ``...Tokens([im_token_id])`` and the argument-free
            ``...Tokens()``.
    """
    # Keep the caller's list object when one is supplied; otherwise use a
    # fresh empty list (never a shared mutable default).
    self.token_ids = token_ids if token_ids is not None else []
|
||||
|
||||
def pad_input_tokens(
    self, input_ids: List[int], mm_inputs: "MultimodalInputs"
) -> List[int]:
    """Replace multimodal placeholder tokens with the pad values of mm_items.

    Each modality (image, audio, video) is handled separately: every
    occurrence of that modality's placeholder token id in ``input_ids`` is
    replaced by the ``pad_value`` of the matching item in
    ``mm_inputs.mm_items``.

    Args:
        input_ids: Token ids of the prompt.
        mm_inputs: Object exposing ``mm_items`` plus ``im_token_id``,
            ``audio_token_id`` and ``video_token_id``.

    Returns:
        ``input_ids`` unchanged (same object) when it is empty or there are
        no multimodal items; otherwise a new list with placeholders replaced.

    Raises:
        ValueError: If an item's modality has no usable token id on
            ``mm_inputs`` (the id is ``None`` or the modality is unknown).
    """
    if not input_ids or not mm_inputs.mm_items:
        return input_ids

    # Map each modality's placeholder token id to its pad_value. When several
    # items share a modality, the last item's pad_value wins.
    token_to_pad_mapping = {}
    for item in mm_inputs.mm_items:
        if item.is_image() and mm_inputs.im_token_id is not None:
            token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
        elif item.is_audio() and mm_inputs.audio_token_id is not None:
            token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
        elif item.is_video() and mm_inputs.video_token_id is not None:
            token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
        else:
            raise ValueError(f"No multimodal token id provided for {item.modality}")

    # Build every mask from the untouched ids and write into a copy, so a
    # pad_value that happens to equal another modality's token id is not
    # replaced a second time.
    source_ids = torch.tensor(input_ids)
    padded_ids = source_ids.clone()
    for token_id, pad_value in token_to_pad_mapping.items():
        padded_ids[source_ids == token_id] = pad_value

    return padded_ids.tolist()
|
||||
|
||||
|
||||
embedding_cache = None
|
||||
|
||||
@@ -174,6 +174,15 @@ class Modality(Enum):
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
|
||||
@staticmethod
def from_str(modality_str: str) -> "Modality":
    """Return the ``Modality`` member named by ``modality_str``.

    The lookup is case-insensitive: the string is upper-cased before being
    matched against member names.

    Raises:
        ValueError: If no member has that name.
    """
    try:
        return Modality[modality_str.upper()]
    except KeyError:
        # ``from None`` hides the internal KeyError; the ValueError message
        # already names the bad input and the valid choices.
        raise ValueError(
            f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
        ) from None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MultimodalDataItem:
|
||||
|
||||
@@ -482,20 +482,25 @@ class TokenizerManager:
|
||||
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
||||
|
||||
if self.mm_processor and obj.contains_mm_input():
|
||||
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
if not isinstance(obj.image_data, list):
|
||||
obj.image_data = [obj.image_data]
|
||||
if not isinstance(obj.audio_data, list):
|
||||
obj.audio_data = [obj.audio_data]
|
||||
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
image_data=obj.image_data,
|
||||
audio_data=obj.audio_data,
|
||||
input_text=input_text or input_ids,
|
||||
request_obj=obj,
|
||||
max_req_input_len=self.max_req_input_len,
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
if mm_inputs and "input_ids" in mm_inputs:
|
||||
input_ids = mm_inputs["input_ids"]
|
||||
else:
|
||||
image_inputs: Optional[Dict] = None
|
||||
mm_inputs = None
|
||||
|
||||
self._validate_one_request(obj, input_ids)
|
||||
return self._create_tokenized_object(
|
||||
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
|
||||
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
|
||||
)
|
||||
|
||||
def _validate_one_request(
|
||||
@@ -559,7 +564,7 @@ class TokenizerManager:
|
||||
input_text: str,
|
||||
input_ids: List[int],
|
||||
input_embeds: Optional[Union[List[float], None]] = None,
|
||||
image_inputs: Optional[Dict] = None,
|
||||
mm_inputs: Optional[Dict] = None,
|
||||
token_type_ids: Optional[List[int]] = None,
|
||||
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
|
||||
"""Create a tokenized request object from common parameters."""
|
||||
@@ -584,7 +589,7 @@ class TokenizerManager:
|
||||
obj.rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
image_inputs,
|
||||
mm_inputs,
|
||||
sampling_params,
|
||||
obj.return_logprob,
|
||||
obj.logprob_start_len,
|
||||
@@ -606,7 +611,7 @@ class TokenizerManager:
|
||||
obj.rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
image_inputs,
|
||||
mm_inputs,
|
||||
token_type_ids,
|
||||
sampling_params,
|
||||
)
|
||||
@@ -644,9 +649,9 @@ class TokenizerManager:
|
||||
) -> None:
|
||||
"""Validate constraints for batch tokenization processing."""
|
||||
for i in range(batch_size):
|
||||
if self.is_generation and obj[i].image_data:
|
||||
if self.is_generation and obj[i].contains_mm_input():
|
||||
raise ValueError(
|
||||
"For image input processing do not set `enable_tokenizer_batch_encode`."
|
||||
"For multimodal input processing do not set `enable_tokenizer_batch_encode`."
|
||||
)
|
||||
if obj[i].input_ids is not None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
weights_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weights_loader(param, loaded_weight)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: "MultimodalInputs"):
    """Replace multimodal placeholder tokens in ``input_ids`` with pad values.

    Delegates to ``MultiModalityDataPaddingPatternMultimodalTokens``, which
    reads the token ids and items from ``mm_inputs``.
    """
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]):
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
@@ -244,26 +244,11 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
||||
def pad_input_ids(
    self,
    input_ids: List[int],
    mm_inputs: "Optional[MultimodalInputs]" = None,
) -> List[int]:
    """Pad input IDs with image and audio tokens.

    Args:
        input_ids: Token ids of the prompt.
        mm_inputs: Multimodal inputs for the request. ``None`` is still
            accepted (as in the earlier signature) and means there is
            nothing to pad.

    Returns:
        ``input_ids`` unchanged when ``mm_inputs`` is ``None``; otherwise the
        result of ``MultiModalityDataPaddingPatternMultimodalTokens``.
    """
    if mm_inputs is None:
        return input_ids

    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.language_model.get_input_embeddings()
|
||||
@@ -431,7 +416,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
|
||||
)
|
||||
|
||||
positions += 1
|
||||
|
||||
if input_ids is not None:
|
||||
# Prepare per-layer inputs from inputs_ids
|
||||
per_layer_inputs_mask = torch.logical_and(
|
||||
|
||||
@@ -154,8 +154,7 @@ class KimiVLForConditionalGeneration(nn.Module):
|
||||
return res
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: "MultimodalInputs"):
    """Replace multimodal placeholder tokens in ``input_ids`` with pad values.

    The padding pattern is built without arguments and given ``mm_inputs``,
    from which it takes the token ids and items.
    """
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -50,10 +50,7 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config.text_config)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: "MultimodalInputs"):
    """Replace multimodal placeholder tokens in ``input_ids`` with pad values.

    No token id is extracted here; ``mm_inputs`` is handed to the padding
    pattern as-is.
    """
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(
|
||||
|
||||
@@ -446,9 +446,7 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: "MultimodalInputs"):
    """Replace multimodal placeholder tokens in ``input_ids`` with pad values.

    Thin wrapper around
    ``MultiModalityDataPaddingPatternMultimodalTokens.pad_input_tokens``.
    """
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def should_apply_lora(self, module_name: str) -> bool:
|
||||
|
||||
@@ -268,15 +268,14 @@ class PixtralHFVisionModel(nn.Module):
|
||||
|
||||
DEFAULT_IMAGE_TOKEN_ID = 10
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: "MultimodalInputs"):
    """Pad ``input_ids`` using the padder stored on ``self.input_padder``.

    Returns whatever ``self.input_padder.pad_input_tokens`` returns for the
    given ``input_ids`` and ``mm_inputs``.
    """
    return self.input_padder.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PixtralVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
*,
|
||||
image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@@ -314,11 +313,8 @@ class PixtralHFVisionModel(nn.Module):
|
||||
)
|
||||
|
||||
# Initialize patch position embedding
|
||||
self.image_token_id = image_token_id
|
||||
self.patch_positional_embedding = PixtralRotaryEmbedding(config)
|
||||
self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
|
||||
[self.image_token_id]
|
||||
)
|
||||
self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens()
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
|
||||
@@ -493,9 +493,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: "MultimodalInputs"):
    """Replace multimodal placeholder tokens in ``input_ids`` with pad values.

    The argument-free padding pattern takes everything it needs from
    ``mm_inputs``.
    """
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
|
||||
@@ -479,10 +479,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: "MultimodalInputs"):
    """Replace multimodal placeholder tokens in ``input_ids`` with pad values.

    Delegates directly to
    ``MultiModalityDataPaddingPatternMultimodalTokens.pad_input_tokens``.
    """
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
|
||||
@@ -270,15 +270,10 @@ class VILAForConditionalGeneration(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
def pad_input_ids(
    self, input_ids: List[int], mm_inputs: "MultimodalInputs"
) -> List[int]:
    """Replace multimodal placeholder tokens in ``input_ids`` with pad values.

    The padding pattern is constructed without ``token_ids`` and is handed
    ``mm_inputs`` directly.
    """
    pattern = MultiModalityDataPaddingPatternMultimodalTokens()
    return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
##### BEGIN COPY modeling_vila.py #####
|
||||
|
||||
|
||||
@@ -17,15 +17,6 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.utils import encode_video, load_audio, load_image
|
||||
|
||||
|
||||
class MultimodalInputFormat(Enum):
    """Enum for different multimodal input formats."""

    # NOTE(review): in this view the enum is only referenced by the
    # format-categorisation helpers of process_and_combine_mm_data; confirm
    # it is still needed.
    RAW_IMAGES = "raw_images"
    PRECOMPUTED_FEATURES = "precomputed_features"
    PIXEL_VALUES = "pixel_values"
    AUDIO = "audio"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseMultiModalProcessorOutput:
|
||||
# input_text, with each frame of video/image represented with a image_token
|
||||
@@ -110,18 +101,45 @@ class BaseMultimodalProcessor(ABC):
|
||||
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
|
||||
)
|
||||
|
||||
# Mapping from attribute names to modality types
|
||||
self.ATTR_NAME_TO_MODALITY = {
|
||||
# Image-related attributes
|
||||
"pixel_values": Modality.IMAGE,
|
||||
"image_sizes": Modality.IMAGE,
|
||||
"image_grid_thw": Modality.IMAGE,
|
||||
"image_emb_mask": Modality.IMAGE,
|
||||
"image_spatial_crop": Modality.IMAGE,
|
||||
"tgt_size": Modality.IMAGE,
|
||||
"image_grid_hws": Modality.IMAGE,
|
||||
"aspect_ratio_id": Modality.IMAGE,
|
||||
"aspect_ratio_mask": Modality.IMAGE,
|
||||
"second_per_grid_ts": Modality.IMAGE,
|
||||
# Audio-related attributes
|
||||
"audio_features": Modality.AUDIO,
|
||||
"audio_feature_lens": Modality.AUDIO,
|
||||
"input_features": Modality.AUDIO,
|
||||
"input_features_mask": Modality.AUDIO,
|
||||
# Video-related attributes
|
||||
"video_grid_thws": Modality.VIDEO,
|
||||
# Generic attributes that could apply to multiple modalities
|
||||
# "precomputed_features" - handled specially as it can be any modality
|
||||
}
|
||||
|
||||
def process_mm_data(
|
||||
self, input_text, images=None, videos=None, audios=None, **kwargs
|
||||
):
|
||||
"""
|
||||
process multimodal data with transformers AutoProcessor
|
||||
"""
|
||||
if images is not None:
|
||||
if images:
|
||||
kwargs["images"] = images
|
||||
if videos is not None:
|
||||
if videos:
|
||||
kwargs["videos"] = videos
|
||||
if audios is not None:
|
||||
if audios:
|
||||
kwargs["audios"] = audios
|
||||
if self.__class__.__name__ == "Gemma3nSGLangProcessor":
|
||||
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
|
||||
kwargs["audio"] = audios
|
||||
|
||||
processor = self._processor
|
||||
if hasattr(processor, "image_processor") and isinstance(
|
||||
@@ -144,6 +162,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data,
|
||||
audio_data,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
@@ -418,175 +437,137 @@ class BaseMultimodalProcessor(ABC):
|
||||
values[k] = v
|
||||
return values
|
||||
|
||||
def process_and_combine_mm_data(
|
||||
self, base_output: BaseMultiModalProcessorOutput
|
||||
) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
|
||||
def collect_mm_items_from_processor_output(
    self, data_dict: dict
) -> "List[MultimodalDataItem]":
    """Create mm_items directly from processor output.

    Every key of ``data_dict`` except ``"input_ids"`` is looked up in
    ``self.ATTR_NAME_TO_MODALITY``. Keys that map to a modality are copied
    onto one ``MultimodalDataItem`` per modality, provided the item has an
    attribute of that name. ``"precomputed_features"`` takes its modality
    from ``data_dict["modality"]`` and defaults to IMAGE.

    Args:
        data_dict: Mapping of attribute name to value.

    Returns:
        One item per modality seen, in first-seen order.
    """
    items = {}  # modality -> MultimodalDataItem

    for attr_name, value in data_dict.items():
        if attr_name == "input_ids":
            continue

        # Get modality for this attribute
        modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)

        if modality is None and attr_name == "precomputed_features":
            # Precomputed features can belong to any modality: use the
            # "modality" entry if present, falling back to IMAGE when it is
            # missing or not a valid modality name.
            modality_str = data_dict.get("modality")
            try:
                modality = (
                    Modality.from_str(modality_str)
                    if modality_str
                    else Modality.IMAGE
                )
            except ValueError:
                modality = Modality.IMAGE

        if modality is None:
            # Attribute not associated with any modality; ignore it.
            continue

        # Create item if needed
        item = items.get(modality)
        if item is None:
            item = MultimodalDataItem(modality=modality)
            items[modality] = item

        # Only set attributes the item type actually declares.
        if hasattr(item, attr_name):
            setattr(item, attr_name, value)

    return list(items.values())
|
||||
|
||||
def _process_and_collect_mm_items(
    self, input_text: str, images=None, audios=None, videos=None, **kwargs
) -> "Tuple[List[MultimodalDataItem], torch.Tensor]":
    """
    Helper method to process multimodal data and create mm_items in one step.

    Runs ``self.process_mm_data`` on the given text and media, flattens the
    returned ``input_ids`` and builds the items from the same output via
    ``self.collect_mm_items_from_processor_output``.

    Returns:
        Tuple of (created mm_items, input_ids)
    """
    ret = self.process_mm_data(
        input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
    )

    input_ids = ret["input_ids"].flatten()
    collected_items = self.collect_mm_items_from_processor_output(ret)

    return collected_items, input_ids
|
||||
|
||||
def process_and_combine_mm_data(
    self, base_output: "BaseMultiModalProcessorOutput"
) -> "Tuple[List[MultimodalDataItem], torch.Tensor]":
    """
    Process multimodal data and return the combined multimodal items and input_ids.
    Supports mixed modalities (images and audio in the same request).

    Inputs in ``base_output.images`` / ``base_output.audios`` may be dicts
    (already-processed output), ``Image.Image`` objects or ``np.ndarray``
    audio. Dicts are converted to items directly; raw images and audio go
    through the processor together in one call.

    Returns:
        Tuple of (list of mm_items, input_ids)

    Raises:
        ValueError: If an input has an unsupported type or a collected item
            has an unknown modality.
    """

    def _tokenize_text() -> torch.Tensor:
        """Tokenize ``base_output.input_text`` into a flat id tensor."""
        return self._processor.tokenizer(
            base_output.input_text,
            return_tensors="pt",
            add_special_tokens=True,
        ).input_ids.flatten()

    # Collect all items and categorize them
    all_items = (base_output.images or []) + (base_output.audios or [])

    # Handle text-only case
    if not all_items:
        return [], _tokenize_text()

    dict_items, raw_images, raw_audios = [], [], []
    for item in all_items:
        if isinstance(item, dict):
            dict_items.append(item)
        elif isinstance(item, Image.Image):
            raw_images.append(item)
        elif isinstance(item, np.ndarray):
            raw_audios.append(item)
        else:
            raise ValueError(f"Unknown multimodal item type: {type(item)}")

    # Process items and get input_ids
    all_collected_items = []
    input_ids = None

    # Handle dict items (already processed)
    for dict_item in dict_items:
        all_collected_items.extend(
            self.collect_mm_items_from_processor_output(dict_item)
        )

    # Handle raw items (need processing)
    if raw_images or raw_audios:
        collected_items, input_ids = self._process_and_collect_mm_items(
            input_text=base_output.input_text,
            images=raw_images,
            audios=raw_audios,
        )
        all_collected_items.extend(collected_items)

    # Fallback tokenization if no raw items were processed
    if input_ids is None:
        input_ids = _tokenize_text()

    # Add offsets to all items
    for mm_item in all_collected_items:
        if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
            mm_item.image_offsets = self.get_mm_items_offset(
                input_ids=input_ids,
                mm_token_id=self.IM_TOKEN_ID,
            )
        elif mm_item.modality == Modality.AUDIO:
            mm_item.audio_offsets = self.get_mm_items_offset(
                input_ids=input_ids,
                mm_token_id=self.AUDIO_TOKEN_ID,
            )
        elif mm_item.modality == Modality.VIDEO:
            mm_item.video_offsets = self.get_mm_items_offset(
                input_ids=input_ids,
                mm_token_id=self.VIDEO_TOKEN_ID,
            )
        else:
            raise ValueError(f"Unknown modality: {mm_item.modality}")

    return all_collected_items, input_ids
|
||||
|
||||
@@ -15,20 +15,11 @@ class ClipImageProcessor(BaseMultimodalProcessor):
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
if len(image_data) > 0:
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
else:
|
||||
images = load_image(image_data[0])[0]
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
|
||||
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||
image_inputs["data_hashes"] = [hash(str(image_data))]
|
||||
|
||||
@@ -44,17 +44,10 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_mm_data(
|
||||
input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
res = self.process_mm_data(
|
||||
|
||||
@@ -36,11 +36,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
@@ -51,11 +46,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
|
||||
"mm_items": mm_items,
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
}
|
||||
|
||||
@@ -59,17 +59,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
**kwargs,
|
||||
):
|
||||
"""Process multimodal data including images and audio."""
|
||||
|
||||
audio_data = request_obj.audio_data
|
||||
if not image_data and not audio_data:
|
||||
return None
|
||||
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
if isinstance(audio_data, str):
|
||||
audio_data = [audio_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
@@ -83,13 +72,11 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
),
|
||||
)
|
||||
|
||||
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
"audio_start_id": self.AUDIO_START_TOKEN_ID,
|
||||
"audio_end_id": self.AUDIO_END_TOKEN_ID,
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"audio_token_id": self.AUDIO_TOKEN_ID,
|
||||
}
|
||||
|
||||
@@ -172,13 +172,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_text, request_obj, max_req_input_len, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
# Ensure image_data is a list
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
|
||||
@@ -22,12 +22,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
processor = self._processor
|
||||
|
||||
base_out = self.load_mm_data(
|
||||
|
||||
@@ -30,11 +30,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
@@ -44,10 +39,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
}
|
||||
|
||||
@@ -110,9 +110,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
modalities = request_obj.modalities or ["image"]
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
grid_pinpoints = (
|
||||
@@ -122,9 +119,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
if "multi-images" in modalities or "video" in modalities:
|
||||
# Multiple images
|
||||
|
||||
@@ -23,19 +23,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
audio_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
audio_data = request_obj.audio_data
|
||||
if not image_data and not audio_data:
|
||||
return None
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
if not isinstance(audio_data, list):
|
||||
audio_data = [audio_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
max_req_input_len=max_req_input_len,
|
||||
|
||||
@@ -15,21 +15,11 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
if len(image_data) > 0:
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
else:
|
||||
images = load_image(image_data[0])[0]
|
||||
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
image_inputs["mm_items"] = [
|
||||
|
||||
@@ -37,9 +37,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
@@ -26,22 +26,12 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
audio_data,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
audio_data = request_obj.audio_data
|
||||
|
||||
if not image_data and not audio_data:
|
||||
return None
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
if not isinstance(audio_data, list):
|
||||
audio_data = [audio_data]
|
||||
|
||||
if audio_data:
|
||||
logger.warning(
|
||||
"Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
|
||||
|
||||
@@ -78,12 +78,6 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
mm_data = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
|
||||
@@ -49,9 +49,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
@@ -130,12 +127,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
video_grid_thw = None # TODO
|
||||
|
||||
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
|
||||
if combined_mm_item is None:
|
||||
if not mm_items:
|
||||
# Note(Xinyuan): This is the case where image loading fails.
|
||||
return None
|
||||
|
||||
combined_mm_item = mm_items[0] # only image is supported for now
|
||||
video_grid_thw = None # TODO
|
||||
second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
|
||||
|
||||
@@ -157,7 +155,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": [combined_mm_item],
|
||||
"mm_items": mm_items,
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
|
||||
@@ -37,6 +37,8 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
||||
_processor: VILAProcessor,
|
||||
) -> None:
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IM_TOKEN_ID = hf_config.image_token_id
|
||||
self.VIDEO_TOKEN_ID = hf_config.video_token_id
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -46,13 +48,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
||||
max_req_input_len: int,
|
||||
**kwargs,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
mm_data = self.load_mm_data(
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self._processor.tokenizer.image_token
|
||||
@@ -61,25 +57,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
||||
image_data=image_data,
|
||||
)
|
||||
|
||||
inputs = self.process_mm_data(
|
||||
input_text=mm_data.input_text,
|
||||
images=mm_data.images,
|
||||
)
|
||||
mm_items, input_ids = self.process_and_combine_mm_data(base_output)
|
||||
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=inputs.input_ids[0],
|
||||
mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
|
||||
)
|
||||
|
||||
mm_items: List[MultimodalDataItem] = [
|
||||
MultimodalDataItem(
|
||||
modality=Modality.IMAGE,
|
||||
image_offsets=image_offsets,
|
||||
pixel_values=inputs.pixel_values,
|
||||
)
|
||||
]
|
||||
|
||||
return dict(
|
||||
input_ids=inputs.input_ids[0].tolist(),
|
||||
mm_items=mm_items,
|
||||
)
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"video_token_id": self.VIDEO_TOKEN_ID,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user