[Refactor] simplify multimodal data processing (#8107)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -42,6 +42,9 @@ def select_best_resolution(image_size, candidate_resolutions):
|
||||
|
||||
|
||||
class DictOutput(object):
|
||||
def items(self):
|
||||
return self.__dict__.items()
|
||||
|
||||
def keys(self):
|
||||
return self.__dict__.keys()
|
||||
|
||||
@@ -59,7 +62,9 @@ class DictOutput(object):
|
||||
class VLChatProcessorOutput(DictOutput):
|
||||
input_ids: torch.LongTensor
|
||||
target_ids: torch.LongTensor
|
||||
images: torch.Tensor
|
||||
pixel_values: (
|
||||
torch.Tensor
|
||||
) # rename from "images" to "pixel_values" for compatibility
|
||||
images_seq_mask: torch.BoolTensor
|
||||
images_spatial_crop: torch.LongTensor
|
||||
|
||||
@@ -312,10 +317,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
|
||||
images = torch.stack(images_list, dim=0)
|
||||
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
|
||||
|
||||
images_spatial_crop = torch.stack(
|
||||
[images_spatial_crop], dim=0
|
||||
) # stack the tensor to make it a batch of 1
|
||||
|
||||
prepare = VLChatProcessorOutput(
|
||||
input_ids=input_ids,
|
||||
target_ids=target_ids,
|
||||
images=images,
|
||||
pixel_values=images,
|
||||
images_seq_mask=images_seq_mask,
|
||||
images_spatial_crop=images_spatial_crop,
|
||||
)
|
||||
|
||||
@@ -284,6 +284,9 @@ class VLMImageProcessor(BaseImageProcessor):
|
||||
|
||||
|
||||
class DictOutput(object):
|
||||
def items(self):
|
||||
return self.__dict__.items()
|
||||
|
||||
def keys(self):
|
||||
return self.__dict__.keys()
|
||||
|
||||
|
||||
@@ -221,17 +221,17 @@ def _get_precomputed_embedding(
|
||||
items: List[MultimodalDataItem],
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
If all items have precomputed_features, return their concatenation.
|
||||
If some but not all have precomputed_features, raise NotImplementedError.
|
||||
If none have precomputed_features, return None.
|
||||
If all items have precomputed_embeddings, return their concatenation.
|
||||
If some but not all have precomputed_embeddings, raise NotImplementedError.
|
||||
If none have precomputed_embeddings, return None.
|
||||
"""
|
||||
precomputed_features = [item.precomputed_features for item in items]
|
||||
if any(feature is not None for feature in precomputed_features):
|
||||
if not all(feature is not None for feature in precomputed_features):
|
||||
precomputed_embeddings = [item.precomputed_embeddings for item in items]
|
||||
if any(feature is not None for feature in precomputed_embeddings):
|
||||
if not all(feature is not None for feature in precomputed_embeddings):
|
||||
raise NotImplementedError(
|
||||
"MM inputs where only some items are precomputed."
|
||||
)
|
||||
result = torch.concat(precomputed_features)
|
||||
result = torch.concat(precomputed_embeddings)
|
||||
# some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
|
||||
result = result.reshape(-1, result.shape[-1])
|
||||
return result
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
|
||||
|
||||
|
||||
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
|
||||
models = [Qwen2AudioForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
self.AUDIO_TOKEN_REGEX = re.compile(
|
||||
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
|
||||
)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
audio_data = request_obj.audio_data
|
||||
if not isinstance(audio_data, list):
|
||||
audio_data = [audio_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
audio_token=self.AUDIO_TOKEN,
|
||||
audio_token_regex=self.AUDIO_TOKEN_REGEX,
|
||||
),
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
res = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
audio=base_output.audios,
|
||||
)
|
||||
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
|
||||
audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
|
||||
audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
|
||||
|
||||
items = []
|
||||
input_ids = res["input_ids"].flatten()
|
||||
|
||||
if (
|
||||
"input_features" in res
|
||||
and res["input_features"] is not None
|
||||
and len(res["input_features"]) != 0
|
||||
):
|
||||
if audio_start_id is not None and audio_end_id is not None:
|
||||
audio_offsets = self.get_mm_items_offset_by_pair(
|
||||
input_ids=input_ids,
|
||||
mm_start_id=audio_start_id,
|
||||
mm_end_id=audio_end_id,
|
||||
)
|
||||
else:
|
||||
audio_offsets = None
|
||||
|
||||
input_lengths = res["feature_attention_mask"].sum(dim=-1)
|
||||
input_lengths = (input_lengths - 1) // 2 + 1
|
||||
output_lengths = (input_lengths - 2) // 2 + 1
|
||||
|
||||
item = MultimodalDataItem(
|
||||
feature=res["input_features"],
|
||||
audio_feature_lens=output_lengths,
|
||||
audio_offsets=audio_offsets,
|
||||
modality=Modality.AUDIO,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
return {
|
||||
"mm_items": items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"audio_start_id": audio_start_id,
|
||||
"audio_token_id": audio_token_id,
|
||||
"audio_end_id": audio_end_id,
|
||||
}
|
||||
@@ -201,7 +201,7 @@ class MultimodalDataItem:
|
||||
For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
|
||||
One for images and one for audio.
|
||||
|
||||
We put the common fields first and the model-specific fields last.
|
||||
We put the common fields first and the model-specific fields in model_specific_data.
|
||||
"""
|
||||
|
||||
modality: Modality
|
||||
@@ -211,37 +211,31 @@ class MultimodalDataItem:
|
||||
# the raw features returned by processor, e.g. pixel_values or audio_features
|
||||
feature: Union[torch.Tensor, np.ndarray] = None
|
||||
|
||||
image_sizes: Tuple[int, int] = None
|
||||
# the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
|
||||
precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||
|
||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
||||
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||
# Model-specific data stored in a dictionary
|
||||
model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)
|
||||
|
||||
# For qwen-vl
|
||||
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
|
||||
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
||||
def __getattr__(self, name: str):
|
||||
if (
|
||||
"model_specific_data" in self.__dict__
|
||||
and name in self.__dict__["model_specific_data"]
|
||||
):
|
||||
return self.__dict__["model_specific_data"][name]
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
# For deepseek-vl
|
||||
image_emb_mask: Optional[torch.Tensor] = None
|
||||
image_spatial_crop: Optional[torch.Tensor] = None
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
if key in self.__dict__:
|
||||
self.__dict__[key] = value
|
||||
else:
|
||||
self.model_specific_data[key] = value
|
||||
|
||||
# For minicpmv
|
||||
# [num_images, (n, w, h)]
|
||||
tgt_size: Tuple[int, int] = None
|
||||
|
||||
# For mllama
|
||||
aspect_ratio_id: Optional[List[torch.Tensor]] = None
|
||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# For kimi-vl
|
||||
image_grid_hws: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# For gemma3n
|
||||
input_features_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# For phi4-mm
|
||||
image_attention_mask: Optional[torch.Tensor] = None
|
||||
audio_attention_mask: Optional[torch.Tensor] = None
|
||||
def set(self, key: str, value: Any):
|
||||
self.__setitem__(key, value)
|
||||
|
||||
@staticmethod
|
||||
def is_empty_list(l):
|
||||
@@ -259,7 +253,7 @@ class MultimodalDataItem:
|
||||
if self.feature is not None:
|
||||
hashed_feature = self.feature
|
||||
else:
|
||||
hashed_feature = self.precomputed_features
|
||||
hashed_feature = self.precomputed_embeddings
|
||||
self.hash = hash_feature(hashed_feature)
|
||||
assert self.hash is not None
|
||||
self.pad_value = self.hash % (1 << 30)
|
||||
@@ -268,24 +262,13 @@ class MultimodalDataItem:
|
||||
return self.modality == modality
|
||||
|
||||
def is_audio(self):
|
||||
return (self.modality == Modality.AUDIO) and (
|
||||
self.precomputed_features is not None
|
||||
or not MultimodalDataItem.is_empty_list(self.feature)
|
||||
)
|
||||
return self.modality == Modality.AUDIO
|
||||
|
||||
def is_image(self):
|
||||
return (
|
||||
self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
|
||||
) and (
|
||||
self.precomputed_features is not None
|
||||
or not MultimodalDataItem.is_empty_list(self.feature)
|
||||
)
|
||||
return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]
|
||||
|
||||
def is_video(self):
|
||||
return (self.modality == Modality.VIDEO) and (
|
||||
self.precomputed_features is not None
|
||||
or not MultimodalDataItem.is_empty_list(self.feature)
|
||||
)
|
||||
return self.modality == Modality.VIDEO
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return self.is_image() or self.is_video() or self.is_audio()
|
||||
@@ -306,8 +289,7 @@ class MultimodalDataItem:
|
||||
|
||||
def merge(self, other):
|
||||
self.feature += other.feature
|
||||
self.image_sizes += other.image_sizes
|
||||
self.image_offsets += other.image_offsets
|
||||
self.offsets += other.offsets
|
||||
self.hash = hash((self.hash, other.hash))
|
||||
self.set_pad_value()
|
||||
|
||||
|
||||
@@ -260,7 +260,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]):
|
||||
|
||||
images_spatial_crop = torch.cat(
|
||||
[item.image_spatial_crop for item in items], dim=0
|
||||
[item.images_spatial_crop for item in items], dim=0
|
||||
)
|
||||
|
||||
assert images_spatial_crop.dim() == 3
|
||||
@@ -278,8 +278,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
_, hw, n_dim = images_embeds.shape
|
||||
h = w = int(hw**0.5)
|
||||
tile_index = 0
|
||||
for jdx in range(item.image_spatial_crop.shape[1]):
|
||||
num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
|
||||
for jdx in range(item.images_spatial_crop.shape[1]):
|
||||
num_width_tiles, num_height_tiles = item.images_spatial_crop[0, jdx]
|
||||
if num_width_tiles == 0 or num_height_tiles == 0:
|
||||
break
|
||||
num_tiles_in_image = num_width_tiles * num_height_tiles
|
||||
|
||||
@@ -81,6 +81,7 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.text_config if hasattr(config, "text_config") else config
|
||||
)
|
||||
self.padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens()
|
||||
|
||||
def _has_vision_weights(self, config) -> bool:
|
||||
"""Check if the model has vision components by examining the checkpoint."""
|
||||
@@ -135,8 +136,7 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
return False
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
|
||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
return self.padding_pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(
|
||||
self,
|
||||
|
||||
@@ -435,7 +435,12 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
dtype = next(self.vision_encoder.parameters()).dtype
|
||||
pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
|
||||
image_attention_mask = torch.cat(
|
||||
[item.image_attention_mask for item in items], dim=0
|
||||
[
|
||||
item.image_attention_mask
|
||||
for item in items
|
||||
if hasattr(item, "image_attention_mask")
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
|
||||
image_embeds = self.vision_encoder(
|
||||
@@ -456,7 +461,7 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
audio_features=item.feature.to(device).type(dtype),
|
||||
audio_attention_mask=(
|
||||
item.audio_attention_mask.to(device)
|
||||
if item.audio_attention_mask is not None
|
||||
if hasattr(item, "audio_attention_mask")
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -155,17 +155,15 @@ class BaseMultimodalProcessor(ABC):
|
||||
self.ATTR_NAME_TO_MODALITY = {
|
||||
# Image-related attributes
|
||||
"pixel_values": Modality.IMAGE,
|
||||
"pixel_values_videos": Modality.VIDEO,
|
||||
"image_sizes": Modality.IMAGE,
|
||||
"image_grid_thw": Modality.IMAGE,
|
||||
"image_attention_mask": Modality.IMAGE,
|
||||
"image_emb_mask": Modality.IMAGE,
|
||||
"image_spatial_crop": Modality.IMAGE,
|
||||
"images_spatial_crop": Modality.IMAGE,
|
||||
"tgt_size": Modality.IMAGE,
|
||||
"image_grid_hws": Modality.IMAGE,
|
||||
"aspect_ratio_id": Modality.IMAGE,
|
||||
"aspect_ratio_ids": Modality.IMAGE,
|
||||
"aspect_ratio_mask": Modality.IMAGE,
|
||||
"second_per_grid_ts": Modality.IMAGE,
|
||||
# Audio-related attributes
|
||||
"audio_features": Modality.AUDIO,
|
||||
"audio_feature_lens": Modality.AUDIO,
|
||||
@@ -173,9 +171,11 @@ class BaseMultimodalProcessor(ABC):
|
||||
"input_features_mask": Modality.AUDIO,
|
||||
"audio_attention_mask": Modality.AUDIO,
|
||||
# Video-related attributes
|
||||
"pixel_values_videos": Modality.VIDEO,
|
||||
"second_per_grid_ts": Modality.VIDEO,
|
||||
"video_grid_thw": Modality.VIDEO,
|
||||
# Generic attributes that could apply to multiple modalities
|
||||
# "precomputed_features" - handled specially as it can be any modality
|
||||
# "precomputed_embeddings" - handled specially as it can be any modality
|
||||
}
|
||||
|
||||
# name of the feature filed
|
||||
@@ -222,7 +222,6 @@ class BaseMultimodalProcessor(ABC):
|
||||
audio_data,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
pass
|
||||
@@ -283,7 +282,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
self,
|
||||
text_parts: List[str],
|
||||
multimodal_tokens: MultimodalSpecialTokens,
|
||||
data_iterators: dict,
|
||||
data_iterators: dict[Modality, Iterator[Any]],
|
||||
discard_alpha_channel: bool = True,
|
||||
image_estimated_frames_iter: Optional[iter] = None,
|
||||
image_scaling_factor: float = 1.0,
|
||||
@@ -354,7 +353,6 @@ class BaseMultimodalProcessor(ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
multimodal_tokens: MultimodalSpecialTokens,
|
||||
max_req_input_len: int,
|
||||
image_data: Optional[list] = None,
|
||||
video_data: Optional[list] = None,
|
||||
audio_data: Optional[list] = None,
|
||||
@@ -489,50 +487,11 @@ class BaseMultimodalProcessor(ABC):
|
||||
|
||||
return list(zip(indices_start.tolist(), indices_end.tolist()))
|
||||
|
||||
@staticmethod
|
||||
def _extract_processor_features(
|
||||
items: List[dict], attr_name: str
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Helper function to concat extracted attributes from processor output.
|
||||
"""
|
||||
values = [value for item in items if (value := item.get(attr_name)) is not None]
|
||||
return torch.cat(values) if values else None
|
||||
|
||||
# When we assume that all the items have the same attributes
|
||||
def _extract_processor_features_from_all_attributes(
|
||||
self, items: List[dict]
|
||||
) -> dict:
|
||||
values = {}
|
||||
# Verify all items have the same keys
|
||||
first_keys = set(items[0].keys())
|
||||
for item in items[1:]:
|
||||
if set(item.keys()) != first_keys:
|
||||
raise ValueError(
|
||||
f"All items must have the same attributes. "
|
||||
f"First item has {first_keys}, but found {set(item.keys())}"
|
||||
)
|
||||
|
||||
# Process each attribute
|
||||
for k, v in items[0].items():
|
||||
if isinstance(v, list):
|
||||
values[k] = self._extract_processor_features(items, k)
|
||||
else:
|
||||
# Verify all items have the same value for non-list attributes
|
||||
for item in items[1:]:
|
||||
if item[k] != v:
|
||||
raise ValueError(
|
||||
f"All items must have the same value for attribute {k}. "
|
||||
f"First item has {v}, but found {item[k]}"
|
||||
)
|
||||
values[k] = v
|
||||
return values
|
||||
|
||||
def collect_mm_items_from_processor_output(
|
||||
self, data_dict: dict
|
||||
) -> List[MultimodalDataItem]:
|
||||
"""Create mm_items directly from processor output."""
|
||||
items = {} # modality -> MultimodalDataItem
|
||||
items: dict[Modality, MultimodalDataItem] = {}
|
||||
|
||||
for attr_name, value in data_dict.items():
|
||||
if attr_name == "input_ids":
|
||||
@@ -541,16 +500,15 @@ class BaseMultimodalProcessor(ABC):
|
||||
# Get modality for this attribute
|
||||
modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
|
||||
|
||||
if not modality and attr_name == "precomputed_features":
|
||||
if attr_name == "precomputed_embeddings":
|
||||
modality_str = data_dict.get("modality")
|
||||
try:
|
||||
modality = (
|
||||
Modality.from_str(modality_str)
|
||||
if modality_str
|
||||
else Modality.IMAGE
|
||||
)
|
||||
except ValueError:
|
||||
modality = Modality.IMAGE
|
||||
modality = Modality.IMAGE
|
||||
if modality_str:
|
||||
try:
|
||||
modality = Modality.from_str(modality_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if modality:
|
||||
# Create item if needed
|
||||
if modality not in items:
|
||||
@@ -559,8 +517,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
if attr_name in self.FEATURE_NAMES:
|
||||
attr_name = "feature"
|
||||
|
||||
# Set attribute
|
||||
setattr(items[modality], attr_name, value)
|
||||
items[modality].set(attr_name, value)
|
||||
|
||||
return list(items.values())
|
||||
|
||||
@@ -586,6 +543,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
self,
|
||||
base_output: BaseMultiModalProcessorOutput,
|
||||
mm_tokens: MultimodalSpecialTokens,
|
||||
**kwargs,
|
||||
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
|
||||
"""
|
||||
Process multimodal data and return the combined multimodal items and input_ids.
|
||||
@@ -618,7 +576,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
else:
|
||||
raise ValueError(f"Unknown multimodal item type: {type(item)}")
|
||||
# Process items and get input_ids
|
||||
all_collected_items = []
|
||||
all_collected_items: list[MultimodalDataItem] = []
|
||||
input_ids = None
|
||||
|
||||
# Handle dict items (already processed)
|
||||
@@ -634,6 +592,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
images=raw_images,
|
||||
audios=raw_audios,
|
||||
videos=raw_videos,
|
||||
**kwargs,
|
||||
)
|
||||
all_collected_items.extend(collected_items)
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.clip import CLIPModel
|
||||
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
|
||||
|
||||
class ClipImageProcessor(BaseMultimodalProcessor):
|
||||
@@ -11,23 +12,24 @@ class ClipImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
|
||||
_processor
|
||||
)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
image_data=image_data,
|
||||
)
|
||||
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||
image_inputs["data_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
image_inputs["mm_items"] = [
|
||||
MultimodalDataItem(
|
||||
feature=image_inputs["pixel_values"], modality=Modality.IMAGE
|
||||
)
|
||||
]
|
||||
|
||||
return image_inputs
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
}
|
||||
|
||||
@@ -33,9 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
|
||||
_processor
|
||||
)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<image>", image_token_id=self._processor.image_token_id
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -50,36 +50,16 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
res = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output,
|
||||
self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
conversations=base_output.input_text,
|
||||
)
|
||||
images_seq_mask = res["images_seq_mask"]
|
||||
images_spatial_crop = res["images_spatial_crop"]
|
||||
batched_images_spatial_crop = []
|
||||
batched_images_spatial_crop.append(images_spatial_crop)
|
||||
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
||||
|
||||
items = []
|
||||
input_ids = res["input_ids"]
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids, mm_token_id=self._processor.image_token_id
|
||||
)
|
||||
item = MultimodalDataItem(
|
||||
feature=res["images"],
|
||||
offsets=image_offsets,
|
||||
modality=Modality.IMAGE,
|
||||
image_emb_mask=images_seq_mask,
|
||||
image_spatial_crop=batched_images_spatial_crop,
|
||||
)
|
||||
items += [item]
|
||||
|
||||
return {
|
||||
"mm_items": items,
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_token_id": self._processor.image_token_id,
|
||||
}
|
||||
|
||||
@@ -33,7 +33,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -41,7 +40,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -54,7 +54,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
|
||||
input_text: str = "",
|
||||
request_obj=None,
|
||||
max_req_input_len: int = 0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -63,7 +62,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
max_req_input_len=max_req_input_len,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -170,13 +170,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
return pixel_values, num_patches_list
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_text, request_obj, max_req_input_len, **kwargs
|
||||
self, image_data, input_text, request_obj, **kwargs
|
||||
):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,52 +11,35 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
models = [MultiModalityCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, processor):
|
||||
super().__init__(hf_config, server_args, processor)
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=processor.image_token
|
||||
).build(processor)
|
||||
image_token=_processor.image_token,
|
||||
image_token_id=_processor.image_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
processor = self._processor
|
||||
|
||||
base_out = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
images = base_out.images
|
||||
res = self.process_mm_data(
|
||||
input_text=base_out.input_text,
|
||||
prompt=base_out.input_text,
|
||||
images=images,
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_out, self.mm_tokens, prompt=base_out.input_text
|
||||
)
|
||||
|
||||
input_ids = res["input_ids"].flatten()
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids, mm_token_id=processor.image_id
|
||||
)
|
||||
return {
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
feature=res["pixel_values"],
|
||||
image_emb_mask=res["images_emb_mask"],
|
||||
offsets=image_offsets,
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
],
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_start_id": processor.image_start_id,
|
||||
"im_end_id": processor.image_end_id,
|
||||
"im_token_id": processor.image_id,
|
||||
"im_start_id": self._processor.image_start_id,
|
||||
"im_end_id": self._processor.image_end_id,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -34,7 +33,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
|
||||
@@ -159,7 +159,9 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
feature=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
model_specific_data={
|
||||
"image_sizes": image_sizes,
|
||||
},
|
||||
modality=modality,
|
||||
)
|
||||
],
|
||||
|
||||
@@ -17,10 +17,21 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
|
||||
self.slice_end_id = getattr(tokenizer, "slice_end_id", None)
|
||||
self.audio_start_id = getattr(tokenizer, "audio_start_id", None)
|
||||
self.audio_end_id = getattr(tokenizer, "audio_end_id", None)
|
||||
self.im_start_id = getattr(tokenizer, "im_start_id", None)
|
||||
self.im_end_id = getattr(tokenizer, "im_end_id", None)
|
||||
self.im_token_id = getattr(tokenizer, "unk_id", None)
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="(<image>./</image>)",
|
||||
audio_token="(<audio>./</audio>)",
|
||||
video_token="(<video>./</video>)",
|
||||
image_token_id=self.im_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
@@ -29,12 +40,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audio_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
@@ -48,24 +57,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audios=base_output.audios,
|
||||
)
|
||||
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = tokenizer.slice_start_id
|
||||
slice_end_id = tokenizer.slice_end_id
|
||||
if hasattr(tokenizer, "audio_start_id"):
|
||||
audio_start_id = tokenizer.audio_start_id
|
||||
audio_end_id = tokenizer.audio_end_id
|
||||
|
||||
im_start_id = tokenizer.im_start_id
|
||||
im_end_id = tokenizer.im_end_id
|
||||
im_token_id = tokenizer.unk_id
|
||||
pixel_values = res["pixel_values"]
|
||||
tgt_sizes = res["tgt_sizes"]
|
||||
|
||||
@@ -102,10 +93,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
items = []
|
||||
input_ids = res["input_ids"].flatten()
|
||||
image_offsets = self.get_mm_items_offset_by_pair(
|
||||
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
|
||||
input_ids=input_ids, mm_start_id=self.im_start_id, mm_end_id=self.im_end_id
|
||||
)
|
||||
slice_offsets = self.get_mm_items_offset_by_pair(
|
||||
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
|
||||
input_ids=input_ids,
|
||||
mm_start_id=self.slice_start_id,
|
||||
mm_end_id=self.slice_end_id,
|
||||
)
|
||||
image_offsets.extend(slice_offsets)
|
||||
image_offsets = sorted(image_offsets)
|
||||
@@ -114,7 +107,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
item = MultimodalDataItem(
|
||||
feature=pixel_values,
|
||||
offsets=image_offsets,
|
||||
tgt_size=tgt_sizes_flat,
|
||||
model_specific_data={"tgt_size": tgt_sizes_flat},
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
items += [item]
|
||||
@@ -124,17 +117,17 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
and res["audio_features"] is not None
|
||||
and len(res["audio_features"]) != 0
|
||||
):
|
||||
if audio_start_id is not None and audio_end_id is not None:
|
||||
if self.audio_start_id is not None and self.audio_end_id is not None:
|
||||
audio_offsets = self.get_mm_items_offset_by_pair(
|
||||
input_ids=input_ids,
|
||||
mm_start_id=audio_start_id,
|
||||
mm_end_id=audio_end_id,
|
||||
mm_start_id=self.audio_start_id,
|
||||
mm_end_id=self.audio_end_id,
|
||||
)
|
||||
else:
|
||||
audio_offsets = None
|
||||
item = MultimodalDataItem(
|
||||
feature=[res["audio_features"]],
|
||||
audio_feature_lens=res["audio_feature_lens"],
|
||||
model_specific_data={"audio_feature_lens": res["audio_feature_lens"]},
|
||||
offsets=audio_offsets,
|
||||
modality=Modality.AUDIO,
|
||||
)
|
||||
@@ -142,11 +135,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
return {
|
||||
"mm_items": items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"audio_start_id": audio_start_id,
|
||||
"audio_end_id": audio_end_id,
|
||||
"im_token_id": im_token_id,
|
||||
"im_start_id": im_start_id,
|
||||
"im_end_id": im_end_id,
|
||||
"slice_start_id": slice_start_id,
|
||||
"slice_end_id": slice_end_id,
|
||||
"audio_start_id": self.audio_start_id,
|
||||
"audio_end_id": self.audio_end_id,
|
||||
"im_token_id": self.im_token_id,
|
||||
"im_start_id": self.im_start_id,
|
||||
"im_end_id": self.im_end_id,
|
||||
"slice_start_id": self.slice_start_id,
|
||||
"slice_end_id": self.slice_end_id,
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.mllama import MllamaForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
@@ -11,24 +12,26 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=self._processor.image_token,
|
||||
image_token_id=self._processor.image_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
base_out = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
image_inputs = self.process_mm_data(input_text=input_text, images=images)
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
image_inputs["mm_items"] = [
|
||||
MultimodalDataItem(
|
||||
feature=image_inputs["pixel_values"],
|
||||
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
|
||||
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
]
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_out, self.mm_tokens
|
||||
)
|
||||
|
||||
return image_inputs
|
||||
return {
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
}
|
||||
|
||||
@@ -27,13 +27,13 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||
self.image_token_index = hf_config.image_token_index
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token,
|
||||
image_token_id=self.image_token_index,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
max_req_input_len=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -45,7 +45,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||
processed_data = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
max_req_input_len=max_req_input_len or 4096,
|
||||
image_data=image_data,
|
||||
return_text=True,
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
|
||||
for hf_key, sglang_key in key_mapping.items():
|
||||
if hf_key in result:
|
||||
result[sglang_key] = result[hf_key]
|
||||
del result[hf_key]
|
||||
|
||||
# Filter out None or empty tensors from the result.
|
||||
# This prevents the sglang function base_processor.collect_mm_items_from_processor_output()
|
||||
@@ -58,7 +59,7 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
self.AUDIO_TOKEN_ID = 200011
|
||||
self.AUDIO_SAMPLE_RATE = 16000
|
||||
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN,
|
||||
image_token_id=self.IM_TOKEN_ID,
|
||||
audio_token=self.AUDIO_TOKEN,
|
||||
@@ -71,15 +72,13 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audio_data,
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
audio_sample_rate=self.AUDIO_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
@@ -91,12 +90,12 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
]
|
||||
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.multimodal_tokens
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"audio_token_id": self.AUDIO_TOKEN_ID,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
"audio_token_id": self.mm_tokens.audio_token_id,
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
|
||||
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||
)
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.pixtral import PixtralVisionModel
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
@@ -45,7 +44,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.image_token_id = getattr(
|
||||
self.IM_TOKEN_ID = getattr(
|
||||
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
|
||||
)
|
||||
# Instantiate the patcher logic helper using the class defined above
|
||||
@@ -53,8 +52,9 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
self.vision_config = hf_config.vision_config
|
||||
self.image_size = self.vision_config.image_size
|
||||
self.patch_size = self.vision_config.patch_size
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token,
|
||||
image_token_id=self.IM_TOKEN_ID,
|
||||
).build(_processor)
|
||||
_processor.tokenizer.add_special_tokens(
|
||||
{
|
||||
@@ -80,42 +80,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
):
|
||||
mm_data = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
max_req_input_len=kwargs.get("max_req_input_len", 4096),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
image_data=image_data,
|
||||
return_text=True,
|
||||
)
|
||||
|
||||
if mm_data.images:
|
||||
resize_tasks = [self._resize(image) for image in mm_data.images]
|
||||
mm_data.images = await asyncio.gather(*resize_tasks)
|
||||
|
||||
processor_output = self.process_mm_data(
|
||||
input_text=mm_data.input_text,
|
||||
images=mm_data.images,
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
mm_data, self.mm_tokens
|
||||
)
|
||||
|
||||
if "pixel_values" in processor_output:
|
||||
input_ids = processor_output["input_ids"].view(-1)
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=self.image_token_id,
|
||||
)
|
||||
mm_items = [
|
||||
MultimodalDataItem(
|
||||
feature=processor_output["pixel_values"],
|
||||
image_sizes=processor_output["image_sizes"],
|
||||
modality=Modality.IMAGE,
|
||||
offsets=image_offsets,
|
||||
)
|
||||
]
|
||||
|
||||
input_ids = input_ids.tolist()
|
||||
processor_output.update(
|
||||
input_ids=input_ids,
|
||||
mm_items=mm_items,
|
||||
# there's no im_start_id for pixtral, only im_token and im_end_token
|
||||
im_end_id=self.IMG_END_TOKEN_ID,
|
||||
im_token_id=self.image_token_id,
|
||||
)
|
||||
return processor_output
|
||||
return {
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"im_token": self._processor.image_token,
|
||||
}
|
||||
|
||||
65
python/sglang/srt/multimodal/processors/qwen_audio.py
Normal file
65
python/sglang/srt/multimodal/processors/qwen_audio.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import re
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
|
||||
|
||||
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
|
||||
models = [Qwen2AudioForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
self.AUDIO_TOKEN_REGEX = re.compile(
|
||||
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
|
||||
)
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
self.audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
|
||||
self.audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
|
||||
self.audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
audio_token=self.AUDIO_TOKEN,
|
||||
audio_token_regex=self.AUDIO_TOKEN_REGEX,
|
||||
audio_token_id=self.audio_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
audio_data,
|
||||
input_text,
|
||||
**kwargs,
|
||||
):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
audio_data=audio_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
mm_items, input_ids, ret = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
assert (
|
||||
"feature_attention_mask" in ret
|
||||
), "feature_attention_mask not found in processor output"
|
||||
input_lengths = ret["feature_attention_mask"].sum(dim=-1)
|
||||
input_lengths = (input_lengths - 1) // 2 + 1
|
||||
output_lengths = (input_lengths - 2) // 2 + 1
|
||||
|
||||
mm_items[0].model_specific_data["audio_feature_lens"] = output_lengths
|
||||
|
||||
return {
|
||||
"mm_items": mm_items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"audio_start_id": self.audio_start_id,
|
||||
"audio_token_id": self.audio_token_id,
|
||||
"audio_end_id": self.audio_end_id,
|
||||
}
|
||||
@@ -227,7 +227,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -237,7 +236,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
image_data=image_data,
|
||||
video_data=request_obj.video_data,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
# Qwen-specific: resize images if they are raw Image objects
|
||||
|
||||
@@ -47,13 +47,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
||||
image_data: Optional[ImageDataInputItem | List[ImageDataInputItem]],
|
||||
input_text: str | List[int],
|
||||
request_obj: GenerateReqInput | EmbeddingReqInput,
|
||||
max_req_input_len: int,
|
||||
**kwargs,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
image_data=image_data,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user