sglang/python/sglang/srt/managers/mm_utils.py
"""
Multi-modality utils
"""
import logging
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
import torch
from torch import nn
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import flatten_nested_list, print_warning_once
logger = logging.getLogger(__name__)
class MultiModalityDataPaddingPattern:
"""
Data tokens (like image tokens) often need special handling during padding
to maintain model compatibility. This class provides the interface for
implementing different padding strategies for data tokens
"""
@abstractmethod
def pad_input_tokens(
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
Pad the input ids sequence containing data tokens, and replace them with pad_values
"""
pass
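# A model's `pad_input_ids` typically instantiates one of the concrete patterns below
# and delegates to it. A minimal sketch (the token ids here are hypothetical):
#
#     pattern = MultiModalityDataPaddingPatternTokenPairs([(im_start_id, im_end_id)])
#     input_ids = pattern.pad_input_tokens(input_ids, mm_inputs)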
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
"""
def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
self.data_token_id_pairs = data_token_pairs
def pad_input_tokens(
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
This function will replace the data-tokens inbetween with pad_values accordingly
"""
pad_values = [item.pad_value for item in mm_inputs.mm_items]
data_token_pairs = self.data_token_id_pairs
mm_inputs.data_offsets = []
if data_token_pairs is None and mm_inputs.im_start_id is not None:
# fall back to the image start/end ids; keep them as a (start, end) tuple so the
# unpacking below works
data_token_pairs = [(mm_inputs.im_start_id, mm_inputs.im_end_id)]
if data_token_pairs is None:
print_warning_once(
"No data_token_pairs provided; RadixAttention prefix caching may be affected."
)
return input_ids
start_token_ids = [s for s, _e in data_token_pairs]
end_token_ids = [e for _s, e in data_token_pairs]
padded_ids = []
last_idx = 0
data_idx = -1
start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
end_indices = [i for i, x in enumerate(input_ids) if x in end_token_ids]
if len(start_indices) != len(end_indices):
# unbalanced start/end tokens; leave input_ids unchanged
return input_ids
for start_idx, end_idx in zip(start_indices, end_indices):
padded_ids.extend(input_ids[last_idx : start_idx + 1])
if input_ids[start_idx] in start_token_ids:
data_idx += 1
mm_inputs.data_offsets += [start_idx]
if data_idx >= len(pad_values):
data_idx = len(pad_values) - 1
num_tokens = end_idx - start_idx - 1
pad_value = pad_values[data_idx]
padded_ids.extend([pad_value] * num_tokens)
last_idx = end_idx
padded_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(padded_ids), "Length validation failed: padded sequence length differs from the input"
return padded_ids
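# A worked example of the token-pair strategy above (illustrative; all ids
# hypothetical): with data_token_pairs=[(100, 101)] for <image>/</image> and a single
# image item whose pad_value is 7,
#     [5, 100, 1, 2, 3, 101, 6]  ->  [5, 100, 7, 7, 7, 101, 6]
# Only the tokens strictly between the pair are replaced, so the sequence length and
# the enclosing special tokens are preserved, and mm_inputs.data_offsets records the
# start index 1.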
class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be represented as repetitions of a single token
e.g. <image><image>....<image>, or <audio><audio>...<audio>
"""
def __init__(self, token_ids: List[int]) -> None:
self.token_ids = token_ids
def pad_input_tokens(
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
"""
pad_values = [item.pad_value for item in mm_inputs.mm_items]
if not pad_values:
# No multimodal items, return original input_ids
return input_ids
if not input_ids:
return []
input_ids_tensor = torch.tensor(input_ids)
device = input_ids_tensor.device
token_ids_tensor = torch.tensor(self.token_ids, device=device)
mask = torch.isin(input_ids_tensor, token_ids_tensor)
if not mask.any():
# No tokens match token_ids, return original input_ids
return input_ids
# Find contiguous regions
padded_mask = torch.cat(
(
torch.tensor([False], device=device),
mask,
torch.tensor([False], device=device),
)
)
# Find indices where the mask value changes
diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
# Start indices are where False changes to True
starts = diff_indices[::2]
# End indices are where True changes to False (exclusive index)
ends = diff_indices[1::2]
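# Illustrative trace of the region finding above: for mask [F, T, T, F, T],
# padded_mask is [F, F, T, T, F, T, F]; the value changes at diff_indices
# [1, 3, 4, 5], giving starts [1, 4] and (exclusive) ends [3, 5].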
# Check if the number of regions matches the number of pad values
if len(starts) != len(pad_values):
num_regions = len(starts)
num_pad_values = len(pad_values)
logger.warning(
f"Number of contiguous multimodal token regions ({num_regions}) does not match "
f"the number of pad values ({num_pad_values}); cycling pad values to compensate."
)
if num_regions > 0 and num_pad_values > 0:
# Cycle the pad values until there is one per region
pad_values = (pad_values * (num_regions // num_pad_values + 1))[:num_regions]
else:
# No regions or no pad values: nothing to pad below
pad_values = []
# Create a copy to modify
output_ids_tensor = input_ids_tensor.clone()
# Replace tokens in each region with the corresponding pad value
# Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
for i in range(min(len(starts), len(pad_values))):
start_idx = starts[i]
end_idx = ends[i]
pad_value = pad_values[i]
if pad_value is not None: # Ensure pad_value is not None before assignment
output_ids_tensor[start_idx:end_idx] = pad_value
else:
logger.warning(f"Skipping region {i} due to None pad_value.")
return output_ids_tensor.tolist()
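# A worked example of the repeated-token strategy above (illustrative; ids
# hypothetical): with token_ids=[8] for <image> and two items whose pad_values are
# [7, 9],
#     [1, 8, 8, 2, 8, 8, 8, 3]  ->  [1, 7, 7, 2, 9, 9, 9, 3]
# i.e. the two contiguous <image> regions are replaced by the two pad values in order.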
def get_embedding_and_mask(
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
embedding_items: List[MultimodalDataItem],
placeholder_tensor: torch.Tensor,
input_ids: torch.Tensor,
):
"""
Get the multimodal embedding and its mask from input_ids
"""
# 1. Get the embedding
embedding = data_embedding_func(embedding_items)
# 2. Check the embedding
if embedding.dim() == 2:
num_mm_tokens_in_embedding = embedding.shape[0]
else:
num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
# the mask of multimodal tokens from input_ids
special_multimodal_mask = torch.isin(
input_ids,
placeholder_tensor,
).unsqueeze(-1)
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
logger.warning(
"Number of tokens in the multimodal embedding does not match the number of "
f"multimodal tokens in the input text: got {num_mm_tokens_in_input_ids} tokens "
f"in the text but {num_mm_tokens_in_embedding} tokens from the multimodal embedding."
)
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding.
# A fix may be to cache the unfinished multimodal embedding for future reuse and to determine the tokens
# to embed with extend_start_loc and extend_seq_lens.
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
if chunked_prefill_size != -1:
logger.warning(
"You may want to avoid this issue by increasing `chunked_prefill_size`, or by disabling chunked prefill."
)
# extract from the end: this is a compromise
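# Illustrative example of the compromise (hypothetical sizes): if an item was
# embedded to 1024 tokens but the current chunk's input_ids only contain the
# trailing 768 placeholder tokens, only the last 768 embedding rows are kept.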
if embedding.dim() == 2:
embedding = embedding[-num_mm_tokens_in_input_ids:, :]
else:
# keep whole trailing items; embedding.shape[1] is the number of tokens per item
num_items_to_keep = num_mm_tokens_in_input_ids // embedding.shape[1]
embedding = embedding[-num_items_to_keep:, :]
else:
raise RuntimeError(
f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
)
return embedding, special_multimodal_mask
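# A shape sketch for the function above (illustrative sizes): embedding two image
# items that each produce 576 tokens with hidden size H yields either a (1152, H)
# or a (2, 576, H) tensor, so num_mm_tokens_in_embedding is 1152 in both cases; if
# input_ids contains 1152 matching placeholder tokens, the returned mask has shape
# (len(input_ids), 1) with exactly 1152 True rows.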
def embed_mm_inputs(
mm_inputs: MultimodalInputs,
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
image_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
audio_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
) -> Optional[torch.Tensor]:
"""
Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
Args:
placeholder_tokens: denoting the token of multimodal data in input_ids.
If none, the pad_values of multimodal items are used
Returns:
final embedding: Optional[torch.Tensor]
"""
if mm_inputs is None:
return None
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values.
# We assume that multimodal data items are represented by their pad_values in input_ids
# (see `pad_input_ids` for more detail), unless placeholder_tokens is specified.
if placeholder_tokens is not None:
placeholder_token_ids = flatten_nested_list(
[placeholder_token for placeholder_token in placeholder_tokens.values()]
)
else:
placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
assert isinstance(placeholder_token_ids[0], int)
placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
placeholder_masks = torch.isin(input_ids, placeholder_tensor)
appearing_pad_values = torch.unique(
input_ids[placeholder_masks], return_counts=False
)
if appearing_pad_values.numel() == 0:
# every multimodal token was already consumed by the prefix cache
inputs_embeds = input_embedding(input_ids)
else:
appearing_items = [
item
for item in mm_inputs.mm_items
if item.pad_value is not None and item.pad_value in appearing_pad_values
]
using_all_items = False
if len(appearing_items) == 0:
# This happens mostly when the placeholder_tokens arg is passed
logger.warning(
"No multimodal data item's pad value exists in the placeholder ids. Using all items."
)
using_all_items = True
appearing_items = mm_inputs.mm_items
embeddings, masks = [], []
# 2. Get multimodal embedding separately
# TODO: make this more generic
# Try get image embedding if any
if (
any(item.is_image() for item in appearing_items)
and image_data_embedding_func
):
items = [item for item in appearing_items if item.is_image()]
embedding, mask = get_embedding_and_mask(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
# use the specified modality tokens to identify the locations to embed
torch.tensor(placeholder_tokens[Modality.IMAGE], device=input_ids.device)
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
),
input_ids=input_ids,
)
embeddings += [embedding]
masks += [mask]
# Try get audio embedding if any
if (
any(item.is_audio() for item in appearing_items)
and audio_data_embedding_func
):
items = [item for item in appearing_items if item.is_audio()]
embedding, mask = get_embedding_and_mask(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
torch.tensor(placeholder_tokens[Modality.AUDIO], device=input_ids.device)
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
),
input_ids=input_ids,
)
embeddings += [embedding]
masks += [mask]
# 3. Get input embeddings
vocab_size = input_embedding.num_embeddings
# Important: clamp only after locating the original multimodal regions.
# The input_ids for multimodal tokens are filled with hash values of the multimodal
# data for prefix matching in radix attention. Their values are useless because
# their embeddings will be replaced by multimodal embeddings anyway.
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
# 4. Scatter embeddings into input embedding
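# masked_scatter fills the True positions of the mask with consecutive rows of the
# source tensor, e.g. (illustrative): base [[a], [b], [c]] with mask [[F], [T], [T]]
# and source [[x], [y]] yields [[a], [x], [y]], so the i-th multimodal embedding row
# lands on the i-th placeholder position.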
for embedding, mask in zip(embeddings, masks):
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(
mask,
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
)
return inputs_embeds
def general_mm_embed_routine(
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
language_model: nn.Module,
image_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
audio_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
**kwargs,
) -> torch.Tensor:
"""
A general wrapper function to get final input embeds from multimodal models with a language model as causal model
Args:
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
image_data_embedding_func : the function returning the image embedding
audio_data_embedding_func : the function returning the image embedding
Returns:
forwarded hidden states
"""
assert hasattr(language_model, "get_input_embeddings")
embed_tokens = language_model.get_input_embeddings()
if (
not forward_batch.forward_mode.is_decode()
and forward_batch.contains_mm_inputs()
):
mm_input = forward_batch.merge_mm_inputs()
inputs_embeds = embed_mm_inputs(
mm_inputs=mm_input,
input_ids=input_ids,
input_embedding=embed_tokens,
image_data_embedding_func=image_data_embedding_func,
audio_data_embedding_func=audio_data_embedding_func,
placeholder_tokens=placeholder_tokens,
)
# Once used, mm_inputs is no longer needed, since chunked prefill is disabled for
# multimodal models; clearing it here is just being defensive.
forward_batch.mm_inputs = None
else:
inputs_embeds = embed_tokens(input_ids)
hidden_states = language_model(
input_ids=None,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
**kwargs,
)
return hidden_states
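# A typical call site in a multimodal model's forward pass (a minimal sketch; the
# attribute names `self.language_model` and `self.get_image_feature` are
# hypothetical here):
#
#     def forward(self, input_ids, positions, forward_batch):
#         return general_mm_embed_routine(
#             input_ids=input_ids,
#             forward_batch=forward_batch,
#             language_model=self.language_model,
#             image_data_embedding_func=self.get_image_feature,
#             positions=positions,
#         )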
def get_multimodal_data_bounds(
input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
"""
Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)
Returns:
[bounds_count, 2]
"""
# All the multimodal data in the batch should share the same special bound token ids.
start_tokens = [s for s, _e in token_pairs]
end_tokens = [e for _s, e in token_pairs]
assert all(isinstance(t, int) for t in start_tokens)
assert all(isinstance(t, int) for t in end_tokens)
start_cond = torch.isin(
input_ids, torch.tensor(start_tokens, device=input_ids.device)
)
end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
(data_start_tokens,) = torch.where(start_cond)
(data_end_tokens,) = torch.where(end_cond)
# The im_start_id can sometimes be cached as part of the prefix even though it is
# needed for embedding the multimodal data; recover the missing start index below.
if len(data_start_tokens) != len(data_end_tokens):
if (
len(data_start_tokens) + 1 == len(data_end_tokens)
and input_ids[0] in pad_values
and data_end_tokens[0] < data_start_tokens[0]
):
data_start_tokens = torch.cat(
[
torch.tensor([0], device=data_start_tokens.device),
data_start_tokens,
]
)
valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
if valid_mm_data_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
# Filter out pairs where start_token >= end_token
valid_pairs = []
for i in range(valid_mm_data_nums):
start_token = data_start_tokens[i]
end_token = data_end_tokens[i]
if start_token < end_token:
valid_pairs.append((start_token + 1, end_token - 1))
if not valid_pairs:
return torch.zeros((0, 2), device=input_ids.device)
# Convert valid pairs to tensor
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor
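# A worked example (illustrative; ids hypothetical): with token_pairs=[(100, 101)],
# input_ids [5, 100, 7, 7, 101, 6] yields start index 1 and end index 4, so the
# returned bounds are [[2, 3]]: the inclusive indices of the data tokens that sit
# strictly between the pair.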