model: Minicpmo (#3023)
This commit is contained in:
@@ -9,7 +9,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
ImageInputs,
|
||||
MultimodalInputs,
|
||||
global_server_args_dict,
|
||||
logger,
|
||||
)
|
||||
@@ -26,7 +26,7 @@ class MultiModalityDataPaddingPattern:
|
||||
|
||||
@abstractmethod
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
self, input_ids: List[int], image_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
Pad the input ids sequence containing data tokens, and replace them with pad_values
|
||||
@@ -44,16 +44,16 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
self.data_token_id_pairs = data_token_pairs
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
This function will replace the data-tokens inbetween with pad_values accordingly
|
||||
"""
|
||||
pad_values = image_inputs.pad_values
|
||||
pad_values = mm_inputs.pad_values
|
||||
data_token_pairs = self.data_token_id_pairs
|
||||
image_inputs.image_offsets = []
|
||||
mm_inputs.image_offsets = []
|
||||
if data_token_pairs is None:
|
||||
data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
|
||||
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
|
||||
if data_token_pairs is None:
|
||||
logger.warning(
|
||||
"No data_token_pairs provided, RadixAttention might be influenced."
|
||||
@@ -61,8 +61,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
return input_ids
|
||||
start_token_ids = [s for s, _e in data_token_pairs]
|
||||
end_tokens_ids = [e for _s, e in data_token_pairs]
|
||||
# First start token marks new data
|
||||
data_start_token = start_token_ids[0]
|
||||
|
||||
padded_ids = []
|
||||
last_idx = 0
|
||||
@@ -77,9 +75,12 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||
padded_ids.extend(input_ids[last_idx : start_idx + 1])
|
||||
|
||||
if input_ids[start_idx] == data_start_token:
|
||||
if input_ids[start_idx] in start_token_ids:
|
||||
data_idx += 1
|
||||
image_inputs.image_offsets += [start_idx]
|
||||
mm_inputs.image_offsets += [start_idx]
|
||||
|
||||
if data_idx >= len(mm_inputs.pad_values):
|
||||
data_idx = len(mm_inputs.pad_values) - 1
|
||||
|
||||
num_tokens = end_idx - start_idx - 1
|
||||
pad_value = pad_values[data_idx]
|
||||
@@ -89,7 +90,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
|
||||
padded_ids.extend(input_ids[last_idx:])
|
||||
|
||||
assert len(input_ids) == len(padded_ids)
|
||||
assert len(input_ids) == len(padded_ids), "Length validation fails"
|
||||
return padded_ids
|
||||
|
||||
|
||||
@@ -107,26 +108,25 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
|
||||
self.num_data_token_calc_func = num_data_token_calc_func
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
This function will follow the procedure of:
|
||||
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
|
||||
2. the padded data tokens will be replaced with their pad_values
|
||||
"""
|
||||
image_grid_thws = image_inputs.image_grid_thws
|
||||
pad_values = image_inputs.pad_values
|
||||
image_grid_thws = mm_inputs.image_grid_thws
|
||||
pad_values = mm_inputs.pad_values
|
||||
|
||||
image_indices = [
|
||||
idx
|
||||
for idx, token in enumerate(input_ids)
|
||||
if token == image_inputs.im_token_id
|
||||
idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
|
||||
]
|
||||
|
||||
image_inputs.image_offsets = []
|
||||
mm_inputs.image_offsets = []
|
||||
|
||||
input_ids_with_image = []
|
||||
for image_cnt, _ in enumerate(image_grid_thws):
|
||||
# print(f"image_cnt {image_cnt}")
|
||||
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
|
||||
if image_cnt == 0:
|
||||
non_image_tokens = input_ids[: image_indices[image_cnt]]
|
||||
@@ -135,7 +135,7 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
|
||||
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
|
||||
]
|
||||
input_ids_with_image.extend(non_image_tokens)
|
||||
image_inputs.image_offsets.append(len(input_ids_with_image))
|
||||
mm_inputs.image_offsets.append(len(input_ids_with_image))
|
||||
pad_ids = pad_values * (
|
||||
(num_image_tokens + len(pad_values)) // len(pad_values)
|
||||
)
|
||||
@@ -170,11 +170,11 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
|
||||
return input_ids_tensor.tolist()
|
||||
|
||||
|
||||
def embed_image_inputs(
|
||||
image_input: ImageInputs,
|
||||
def embed_mm_inputs(
|
||||
mm_input: MultimodalInputs,
|
||||
input_ids: torch.Tensor,
|
||||
input_embedding: nn.Embedding,
|
||||
image_embedding_func,
|
||||
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
|
||||
placeholder_token_ids: List[int] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
@@ -184,10 +184,10 @@ def embed_image_inputs(
|
||||
Returns:
|
||||
final embedding: Optional[torch.Tensor]
|
||||
"""
|
||||
if image_input is None:
|
||||
if mm_input is None:
|
||||
return None
|
||||
|
||||
placeholder_token_ids = placeholder_token_ids or image_input.pad_values
|
||||
placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
|
||||
|
||||
# boolean masking the special tokens
|
||||
special_image_mask = torch.isin(
|
||||
@@ -196,12 +196,18 @@ def embed_image_inputs(
|
||||
).unsqueeze(-1)
|
||||
|
||||
num_image_tokens_in_input_ids = special_image_mask.sum()
|
||||
# print(f"{num_image_tokens_in_input_ids}")
|
||||
# print(f"{input_ids}")
|
||||
|
||||
# return
|
||||
if num_image_tokens_in_input_ids == 0:
|
||||
# unexpected
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
else:
|
||||
image_embedding = image_embedding_func(image_input)
|
||||
# print(f"Getting image feature")
|
||||
image_embedding = mm_data_embedding_func(mm_input)
|
||||
|
||||
# print(f"image_embedding: {image_embedding.shape}")
|
||||
|
||||
if image_embedding.dim() == 2:
|
||||
num_image_tokens_in_embedding = image_embedding.shape[0]
|
||||
@@ -273,31 +279,95 @@ def embed_image_embedding(
|
||||
|
||||
def general_mm_embed_routine(
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
embed_tokens: nn.Embedding,
|
||||
image_embedding_func: Callable[[ImageInputs], torch.Tensor],
|
||||
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
|
||||
placeholder_token_ids: List[int] = None,
|
||||
):
|
||||
"""
|
||||
a general wrapper function to get final input embeds from multimodal models
|
||||
with a language model as causal model
|
||||
|
||||
Args:
|
||||
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
|
||||
|
||||
"""
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or not forward_batch.contains_image_inputs()
|
||||
not forward_batch.forward_mode.is_decode()
|
||||
and forward_batch.contains_mm_inputs()
|
||||
):
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
else:
|
||||
image = forward_batch.merge_image_inputs()
|
||||
inputs_embeds = embed_image_inputs(
|
||||
image_input=image,
|
||||
image = forward_batch.merge_mm_inputs()
|
||||
inputs_embeds = embed_mm_inputs(
|
||||
mm_input=image,
|
||||
input_ids=input_ids,
|
||||
input_embedding=embed_tokens,
|
||||
image_embedding_func=image_embedding_func,
|
||||
mm_data_embedding_func=mm_data_embedding_func,
|
||||
placeholder_token_ids=placeholder_token_ids,
|
||||
)
|
||||
# once used, image_inputs is useless
|
||||
# once used, mm_inputs is useless
|
||||
# just being defensive here
|
||||
forward_batch.image_inputs = None
|
||||
forward_batch.mm_inputs = None
|
||||
else:
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def get_multimodal_data_bounds(
    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)

    Start and end marker positions are matched in order of appearance: the
    i-th start marker is paired with the i-th end marker. Pairs whose start
    position is not strictly before the end position are dropped.

    Args:
        input_ids: 1-D tensor of token ids (the single-element unpacking of
            ``torch.where`` below requires exactly one dimension).
        pad_values: placeholder ids that stand in for multimodal data. Only
            used to detect a sequence whose leading start marker is missing.
        token_pairs: ``(start_token_id, end_token_id)`` pairs delimiting data.

    Returns:
        [bounds_count, 2] int64 tensor on ``input_ids.device``. Each row is
        ``(start_index + 1, end_index - 1)``, i.e. the first and last
        positions strictly between a matched marker pair. The shape is
        ``(0, 2)`` when there is no valid pair.
    """
    # All the images in the batch should share the same special image
    # bound token ids.
    start_tokens = [s for s, _e in token_pairs]
    end_tokens = [e for _s, e in token_pairs]

    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)

    device = input_ids.device
    start_cond = torch.isin(input_ids, torch.tensor(start_tokens, device=device))
    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=device))

    (data_start_tokens,) = torch.where(start_cond)
    (data_end_tokens,) = torch.where(end_cond)

    num_starts = data_start_tokens.numel()
    num_ends = data_end_tokens.numel()

    # The im_start_id sometimes can be cached as prefix, but it is needed for
    # the embedding of the images. In that case the sequence begins directly
    # with pad values and has exactly one more end marker than start markers;
    # position 0 is then used as the stand-in start position.
    if (
        num_starts + 1 == num_ends
        and int(input_ids[0]) in pad_values
        # With no start marker at all there is nothing to compare against;
        # indexing the empty tensor here would raise IndexError.
        and (num_starts == 0 or bool(data_end_tokens[0] < data_start_tokens[0]))
    ):
        data_start_tokens = torch.cat(
            [
                torch.tensor([0], device=data_start_tokens.device),
                data_start_tokens,
            ]
        )

    # Pair markers positionally; surplus markers on either side are ignored.
    valid_image_nums = min(data_start_tokens.numel(), num_ends)
    starts = data_start_tokens[:valid_image_nums]
    ends = data_end_tokens[:valid_image_nums]

    # Filter out pairs where start_token >= end_token
    keep = starts < ends

    # Stacking (possibly empty) index tensors keeps dtype (int64) and device
    # identical for empty and non-empty results.
    return torch.stack((starts[keep] + 1, ends[keep] - 1), dim=1)
|
||||
|
||||
Reference in New Issue
Block a user