479 lines
18 KiB
Python
479 lines
18 KiB
Python
"""
|
|
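Multi-modality utils: padding patterns for multimodal placeholder tokens and helpers that build input embeddings for multimodal models.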
|
|
"""
|
|
|
|
import logging
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple

import torch
from torch import nn

from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
    global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import flatten_nested_list, print_warning_once
|
|
|
|
# Module-level logger used by the padding and embedding helpers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MultiModalityDataPaddingPattern(ABC):
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens.

    The class derives from ``ABC`` so that ``@abstractmethod`` is actually
    enforced: a subclass that does not implement ``pad_input_tokens`` cannot
    be instantiated.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values.

        Args:
            input_ids: token ids of the prompt.
            mm_inputs: the request's multimodal inputs, which supply the pad values.

        Returns:
            The token ids with the data tokens replaced by pad values.
        """
        pass
|
|
|
|
|
|
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        # (start_token_id, end_token_id) pairs. When None, a single pair is taken
        # from ``mm_inputs.im_start_id`` / ``mm_inputs.im_end_id`` at padding time.
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Replace the tokens strictly between each start/end token pair with pad values.

        The k-th enclosed region is filled with the pad value of the k-th item in
        ``mm_inputs.mm_items``; regions beyond the number of items reuse the last
        item's pad value. The start/end tokens themselves are kept.

        Side effect: ``mm_inputs.data_offsets`` is reset to ``[]`` and then filled
        with the index of every start token whose region was padded.

        Returns ``input_ids`` unchanged when no token pair is known, when there are
        no multimodal items, or when the start/end tokens cannot be matched into
        ordered, non-overlapping regions.
        """
        pad_values = [item.pad_value for item in mm_inputs.mm_items]
        mm_inputs.data_offsets = []

        data_token_pairs = self.data_token_id_pairs
        if data_token_pairs is None:
            start_id, end_id = mm_inputs.im_start_id, mm_inputs.im_end_id
            if start_id is None or end_id is None:
                print_warning_once(
                    "No data_token_pairs provided, RadixAttention might be influenced."
                )
                return input_ids
            # Must be a list of (start, end) tuples: the unpacking below relies on it.
            data_token_pairs = [(start_id, end_id)]

        if not pad_values:
            # No item to take a pad value from; indexing pad_values would fail.
            return input_ids

        start_token_ids = {s for s, _e in data_token_pairs}
        end_token_ids = {e for _s, e in data_token_pairs}

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_token_ids]

        if len(start_indices) != len(end_indices):
            return input_ids

        regions = list(zip(start_indices, end_indices))
        # Require start_0 < end_0 < start_1 < end_1 < ...; otherwise the slicing
        # below would duplicate or drop tokens.
        boundaries = [idx for region in regions for idx in region]
        if any(a >= b for a, b in zip(boundaries, boundaries[1:])):
            return input_ids

        last_pad_idx = len(pad_values) - 1
        padded_ids: List[int] = []
        last_idx = 0
        for data_idx, (start_idx, end_idx) in enumerate(regions):
            # Copy up to and including the start token. The previous region's end
            # token sits at ``last_idx`` and is copied by this slice as well.
            padded_ids.extend(input_ids[last_idx : start_idx + 1])
            mm_inputs.data_offsets.append(start_idx)

            pad_value = pad_values[min(data_idx, last_pad_idx)]
            padded_ids.extend([pad_value] * (end_idx - start_idx - 1))

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids), "Length validation fails"
        return padded_ids
|
|
|
|
|
|
class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
    """Padding pattern for data written as runs of one repeated placeholder token.

    e.g. <image><image>....<image>, or <audio><audio>...<audio>
    """

    def __init__(self, token_ids: List[int]) -> None:
        # Placeholder token ids whose contiguous runs mark multimodal data.
        self.token_ids = token_ids

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Overwrite every contiguous run of placeholder tokens with a pad value.

        Runs are located by matching ``input_ids`` against ``self.token_ids``. The
        i-th run receives the pad value of the i-th item in ``mm_inputs.mm_items``.
        When the run count and the item count differ, the pad values are repeated
        cyclically and cut to one value per run. A run whose pad value is None is
        left as-is and a warning is logged.

        Returns ``input_ids`` itself when there are no items or no placeholder
        tokens, ``[]`` for empty input, and a new list otherwise.
        """
        pad_values = [item.pad_value for item in mm_inputs.mm_items]
        if not pad_values:
            return input_ids
        if not input_ids:
            return []

        ids = torch.tensor(input_ids)
        dev = ids.device
        hit = torch.isin(ids, torch.tensor(self.token_ids, device=dev))
        if not hit.any():
            return input_ids

        # Frame the mask with False so each run has a rising and a falling edge.
        edge = torch.tensor([False], device=dev)
        framed = torch.cat((edge, hit, edge))
        transitions = torch.where(framed[1:] != framed[:-1])[0]
        run_starts = transitions[::2]  # False -> True
        run_ends = transitions[1::2]  # True -> False, exclusive

        n_runs = len(run_starts)
        if n_runs != len(pad_values):
            # Both counts are positive at this point: repeat the pad values
            # cyclically, then trim to exactly one value per run.
            reps = n_runs // len(pad_values) + 1
            pad_values = (pad_values * reps)[:n_runs]

        result = ids.clone()
        for run_no, (lo, hi, value) in enumerate(zip(run_starts, run_ends, pad_values)):
            if value is None:
                logger.warning(f"Skipping region {run_no} due to None pad_value.")
                continue
            result[lo:hi] = value
        return result.tolist()
|
|
|
|
|
|
def get_embedding_and_mask(
    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
    embedding_items: List[MultimodalDataItem],
    placeholder_tensor: torch.Tensor,
    input_ids: torch.Tensor,
):
    """
    Get the multimodal embedding and its mask from input_ids.

    Args:
        data_embedding_func: computes the embedding for ``embedding_items``.
        embedding_items: the multimodal items to embed.
        placeholder_tensor: token ids marking where the embedding belongs in
            ``input_ids``.
        input_ids: token ids of the current forward pass.

    Returns:
        ``(embedding, mask)`` where ``mask`` is ``isin(input_ids, placeholder_tensor)``
        with a trailing singleton dim. If ``input_ids`` holds fewer placeholder
        tokens than the embedding provides (e.g. chunked prefill), only the last
        ``num_mm_tokens_in_input_ids`` token rows are kept; a non-2-D embedding is
        first reshaped so that its first two dims form a single token dim.

    Raises:
        RuntimeError: if ``input_ids`` holds more placeholder tokens than the
            embedding provides.
    """
    # 1. Get the embedding
    embedding = data_embedding_func(embedding_items)

    # 2. Count the tokens the embedding provides. A 2-D embedding is counted by
    # rows; otherwise the first two dims together count as tokens
    # (presumably (items, tokens_per_item, hidden) — TODO confirm).
    if embedding.dim() == 2:
        num_mm_tokens_in_embedding = embedding.shape[0]
    else:
        num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]

    # the mask of multimodal tokens from input_ids
    special_multimodal_mask = torch.isin(
        input_ids,
        placeholder_tensor,
    ).unsqueeze(-1)

    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
        logger.warning(
            "Number of tokens in multimodal embedding does not match those in the input text. "
            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
            "tokens from multimodal embeddings."
        )
        if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
            # a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
            # extend_start_loc and extend_seq_lens
            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
            if chunked_prefill_size != -1:
                logger.warning(
                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
                )
            # extract from the end: this is a compromise
            if embedding.dim() != 2:
                # Collapse the counted token dims so exactly the last N tokens can be
                # kept; slicing whole items would keep the wrong number of tokens
                # whenever N is not a multiple of the per-item token count.
                embedding = embedding.reshape(num_mm_tokens_in_embedding, -1)
            # Slice from an explicit start: ``embedding[-0:]`` would keep every row
            # when the text holds no placeholder tokens at all.
            keep_from = num_mm_tokens_in_embedding - num_mm_tokens_in_input_ids
            embedding = embedding[keep_from:]
        else:
            raise RuntimeError(
                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
            )

    return embedding, special_multimodal_mask
|
|
|
|
|
|
def embed_mm_inputs(
    mm_inputs: MultimodalInputs,
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    image_data_embedding_func: Optional[
        Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    audio_data_embedding_func: Optional[
        Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
) -> Optional[torch.Tensor]:
    """
    Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations

    Args:
        mm_inputs: multimodal inputs of the batch; None means nothing to embed.
        input_ids: token ids of the batch. When multimodal tokens are present,
            this tensor is clamped IN PLACE to ``[0, vocab_size - 1]``.
        input_embedding: the language model's token embedding.
        image_data_embedding_func: computes image embeddings for image items.
        audio_data_embedding_func: computes audio embeddings for audio items.
        placeholder_tokens: denoting the token of multimodal data in input_ids.
            If none, the pad_values of multimodal items are used

    Returns:
        final embedding: Optional[torch.Tensor] (None when ``mm_inputs`` is None)
    """
    if mm_inputs is None:
        return None

    # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
    # we assume that multimodal data are represented with its pad_values in input_ids
    # See `pad_input_ids` for more detail
    if placeholder_tokens is not None:
        placeholder_token_ids = flatten_nested_list(list(placeholder_tokens.values()))
    else:
        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]

    if not placeholder_token_ids:
        # Nothing can mark multimodal data in input_ids: plain text embedding.
        return input_embedding(input_ids)

    assert isinstance(placeholder_token_ids[0], int)

    placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
    placeholder_masks = torch.isin(input_ids, placeholder_tensor)
    # A Python set gives cheap membership tests below instead of one tensor op per item.
    appearing_pad_values = set(torch.unique(input_ids[placeholder_masks]).tolist())

    if not appearing_pad_values:
        # all been prefixed
        return input_embedding(input_ids)

    appearing_items = [
        item
        for item in mm_inputs.mm_items
        if item.pad_value is not None and item.pad_value in appearing_pad_values
    ]

    using_all_items = False
    if not appearing_items:
        # This happens mostly when arg placeholder_token_ids is passed
        logger.warning(
            "No multimodal data item's pad value exist in placeholder ids. Using all items"
        )
        using_all_items = True
        appearing_items = mm_inputs.mm_items

    # 2. Get multimodal embedding separately (image first, then audio).
    # TODO: make this more generic
    modality_specs = (
        (Modality.IMAGE, image_data_embedding_func, lambda item: item.is_image()),
        (Modality.AUDIO, audio_data_embedding_func, lambda item: item.is_audio()),
    )
    embeddings, masks = [], []
    for modality, embedding_func, belongs_to_modality in modality_specs:
        if not embedding_func:
            continue
        items = [item for item in appearing_items if belongs_to_modality(item)]
        if not items:
            continue
        if using_all_items:
            # Use the specified modality token to identify the location to embed.
            # Flatten so both a single id and a list of ids become a flat id list;
            # torch.isin needs a tensor, not a Python list.
            token_ids = flatten_nested_list([placeholder_tokens[modality]])
        else:
            token_ids = [item.pad_value for item in items]
        embedding, mask = get_embedding_and_mask(
            data_embedding_func=embedding_func,
            embedding_items=items,
            placeholder_tensor=torch.tensor(token_ids, device=input_ids.device),
            input_ids=input_ids,
        )
        embeddings.append(embedding)
        masks.append(mask)

    # 3. Get input embeddings
    vocab_size = input_embedding.num_embeddings
    # Important: clamp after getting original multimodal regions
    # Clamp input ids. This is because the input_ids for the multimodal tokens are
    # filled with the hash values of the multimodal for the prefix matching in the radix attention.
    # There values are useless because their embeddings will be replaced by vision embeddings anyway.
    input_ids.clamp_(min=0, max=vocab_size - 1)
    inputs_embeds = input_embedding(input_ids)

    # 4. Scatter embeddings into input embedding
    for embedding, mask in zip(embeddings, masks):
        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
        inputs_embeds = inputs_embeds.masked_scatter(
            mask,
            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
        )
    return inputs_embeds
|
|
|
|
|
|
def general_mm_embed_routine(
    input_ids: torch.Tensor,
    forward_batch: ForwardBatch,
    language_model: nn.Module,
    image_data_embedding_func: Optional[
        Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    audio_data_embedding_func: Optional[
        Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
    **kwargs,
) -> torch.Tensor:
    """
    A general wrapper function to get final input embeds from multimodal models with a language model as causal model

    Args:
        input_ids: token ids of the batch.
        forward_batch: the batch being run; its ``mm_inputs`` is set to None
            once the multimodal embedding has been computed.
        language_model: must expose ``get_input_embeddings()``; it is called with
            ``input_ids=None`` and the computed ``input_embeds``.
        image_data_embedding_func: the function returning the image embedding
        audio_data_embedding_func: the function returning the audio embedding
        placeholder_tokens: per-modality ids of the mm data placeholder tokens;
            if None, the pad values of the multimodal items are used.
        **kwargs: forwarded to ``language_model``.

    Returns:
        forwarded hidden states
    """
    assert hasattr(language_model, "get_input_embeddings")
    embed_tokens = language_model.get_input_embeddings()

    inputs_embeds = None
    if (
        not forward_batch.forward_mode.is_decode()
        and forward_batch.contains_mm_inputs()
    ):
        mm_input = forward_batch.merge_mm_inputs()
        inputs_embeds = embed_mm_inputs(
            mm_inputs=mm_input,
            input_ids=input_ids,
            input_embedding=embed_tokens,
            image_data_embedding_func=image_data_embedding_func,
            audio_data_embedding_func=audio_data_embedding_func,
            placeholder_tokens=placeholder_tokens,
        )
        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
        # just being defensive here
        forward_batch.mm_inputs = None

    if inputs_embeds is None:
        # Text-only path, or the merged multimodal inputs were None.
        inputs_embeds = embed_tokens(input_ids)

    hidden_states = language_model(
        input_ids=None,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
        **kwargs,
    )
    return hidden_states
|
|
|
|
|
|
def get_multimodal_data_bounds(
    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)

    Each row is ``(start + 1, end - 1)`` for a matched start/end special-token
    pair, i.e. the inclusive index range strictly between the two tokens. Pairs
    whose start is not before their end are dropped.

    Args:
        input_ids: 1-D tensor of token ids.
        pad_values: pad values of the multimodal items. Used to detect a sequence
            that begins inside a data region whose start token was cached as prefix.
        token_pairs: (start_token_id, end_token_id) pairs; all ids must be ints.

    Returns:
        [bounds_count, 2] long tensor on ``input_ids.device``; shape ``(0, 2)``
        when no valid pair exists.
    """
    # All the multimodal data in the batch should share the same special bound token ids.
    start_tokens = [s for s, _e in token_pairs]
    end_tokens = [e for _s, e in token_pairs]

    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)

    device = input_ids.device
    start_cond = torch.isin(input_ids, torch.tensor(start_tokens, device=device))
    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=device))

    (data_start_tokens,) = torch.where(start_cond)
    (data_end_tokens,) = torch.where(end_cond)

    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
    if (
        len(data_start_tokens) + 1 == len(data_end_tokens)
        and input_ids[0].item() in pad_values
        # No visible start token at all is exactly the cached-prefix case; guard
        # it so we never index an empty tensor.
        and (len(data_start_tokens) == 0 or data_end_tokens[0] < data_start_tokens[0])
    ):
        # NOTE(review): index 0 acts as the virtual start token, so the first row
        # begins at 1 and excludes input_ids[0] although it is a pad token — confirm.
        data_start_tokens = torch.cat(
            [torch.tensor([0], device=data_start_tokens.device), data_start_tokens]
        )

    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
    starts = data_start_tokens[:valid_mm_data_nums]
    ends = data_end_tokens[:valid_mm_data_nums]

    # Filter out pairs where start_token >= end_token
    keep = starts < ends
    # Stacking (possibly empty) long index tensors keeps the dtype consistent:
    # the result is always long, including the empty (0, 2) case.
    return torch.stack((starts[keep] + 1, ends[keep] - 1), dim=1)
|