[Feat] Add modalities for vision server when handling pixel values for llava (#1346)

This commit is contained in:
Kaichen Zhang - NTU
2024-09-09 17:07:34 +08:00
committed by GitHub
parent 8e6bdf851c
commit 662ecd9368
11 changed files with 40 additions and 2 deletions

View File

@@ -138,6 +138,12 @@ class LlavaBaseForCausalLM(nn.Module):
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size
# input_metadata.modalities is a List[List[str]]; flatten it into a List[str].
# The length of the list should be equal to the batch size.
modalities_list = []
for modalities in input_metadata.modalities:
if modalities is not None:
modalities_list.extend(modalities)
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
@@ -179,7 +185,7 @@ class LlavaBaseForCausalLM(nn.Module):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if len(image_sizes[image_idx]) == 1:
if modalities_list[image_idx] == 1:
image_aspect_ratio = (
self.config.image_aspect_ratio
) # single image
@@ -191,6 +197,7 @@ class LlavaBaseForCausalLM(nn.Module):
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
and modalities_list[image_idx] == "image"
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
@@ -290,7 +297,7 @@ class LlavaBaseForCausalLM(nn.Module):
)
image_feature = image_feature.unsqueeze(0)
else:
if image_feature.shape[0] > 16: # video
if modalities_list[image_idx] == "video": # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(