[Feat] Add modalities for vision server when handling pixel values for llava (#1346)
This commit is contained in:
committed by
GitHub
parent
8e6bdf851c
commit
662ecd9368
@@ -138,6 +138,12 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
bs = input_metadata.batch_size
|
||||
# input_metadata.modalities is a List[List[str]]; flatten it into a single List[str]
|
||||
# The length of the List should be equal to batch size
|
||||
modalities_list = []
|
||||
for modalities in input_metadata.modalities:
|
||||
if modalities is not None:
|
||||
modalities_list.extend(modalities)
|
||||
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
@@ -179,7 +185,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
new_image_features = []
|
||||
height = width = self.num_patches_per_side
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if len(image_sizes[image_idx]) == 1:
|
||||
if modalities_list[image_idx] == 1:
|
||||
image_aspect_ratio = (
|
||||
self.config.image_aspect_ratio
|
||||
) # single image
|
||||
@@ -191,6 +197,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
if (
|
||||
image_feature.shape[0] > 1
|
||||
and "anyres" in image_aspect_ratio
|
||||
and modalities_list[image_idx] == "image"
|
||||
):
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
@@ -290,7 +297,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
)
|
||||
image_feature = image_feature.unsqueeze(0)
|
||||
else:
|
||||
if image_feature.shape[0] > 16: # video
|
||||
if modalities_list[image_idx] == "video": # video
|
||||
# 2x2 pooling
|
||||
num_of_frames = image_feature.shape[0]
|
||||
image_feature = image_feature.view(
|
||||
|
||||
Reference in New Issue
Block a user