fix batch error for llava-hd (#98)
This commit is contained in:
@@ -112,24 +112,28 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
need_vision = need_vision & has_pixel
|
need_vision = need_vision & has_pixel
|
||||||
|
|
||||||
if need_vision.any():
|
if need_vision.any():
|
||||||
pixel_values = torch.tensor(
|
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
||||||
np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
|
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
|
||||||
device=self.vision_tower.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
########## Encode Image ########
|
########## Encode Image ########
|
||||||
|
|
||||||
if pixel_values.ndim == 5:
|
if pixel_values[0].ndim == 4:
|
||||||
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
|
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
|
||||||
concat_images = torch.cat(
|
np.concatenate(pixel_values, axis=0)
|
||||||
[image for image in pixel_values], dim=0
|
# ndim=4
|
||||||
) # ndim=4
|
concat_images = torch.tensor(
|
||||||
|
np.concatenate(pixel_values, axis=0),
|
||||||
|
device=self.vision_tower.device,
|
||||||
|
)
|
||||||
image_features = self.encode_images(concat_images)
|
image_features = self.encode_images(concat_images)
|
||||||
split_sizes = [image.shape[0] for image in pixel_values]
|
split_sizes = [image.shape[0] for image in pixel_values]
|
||||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||||
# hd image_features: BS, num_patch, 576, 4096
|
# hd image_features: BS, num_patch, 576, 4096
|
||||||
else:
|
else:
|
||||||
# normal pixel: BS, C=3, H=336, W=336
|
# normal pixel: BS, C=3, H=336, W=336
|
||||||
|
pixel_values = torch.tensor(
|
||||||
|
np.array(pixel_values), device=self.vision_tower.device
|
||||||
|
)
|
||||||
image_features = self.encode_images(pixel_values)
|
image_features = self.encode_images(pixel_values)
|
||||||
# image_features: BS, 576, 4096
|
# image_features: BS, 576, 4096
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user