[Fix] Fix llava on multi images (#1247)
@@ -28,7 +28,6 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
     SiglipVisionConfig,
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -66,13 +65,18 @@ class LlavaLlamaForCausalLM(nn.Module):
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
             )
 
-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
-        for image_s in image_size:
-            if len(image_size) > 16:
+        for image_s in image_sizes:
+            if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
                     math.ceil(self.image_size / self.patch_size / 2) ** 2
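[Note] The new pad_input_ids signature receives the pixel values and sizes of every image in the request, so it can emit one placeholder span and one offset per image instead of relying on a single pt_shape. A rough standalone sketch of that idea follows; the IMAGE_TOKEN_ID, the helper name, and the use of precomputed per-image feature lengths are illustrative simplifications, not the actual sglang implementation.

from typing import List, Tuple

IMAGE_TOKEN_ID = -200  # hypothetical placeholder id, for illustration only

def pad_input_ids_sketch(
    input_ids: List[int],
    pad_value: List[int],
    feature_lens: List[int],
) -> Tuple[List[int], List[int]]:
    out, offsets, img_idx = [], [], 0
    for tok in input_ids:
        if tok == IMAGE_TOKEN_ID and img_idx < len(feature_lens):
            offsets.append(len(out))  # where this image's features will be written later
            n = feature_lens[img_idx]
            # repeat the pad pattern until it covers the image feature length
            out.extend((pad_value * (n // len(pad_value) + 1))[:n])
            img_idx += 1
        else:
            out.append(tok)
    return out, offsets

# e.g. two images taking 4 and 6 feature tokens
ids, offs = pad_input_ids_sketch([1, -200, 2, -200, 3], [9], [4, 6])
assert offs == [1, 6]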
@@ -153,17 +157,15 @@ class LlavaLlamaForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size
 
-            # Embed text input
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
+
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
             )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset
 
             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
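[Note] The old check compared the extend start position against a single image_feature_len and AND-ed it with has_pixel; the new check compares each request's first extended position against that request's last image offset, so requests with no image or whose images are already covered by the cached prefix are skipped. A tiny numeric sketch of the same test, with made-up values:

import numpy as np

image_offsets = [[5, 40], [], [3]]      # per-request image start offsets
start_positions = np.array([0, 0, 10])  # first position extended in this pass

max_image_offset = np.array(
    [max(offs) if offs else -1 for offs in image_offsets]
)
need_vision = start_positions <= max_image_offset
print(need_vision)  # [ True False False]: no image in request 2, request 3 resumes past its image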
@@ -332,31 +334,35 @@ class LlavaLlamaForCausalLM(nn.Module):
                     new_image_features.append(image_feature)
                 image_features = new_image_features
 
+            # Fill in the placeholder for the image
             extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
             pt = 0
             for i in range(bs):
                 if not need_vision[i]:
                     continue
 
                 start_idx = extend_start_loc_cpu[i]
-                pad_dim = image_features[pt].shape[-1]  # 576, 4096
-                dim = input_embeds.shape[1]
-                assert (
-                    pad_dim == dim
-                ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                # Fill in the placeholder for the image
-                try:
-                    for j, image_off in enumerate(image_offsets[i]):
-                        # print("actual image_features length: ", image_features[pt][j].shape[0])
-                        pad_len = image_features[pt][j].shape[0]
-                        input_embeds[
-                            start_idx + image_off : start_idx + image_off + pad_len
-                        ] = image_features[pt][j]
-                except RuntimeError as e:
-                    print(f"RuntimeError in llava image encoding: {e}")
-                    print(image_features[pt].shape)
-                    print(input_embeds.shape)
-                    print(start_idx, image_offsets[i])
+                prefix_len = prefix_lens_cpu[i]
+
+                # Multiple images
+                for j, image_offset in enumerate(image_offsets[i]):
+                    if image_offset < prefix_len:
+                        continue
+
+                    tmp_image_feature = image_features[pt][j]
+                    pad_len = tmp_image_feature.shape[0]
+
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                    try:
+                        input_embeds[left_idx:right_idx] = tmp_image_feature
+                    except RuntimeError as e:
+                        print(f"RuntimeError in image encoding: {e}")
+                        print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                        print(
+                            f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                        )
                 pt += 1
 
             return self.language_model(
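[Note] The rewritten fill-in loop walks every image of a request and subtracts the cached prefix length before slicing into input_embeds, because input_embeds only covers the tokens being extended while image_offsets are positions in the full prompt. A toy example of that index arithmetic; every number here is invented and only meant to show why the subtraction is needed:

import torch

hidden = 8
prefix_len = 16    # tokens already served from the prefix cache
extend_len = 32    # tokens actually embedded in this forward pass
start_idx = 0      # start of this request inside the flattened batch
image_offset = 20  # absolute position of the image inside the full prompt
image_feature = torch.randn(6, hidden)

input_embeds = torch.zeros(extend_len, hidden)
if image_offset >= prefix_len:  # skip images already covered by the cached prefix
    left = start_idx + (image_offset - prefix_len)
    right = left + image_feature.shape[0]
    input_embeds[left:right] = image_feature
    assert (left, right) == (4, 10)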
@@ -366,8 +372,9 @@ class LlavaLlamaForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
-        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # Load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         if "clip" in vision_path:
             self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,8 +429,6 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)
 
-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size
@@ -495,36 +500,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
         )
 
 
-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
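[Note] For reference, a minimal frontend snippet that exercises the multi-image path this commit fixes. It assumes a running sglang server on the default port and that the llava chat template accepts several sgl.image segments in a single turn; the file names are placeholders.

import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def two_image_qa(s, image_1, image_2, question):
    # Two images in one user turn, followed by a text question.
    s += sgl.user(sgl.image(image_1) + sgl.image(image_2) + question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

state = two_image_qa.run(
    image_1="cat.jpg",
    image_2="dog.jpg",
    question="What is different between the two images?",
)
print(state["answer"])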