refactor: multimodal data (#4754)

This commit is contained in:
Mick
2025-04-01 00:57:51 +08:00
committed by GitHub
parent c7457191a0
commit 5cb552b1d4
36 changed files with 989 additions and 1138 deletions

View File

@@ -31,7 +31,7 @@ from transformers import (
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from sglang.srt.utils import add_prefix, flatten_nested_list
class LlavaBaseForCausalLM(nn.Module):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
image_sizes = flatten_nested_list(
[item.image_sizes for item in image_inputs.mm_items]
)
pad_values = [item.pad_value for item in image_inputs.mm_items]
# hardcode for spatial_unpad + anyres
if image_inputs.modalities is not None and (
"multi-images" in image_inputs.modalities
or "video" in image_inputs.modalities
if any(
item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
for item in image_inputs.mm_items
):
image_aspect_ratio = "pad"
else:
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
math.ceil(self.image_size / self.patch_size / 2) ** 2
)
else:
new_image_feature_len = self.image_feature_len # multiimage
new_image_feature_len = self.image_feature_len # multi-image
height = width = self.num_patches_per_side
if "anyres" in image_aspect_ratio:
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = (
input_ids[:offset]
+ [pad_values[image_idx]] * new_image_feature_len
+ [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
+ input_ids[offset + 1 :]
)
offset_list.append(offset)
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
modalities_list = []
max_image_offset = []
for im in image_inputs:
if im and im.modalities is not None:
modalities_list.extend(im.modalities)
if im:
modalities_list.extend([item.modality for item in im.mm_items])
if im and im.image_offsets:
max_image_offset.append(
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
if need_vision.any():
bs = forward_batch.batch_size
pixel_values = [
image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
]
pixel_values = flatten_nested_list(
[
[item.pixel_values for item in image_inputs[i].mm_items]
for i in range(bs)
if need_vision[i]
]
)
image_sizes = [
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
flatten_nested_list(
[item.image_sizes for item in image_inputs[i].mm_items]
)
for i in range(bs)
if need_vision[i]
]
########## Encode Image ########
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if modalities_list[image_idx] == "image":
if modalities_list[image_idx] == Modality.IMAGE:
image_aspect_ratio = (
self.config.image_aspect_ratio
) # single image
elif (
modalities_list[image_idx] == "multi-images"
or modalities_list[image_idx] == "video"
modalities_list[image_idx] == Modality.MULTI_IMAGES
or modalities_list[image_idx] == Modality.VIDEO
):
image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = (
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
and modalities_list[image_idx] == "image"
and modalities_list[image_idx] == Modality.IMAGE
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
)
image_feature = image_feature.unsqueeze(0)
else:
if modalities_list[image_idx] == "video": # video
if modalities_list[image_idx] == Modality.VIDEO: # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(