chore: improvements on mm_utils (#7737)
This commit is contained in:
@@ -85,8 +85,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
"No data_token_pairs provided, RadixAttention might be influenced."
|
||||
)
|
||||
return input_ids
|
||||
start_token_ids = [s for s, _e in data_token_pairs]
|
||||
end_tokens_ids = [e for _s, e in data_token_pairs]
|
||||
start_token_ids = {s for s, _e in data_token_pairs}
|
||||
end_tokens_ids = {e for _s, e in data_token_pairs}
|
||||
|
||||
padded_ids = []
|
||||
last_idx = 0
|
||||
@@ -135,7 +135,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
||||
if not input_ids or not mm_inputs.mm_items:
|
||||
return input_ids
|
||||
|
||||
input_ids_tensor = torch.tensor(input_ids)
|
||||
input_ids_tensor = torch.as_tensor(input_ids)
|
||||
|
||||
# Create mapping of token_ids to pad_values for each modality
|
||||
token_to_pad_mapping = {}
|
||||
@@ -211,7 +211,7 @@ def get_embedding_chunk(
|
||||
end_index += extend_end_index - start + 1
|
||||
elif extend_end_index > end:
|
||||
end_index += end - start + 1
|
||||
# some models embedding is 3-dim, reshape it to 2-dim
|
||||
# some models' embedding is 3-dim, reshape it to 2-dim
|
||||
embedding = embedding.reshape(-1, embedding.shape[-1])
|
||||
embedding_chunk = embedding[start_index:end_index]
|
||||
return embedding_chunk, start_index, end_index
|
||||
@@ -428,7 +428,7 @@ def embed_mm_inputs(
|
||||
modality_id = modality.name.lower()
|
||||
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
|
||||
if len(items) != 0 and embedder is not None:
|
||||
placeholder_tensor = torch.tensor(
|
||||
placeholder_tensor = torch.as_tensor(
|
||||
[item.pad_value for item in items],
|
||||
device=input_ids.device,
|
||||
)
|
||||
@@ -473,11 +473,9 @@ def embed_mm_inputs(
|
||||
for embedding, mask in zip(embeddings, masks):
|
||||
if embedding is None or mask is None:
|
||||
continue
|
||||
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
mask,
|
||||
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
|
||||
)
|
||||
# in-place update
|
||||
indices = torch.where(mask.squeeze(dim=-1))[0]
|
||||
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
@@ -561,34 +559,36 @@ def get_multimodal_data_bounds(
|
||||
[bounds_count, 2]
|
||||
"""
|
||||
# All the multimodal data in the batch should share the same special bound token ids.
|
||||
start_tokens = [s for s, _e in token_pairs]
|
||||
end_tokens = [e for _s, e in token_pairs]
|
||||
start_tokens = {s for s, _e in token_pairs}
|
||||
end_tokens = {e for _s, e in token_pairs}
|
||||
|
||||
assert all(isinstance(t, int) for t in start_tokens)
|
||||
assert all(isinstance(t, int) for t in end_tokens)
|
||||
|
||||
start_cond = torch.isin(
|
||||
input_ids, torch.tensor(start_tokens, device=input_ids.device)
|
||||
input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
|
||||
)
|
||||
end_cond = torch.isin(
|
||||
input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
|
||||
)
|
||||
end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
|
||||
|
||||
(data_start_tokens,) = torch.where(start_cond)
|
||||
(data_end_tokens,) = torch.where(end_cond)
|
||||
|
||||
data_start_tokens_cpu = data_start_tokens.cpu().tolist()
|
||||
data_end_tokens_cpu = data_end_tokens.cpu().tolist()
|
||||
|
||||
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
|
||||
if len(data_start_tokens) != len(data_end_tokens):
|
||||
if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
|
||||
if (
|
||||
len(data_start_tokens) + 1 == len(data_end_tokens)
|
||||
and input_ids[0] in pad_values
|
||||
and data_end_tokens[0] < data_start_tokens[0]
|
||||
len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
|
||||
and input_ids[0].item() in pad_values
|
||||
and data_end_tokens_cpu
|
||||
and data_start_tokens_cpu
|
||||
and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
|
||||
):
|
||||
data_start_tokens = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=data_start_tokens.device),
|
||||
data_start_tokens,
|
||||
]
|
||||
)
|
||||
valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
|
||||
data_start_tokens_cpu.insert(0, 0)
|
||||
valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))
|
||||
|
||||
if valid_mm_data_nums == 0:
|
||||
return torch.zeros((0, 2), device=input_ids.device)
|
||||
@@ -596,8 +596,8 @@ def get_multimodal_data_bounds(
|
||||
# Filter out pairs where start_token >= end_token
|
||||
valid_pairs = []
|
||||
for i in range(valid_mm_data_nums):
|
||||
start_token = data_start_tokens[i]
|
||||
end_token = data_end_tokens[i]
|
||||
start_token = data_start_tokens_cpu[i]
|
||||
end_token = data_end_tokens_cpu[i]
|
||||
if start_token < end_token:
|
||||
valid_pairs.append((start_token + 1, end_token - 1))
|
||||
|
||||
@@ -605,7 +605,7 @@ def get_multimodal_data_bounds(
|
||||
return torch.zeros((0, 2), device=input_ids.device)
|
||||
|
||||
# Convert valid pairs to tensor
|
||||
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
|
||||
valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
|
||||
return valid_pairs_tensor
|
||||
|
||||
|
||||
@@ -634,11 +634,7 @@ def tensor_hash(tensor_list) -> int:
|
||||
tensor = tensor.float()
|
||||
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
if tensor.is_cuda:
|
||||
# TODO: improve this
|
||||
tensor_cpu = tensor.cpu()
|
||||
else:
|
||||
tensor_cpu = tensor
|
||||
tensor_cpu = tensor.cpu()
|
||||
|
||||
mv = memoryview(tensor_cpu.numpy())
|
||||
return data_hash(mv.tobytes())
|
||||
|
||||
Reference in New Issue
Block a user