Fix mixed batch for multimodal models (#1702)

This commit is contained in:
Lianmin Zheng
2024-10-17 10:27:26 -07:00
committed by GitHub
parent dd3809fad8
commit d17d19e5b8
3 changed files with 58 additions and 6 deletions

View File

@@ -160,9 +160,6 @@ class LlavaBaseForCausalLM(nn.Module):
image_sizes = [
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
]
- image_offsets = [
- image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
- ]
########## Encode Image ########
@@ -358,7 +355,7 @@ class LlavaBaseForCausalLM(nn.Module):
prefix_len = prefix_lens_cpu[i]
# Multiple images
- for j, image_offset in enumerate(image_offsets[i]):
+ for j, image_offset in enumerate(image_inputs[i].image_offsets):
if image_offset < prefix_len:
continue