Refactor mm processors and Enable mixed modality processing (#7629)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -125,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
||||
e.g. <image><image>....<image>, or <audio><audio>...<audio>
|
||||
"""
|
||||
|
||||
def __init__(self, token_ids: List[int]) -> None:
|
||||
self.token_ids = token_ids
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
|
||||
and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
|
||||
Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
|
||||
Each modality (image, audio, video) is handled separately based on its token_id.
|
||||
"""
|
||||
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
||||
if not pad_values:
|
||||
# No multimodal items, return original input_ids
|
||||
if not input_ids or not mm_inputs.mm_items:
|
||||
return input_ids
|
||||
if not input_ids:
|
||||
return []
|
||||
|
||||
input_ids_tensor = torch.tensor(input_ids)
|
||||
device = input_ids_tensor.device
|
||||
token_ids_tensor = torch.tensor(self.token_ids, device=device)
|
||||
mask = torch.isin(input_ids_tensor, token_ids_tensor)
|
||||
|
||||
if not mask.any():
|
||||
# No tokens match token_ids, return original input_ids
|
||||
return input_ids
|
||||
# Create mapping of token_ids to pad_values for each modality
|
||||
token_to_pad_mapping = {}
|
||||
|
||||
# Find contiguous regions
|
||||
padded_mask = torch.cat(
|
||||
(
|
||||
torch.tensor([False], device=device),
|
||||
mask,
|
||||
torch.tensor([False], device=device),
|
||||
)
|
||||
)
|
||||
# Find indices where the mask value changes
|
||||
diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
|
||||
|
||||
# Start indices are where False changes to True
|
||||
starts = diff_indices[::2]
|
||||
# End indices are where True changes to False (exclusive index)
|
||||
ends = diff_indices[1::2]
|
||||
|
||||
# Check if the number of regions matches the number of pad values
|
||||
if len(starts) != len(pad_values):
|
||||
# Maybe log a warning here?
|
||||
num_regions = len(starts)
|
||||
num_pad_values = len(pad_values)
|
||||
if num_regions > 0 and num_pad_values > 0:
|
||||
pad_values = (pad_values * (num_regions // num_pad_values + 1))[
|
||||
:num_regions
|
||||
]
|
||||
else: # If no regions or no pad_values, this loop won't run anyway.
|
||||
pad_values = [] # Ensure pad_values is empty if starts is empty
|
||||
|
||||
# Create a copy to modify
|
||||
output_ids_tensor = input_ids_tensor.clone()
|
||||
|
||||
# Replace tokens in each region with the corresponding pad value
|
||||
# Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
|
||||
for i in range(min(len(starts), len(pad_values))):
|
||||
start_idx = starts[i]
|
||||
end_idx = ends[i]
|
||||
pad_value = pad_values[i]
|
||||
if pad_value is not None: # Ensure pad_value is not None before assignment
|
||||
output_ids_tensor[start_idx:end_idx] = pad_value
|
||||
for item in mm_inputs.mm_items:
|
||||
if item.is_image() and mm_inputs.im_token_id is not None:
|
||||
token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
|
||||
elif item.is_audio() and mm_inputs.audio_token_id is not None:
|
||||
token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
|
||||
elif item.is_video() and mm_inputs.video_token_id is not None:
|
||||
token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
|
||||
else:
|
||||
logger.warning(f"Skipping region {i} due to None pad_value.")
|
||||
return output_ids_tensor.tolist()
|
||||
raise ValueError(f"No multimodal token id provided for {item.modality}")
|
||||
|
||||
# Apply replacements for all tokens at once
|
||||
for token_id, pad_value in token_to_pad_mapping.items():
|
||||
input_ids_tensor[input_ids_tensor == token_id] = pad_value
|
||||
|
||||
ret_input_ids = input_ids_tensor.tolist()
|
||||
|
||||
return ret_input_ids
|
||||
|
||||
|
||||
embedding_cache = None
|
||||
|
||||
@@ -174,6 +174,15 @@ class Modality(Enum):
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
|
||||
@staticmethod
def from_str(modality_str: str) -> "Modality":
    """Look up a ``Modality`` member by name, ignoring letter case.

    Args:
        modality_str: Name of the member in any letter case (for example
            ``"audio"``); it is upper-cased before the lookup.

    Returns:
        The ``Modality`` member whose name equals ``modality_str.upper()``.

    Raises:
        ValueError: If no member has that name. The message lists the
            names of all valid members.
    """
    try:
        return Modality[modality_str.upper()]
    except KeyError as err:
        # Chain explicitly so the traceback reports the failed lookup as
        # the direct cause instead of "during handling of the above
        # exception, another exception occurred".
        raise ValueError(
            f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
        ) from err
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MultimodalDataItem:
|
||||
|
||||
@@ -482,20 +482,25 @@ class TokenizerManager:
|
||||
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
||||
|
||||
if self.mm_processor and obj.contains_mm_input():
|
||||
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
if not isinstance(obj.image_data, list):
|
||||
obj.image_data = [obj.image_data]
|
||||
if not isinstance(obj.audio_data, list):
|
||||
obj.audio_data = [obj.audio_data]
|
||||
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
image_data=obj.image_data,
|
||||
audio_data=obj.audio_data,
|
||||
input_text=input_text or input_ids,
|
||||
request_obj=obj,
|
||||
max_req_input_len=self.max_req_input_len,
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
if mm_inputs and "input_ids" in mm_inputs:
|
||||
input_ids = mm_inputs["input_ids"]
|
||||
else:
|
||||
image_inputs: Optional[Dict] = None
|
||||
mm_inputs = None
|
||||
|
||||
self._validate_one_request(obj, input_ids)
|
||||
return self._create_tokenized_object(
|
||||
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
|
||||
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
|
||||
)
|
||||
|
||||
def _validate_one_request(
|
||||
@@ -559,7 +564,7 @@ class TokenizerManager:
|
||||
input_text: str,
|
||||
input_ids: List[int],
|
||||
input_embeds: Optional[Union[List[float], None]] = None,
|
||||
image_inputs: Optional[Dict] = None,
|
||||
mm_inputs: Optional[Dict] = None,
|
||||
token_type_ids: Optional[List[int]] = None,
|
||||
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
|
||||
"""Create a tokenized request object from common parameters."""
|
||||
@@ -584,7 +589,7 @@ class TokenizerManager:
|
||||
obj.rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
image_inputs,
|
||||
mm_inputs,
|
||||
sampling_params,
|
||||
obj.return_logprob,
|
||||
obj.logprob_start_len,
|
||||
@@ -606,7 +611,7 @@ class TokenizerManager:
|
||||
obj.rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
image_inputs,
|
||||
mm_inputs,
|
||||
token_type_ids,
|
||||
sampling_params,
|
||||
)
|
||||
@@ -644,9 +649,9 @@ class TokenizerManager:
|
||||
) -> None:
|
||||
"""Validate constraints for batch tokenization processing."""
|
||||
for i in range(batch_size):
|
||||
if self.is_generation and obj[i].image_data:
|
||||
if self.is_generation and obj[i].contains_mm_input():
|
||||
raise ValueError(
|
||||
"For image input processing do not set `enable_tokenizer_batch_encode`."
|
||||
"For multimodal input processing do not set `enable_tokenizer_batch_encode`."
|
||||
)
|
||||
if obj[i].input_ids is not None:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user