[Bug] Fix InternVL KeyError: ((1, 1, 3), '<i8') (#108)
This commit is contained in:
@@ -13,6 +13,7 @@ from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
|
||||
|
||||
import numpy.typing as npt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
@@ -297,6 +298,8 @@ def video_to_pixel_values_internvl(
|
||||
transform = build_transform(input_size=input_size)
|
||||
frames_list = list[Image.Image]()
|
||||
for frame in video:
|
||||
if frame.dtype != np.uint8:
|
||||
frame = frame.astype(np.uint8)
|
||||
pil_frame = dynamic_preprocess_internvl(
|
||||
Image.fromarray(frame, mode="RGB"),
|
||||
target_ratios=target_ratios,
|
||||
@@ -1420,4 +1423,4 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="mlp1",
|
||||
tower_model="vision_model")
|
||||
tower_model="vision_model")
|
||||
|
||||
Reference in New Issue
Block a user