[Fix] Fix llava on multi images (#1247)

This commit is contained in:
Lianmin Zheng
2024-08-28 06:33:05 -07:00
committed by GitHub
parent b1a540ec42
commit bf53bf5142
22 changed files with 272 additions and 488 deletions

View File

@@ -28,7 +28,6 @@ from transformers import (
LlavaConfig,
MistralConfig,
Qwen2Config,
SiglipVisionConfig,
SiglipVisionModel,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -66,13 +65,18 @@ class LlavaLlamaForCausalLM(nn.Module):
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
)
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
def pad_input_ids(
self,
input_ids: List[int],
pad_value: List[int],
pixel_values: List,
image_sizes: List[List[int]],
):
# hardcode for spatial_unpad + anyres
image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
offset_list = []
for image_s in image_size:
if len(image_size) > 16:
for image_s in image_sizes:
if len(image_sizes) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = (
math.ceil(self.image_size / self.patch_size / 2) ** 2
@@ -153,17 +157,15 @@ class LlavaLlamaForCausalLM(nn.Module):
if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size
# Embed text input
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
.cpu()
.numpy()
# Whether the requests need vision inputs
max_image_offset = np.array(
[max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
)
# FIXME: We need to subtract the length of the system prompt
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
need_vision = need_vision & has_pixel
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
need_vision = start_positions <= max_image_offset
if need_vision.any():
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
@@ -332,31 +334,35 @@ class LlavaLlamaForCausalLM(nn.Module):
new_image_features.append(image_feature)
image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
pad_dim = image_features[pt].shape[-1] # 576, 4096
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
for j, image_off in enumerate(image_offsets[i]):
# print("actual image_features length: ", image_features[pt][j].shape[0])
pad_len = image_features[pt][j].shape[0]
input_embeds[
start_idx + image_off : start_idx + image_off + pad_len
] = image_features[pt][j]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(image_features[pt].shape)
print(input_embeds.shape)
print(start_idx, image_offsets[i])
prefix_len = prefix_lens_cpu[i]
# Multiple images
for j, image_offset in enumerate(image_offsets[i]):
if image_offset < prefix_len:
continue
tmp_image_feature = image_features[pt][j]
pad_len = tmp_image_feature.shape[0]
left_idx = start_idx + (image_offset - prefix_len)
right_idx = start_idx + (image_offset - prefix_len) + pad_len
try:
input_embeds[left_idx:right_idx] = tmp_image_feature
except RuntimeError as e:
print(f"RuntimeError in image encoding: {e}")
print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
print(
f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
)
pt += 1
return self.language_model(
@@ -366,8 +372,9 @@ class LlavaLlamaForCausalLM(nn.Module):
return self.language_model(input_ids, positions, input_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
# Load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
# We put the initialization here instead of __init__ to allow it being reused by other subclasses.
vision_path = self.config.mm_vision_tower
if "clip" in vision_path:
self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,8 +429,6 @@ class LlavaLlamaForCausalLM(nn.Module):
# load language model
self.language_model.load_weights(weights)
monkey_path_clip_vision_embed_forward()
@property
def num_patches_per_side(self):
return self.image_size // self.patch_size
@@ -495,36 +500,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
)
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_path_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]