[fix] Fix prefix caching for multi-image/video (#2239)
This commit is contained in:
@@ -145,15 +145,17 @@ class ImageInputs:
|
||||
# Use image hash as fake token_ids, which is then used for prefix matching
|
||||
ret = ImageInputs(
|
||||
pixel_values=obj["pixel_values"],
|
||||
image_hashes=hash(tuple(obj["image_hashes"])),
|
||||
image_hashes=obj["image_hashes"],
|
||||
)
|
||||
image_hash = ret.image_hashes
|
||||
ret.pad_values = [
|
||||
(image_hash) % vocab_size,
|
||||
(image_hash >> 16) % vocab_size,
|
||||
(image_hash >> 32) % vocab_size,
|
||||
(image_hash >> 64) % vocab_size,
|
||||
]
|
||||
if not isinstance(ret.image_hashes, list):
|
||||
ret.pad_values = [
|
||||
(ret.image_hashes) % vocab_size,
|
||||
(ret.image_hashes >> 16) % vocab_size,
|
||||
(ret.image_hashes >> 32) % vocab_size,
|
||||
(ret.image_hashes >> 64) % vocab_size,
|
||||
]
|
||||
else:
|
||||
ret.pad_values = [x % vocab_size for x in ret.image_hashes]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
@@ -171,14 +173,18 @@ class ImageInputs:
|
||||
def merge(self, other, vocab_size):
|
||||
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||
self.image_hashes += other.image_hashes
|
||||
|
||||
self.pad_values = [
|
||||
(self.image_hashes) % vocab_size,
|
||||
(self.image_hashes >> 16) % vocab_size,
|
||||
(self.image_hashes >> 32) % vocab_size,
|
||||
(self.image_hashes >> 64) % vocab_size,
|
||||
]
|
||||
if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list):
|
||||
self.image_hashes += other.image_hashes
|
||||
self.pad_values = [x % vocab_size for x in self.image_hashes]
|
||||
else:
|
||||
self.image_hashes = hash(tuple(self.image_hashes, other.image_hashes))
|
||||
self.pad_values = [
|
||||
(self.image_hashes) % vocab_size,
|
||||
(self.image_hashes >> 16) % vocab_size,
|
||||
(self.image_hashes >> 32) % vocab_size,
|
||||
(self.image_hashes >> 64) % vocab_size,
|
||||
]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
|
||||
@@ -57,7 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
else:
|
||||
image_aspect_ratio = "anyres"
|
||||
offset_list = []
|
||||
for image_s in image_sizes:
|
||||
for image_idx, image_s in enumerate(image_sizes):
|
||||
if len(image_sizes) > 16:
|
||||
# 2x2 pooling with stride 2
|
||||
new_image_feature_len = (
|
||||
@@ -92,10 +92,6 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
new_w = int(new_w // times)
|
||||
new_image_feature_len += new_h * (new_w + 1)
|
||||
|
||||
pad_ids = pad_values * (
|
||||
(new_image_feature_len + len(pad_values)) // len(pad_values)
|
||||
)
|
||||
# print("calculated new_image_feature_len: ", new_image_feature_len)
|
||||
try:
|
||||
offset = input_ids.index(self.config.image_token_index)
|
||||
except ValueError:
|
||||
@@ -103,7 +99,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
# old_len + pad_len - 1, because we need to remove image_token_id
|
||||
input_ids = (
|
||||
input_ids[:offset]
|
||||
+ pad_ids[:new_image_feature_len]
|
||||
+ [pad_values[image_idx]] * new_image_feature_len
|
||||
+ input_ids[offset + 1 :]
|
||||
)
|
||||
offset_list.append(offset)
|
||||
|
||||
@@ -500,7 +500,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
return num_image_tokens
|
||||
|
||||
# Use grid_t * grid_w * grid_h to pad tokens for each image
|
||||
# and replaced padding by unique image hash
|
||||
# add replaced padding by unique image hash
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
image_grid_thws = image_inputs.image_grid_thws
|
||||
pad_values = image_inputs.pad_values
|
||||
|
||||
Reference in New Issue
Block a user