[fix] Fix prefix caching for multi-image/video (#2239)

This commit is contained in:
Ying Sheng
2024-11-28 12:08:13 -08:00
committed by GitHub
parent 65fdb28929
commit b7038fec9b
4 changed files with 26 additions and 22 deletions

View File

@@ -57,7 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
else:
image_aspect_ratio = "anyres"
offset_list = []
for image_s in image_sizes:
for image_idx, image_s in enumerate(image_sizes):
if len(image_sizes) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = (
@@ -92,10 +92,6 @@ class LlavaBaseForCausalLM(nn.Module):
new_w = int(new_w // times)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_values * (
(new_image_feature_len + len(pad_values)) // len(pad_values)
)
# print("calculated new_image_feature_len: ", new_image_feature_len)
try:
offset = input_ids.index(self.config.image_token_index)
except ValueError:
@@ -103,7 +99,7 @@ class LlavaBaseForCausalLM(nn.Module):
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ [pad_values[image_idx]] * new_image_feature_len
+ input_ids[offset + 1 :]
)
offset_list.append(offset)

View File

@@ -500,7 +500,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
return num_image_tokens
# Use grid_t * grid_w * grid_h to pad tokens for each image
# and replaced padding by unique image hash
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values