"""
|
|
Multi-modality utils
|
|
"""
|
|
|
|
import hashlib
|
|
import pickle
|
|
from abc import abstractmethod
|
|
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
|
|
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
|
from sglang.srt.managers.schedule_batch import (
|
|
Modality,
|
|
MultimodalDataItem,
|
|
MultimodalInputs,
|
|
global_server_args_dict,
|
|
)
|
|
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
|
|
from sglang.utils import logger
|
|
|
|
_is_npu = is_npu()
|
|
|
|
# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
|
|
# to ensure consistent logging behavior across the codebase. This prevents issues with log
|
|
# propagation that can cause some log messages (like 'server is fired up') to not appear
|
|
# in the console when multimodal support is enabled.
|
|
|
|
# TODO(mick): nccl
|
|
# cuda_ipc: for intranode tensor sharing
|
|
TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
|
|
|
|
|
|


class TransportProxyTensor(torch.Tensor):
    """
    A convenient torch.Tensor subclass that carries extra metadata and supports
    efficient inter-process communication.
    """

    @staticmethod
    def __new__(
        cls,
        data: torch.Tensor,
        name: Optional[str] = None,
        fields: Optional[Dict[str, Any]] = None,
        transport_mode: TensorTransportMode = "default",
        *args,
        **kwargs,
    ):
        if not isinstance(data, torch.Tensor):
            raise TypeError(
                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
            )

        instance = data.as_subclass(cls)

        instance._metadata = {
            "name": name,
            "fields": fields if fields is not None else {},
            "transport_mode": transport_mode,
        }

        return instance

    def __getstate__(self):
        """
        Called during pickling. Implements the serialization logic.
        """
        # Collect all serializable metadata from _metadata.
        state = {
            "metadata": self._metadata,
            "tensor_data": None,
            "ipc_extra": None,
        }

        transport_mode = self._metadata.get("transport_mode", "default")

        if transport_mode == "cuda_ipc" and self.is_cuda:
            try:
                storage = self.untyped_storage()
                handle = storage._share_cuda_()

                state["ipc_extra"] = {
                    "handle": handle,
                    "shape": self.shape,
                    "dtype": self.dtype,
                    "stride": self.stride(),
                    "device_index": self.device.index,
                }
                state["tensor_data"] = None
            except Exception:
                # Failed to get a CUDA IPC handle (possibly under tensor
                # parallelism). Fall back to the default transport.
                state["metadata"]["transport_mode"] = "default"
                state["tensor_data"] = self.as_subclass(torch.Tensor)
        else:
            state["metadata"]["transport_mode"] = "default"
            state["tensor_data"] = self.as_subclass(torch.Tensor)

        return state

    def __setstate__(self, state: Dict[str, Any]):
        """
        Called during unpickling. Implements the deserialization logic.
        """
        self._metadata = state["metadata"]

        transport_mode = self._metadata.get("transport_mode", "default")

        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
            ipc_extra = state["ipc_extra"]
            handle, shape, dtype, stride, source_device_index = (
                ipc_extra["handle"],
                ipc_extra["shape"],
                ipc_extra["dtype"],
                ipc_extra["stride"],
                ipc_extra["device_index"],
            )

            try:
                target_device = torch.device(f"cuda:{source_device_index}")
                with torch.cuda.device(target_device):
                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
                    reconstructed_tensor = torch.empty(
                        0, dtype=dtype, device=target_device
                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
                self.set_(reconstructed_tensor)
            except Exception as e:
                logger.error(f"Failed to deserialize from CUDA IPC handle ({e}).")
                raise

        elif state["tensor_data"] is not None:
            self.set_(state["tensor_data"])
        else:
            raise pickle.UnpicklingError(
                "Invalid state for TransportProxyTensor: no tensor data found."
            )

    @property
    def name(self) -> Optional[str]:
        return self._metadata.get("name")

    @property
    def fields(self) -> Dict[str, Any]:
        return self._metadata.get("fields", {})

    @property
    def transport_mode(self) -> TensorTransportMode:
        return self._metadata.get("transport_mode", "default")
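
# Illustrative usage sketch (assumptions: a multiprocessing queue `queue` and a
# consumer process exist elsewhere; they are not part of this module):
#
#   proxy = TransportProxyTensor(
#       data=torch.randn(16, 1024, device="cuda"),
#       name="image_embedding",
#       transport_mode="cuda_ipc",
#   )
#   queue.put(proxy)        # pickling sends only the CUDA IPC handle, not the bytes
#   received = queue.get()  # __setstate__ rebuilds a view of the same GPU memory
#   assert received.name == "image_embedding"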


class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Pad the input_ids sequence containing data tokens, replacing them with pad_values.
        """
        pass


class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens are enclosed by special token pairs
    (e.g. <image>...</image>), given as data_token_pairs.

    All padded values in a region enclosed by a token pair are the same, namely
    the corresponding MultimodalDataItem's pad value.

    This strategy should be applied when data content is marked by start/end
    token pairs in the input sequence.
    """

    def __init__(
        self,
        data_token_pairs: Optional[List[Tuple[int, int]]],
        data_start_token_ids: Optional[List[int]] = None,
    ) -> None:
        """
        Args:
            data_start_token_ids: marks the start of a single multimodal data item.
                See MiniCPM-o's slice_start_id for an example.
        """
        self.data_token_id_pairs = data_token_pairs
        self.data_start_token_ids = data_start_token_ids or [
            s for s, _e in data_token_pairs
        ]

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Replace the data tokens between each start/end token pair with the
        corresponding pad_values.
        """
        pad_values = [item.pad_value for item in mm_inputs.mm_items]
        data_token_pairs = self.data_token_id_pairs
        mm_inputs.data_offsets = []
        if data_token_pairs is None:
            # Fall back to the image start/end ids if they are provided.
            if mm_inputs.im_start_id is not None and mm_inputs.im_end_id is not None:
                data_token_pairs = [(mm_inputs.im_start_id, mm_inputs.im_end_id)]
        if data_token_pairs is None:
            print_warning_once(
                "No data_token_pairs provided, RadixAttention may be affected."
            )
            return input_ids
        start_token_ids = {s for s, _e in data_token_pairs}
        end_tokens_ids = {e for _s, e in data_token_pairs}

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] in self.data_start_token_ids:
                data_idx += 1
                mm_inputs.data_offsets += [start_idx]

            if data_idx >= len(pad_values):
                data_idx = len(pad_values) - 1

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids), "Length validation fails"
        return padded_ids
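
# Illustrative example: with data_token_pairs=[(IM_START, IM_END)] and a single
# image item whose pad_value is P (placeholder ids for this sketch),
# pad_input_tokens maps
#   [t0, IM_START, a, b, c, IM_END, t1]
# to
#   [t0, IM_START, P, P, P, IM_END, t1]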


class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens are represented as repetitions of a single token,
    e.g. <image><image>...<image>, or <audio><audio>...<audio>.
    """

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Replaces multimodal tokens in input_ids with the corresponding pad_values from mm_items.
        Each modality (image, audio, video) is handled separately based on its token_id.
        """
        if not input_ids or not mm_inputs.mm_items:
            return input_ids

        input_ids_tensor = torch.as_tensor(input_ids)

        # Create a mapping from token_id to pad_value for each modality.
        token_to_pad_mapping = {}

        for item in mm_inputs.mm_items:
            if item.is_image() and mm_inputs.im_token_id is not None:
                token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
            elif item.is_audio() and mm_inputs.audio_token_id is not None:
                token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
            elif item.is_video() and mm_inputs.video_token_id is not None:
                token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
            else:
                raise ValueError(f"No multimodal token id provided for {item.modality}")

        # Apply the replacements for all tokens at once.
        for token_id, pad_value in token_to_pad_mapping.items():
            input_ids_tensor[input_ids_tensor == token_id] = pad_value

        return input_ids_tensor.tolist()
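
# Illustrative example: with im_token_id=I and one image item whose pad_value is P
# (placeholder ids for this sketch), pad_input_tokens maps
#   [t0, I, I, I, t1]  ->  [t0, P, P, P, t1]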


embedding_cache: Optional[MultiModalCache] = None


def init_embedding_cache(max_size: int = 0):
    global embedding_cache
    embedding_cache = MultiModalCache(max_size)


def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
    hash_list = [item.hash for item in embedding_items]
    return hash(tuple(hash_list))
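
# Illustrative initialization (the max_size unit and eviction policy are defined
# by MultiModalCache; the value below is only a placeholder):
#
#   init_embedding_cache(max_size=512 * 1024 * 1024)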


def get_embedding_chunk(
    embedding: torch.Tensor,
    extend_prefix_len: int,
    extend_seq_len: int,
    items_offset: List[Tuple[int, int]],
) -> Tuple[torch.Tensor, int, int]:
    """
    Extract a chunk of embeddings based on the specified prefix length, sequence length, and offset ranges.

    Args:
        embedding: The full embedding tensor to extract a chunk from
        extend_prefix_len: The starting position (prefix length) for extraction
        extend_seq_len: The number of tokens to extract
        items_offset: List of (start, end) offset ranges for multimodal items in the input sequence

    Returns:
        A tuple containing:
        - The extracted embedding chunk as a tensor
        - The start index used for extraction
        - The end index used for extraction

    Note:
        If there is no overlap between the requested range and the offset ranges,
        an empty tensor is returned with zeros for the start and end indices.
    """
    start_index, end_index = 0, 0
    extend_start_index = extend_prefix_len
    extend_end_index = extend_prefix_len + extend_seq_len - 1

    for start, end in items_offset:
        if start <= extend_start_index <= end:
            start_index += extend_start_index - start
        elif extend_start_index > end:
            start_index += end - start + 1

        if start <= extend_end_index <= end:
            end_index += extend_end_index - start + 1
        elif extend_end_index > end:
            end_index += end - start + 1

    # Some models' embeddings are 3-dim; reshape to 2-dim.
    embedding = embedding.reshape(-1, embedding.shape[-1])
    embedding_chunk = embedding[start_index:end_index]
    return embedding_chunk, start_index, end_index
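
# Worked example (illustrative): one item occupies offsets (2, 7) of the input,
# so its embedding has 6 rows. With extend_prefix_len=4 and extend_seq_len=3,
# the extend window covers positions 4..6, i.e. rows 2..4 of the embedding, so
# start_index=2, end_index=5 and embedding[2:5] is returned.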


def _get_precomputed_embedding(
    items: List[MultimodalDataItem],
) -> Optional[torch.Tensor]:
    """
    If all items have precomputed embeddings, return their concatenation.
    If some but not all have precomputed embeddings, raise NotImplementedError.
    If none have precomputed embeddings, return None.
    """
    precomputed_embeddings = [item.precomputed_embeddings for item in items]
    if any(feature is not None for feature in precomputed_embeddings):
        if not all(feature is not None for feature in precomputed_embeddings):
            raise NotImplementedError(
                "MM inputs where only some items have precomputed embeddings are not supported."
            )
        result = torch.concat(precomputed_embeddings)
        # Some models' embeddings are 3-dim; reshape to 2-dim (similar to get_embedding_chunk).
        result = result.reshape(-1, result.shape[-1])
        return result
    return None


def _get_chunked_prefill_embedding(
    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
    embedding_items: List[MultimodalDataItem],
    items_size: List[int],
    prefix_length: List[int],
    extend_length: List[int],
    items_offset_list: List[List[Tuple[int, int]]],
) -> Optional[torch.Tensor]:
    # Calculate the embedding for each request, fetching it from the cache where
    # possible to avoid repeated computation.
    embedding_list = []
    # FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
    max_iterations = min(len(items_size) - 1, len(prefix_length))
    for i in range(max_iterations):
        if items_size[i] == items_size[i + 1]:
            continue
        embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
        items_offset = items_offset_list[i]
        assert items_offset is not None, items_offset
        embedding_items_hash = get_embedding_hash(embedding_items_per_req)
        # If all items are covered by the prefix, no embedding needs to be calculated.
        if all(offset_end < prefix_length[i] for _, offset_end in items_offset):
            continue
        embedding_per_req = embedding_cache.get(embedding_items_hash)
        if embedding_per_req is None:
            embedding_per_req = data_embedding_func(embedding_items_per_req)
            if not embedding_cache.put(embedding_items_hash, embedding_per_req):
                print_warning_once(
                    "Multimodal embedding cache is full. This typically occurs when a single "
                    "embedding exceeds the cache size limit. Consider increasing the "
                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
                    "embedding size."
                )

        embedding_per_req_chunk, _, _ = get_embedding_chunk(
            embedding=embedding_per_req,
            extend_prefix_len=prefix_length[i],
            extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
            items_offset=items_offset,
        )
        embedding_list.append(embedding_per_req_chunk)
    if len(embedding_list) == 0:
        return None
    return torch.concat(embedding_list, dim=0)
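
# Note on cache behavior: two chunked-prefill passes over the same request share
# one embedding computation. The first pass computes and caches the full
# embedding under the hash of its items; later passes hit the cache and only
# slice out their own chunk via get_embedding_chunk.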


def _get_multimodal_mask(
    input_ids: torch.Tensor, placeholder_tensor: torch.Tensor
) -> torch.Tensor:
    # True at positions whose token id is one of the multimodal placeholder ids.
    return torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)


def _adjust_embedding_length(
    embedding: torch.Tensor,
    mask: torch.Tensor,
    logger,
) -> torch.Tensor:
    num_mm_tokens_in_embedding = embedding.shape[0]
    num_mm_tokens_in_input_ids = mask.sum().item()
    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
        logger.warning(
            f"Number of tokens in multimodal embedding does not match those in the input text. "
            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
            f"tokens from multimodal embeddings."
        )
        if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
            if chunked_prefill_size != -1:
                logger.warning(
                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill."
                )
            # Extract from the end: this is a compromise.
            if embedding.dim() == 2:
                embedding = embedding[-num_mm_tokens_in_input_ids:, :]
            else:
                num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
                embedding = embedding[-num_multimodal:, :]
        else:
            raise RuntimeError(
                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error."
            )
    return embedding


def get_embedding_and_mask(
    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
    embedding_items: List[MultimodalDataItem],
    placeholder_tensor: torch.Tensor,
    input_ids: torch.Tensor,
    items_size: List[int],
    prefix_length: List[int],
    extend_length: List[int],
    items_offset_list: List[List[Tuple[int, int]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.

    Args:
        data_embedding_func: Function that generates embeddings for multimodal items
        embedding_items: List of multimodal items to embed
        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
        input_ids: The input token IDs tensor
        items_size: Cumulative sizes of multimodal items per request
        prefix_length: Prefix lengths for each request
        extend_length: Sequence lengths for each request
        items_offset_list: List of offset ranges for multimodal items in each request

    Returns:
        A tuple containing:
        - The generated embeddings tensor
        - A boolean mask tensor indicating where these embeddings should be placed
    """
    # 1. Get embedding
    embedding = _get_precomputed_embedding(embedding_items)
    if embedding is None:
        embedding = _get_chunked_prefill_embedding(
            data_embedding_func,
            embedding_items,
            items_size,
            prefix_length,
            extend_length,
            items_offset_list,
        )
        if embedding is None:
            return None, None
    # 2. Get mask
    if _is_npu:
        torch.npu.current_stream().synchronize()
    special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
    # 3. Adjust embedding length if needed
    embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
    return embedding, special_multimodal_mask


def embed_mm_inputs(
    mm_inputs_list: List[MultimodalInputs],
    extend_prefix_lens: List[int],
    extend_seq_lens: List[int],
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    multimodal_model: Optional[nn.Module] = None,
    data_embedding_func_mapping: Optional[
        Dict[Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]]
    ] = None,
    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
) -> Optional[torch.Tensor]:
    """
    Embed multimodal inputs and integrate them with text token embeddings.

    Args:
        mm_inputs_list: List of multimodal inputs to process
        extend_prefix_lens: Prefix lengths for each request
        extend_seq_lens: Sequence lengths for each request
        input_ids: Input token IDs tensor
        input_embedding: Embedding layer for text tokens
        multimodal_model: Model used to look up per-modality embedders when no explicit mapping is given
        data_embedding_func_mapping: Optional mapping from modality to its embedding function
        placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

    Returns:
        Combined embedding tensor with multimodal content integrated
    """
    if mm_inputs_list is None:
        return None

    # 1. Find the multimodal data in input_ids with the help of pad_values.
    # We assume that multimodal data are represented by their pad_values in input_ids.
    item_flatten_list = []
    for mm_inputs in mm_inputs_list:
        item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

    embeddings, masks = [], []
    # 2. Get the multimodal embedding of each modality separately, if any.
    for modality in Modality.all():
        items = [
            item for item in item_flatten_list if item.is_modality(modality=modality)
        ]
        embedder = (
            None
            if data_embedding_func_mapping is None
            else data_embedding_func_mapping.get(modality, None)
        )
        if embedder is None:
            # "image", "video", etc.
            modality_id = modality.name.lower()
            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
        if len(items) != 0 and embedder is not None:
            placeholder_tensor = torch.as_tensor(
                [item.pad_value for item in items],
                device=input_ids.device,
            )
            # Calculate the per-request item-count offsets.
            items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
            items_offsets = []
            for i, mm_inputs in enumerate(mm_inputs_list):
                mm_items = [
                    item
                    for item in mm_inputs.mm_items
                    if item.is_modality(modality=modality)
                ]
                items_size[i + 1] = len(mm_items)
                items_offsets.append(
                    flatten_nested_list([item.offsets for item in mm_items])
                )
            items_size = torch.cumsum(items_size, dim=0).tolist()

            embedding, mask = get_embedding_and_mask(
                data_embedding_func=embedder,
                embedding_items=items,
                placeholder_tensor=placeholder_tensor,
                input_ids=input_ids,
                items_size=items_size,
                prefix_length=extend_prefix_lens,
                extend_length=extend_seq_lens,
                items_offset_list=items_offsets,
            )
            embeddings += [embedding]
            masks += [mask]

    # 3. Get the input embeddings.
    vocab_size = input_embedding.num_embeddings
    # Important: clamp only after locating the original multimodal regions.
    # The input_ids of multimodal tokens are filled with hash-derived pad values
    # (used for prefix matching in RadixAttention), so they may lie outside the
    # vocabulary. These values are useless here because their embeddings will be
    # replaced by the multimodal embeddings anyway.
    input_ids.clamp_(min=0, max=vocab_size - 1)
    inputs_embeds = input_embedding(input_ids)

    # 4. Scatter the multimodal embeddings into the input embeddings.
    for embedding, mask in zip(embeddings, masks):
        if embedding is None or mask is None:
            continue
        # In-place update.
        indices = torch.where(mask.squeeze(dim=-1))[0]
        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
    return inputs_embeds
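
# Minimal sketch of the scatter in step 4 above (illustrative values):
#
#   inputs_embeds = torch.zeros(5, 8)
#   mask = torch.tensor([False, True, True, False, False]).unsqueeze(-1)
#   embedding = torch.ones(2, 8)  # one row per masked position
#   inputs_embeds[torch.where(mask.squeeze(-1))[0]] = embedding
#
# Rows 1 and 2 of inputs_embeds now hold the multimodal embedding rows.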


def general_mm_embed_routine(
    input_ids: torch.Tensor,
    forward_batch: ForwardBatch,
    language_model: nn.Module,
    multimodal_model: Optional[nn.Module] = None,
    data_embedding_funcs: Optional[
        Dict[Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]]
    ] = None,
    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Process multimodal inputs and forward through the language model.

    Args:
        input_ids: Input token IDs tensor
        forward_batch: Batch information for the model forward pass
        language_model: Base language model to use
        multimodal_model: Model providing per-modality feature extractors
        data_embedding_funcs: A dictionary mapping each modality type to its embedding function
        placeholder_tokens: Token IDs for multimodal placeholders
        **kwargs: Additional arguments passed to the language model

    Returns:
        Hidden states from the language model forward pass
    """
    assert hasattr(language_model, "get_input_embeddings")
    embed_tokens = language_model.get_input_embeddings()
    if (
        not forward_batch.forward_mode.is_decode()
        and not forward_batch.forward_mode.is_target_verify()
        and forward_batch.contains_mm_inputs()
    ):
        mm_inputs_list = [
            mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
        ]
        extend_prefix_lens = [
            prefix_len
            for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
            if forward_batch.mm_inputs[i] is not None
        ]
        extend_seq_lens = [
            seq_len
            for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
            if forward_batch.mm_inputs[i] is not None
        ]
        inputs_embeds = embed_mm_inputs(
            mm_inputs_list=mm_inputs_list,
            extend_prefix_lens=extend_prefix_lens,
            extend_seq_lens=extend_seq_lens,
            input_ids=input_ids,
            input_embedding=embed_tokens,
            multimodal_model=multimodal_model,
            data_embedding_func_mapping=data_embedding_funcs,
            placeholder_tokens=placeholder_tokens,
        )
        # Once used, mm_inputs is no longer needed, since chunked prefill is
        # disabled for multimodal models; clearing it here is just defensive.
        forward_batch.mm_inputs = None
    else:
        inputs_embeds = embed_tokens(input_ids)

    hidden_states = language_model(
        input_ids=None,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
        **kwargs,
    )
    return hidden_states
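
# Illustrative usage sketch: how a multimodal model's forward() would typically
# delegate to this routine (the class and attribute names below are assumptions,
# not part of this module):
#
#   class MyVLMForCausalLM(nn.Module):
#       def forward(self, input_ids, positions, forward_batch):
#           return general_mm_embed_routine(
#               input_ids=input_ids,
#               forward_batch=forward_batch,
#               language_model=self.language_model,
#               multimodal_model=self.vision_tower,  # exposes get_image_feature
#               positions=positions,
#           )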


def get_multimodal_data_bounds(
    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)

    Returns:
        [bounds_count, 2]
    """
    # All the multimodal data in the batch should share the same special bound token ids.
    start_tokens = {s for s, _e in token_pairs}
    end_tokens = {e for _s, e in token_pairs}

    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)

    # Convert the sets to lists: torch.as_tensor cannot infer a dtype from a set.
    start_cond = torch.isin(
        input_ids, torch.as_tensor(list(start_tokens), device=input_ids.device)
    )
    end_cond = torch.isin(
        input_ids, torch.as_tensor(list(end_tokens), device=input_ids.device)
    )

    (data_start_tokens,) = torch.where(start_cond)
    (data_end_tokens,) = torch.where(end_cond)

    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
    data_end_tokens_cpu = data_end_tokens.cpu().tolist()

    # The im_start_id can sometimes be cached as part of the prefix, while it is
    # still needed for embedding the multimodal data.
    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
        if (
            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
            and input_ids[0].item() in pad_values
            and data_end_tokens_cpu
            and data_start_tokens_cpu
            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
        ):
            data_start_tokens_cpu.insert(0, 0)
    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))

    if valid_mm_data_nums == 0:
        return torch.zeros((0, 2), device=input_ids.device)

    # Filter out pairs where start_token >= end_token.
    valid_pairs = []
    for i in range(valid_mm_data_nums):
        start_token = data_start_tokens_cpu[i]
        end_token = data_end_tokens_cpu[i]
        if start_token < end_token:
            valid_pairs.append((start_token + 1, end_token - 1))

    if not valid_pairs:
        return torch.zeros((0, 2), device=input_ids.device)

    # Convert the valid pairs to a tensor.
    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
    return valid_pairs_tensor
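
# Worked example (illustrative): with token_pairs=[(S, E)], pad_values=[P], and
#   input_ids = [t0, S, P, P, E, t1]
# the function returns [[2, 3]]: the inclusive index range of the pad tokens
# between S and E.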


def data_hash(data) -> int:
    # Use the first 8 bytes of SHA-256 as a stable 64-bit hash.
    hash_bytes = hashlib.sha256(data).digest()[:8]
    return int.from_bytes(hash_bytes, byteorder="big", signed=False)


def tensor_hash(tensor_list) -> int:
    """
    Hash a tensor or a (possibly nested) list of tensors.
    """
    tensor = tensor_list
    if isinstance(tensor_list, list):
        tensor_list = flatten_nested_list(tensor_list)
        tensor_list = [
            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
        ]
        tensor = torch.concat(tensor_list)
    if tensor.is_cuda:
        return gpu_tensor_hash(tensor.cuda())
    tensor = tensor.detach().contiguous()

    if tensor.dtype == torch.bfloat16:
        # memoryview() doesn't support PyTorch's BFloat16 dtype.
        tensor = tensor.float()

    assert isinstance(tensor, torch.Tensor)
    tensor_cpu = tensor.cpu()

    mv = memoryview(tensor_cpu.numpy())
    return data_hash(mv.tobytes())


def hash_feature(f):
    if isinstance(f, list):
        if isinstance(f[0], torch.Tensor):
            return tensor_hash(f)
        # hashlib needs a bytes-like object, so serialize the flattened list first.
        return data_hash(np.asarray(flatten_nested_list(f)).tobytes())
    elif isinstance(f, np.ndarray):
        arr = np.ascontiguousarray(f)
        arr_bytes = arr.tobytes()
        return data_hash(arr_bytes)
    elif isinstance(f, torch.Tensor):
        return tensor_hash([f])
    return data_hash(f)
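
# Illustrative calls: all feature container types funnel into the same hashing
# entry point and return a 64-bit int (CUDA tensors use gpu_tensor_hash instead).
#
#   h1 = hash_feature(np.zeros((2, 3), dtype=np.float32))
#   h2 = hash_feature(torch.zeros(2, 3))
#   h3 = hash_feature([torch.zeros(3), torch.ones(3)])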