fix batch error for llava-hd (#98)
This commit is contained in:
@@ -112,24 +112,28 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
need_vision = need_vision & has_pixel
|
||||
|
||||
if need_vision.any():
|
||||
pixel_values = torch.tensor(
|
||||
np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
|
||||
device=self.vision_tower.device,
|
||||
)
|
||||
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
||||
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
|
||||
|
||||
########## Encode Image ########
|
||||
|
||||
if pixel_values.ndim == 5:
|
||||
if pixel_values[0].ndim == 4:
|
||||
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
|
||||
concat_images = torch.cat(
|
||||
[image for image in pixel_values], dim=0
|
||||
) # ndim=4
|
||||
np.concatenate(pixel_values, axis=0)
|
||||
# ndim=4
|
||||
concat_images = torch.tensor(
|
||||
np.concatenate(pixel_values, axis=0),
|
||||
device=self.vision_tower.device,
|
||||
)
|
||||
image_features = self.encode_images(concat_images)
|
||||
split_sizes = [image.shape[0] for image in pixel_values]
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
# hd image_features: BS, num_patch, 576, 4096
|
||||
else:
|
||||
# normal pixel: BS, C=3, H=336, W=336
|
||||
pixel_values = torch.tensor(
|
||||
np.array(pixel_values), device=self.vision_tower.device
|
||||
)
|
||||
image_features = self.encode_images(pixel_values)
|
||||
# image_features: BS, 576, 4096
|
||||
|
||||
|
||||
Reference in New Issue
Block a user