Files
sglang/python/sglang/srt/managers/mm_utils.py
2025-09-07 20:44:34 -07:00

776 lines
29 KiB
Python

"""
Multi-modality utils
"""
import hashlib
import pickle
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
import numpy as np
import torch
from torch import nn
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
from sglang.utils import logger
_is_npu = is_npu()
# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
# to ensure consistent logging behavior across the codebase. This prevents issues with log
# propagation that can cause some log messages (like 'server is fired up') to not appear
# in the console when multimodal support is enabled.
# TODO(mick): nccl
# cuda_ipc: for intranode tensor sharing
TensorTransportMode = Literal["cuda_ipc", "auto", "default"]


class TransportProxyTensor(torch.Tensor):
    """
    A convenient torch.Tensor subclass that carries extra metadata and supports
    efficient inter-process communications.

    The metadata (``name``, ``fields``, ``transport_mode``) is stored in the
    ``_metadata`` dict attached to the instance in ``__new__``. Serialization is
    customized through ``__getstate__`` / ``__setstate__``: with
    ``transport_mode == "cuda_ipc"`` and a CUDA tensor, only a CUDA IPC handle
    plus layout information is put into the state; otherwise the tensor itself
    is embedded in the state.
    """

    @staticmethod
    def __new__(
        cls,
        data: torch.Tensor,
        name: Optional[str] = None,
        fields: Optional[Dict[str, Any]] = None,
        transport_mode: TensorTransportMode = "default",
        *args,
        **kwargs,
    ):
        """
        Wrap ``data`` as a ``TransportProxyTensor`` (via ``as_subclass``).

        Args:
            data: The tensor to wrap.
            name: Optional name stored in the metadata.
            fields: Optional extra metadata; an empty dict is stored when None.
            transport_mode: Requested transport mode for pickling.

        Raises:
            TypeError: If ``data`` is not a ``torch.Tensor``.
        """
        if not isinstance(data, torch.Tensor):
            raise TypeError(
                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
            )
        instance = data.as_subclass(cls)
        instance._metadata = {
            "name": name,
            "fields": fields if fields is not None else {},
            "transport_mode": transport_mode,
        }
        return instance

    def __getstate__(self):
        """
        Called during pickling. Implements the serialization logic.

        Returns a dict with keys ``metadata``, ``tensor_data`` and ``ipc_extra``.
        Exactly one of ``tensor_data`` / ``ipc_extra`` is populated.
        """
        # Work on a copy: the fallback below rewrites "transport_mode" for the
        # serialized state only and must not alter this tensor's own metadata.
        metadata = dict(self._metadata)
        state = {
            "metadata": metadata,
            "tensor_data": None,
            "ipc_extra": None,
        }
        ipc_extra = None
        if metadata.get("transport_mode", "default") == "cuda_ipc" and self.is_cuda:
            try:
                storage = self.untyped_storage()
                handle = storage._share_cuda_()
                ipc_extra = {
                    "handle": handle,
                    "shape": self.shape,
                    "dtype": self.dtype,
                    "stride": self.stride(),
                    # Needed to rebuild views that do not start at the
                    # beginning of their storage.
                    "storage_offset": self.storage_offset(),
                    "device_index": self.device.index,
                }
            except Exception:
                # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
                ipc_extra = None

        if ipc_extra is not None:
            state["ipc_extra"] = ipc_extra
        else:
            metadata["transport_mode"] = "default"
            state["tensor_data"] = self.as_subclass(torch.Tensor)
        return state

    def __setstate__(self, state: Dict[str, Any]):
        """
        Called during unpickling. Implements the deserialization logic.

        Raises:
            pickle.UnpicklingError: If the state carries neither an IPC handle
                (in ``cuda_ipc`` mode) nor tensor data.
            Exception: Any error raised while rebuilding from the CUDA IPC
                handle is printed and re-raised unchanged.
        """
        self._metadata = state["metadata"]
        transport_mode = self._metadata.get("transport_mode", "default")
        ipc_extra = state["ipc_extra"]
        if transport_mode == "cuda_ipc" and ipc_extra is not None:
            try:
                target_device = torch.device(f"cuda:{ipc_extra['device_index']}")
                with torch.cuda.device(target_device):
                    storage = torch.UntypedStorage._new_shared_cuda(
                        *ipc_extra["handle"]
                    )
                    reconstructed_tensor = torch.empty(
                        0, dtype=ipc_extra["dtype"], device=target_device
                    ).set_(
                        storage,
                        # States produced without this key used offset 0.
                        storage_offset=ipc_extra.get("storage_offset", 0),
                        size=ipc_extra["shape"],
                        stride=ipc_extra["stride"],
                    )
                    self.set_(reconstructed_tensor)
            except Exception as e:
                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
                raise
        elif state["tensor_data"] is not None:
            self.set_(state["tensor_data"])
        else:
            raise pickle.UnpicklingError(
                "Invalid state for TransportProxyTensor: no tensor data found."
            )

    @property
    def name(self) -> Optional[str]:
        """Name stored in the metadata, or None."""
        return self._metadata.get("name")

    @property
    def fields(self) -> Dict[str, Any]:
        """Extra metadata fields (empty dict when absent)."""
        return self._metadata.get("fields", {})

    @property
    def transport_mode(self) -> TensorTransportMode:
        """Transport mode stored in the metadata (``"default"`` when absent)."""
        return self._metadata.get("transport_mode", "default")
class MultiModalityDataPaddingPattern:
    """
    Interface for padding strategies applied to multimodal data tokens.

    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility; each concrete strategy implements
    ``pad_input_tokens``.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them
        with pad_values.

        Args:
            input_ids: Token ids of the prompt.
            mm_inputs: Multimodal inputs associated with the prompt.

        Returns:
            The padded token id sequence. This base implementation has no body
            and therefore returns None.
        """
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    The padded value in a region enclosed by a token pair will be the same one, as the MultimodalDataItem's pad value

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(
        self,
        data_token_pairs: Optional[List[Tuple[int, int]]],
        data_start_token_ids: Optional[List[int]] = None,
    ) -> None:
        """
        Args:
            data_token_pairs: (start_token_id, end_token_id) pairs enclosing
                multimodal data. May be None, in which case
                ``pad_input_tokens`` falls back to the ``im_start_id`` /
                ``im_end_id`` of the multimodal inputs.
            data_start_token_ids marks the start of a single multimodal data
                See Minicpmo's slice_start_id for example. Defaults to the
                start tokens of ``data_token_pairs``.
        """
        self.data_token_id_pairs = data_token_pairs
        # `data_token_pairs` is Optional: do not iterate over None.
        self.data_start_token_ids = data_start_token_ids or [
            s for s, _e in (data_token_pairs or [])
        ]

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        This function will replace the data-tokens in between with pad_values accordingly

        Side effect: ``mm_inputs.data_offsets`` is reset and filled with the
        index of every start token that begins a new multimodal data item.

        Returns:
            ``input_ids`` unchanged when no token pairs are available, when the
            numbers of start and end tokens differ, or when there are no pad
            values to apply; otherwise a new list of the same length with the
            tokens strictly between each start/end pair replaced by the
            corresponding item's pad value.
        """
        pad_values = [item.pad_value for item in mm_inputs.mm_items]
        data_token_pairs = self.data_token_id_pairs
        mm_inputs.data_offsets = []

        if data_token_pairs is None:
            # Fall back to the image start/end ids, as ONE (start, end) pair.
            if mm_inputs.im_start_id is not None and mm_inputs.im_end_id is not None:
                data_token_pairs = [(mm_inputs.im_start_id, mm_inputs.im_end_id)]
        if data_token_pairs is None:
            print_warning_once(
                "No data_token_pairs provided, RadixAttention might be influenced."
            )
            return input_ids

        start_token_ids = {s for s, _e in data_token_pairs}
        end_tokens_ids = {e for _s, e in data_token_pairs}
        # When the pairs came from the fallback above, no explicit start ids
        # were configured; every pair start then begins a new data item.
        data_start_token_ids = set(self.data_start_token_ids) or start_token_ids

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            return input_ids
        if start_indices and not pad_values:
            # Regions exist but there is nothing to pad them with.
            return input_ids

        padded_ids = []
        last_idx = 0
        data_idx = -1
        for start_idx, end_idx in zip(start_indices, end_indices):
            # Copy everything up to and including the start token verbatim.
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] in data_start_token_ids:
                data_idx += 1
                mm_inputs.data_offsets.append(start_idx)

            # More regions than items: keep reusing the last item's pad value.
            if data_idx >= len(pad_values):
                data_idx = len(pad_values) - 1

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids), "Length validation fails"

        return padded_ids
class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
    """Padding pattern for data expressed as repetitions of a single token,

    e.g. <image><image>....<image>, or <audio><audio>...<audio>
    """

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Swap every multimodal placeholder token in ``input_ids`` for the
        pad_value of an item of the matching modality (image, audio, video).

        Returns ``input_ids`` untouched when it is empty or when there are no
        multimodal items. When several items share a modality, the pad_value of
        the last such item is the one applied.

        Raises:
            ValueError: If an item's modality has no token id configured on
                ``mm_inputs``.
        """
        if not input_ids or not mm_inputs.mm_items:
            return input_ids

        ids = torch.as_tensor(input_ids)

        # placeholder token id -> pad value, one entry per modality present
        replacements = {}
        for item in mm_inputs.mm_items:
            if item.is_image() and mm_inputs.im_token_id is not None:
                placeholder = mm_inputs.im_token_id
            elif item.is_audio() and mm_inputs.audio_token_id is not None:
                placeholder = mm_inputs.audio_token_id
            elif item.is_video() and mm_inputs.video_token_id is not None:
                placeholder = mm_inputs.video_token_id
            else:
                raise ValueError(f"No multimodal token id provided for {item.modality}")
            replacements[placeholder] = item.pad_value

        # Rewrite each placeholder in one vectorized assignment
        for placeholder, pad_value in replacements.items():
            ids[ids == placeholder] = pad_value

        return ids.tolist()
# Module-level cache of computed multimodal embeddings. It stays None until
# init_embedding_cache() has been called.
embedding_cache: Optional[MultiModalCache] = None


def init_embedding_cache(max_size: int = 0):
    """
    Create a fresh ``MultiModalCache`` and install it as the module-level
    ``embedding_cache``, replacing any previous one.

    Args:
        max_size: Forwarded unchanged to ``MultiModalCache``.
            NOTE(review): presumably a capacity limit - confirm the unit
            against ``MultiModalCache``.
    """
    global embedding_cache
    embedding_cache = MultiModalCache(max_size)
def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
    """
    Combine the ``hash`` attribute of every item into a single integer.

    The result is order-sensitive: it is the built-in ``hash`` of the tuple of
    per-item hashes.
    """
    return hash(tuple(item.hash for item in embedding_items))
def get_embedding_chunk(
    embedding: torch.Tensor,
    extend_prefix_len: int,
    extend_seq_len: int,
    items_offset: List[Tuple[int, int]],
) -> Tuple[torch.Tensor, int, int]:
    """
    Slice out the rows of ``embedding`` that belong to the token window
    ``[extend_prefix_len, extend_prefix_len + extend_seq_len - 1]``.

    Args:
        embedding: The full embedding tensor to extract a chunk from
        extend_prefix_len: The starting position (prefix length) for extraction
        extend_seq_len: The number of tokens to extract
        items_offset: List of inclusive [start, end] offset ranges for
            multimodal items in the input sequence

    Returns:
        A tuple containing:
            - The extracted embedding chunk as a tensor
            - The start row index used for extraction
            - The end row index (exclusive) used for extraction

    Note:
        If the requested window does not overlap any offset range, the chunk
        is empty (start index equals end index).
    """
    window_first = extend_prefix_len
    window_last = extend_prefix_len + extend_seq_len - 1

    def _rows_up_to(position: int, start: int, end: int, inclusive: bool) -> int:
        # Number of rows of the item spanning [start, end] that lie before
        # `position` (or up to and including it when `inclusive`).
        if start <= position <= end:
            return position - start + (1 if inclusive else 0)
        if position > end:
            return end - start + 1
        return 0

    start_index = sum(
        _rows_up_to(window_first, start, end, False) for start, end in items_offset
    )
    end_index = sum(
        _rows_up_to(window_last, start, end, True) for start, end in items_offset
    )

    # some models' embedding is 3-dim, reshape it to 2-dim
    rows = embedding.reshape(-1, embedding.shape[-1])
    return rows[start_index:end_index], start_index, end_index
def _get_precomputed_embedding(
    items: List[MultimodalDataItem],
) -> Optional[torch.Tensor]:
    """
    Return the concatenated ``precomputed_embeddings`` of ``items``, if any.

    - No item carries precomputed embeddings (or ``items`` is empty): None.
    - Every item carries them: their concatenation, flattened to 2-dim.
    - Only some items carry them: NotImplementedError.
    """
    precomputed = [item.precomputed_embeddings for item in items]
    is_missing = [feature is None for feature in precomputed]

    if all(is_missing):
        return None
    if any(is_missing):
        raise NotImplementedError(
            "MM inputs where only some items are precomputed."
        )

    merged = torch.concat(precomputed)
    # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
    return merged.reshape(-1, merged.shape[-1])
def _get_chunked_prefill_embedding(
    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
    embedding_items: List[MultimodalDataItem],
    items_size: List[int],
    prefix_length: List[int],
    extend_length: List[int],
    items_offset_list: List[List[Tuple[int, int]]],
) -> Optional[torch.Tensor]:
    """
    Build the multimodal embedding rows needed by the current extend step.

    For each request the embedding of its items is looked up in the
    module-level ``embedding_cache`` (and computed with ``data_embedding_func``
    on a miss), then only the rows overlapping the request's extend window are
    kept.

    Args:
        data_embedding_func: Computes the embedding for a list of items.
        embedding_items: Items of all requests, flattened.
        items_size: Cumulative item counts; request ``i`` owns
            ``embedding_items[items_size[i]:items_size[i + 1]]``.
        prefix_length: Prefix length of each request.
        extend_length: Extend length of each request (0 is used for a request
            index beyond this list).
        items_offset_list: Per-request list of (start, end) item offsets.

    Returns:
        The concatenated chunks, or None when no request contributed one.
    """
    chunks = []
    # FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
    num_requests = min(len(items_size) - 1, len(prefix_length))
    for req_idx in range(num_requests):
        first_item, end_item = items_size[req_idx], items_size[req_idx + 1]
        if first_item == end_item:
            # this request has no item of the current kind
            continue

        req_items = embedding_items[first_item:end_item]
        items_offset = items_offset_list[req_idx]
        assert items_offset is not None, items_offset

        req_hash = get_embedding_hash(req_items)
        req_prefix_len = prefix_length[req_idx]

        # if every item already lies inside the prefix, nothing has to be embedded
        if not any(offset_end >= req_prefix_len for _, offset_end in items_offset):
            continue

        # Reuse a cached embedding when possible to avoid repeated calculation
        req_embedding = embedding_cache.get(req_hash)
        if req_embedding is None:
            req_embedding = data_embedding_func(req_items)
            was_cached = embedding_cache.put(req_hash, req_embedding)
            if not was_cached:
                print_warning_once(
                    "Multimodal embedding cache is full. This typically occurs when a single "
                    "embedding exceeds the cache size limit. Consider increasing the "
                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
                    "embedding size."
                )

        req_extend_len = extend_length[req_idx] if req_idx < len(extend_length) else 0
        chunk, _, _ = get_embedding_chunk(
            embedding=req_embedding,
            extend_prefix_len=req_prefix_len,
            extend_seq_len=req_extend_len,
            items_offset=items_offset,
        )
        chunks.append(chunk)

    if not chunks:
        return None
    return torch.concat(chunks, dim=0)
def _get_multimodal_mask(
    input_ids: torch.Tensor, placeholder_tensor: torch.Tensor
) -> torch.Tensor:
    """
    Mark the positions of ``input_ids`` whose value occurs in
    ``placeholder_tensor``.

    Returns:
        A boolean tensor shaped like ``input_ids`` with one extra trailing
        dimension of size 1.
    """
    is_placeholder = torch.isin(input_ids, placeholder_tensor)
    return is_placeholder.unsqueeze(-1)
def _adjust_embedding_length(
    embedding: torch.Tensor,
    mask: torch.Tensor,
    logger,
) -> torch.Tensor:
    """
    Make the number of embedding rows agree with the number of multimodal
    token positions selected by ``mask``.

    Args:
        embedding: Multimodal embedding; its first dimension is the row count.
        mask: Boolean mask of multimodal positions in the input ids.
        logger: Logger used to report a mismatch.

    Returns:
        ``embedding`` itself when the counts match. When the embedding has
        more rows than the mask selects, the trailing rows of a 2-dim
        embedding (exactly as many as the mask selects, possibly zero).

    Raises:
        RuntimeError: If the embedding has fewer rows than the mask selects.
    """
    num_mm_tokens_in_embedding = embedding.shape[0]
    num_mm_tokens_in_input_ids = mask.sum().item()
    if num_mm_tokens_in_input_ids == num_mm_tokens_in_embedding:
        return embedding

    logger.warning(
        f"Number of tokens in multimodal embedding does not match those in the input text. "
        f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
        f"tokens from multimodal embeddings."
    )
    if num_mm_tokens_in_input_ids > num_mm_tokens_in_embedding:
        raise RuntimeError(
            f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
        )

    chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
    if chunked_prefill_size != -1:
        logger.warning(
            "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
        )
    # extract from the end: this is a compromise
    if embedding.dim() == 2:
        # Slice from an explicit start index: `embedding[-n:]` would return
        # every row (instead of none) when n == 0.
        first_kept = num_mm_tokens_in_embedding - num_mm_tokens_in_input_ids
        embedding = embedding[first_kept:, :]
    else:
        # NOTE(review): in this branch the quotient below is always 0 (the
        # numerator is smaller than embedding.shape[0]), so the slice keeps
        # the whole embedding. Kept as-is because the intended divisor is not
        # evident from this file - confirm.
        num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
        embedding = embedding[-num_multimodal:, :]
    return embedding
def get_embedding_and_mask(
    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
    embedding_items: List[MultimodalDataItem],
    placeholder_tensor: torch.Tensor,
    input_ids: torch.Tensor,
    items_size: List[int],
    prefix_length: List[int],
    extend_length: List[int],
    items_offset_list: List[List[Tuple[int, int]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Produce the multimodal embedding together with the mask of input positions
    it has to be written to.

    Args:
        data_embedding_func: Function that generates embeddings for multimodal items
        embedding_items: List of multimodal items to embed
        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
        input_ids: The input token IDs tensor
        items_size: Cumulative sizes of multimodal items per request
        prefix_length: Prefix lengths for each request
        extend_length: Sequence lengths for each request
        items_offset_list: List of offset ranges for multimodal items in each request

    Returns:
        A tuple containing:
            - The generated embeddings tensor
            - A boolean mask tensor indicating where these embeddings should be placed
        ``(None, None)`` is returned when there is nothing to embed.
    """
    # 1. Embedding: precomputed embeddings win, otherwise compute/cached lookup
    embedding = _get_precomputed_embedding(embedding_items)
    if embedding is None:
        embedding = _get_chunked_prefill_embedding(
            data_embedding_func=data_embedding_func,
            embedding_items=embedding_items,
            items_size=items_size,
            prefix_length=prefix_length,
            extend_length=extend_length,
            items_offset_list=items_offset_list,
        )
    if embedding is None:
        return None, None

    # 2. Mask of placeholder positions
    if _is_npu:
        torch.npu.current_stream().synchronize()
    special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)

    # 3. Reconcile the embedding length with the mask if needed
    adjusted = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
    return adjusted, special_multimodal_mask
def embed_mm_inputs(
    mm_inputs_list: List[MultimodalInputs],
    extend_prefix_lens: List[int],
    extend_seq_lens: List[int],
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    multimodal_model: nn.Module = None,
    data_embedding_func_mapping: Dict[
        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    placeholder_tokens: dict[Modality, List[int]] = None,
) -> Optional[torch.Tensor]:
    """
    Embed multimodal inputs and integrate them with text token embeddings.

    Args:
        mm_inputs_list: List of multimodal inputs to process
        extend_prefix_lens: Prefix lengths for each request
        extend_seq_lens: Sequence lengths for each request
        input_ids: Input token IDs tensor. It is clamped IN PLACE to the
            vocabulary range before the text embedding lookup.
        input_embedding: Embedding layer for text tokens
        multimodal_model: Model searched for a ``get_<modality>_feature``
            method when no embedding function is mapped for a modality
        data_embedding_func_mapping: Optional per-modality embedding functions
        placeholder_tokens: Accepted for interface compatibility; not read by
            this function (placeholders are the items' pad_values)

    Returns:
        Combined embedding tensor with multimodal content integrated, or None
        when ``mm_inputs_list`` is None
    """
    if mm_inputs_list is None:
        return None

    # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
    # we assume that multimodal data are represented with its pad_values in input_ids
    # None items are dropped once here so the flattened list and the
    # per-request bookkeeping below always agree.
    items_per_request = [
        [item for item in mm_inputs.mm_items if item is not None]
        for mm_inputs in mm_inputs_list
    ]
    item_flatten_list = [
        item for request_items in items_per_request for item in request_items
    ]

    embeddings, masks = [], []
    # 2. Get multimodal embedding separately
    # Try get mm embedding if any
    for modality in Modality.all():
        items = [
            item for item in item_flatten_list if item.is_modality(modality=modality)
        ]
        embedder = (
            None
            if data_embedding_func_mapping is None
            else data_embedding_func_mapping.get(modality, None)
        )
        if embedder is None:
            # "image", "video", etc
            modality_id = modality.name.lower()
            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
        if not items or embedder is None:
            continue

        placeholder_tensor = torch.as_tensor(
            [item.pad_value for item in items],
            device=input_ids.device,
        )
        # calculate per request items length offset:
        # items_size[i] is the number of items of this modality before request i
        items_size = [0]
        items_offsets = []
        for request_items in items_per_request:
            mm_items = [
                item for item in request_items if item.is_modality(modality=modality)
            ]
            items_size.append(items_size[-1] + len(mm_items))
            items_offsets.append(
                flatten_nested_list([item.offsets for item in mm_items])
            )

        embedding, mask = get_embedding_and_mask(
            data_embedding_func=embedder,
            embedding_items=items,
            placeholder_tensor=placeholder_tensor,
            input_ids=input_ids,
            items_size=items_size,
            prefix_length=extend_prefix_lens,
            extend_length=extend_seq_lens,
            items_offset_list=items_offsets,
        )
        embeddings.append(embedding)
        masks.append(mask)

    # 3. Get input embeddings
    vocab_size = input_embedding.num_embeddings
    # Important: clamp after getting original multimodal regions
    # Clamp input ids. This is because the input_ids for the multimodal tokens are
    # filled with the hash values of the multimodal for the prefix matching in the radix attention.
    # These values are useless because their embeddings will be replaced by vision embeddings anyway.
    input_ids.clamp_(min=0, max=vocab_size - 1)
    inputs_embeds = input_embedding(input_ids)

    # 4. scatter embeddings into input embedding
    for embedding, mask in zip(embeddings, masks):
        if embedding is None or mask is None:
            continue
        # in-place update
        indices = torch.where(mask.squeeze(dim=-1))[0]
        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
    return inputs_embeds
def general_mm_embed_routine(
    input_ids: torch.Tensor,
    forward_batch: ForwardBatch,
    language_model: nn.Module,
    multimodal_model: Optional[nn.Module] = None,
    data_embedding_funcs: Dict[
        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Compute the input embeddings (merging multimodal embeddings when the batch
    carries multimodal inputs) and run the language model on them.

    Args:
        input_ids: Input token IDs tensor
        forward_batch: Batch information for model forward pass. Its
            ``mm_inputs`` is set to None once the multimodal inputs were
            embedded.
        language_model: Base language model to use; must expose
            ``get_input_embeddings``
        multimodal_model: Optional model passed on to ``embed_mm_inputs``
        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
        placeholder_tokens: Token IDs for multimodal placeholders
        **kwargs: Additional arguments passed to language model

    Returns:
        Hidden states from language model forward pass
    """
    assert hasattr(language_model, "get_input_embeddings")
    embed_tokens = language_model.get_input_embeddings()

    forward_mode = forward_batch.forward_mode
    needs_mm_embedding = (
        not forward_mode.is_decode()
        and not forward_mode.is_target_verify()
        and forward_batch.contains_mm_inputs()
    )

    if not needs_mm_embedding:
        inputs_embeds = embed_tokens(input_ids)
    else:
        per_request_mm = forward_batch.mm_inputs

        def _only_mm_requests(values):
            # keep the entries that belong to requests with multimodal inputs
            return [
                value
                for idx, value in enumerate(values)
                if per_request_mm[idx] is not None
            ]

        inputs_embeds = embed_mm_inputs(
            mm_inputs_list=[mm for mm in per_request_mm if mm is not None],
            extend_prefix_lens=_only_mm_requests(forward_batch.extend_prefix_lens_cpu),
            extend_seq_lens=_only_mm_requests(forward_batch.extend_seq_lens_cpu),
            input_ids=input_ids,
            input_embedding=embed_tokens,
            multimodal_model=multimodal_model,
            data_embedding_func_mapping=data_embedding_funcs,
            placeholder_tokens=placeholder_tokens,
        )
        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
        # just being defensive here
        forward_batch.mm_inputs = None

    return language_model(
        input_ids=None,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
        **kwargs,
    )
def get_multimodal_data_bounds(
    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)

    Args:
        input_ids: 1-dim tensor of token ids.
        pad_values: Pad values of the multimodal items; used to detect a
            sequence whose leading start token is missing.
        token_pairs: (start_token_id, end_token_id) pairs delimiting data.

    Returns:
        [bounds_count, 2]: an integer tensor on ``input_ids.device`` whose rows
        are (index after the start token, index before the end token). Empty
        (shape ``(0, 2)``) when no valid pair is found.
    """
    # All the multimodal data in the batch should share the same special bound token ids.
    # Sorted lists (not sets): torch cannot build a tensor from a Python set.
    start_tokens = sorted({s for s, _e in token_pairs})
    end_tokens = sorted({e for _s, e in token_pairs})

    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)

    start_cond = torch.isin(
        input_ids,
        torch.as_tensor(start_tokens, dtype=input_ids.dtype, device=input_ids.device),
    )
    end_cond = torch.isin(
        input_ids,
        torch.as_tensor(end_tokens, dtype=input_ids.dtype, device=input_ids.device),
    )

    (data_start_tokens,) = torch.where(start_cond)
    (data_end_tokens,) = torch.where(end_cond)

    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
    data_end_tokens_cpu = data_end_tokens.cpu().tolist()

    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
    if (
        len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
        and input_ids[0].item() in pad_values
        and data_end_tokens_cpu
        and data_start_tokens_cpu
        and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
    ):
        data_start_tokens_cpu.insert(0, 0)

    # Pair starts and ends positionally (extra unmatched tokens are ignored)
    # and drop pairs where start_token >= end_token.
    valid_pairs = [
        (start_token + 1, end_token - 1)
        for start_token, end_token in zip(data_start_tokens_cpu, data_end_tokens_cpu)
        if start_token < end_token
    ]

    if not valid_pairs:
        # Same integer dtype as the non-empty result below.
        return torch.zeros((0, 2), dtype=torch.long, device=input_ids.device)

    # Convert valid pairs to tensor
    return torch.as_tensor(valid_pairs, device=input_ids.device)
def data_hash(data) -> int:
    """
    Hash a bytes-like object to an unsigned 64-bit integer.

    The value is the first 8 bytes of the SHA-256 digest of ``data``, read as
    a big-endian unsigned integer.
    """
    digest = hashlib.sha256(data).digest()
    return int.from_bytes(digest[:8], byteorder="big", signed=False)
def tensor_hash(tensor_list) -> int:
    """
    Hash a tensor, or a (possibly nested) list of tensors.

    A list is flattened and its tensors are concatenated as 1-dim tensors
    before hashing. CUDA tensors are hashed with ``gpu_tensor_hash``; all
    other tensors are hashed from their raw bytes with ``data_hash``.
    """
    if isinstance(tensor_list, list):
        flat_items = flatten_nested_list(tensor_list)
        pieces = [
            piece.flatten() if isinstance(piece, torch.Tensor) else piece
            for piece in flat_items
        ]
        tensor = torch.concat(pieces)
    else:
        tensor = tensor_list

    if tensor.is_cuda:
        return gpu_tensor_hash(tensor.cuda())

    tensor = tensor.detach().contiguous()
    if tensor.dtype == torch.bfloat16:
        # the numpy conversion below doesn't support PyTorch's BFloat16 dtype
        tensor = tensor.float()

    raw_bytes = tensor.cpu().numpy().tobytes()
    return data_hash(raw_bytes)
def hash_feature(f):
    """
    Hash a multimodal feature to an integer.

    Supported inputs:
        - list whose first element is a tensor: hashed with ``tensor_hash``
        - any other list (including an empty one): flattened, pickled and
          hashed with ``data_hash``
        - ``np.ndarray``: hashed from its contiguous bytes
        - ``torch.Tensor``: hashed with ``tensor_hash``
        - anything else: passed to ``data_hash`` directly, so it must be
          bytes-like
    """
    if isinstance(f, list):
        # Guard against an empty list before peeking at the first element.
        if f and isinstance(f[0], torch.Tensor):
            return tensor_hash(f)
        # hashlib only accepts bytes-like input, so the flattened values are
        # serialized first (pickle.dumps only; nothing is ever unpickled).
        return data_hash(pickle.dumps(tuple(flatten_nested_list(f))))
    if isinstance(f, np.ndarray):
        return data_hash(np.ascontiguousarray(f).tobytes())
    if isinstance(f, torch.Tensor):
        return tensor_hash([f])
    return data_hash(f)