chore: improvements on mm_utils (#7737)
@@ -85,8 +85,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
-        start_token_ids = [s for s, _e in data_token_pairs]
-        end_tokens_ids = [e for _s, e in data_token_pairs]
+        start_token_ids = {s for s, _e in data_token_pairs}
+        end_tokens_ids = {e for _s, e in data_token_pairs}

         padded_ids = []
         last_idx = 0
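
The switch from list to set comprehensions above speeds up the per-token membership tests that follow in the padding loop: `in` on a Python list is O(n), while on a set it is O(1) on average. A minimal sketch of the effect (toy token ids, not the module's real data):

```python
token_pairs = [(10, 11), (20, 21), (30, 31)]

start_list = [s for s, _e in token_pairs]  # O(n) membership test per token
start_set = {s for s, _e in token_pairs}   # O(1) average membership test

token_id = 30
assert (token_id in start_list) == (token_id in start_set)  # same result, cheaper lookup
```
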
@@ -135,7 +135,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         if not input_ids or not mm_inputs.mm_items:
             return input_ids

-        input_ids_tensor = torch.tensor(input_ids)
+        input_ids_tensor = torch.as_tensor(input_ids)

         # Create mapping of token_ids to pad_values for each modality
         token_to_pad_mapping = {}
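
`torch.as_tensor` reuses the input's memory when it is already a tensor (or a compatible NumPy array), whereas `torch.tensor` always copies. For a plain Python list such as `input_ids` the two behave identically, so the change costs nothing here but avoids a copy whenever a tensor is passed in. A small demonstration:

```python
import torch

t = torch.arange(4)
assert torch.as_tensor(t) is t   # no copy: the same tensor object is returned

ids = [1, 2, 3]                  # Python lists are copied either way
assert torch.as_tensor(ids).tolist() == ids
```
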
@@ -211,7 +211,7 @@ def get_embedding_chunk(
             end_index += extend_end_index - start + 1
         elif extend_end_index > end:
             end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
     embedding = embedding.reshape(-1, embedding.shape[-1])
     embedding_chunk = embedding[start_index:end_index]
     return embedding_chunk, start_index, end_index
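
The corrected comment describes the reshape that normalizes embeddings: some models return a 3-dim embedding of shape [num_items, tokens_per_item, hidden_size], and `reshape(-1, hidden)` flattens it to the 2-dim [num_tokens, hidden_size] layout the chunk slicing expects. A toy illustration (the shapes are hypothetical):

```python
import torch

embedding = torch.randn(2, 3, 8)                   # [num_items, tokens_per_item, hidden]
flat = embedding.reshape(-1, embedding.shape[-1])  # -> [6, 8]
assert flat.shape == (2 * 3, 8)
chunk = flat[1:4]                                  # rows selected by (start_index, end_index)
assert chunk.shape == (3, 8)
```
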
@@ -428,7 +428,7 @@ def embed_mm_inputs(
         modality_id = modality.name.lower()
         embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
         if len(items) != 0 and embedder is not None:
-            placeholder_tensor = torch.tensor(
+            placeholder_tensor = torch.as_tensor(
                 [item.pad_value for item in items],
                 device=input_ids.device,
             )
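
For context, the surrounding code resolves a per-modality embedder dynamically: `getattr(model, f"get_{modality_id}_feature", None)` returns the method if the model implements it and `None` otherwise, so unsupported modalities are skipped without explicit hasattr checks. A stripped-down sketch (the class and return shape are illustrative, not SGLang's real interfaces):

```python
import torch

class ToyModel:
    def get_image_feature(self, items):
        # hypothetical: one 4-dim embedding row per item
        return torch.ones(len(items), 4)

model = ToyModel()
for modality_id in ("image", "audio"):
    embedder = getattr(model, f"get_{modality_id}_feature", None)
    if embedder is not None:
        print(modality_id, embedder([object(), object()]).shape)  # only "image" prints
```
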
@@ -473,11 +473,9 @@ def embed_mm_inputs(
     for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
-        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        inputs_embeds = inputs_embeds.masked_scatter(
-            mask,
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        # in-place update
+        indices = torch.where(mask.squeeze(dim=-1))[0]
+        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
     return inputs_embeds


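
The rewrite above replaces an out-of-place `masked_scatter` (which materializes a fresh result tensor plus a mask expanded to the full embedding shape) with direct row indexing, which writes into `inputs_embeds` in place. The two are equivalent when the mask marks whole rows and `embedding` carries exactly one row per marked position, as here. A self-contained check on toy shapes:

```python
import torch

hidden = 4
inputs_embeds = torch.zeros(6, hidden)
mask = torch.tensor([False, True, True, False, True, False]).unsqueeze(-1)  # [seq, 1]
embedding = torch.ones(3, hidden)  # one row per True position

# Old path: expand the mask and scatter into a new tensor.
out_old = inputs_embeds.masked_scatter(mask.expand_as(inputs_embeds), embedding)

# New path: gather the True row indices once and assign in place.
indices = torch.where(mask.squeeze(dim=-1))[0]  # tensor([1, 2, 4])
inputs_embeds[indices] = embedding

assert torch.equal(out_old, inputs_embeds)
```
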
@@ -561,34 +559,36 @@ def get_multimodal_data_bounds(
         [bounds_count, 2]
     """
     # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens = [s for s, _e in token_pairs]
-    end_tokens = [e for _s, e in token_pairs]
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}

     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)

     start_cond = torch.isin(
-        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+        input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
     )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))

     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)

+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
     # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(data_start_tokens) != len(data_end_tokens):
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
         if (
-            len(data_start_tokens) + 1 == len(data_end_tokens)
-            and input_ids[0] in pad_values
-            and data_end_tokens[0] < data_start_tokens[0]
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
         ):
-            data_start_tokens = torch.cat(
-                [
-                    torch.tensor([0], device=data_start_tokens.device),
-                    data_start_tokens,
-                ]
-            )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))

     if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)
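
The main change above hoists a single `.cpu().tolist()` transfer out of the bookkeeping logic: every scalar read of a CUDA tensor from Python (e.g. `data_start_tokens[i]`, or `input_ids[0] in pad_values`, hence the added `.item()`) forces a device synchronization, so converting to a Python list once removes the per-element syncs. One caveat worth noting: `torch.as_tensor` accepts lists, tuples, NumPy arrays, and tensors but not Python sets, so a set like `start_tokens` would need `list(start_tokens)` first. A CPU-only sketch of the bounds logic (toy ids; <im_start>=10 and <im_end>=11 are hypothetical):

```python
import torch

input_ids = torch.tensor([1, 10, 5, 5, 11, 2, 10, 5, 11])

start_cond = torch.isin(input_ids, torch.as_tensor([10]))
end_cond = torch.isin(input_ids, torch.as_tensor([11]))

# One transfer each; all later indexing is plain Python with no device syncs.
starts = torch.where(start_cond)[0].cpu().tolist()  # [1, 6]
ends = torch.where(end_cond)[0].cpu().tolist()      # [4, 8]

# Trim the bound tokens themselves, keeping only interior data positions.
bounds = [(s + 1, e - 1) for s, e in zip(starts, ends) if s < e]
print(bounds)  # [(2, 3), (7, 7)]
```
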
@@ -596,8 +596,8 @@ def get_multimodal_data_bounds(
     # Filter out pairs where start_token >= end_token
     valid_pairs = []
     for i in range(valid_mm_data_nums):
-        start_token = data_start_tokens[i]
-        end_token = data_end_tokens[i]
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
         if start_token < end_token:
             valid_pairs.append((start_token + 1, end_token - 1))

@@ -605,7 +605,7 @@ def get_multimodal_data_bounds(
         return torch.zeros((0, 2), device=input_ids.device)

     # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor


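
Here `valid_pairs` is a plain Python list of tuples, so `torch.as_tensor` cannot avoid a copy and behaves exactly like `torch.tensor`; the change just keeps the construction style consistent with the rest of the commit. The result is the [bounds_count, 2] tensor promised by the docstring:

```python
import torch

valid_pairs = [(2, 3), (7, 7)]
bounds = torch.as_tensor(valid_pairs)  # stacks the tuples into rows
assert bounds.shape == (2, 2)
```
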
@@ -634,11 +634,7 @@ def tensor_hash(tensor_list) -> int:
         tensor = tensor.float()

     assert isinstance(tensor, torch.Tensor)
-    if tensor.is_cuda:
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()

     mv = memoryview(tensor_cpu.numpy())
     return data_hash(mv.tobytes())
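
The simplification above relies on a documented PyTorch behavior: `Tensor.cpu()` returns the original object, without copying, when the tensor already lives in CPU memory, so the `is_cuda` branch was redundant. A quick check plus the hashing pattern the function uses (`data_hash` below is a hashlib-based stand-in for the module's helper, not its verbatim code):

```python
import hashlib

import torch

def data_hash(data: bytes) -> int:
    # stand-in for mm_utils' data_hash helper
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")

t = torch.arange(8, dtype=torch.float32)
assert t.cpu() is t  # no-op for CPU tensors: the same object comes back

mv = memoryview(t.cpu().numpy())  # zero-copy view over the tensor's bytes
print(data_hash(mv.tobytes()))
```
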