refactor: multimodal data (#4754)
This commit is contained in:
@@ -31,7 +31,7 @@ from transformers import (
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
|
||||
from sglang.srt.mm_utils import (
|
||||
get_anyres_image_grid_shape,
|
||||
unpad_image,
|
||||
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.srt.utils import add_prefix
|
||||
from sglang.srt.utils import add_prefix, flatten_nested_list
|
||||
|
||||
|
||||
class LlavaBaseForCausalLM(nn.Module):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
|
||||
image_sizes = flatten_nested_list(
|
||||
[item.image_sizes for item in image_inputs.mm_items]
|
||||
)
|
||||
|
||||
pad_values = [item.pad_value for item in image_inputs.mm_items]
|
||||
|
||||
# hardcode for spatial_unpad + anyres
|
||||
if image_inputs.modalities is not None and (
|
||||
"multi-images" in image_inputs.modalities
|
||||
or "video" in image_inputs.modalities
|
||||
if any(
|
||||
item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
|
||||
for item in image_inputs.mm_items
|
||||
):
|
||||
image_aspect_ratio = "pad"
|
||||
else:
|
||||
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
math.ceil(self.image_size / self.patch_size / 2) ** 2
|
||||
)
|
||||
else:
|
||||
new_image_feature_len = self.image_feature_len # multiimage
|
||||
new_image_feature_len = self.image_feature_len # multi-image
|
||||
|
||||
height = width = self.num_patches_per_side
|
||||
if "anyres" in image_aspect_ratio:
|
||||
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
# old_len + pad_len - 1, because we need to remove image_token_id
|
||||
input_ids = (
|
||||
input_ids[:offset]
|
||||
+ [pad_values[image_idx]] * new_image_feature_len
|
||||
+ [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
|
||||
+ input_ids[offset + 1 :]
|
||||
)
|
||||
offset_list.append(offset)
|
||||
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
modalities_list = []
|
||||
max_image_offset = []
|
||||
for im in image_inputs:
|
||||
if im and im.modalities is not None:
|
||||
modalities_list.extend(im.modalities)
|
||||
if im:
|
||||
modalities_list.extend([item.modality for item in im.mm_items])
|
||||
if im and im.image_offsets:
|
||||
max_image_offset.append(
|
||||
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
|
||||
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
|
||||
if need_vision.any():
|
||||
bs = forward_batch.batch_size
|
||||
pixel_values = [
|
||||
image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
|
||||
]
|
||||
pixel_values = flatten_nested_list(
|
||||
[
|
||||
[item.pixel_values for item in image_inputs[i].mm_items]
|
||||
for i in range(bs)
|
||||
if need_vision[i]
|
||||
]
|
||||
)
|
||||
image_sizes = [
|
||||
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
|
||||
flatten_nested_list(
|
||||
[item.image_sizes for item in image_inputs[i].mm_items]
|
||||
)
|
||||
for i in range(bs)
|
||||
if need_vision[i]
|
||||
]
|
||||
|
||||
########## Encode Image ########
|
||||
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
new_image_features = []
|
||||
height = width = self.num_patches_per_side
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if modalities_list[image_idx] == "image":
|
||||
if modalities_list[image_idx] == Modality.IMAGE:
|
||||
image_aspect_ratio = (
|
||||
self.config.image_aspect_ratio
|
||||
) # single image
|
||||
elif (
|
||||
modalities_list[image_idx] == "multi-images"
|
||||
or modalities_list[image_idx] == "video"
|
||||
modalities_list[image_idx] == Modality.MULTI_IMAGES
|
||||
or modalities_list[image_idx] == Modality.VIDEO
|
||||
):
|
||||
image_aspect_ratio = "pad" # multi image
|
||||
# image_aspect_ratio = (
|
||||
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
if (
|
||||
image_feature.shape[0] > 1
|
||||
and "anyres" in image_aspect_ratio
|
||||
and modalities_list[image_idx] == "image"
|
||||
and modalities_list[image_idx] == Modality.IMAGE
|
||||
):
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
)
|
||||
image_feature = image_feature.unsqueeze(0)
|
||||
else:
|
||||
if modalities_list[image_idx] == "video": # video
|
||||
if modalities_list[image_idx] == Modality.VIDEO: # video
|
||||
# 2x2 pooling
|
||||
num_of_frames = image_feature.shape[0]
|
||||
image_feature = image_feature.view(
|
||||
|
||||
Reference in New Issue
Block a user