Refactor mm processors and enable mixed modality processing (#7629)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-06-30 23:14:48 -07:00
committed by GitHub
parent 886d344964
commit 3a911b854d
28 changed files with 235 additions and 428 deletions

View File

@@ -125,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
e.g. <image><image>....<image>, or <audio><audio>...<audio>
"""
def __init__(self, token_ids: List[int]) -> None:
    """Store the multimodal placeholder token ids this padder matches on."""
    self.token_ids = token_ids
def pad_input_tokens(
    self, input_ids: List[int], mm_inputs: "MultimodalInputs"
) -> List[int]:
    """Replace multimodal placeholder tokens with per-item pad values.

    Each modality (image, audio, video) present in ``mm_inputs.mm_items`` is
    mapped from its placeholder token id (``im_token_id``, ``audio_token_id``,
    ``video_token_id``) to that item's ``pad_value``; every occurrence of a
    mapped token id in ``input_ids`` is then rewritten in one vectorized pass.

    Args:
        input_ids: token sequence possibly containing multimodal placeholders.
        mm_inputs: carries ``mm_items`` plus the per-modality placeholder ids.

    Returns:
        A new token list with placeholders replaced by pad values. When
        ``input_ids`` is empty or there are no multimodal items, ``input_ids``
        is returned unchanged.

    Raises:
        ValueError: if an mm_item's modality has no corresponding token id set
            on ``mm_inputs``.
    """
    if not input_ids or not mm_inputs.mm_items:
        return input_ids

    # Build the mapping of placeholder token id -> pad value, one entry per
    # modality actually present in mm_items.
    token_to_pad_mapping = {}
    for item in mm_inputs.mm_items:
        if item.is_image() and mm_inputs.im_token_id is not None:
            token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
        elif item.is_audio() and mm_inputs.audio_token_id is not None:
            token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
        elif item.is_video() and mm_inputs.video_token_id is not None:
            token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
        else:
            raise ValueError(f"No multimodal token id provided for {item.modality}")

    # Apply replacements for all tokens at once via boolean-mask assignment.
    input_ids_tensor = torch.tensor(input_ids)
    for token_id, pad_value in token_to_pad_mapping.items():
        input_ids_tensor[input_ids_tensor == token_id] = pad_value
    ret_input_ids = input_ids_tensor.tolist()
    return ret_input_ids
# Module-level embedding cache; starts unset — presumably initialized lazily by
# a setup routine elsewhere in this module. TODO(review): confirm the writer.
embedding_cache = None

View File

@@ -174,6 +174,15 @@ class Modality(Enum):
VIDEO = auto()
AUDIO = auto()
@staticmethod
def from_str(modality_str: str):
    """Parse a modality name (case-insensitive) into a ``Modality`` member.

    Args:
        modality_str: name of a modality, e.g. ``"image"`` or ``"AUDIO"``.

    Returns:
        The matching ``Modality`` member.

    Raises:
        ValueError: if ``modality_str`` does not name a ``Modality`` member.
    """
    try:
        return Modality[modality_str.upper()]
    except KeyError:
        # `from None` suppresses the internal KeyError so callers see a
        # single clean ValueError instead of a chained traceback.
        raise ValueError(
            f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
        ) from None
@dataclasses.dataclass
class MultimodalDataItem:

View File

@@ -482,20 +482,25 @@ class TokenizerManager:
token_type_ids = encoded.get("token_type_ids", [None])[0]
if self.mm_processor and obj.contains_mm_input():
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
if not isinstance(obj.image_data, list):
obj.image_data = [obj.image_data]
if not isinstance(obj.audio_data, list):
obj.audio_data = [obj.audio_data]
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
audio_data=obj.audio_data,
input_text=input_text or input_ids,
request_obj=obj,
max_req_input_len=self.max_req_input_len,
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
if mm_inputs and "input_ids" in mm_inputs:
input_ids = mm_inputs["input_ids"]
else:
image_inputs: Optional[Dict] = None
mm_inputs = None
self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
)
def _validate_one_request(
@@ -559,7 +564,7 @@ class TokenizerManager:
input_text: str,
input_ids: List[int],
input_embeds: Optional[Union[List[float], None]] = None,
image_inputs: Optional[Dict] = None,
mm_inputs: Optional[Dict] = None,
token_type_ids: Optional[List[int]] = None,
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
"""Create a tokenized request object from common parameters."""
@@ -584,7 +589,7 @@ class TokenizerManager:
obj.rid,
input_text,
input_ids,
image_inputs,
mm_inputs,
sampling_params,
obj.return_logprob,
obj.logprob_start_len,
@@ -606,7 +611,7 @@ class TokenizerManager:
obj.rid,
input_text,
input_ids,
image_inputs,
mm_inputs,
token_type_ids,
sampling_params,
)
@@ -644,9 +649,9 @@ class TokenizerManager:
) -> None:
"""Validate constraints for batch tokenization processing."""
for i in range(batch_size):
if self.is_generation and obj[i].image_data:
if self.is_generation and obj[i].contains_mm_input():
raise ValueError(
"For image input processing do not set `enable_tokenizer_batch_encode`."
"For multimodal input processing do not set `enable_tokenizer_batch_encode`."
)
if obj[i].input_ids is not None:
raise ValueError(