[Feat/WIP] Add LLaVA-OneVision, with support for (1) the SigLIP encoder, (2) the Qwen2 decoder, and (3) an OpenAI-API-compatible server. (#1123)
Co-authored-by: Bo Li <drluodian@gmail.com>
parent 5fafcac008
commit a5b14ad043
@@ -15,6 +15,8 @@ limitations under the License.
 """Inference-only LLaVa model compatible with HuggingFace weights."""

+import math
+import re
 from typing import Iterable, List, Optional, Tuple

 import numpy as np
@@ -26,6 +28,8 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
+    SiglipVisionConfig,
+    SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
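The two new imports give the loader a drop-in alternative to CLIP: `SiglipVisionModel` exposes the same `from_pretrained`/forward surface as `CLIPVisionModel`, which is what lets the weight-loading hunk further down switch towers by checkpoint path alone. A minimal sketch (illustrative only, not part of this diff; the values shown are the transformers defaults for `SiglipVisionConfig`):

```python
from transformers import SiglipVisionConfig, SiglipVisionModel

# Build a SigLIP vision tower from its default config just to show the
# CLIP-compatible interface; real checkpoints override these values.
config = SiglipVisionConfig()
tower = SiglipVisionModel(config)
print(config.image_size, config.patch_size)  # 224 16
```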
@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
         )

     def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-        new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        if self.mm_patch_merge_type.startswith("spatial"):
-            height = width = self.num_patches_per_side
-            if pt_shape[0] > 1:
-                if self.image_aspect_ratio == "anyres":
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_size,
-                        self.image_grid_pinpoints,
-                        self.vision_tower.config.image_size,
-                    )
-                if "unpad" in self.mm_patch_merge_type:
-                    h = num_patch_height * height
-                    w = num_patch_width * width
-                    new_h, new_w = unpad_image_shape(h, w, image_size)
-                    new_image_feature_len += new_h * (new_w + 1)
-
-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
-        )
-        offset = input_ids.index(self.config.image_token_index)
-        # old_len + pad_len - 1, because we need to remove image_token_id
-        new_input_ids = (
-            input_ids[:offset]
-            + pad_ids[:new_image_feature_len]
-            + input_ids[offset + 1 :]
-        )
-        return new_input_ids, offset
+        # hardcode for spatial_unpad + anyres
+        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        offset_list = []
+        for image_s in image_size:
+            if len(image_size) > 16:
+                # 2x2 pooling with stride 2
+                new_image_feature_len = (
+                    math.ceil(self.image_size / self.patch_size / 2) ** 2
+                )
+            else:
+                new_image_feature_len = self.image_feature_len  # multiimage
+
+            height = width = self.num_patches_per_side
+            if "anyres" in image_aspect_ratio:
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_s,
+                    self.image_grid_pinpoints,
+                    self.vision_tower.config.image_size,
+                )
+                h = num_patch_height * height
+                w = num_patch_width * width
+                new_h, new_w = unpad_image_shape(h, w, image_s)
+
+                if "anyres_max" in self.config.image_aspect_ratio:
+                    matched_anyres_max_num_patches = re.match(
+                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
+                    )
+                    if matched_anyres_max_num_patches:
+                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
+                    # times = math.sqrt(h * w / (max_num_patches * unit**2))
+                    times = math.sqrt(
+                        new_h * new_w / (max_num_patches * self.image_feature_len)
+                    )
+                    if times > 1.1:
+                        new_h = int(new_h // times)
+                        new_w = int(new_w // times)
+                new_image_feature_len += new_h * (new_w + 1)
+
+            pad_ids = pad_value * (
+                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            )
+            # print("calculated new_image_feature_len: ", new_image_feature_len)
+            try:
+                offset = input_ids.index(self.config.image_token_index)
+            except ValueError:
+                offset = 0
+            # old_len + pad_len - 1, because we need to remove image_token_id
+            input_ids = (
+                input_ids[:offset]
+                + pad_ids[:new_image_feature_len]
+                + input_ids[offset + 1 :]
+            )
+            offset_list.append(offset)
+        return input_ids, offset_list

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
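The rewritten `pad_input_ids` loops over a list of image sizes and returns one offset per image instead of a single offset. A toy sketch of the core expansion step, with the anyres bookkeeping stripped out (`IMAGE_TOKEN` and all values are invented for illustration):

```python
IMAGE_TOKEN = 9999  # hypothetical image_token_index

def pad_one(input_ids, pad_value, feature_len, image_token=IMAGE_TOKEN):
    # Replace the single image token with feature_len pad ids, as the hunk does.
    pad_ids = pad_value * ((feature_len + len(pad_value)) // len(pad_value))
    try:
        offset = input_ids.index(image_token)
    except ValueError:
        offset = 0
    return input_ids[:offset] + pad_ids[:feature_len] + input_ids[offset + 1 :], offset

ids, off = pad_one([1, 2, IMAGE_TOKEN, 3], pad_value=[0], feature_len=4)
assert ids == [1, 2, 0, 0, 0, 0, 3] and off == 2
```

Running this once per image and appending each `offset` to `offset_list` reproduces the multi-image behaviour of the new code.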
@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):

             # Embed text input
             input_embeds = self.language_model.model.embed_tokens(input_ids)

             # Embed vision input
             need_vision = (
                 (positions[input_metadata.extend_start_loc] < self.image_feature_len)
@@ -163,27 +193,73 @@ class LlavaLlamaForCausalLM(nn.Module):

                 if self.mm_patch_merge_type.startswith("spatial"):
                     new_image_features = []
+                    height = width = self.num_patches_per_side
                     for image_idx, image_feature in enumerate(image_features):
-                        if image_feature.shape[0] > 1:
-                            base_image_feature = image_feature[0]
-                            image_feature = image_feature[1:]
-                            height = width = self.num_patches_per_side
-                            assert height * width == base_image_feature.shape[0]
-                            if self.image_aspect_ratio == "anyres":
-                                (
-                                    num_patch_width,
-                                    num_patch_height,
-                                ) = get_anyres_image_grid_shape(
-                                    image_sizes[image_idx],
-                                    self.image_grid_pinpoints,
-                                    self.vision_tower.config.image_size,
-                                )
+                        if len(image_sizes[image_idx]) == 1:
+                            image_aspect_ratio = (
+                                self.config.image_aspect_ratio
+                            )  # single image
+                        else:
+                            image_aspect_ratio = "pad"  # multi image
+                        # image_aspect_ratio = (
+                        #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
+                        # )
+                        if (
+                            image_feature.shape[0] > 1
+                            and "anyres" in image_aspect_ratio
+                        ):
+                            base_image_feature = image_feature[0]
+                            image_feature = image_feature[1:]
+                            assert height * width == base_image_feature.shape[0]
+
+                            if "anyres_max" in image_aspect_ratio:
+                                matched_anyres_max_num_patches = re.match(
+                                    r"anyres_max_(\d+)", image_aspect_ratio
+                                )
+                                if matched_anyres_max_num_patches:
+                                    max_num_patches = int(
+                                        matched_anyres_max_num_patches.group(1)
+                                    )
+
+                            if (
+                                image_aspect_ratio == "anyres"
+                                or "anyres_max" in image_aspect_ratio
+                            ):
+                                vision_tower_image_size = self.image_size
+                                try:
+                                    num_patch_width, num_patch_height = (
+                                        get_anyres_image_grid_shape(
+                                            image_sizes[image_idx][0],
+                                            self.config.image_grid_pinpoints,
+                                            vision_tower_image_size,
+                                        )
+                                    )
+                                except Exception as e:
+                                    print(f"Error: {e}")
+                                    num_patch_width, num_patch_height = 2, 2
                                 image_feature = image_feature.view(
                                     num_patch_height, num_patch_width, height, width, -1
                                 )
                             else:
-                                raise NotImplementedError()
+                                image_feature = image_feature.view(
+                                    2, 2, height, width, -1
+                                )
+
+                            # (
+                            #     num_patch_width,
+                            #     num_patch_height,
+                            # ) = get_anyres_image_grid_shape(
+                            #     image_sizes[image_idx][0],
+                            #     self.image_grid_pinpoints,
+                            #     self.vision_tower.config.image_size,
+                            # )
+
+                            # image_feature = image_feature.view(
+                            #     num_patch_height, num_patch_width, height, width, -1
+                            # )
+
                             if "unpad" in self.mm_patch_merge_type:
+                                unit = image_feature.shape[2]
                                 image_feature = image_feature.permute(
                                     4, 0, 2, 1, 3
                                 ).contiguous()
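`get_anyres_image_grid_shape` decides how many vision-tower tiles the input image was split into. A simplified sketch of what it computes (the real LLaVA helper uses a best-fit rule over the pinpoints; here a closest-area pick stands in for it, so treat this as an approximation):

```python
def grid_shape(image_size, grid_pinpoints, tower_size):
    # image_size and pinpoints are (width, height); tower_size is e.g. 336 or 384.
    w, h = image_size
    best_w, best_h = min(grid_pinpoints, key=lambda p: abs(p[0] * p[1] - w * h))
    return best_w // tower_size, best_h // tower_size

# A 1000x700 input against 384-based pinpoints maps to a 2x2 tile grid:
print(grid_shape((1000, 700), [(768, 768), (1152, 384), (384, 1152)], 384))  # (2, 2)
```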
@@ -191,8 +267,23 @@ class LlavaLlamaForCausalLM(nn.Module):
                                     2, 3
                                 )
                                 image_feature = unpad_image(
-                                    image_feature, image_sizes[image_idx]
+                                    image_feature, image_sizes[image_idx][0]
                                 )
+                                if (
+                                    "anyres_max" in image_aspect_ratio
+                                    and matched_anyres_max_num_patches
+                                ):
+                                    c, h, w = image_feature.shape
+                                    times = math.sqrt(
+                                        h * w / (max_num_patches * unit**2)
+                                    )
+                                    if times > 1.1:
+                                        image_feature = image_feature[None]
+                                        image_feature = nn.functional.interpolate(
+                                            image_feature,
+                                            [int(h // times), int(w // times)],
+                                            mode="bilinear",
+                                        )[0]
                                 image_feature = torch.cat(
                                     (
                                         image_feature,
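Worked numbers for the new `anyres_max` cap: with a 384/14 SigLIP tower, `unit = 27` tokens per tile side, and `anyres_max_9` allows at most `9 * 27**2 = 6561` spatial tokens, so a 162x162 unpadded grid (`times = 2.0`) is bilinearly shrunk to 81x81. A self-contained check with assumed sizes (not taken from a real request):

```python
import math
import torch
import torch.nn as nn

unit, max_num_patches = 27, 9
feat = torch.randn(1, 8, 162, 162)          # (N, C, h, w) after unpadding
h, w = feat.shape[2:]
times = math.sqrt(h * w / (max_num_patches * unit**2))  # 2.0 > 1.1
if times > 1.1:
    feat = nn.functional.interpolate(
        feat, [int(h // times), int(w // times)], mode="bilinear"
    )
print(feat.shape)  # torch.Size([1, 8, 81, 81]); 81 * 81 == 9 * 27**2
```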
@@ -213,16 +304,31 @@ class LlavaLlamaForCausalLM(nn.Module):
                             image_feature = torch.cat(
                                 (base_image_feature, image_feature), dim=0
                             )
+                            image_feature = image_feature.unsqueeze(0)
                         else:
-                            image_feature = image_feature[0]
-                            if "unpad" in self.mm_patch_merge_type:
-                                image_feature = torch.cat(
-                                    (
-                                        image_feature,
-                                        self.language_model.model.image_newline[None],
-                                    ),
-                                    dim=0,
-                                )
+                            if image_feature.shape[0] > 16:  # video
+                                # 2x2 pooling
+                                num_of_frames = image_feature.shape[0]
+                                image_feature = image_feature.view(
+                                    num_of_frames, height, width, -1
+                                )
+                                image_feature = image_feature.permute(
+                                    0, 3, 1, 2
+                                ).contiguous()  # N, C, H, W
+                                height, weight = image_feature.shape[2:]
+                                scaled_shape = [
+                                    math.ceil(height / 2),
+                                    math.ceil(weight / 2),
+                                ]
+                                image_feature = nn.functional.interpolate(
+                                    image_feature, size=scaled_shape, mode="bilinear"
+                                )
+                                image_feature = (
+                                    image_feature.flatten(2)
+                                    .transpose(1, 2)
+                                    .contiguous()
+                                )  # N, C, H*W

                         new_image_features.append(image_feature)
                     image_features = new_image_features
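The new video branch pools every frame's token grid 2x2 before flattening. A standalone sketch with assumed sizes (32 frames from a 27x27-token tower, feature dim 16); note that `flatten(2).transpose(1, 2)` actually yields `N, H*W, C`, despite the hunk's trailing `N, C, H*W` comment:

```python
import math
import torch
import torch.nn as nn

frames, height, width, dim = 32, 27, 27, 16
feature = torch.randn(frames, height * width, dim)        # per-frame tokens

feature = feature.view(frames, height, width, -1)
feature = feature.permute(0, 3, 1, 2).contiguous()        # N, C, H, W
scaled = [math.ceil(height / 2), math.ceil(width / 2)]    # 14 x 14
feature = nn.functional.interpolate(feature, size=scaled, mode="bilinear")
feature = feature.flatten(2).transpose(1, 2).contiguous() # N, H*W, C
print(feature.shape)  # torch.Size([32, 196, 16])
```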
@@ -233,21 +339,22 @@ class LlavaLlamaForCausalLM(nn.Module):
                     continue

                 start_idx = extend_start_loc_cpu[i]
-                pad_len, pad_dim = image_features[pt].shape  # 576, 4096
+                pad_dim = image_features[pt].shape[-1]  # 576, 4096
                 dim = input_embeds.shape[1]
                 assert (
                     pad_dim == dim
                 ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                 # Fill in the placeholder for the image
                 try:
-                    input_embeds[
-                        start_idx
-                        + image_offsets[i] : start_idx
-                        + image_offsets[i]
-                        + pad_len
-                    ] = image_features[pt]
+                    for j, image_off in enumerate(image_offsets[i]):
+                        # print("actual image_features length: ", image_features[pt][j].shape[0])
+                        pad_len = image_features[pt][j].shape[0]
+                        input_embeds[
+                            start_idx + image_off : start_idx + image_off + pad_len
+                        ] = image_features[pt][j]
                 except RuntimeError as e:
                     print(f"RuntimeError in llava image encoding: {e}")
                     print(image_features[pt].shape)
                     print(input_embeds.shape)
                     print(start_idx, image_offsets[i])
                 pt += 1
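The fill-in step now walks `image_offsets[i]` per request, writing each image's features at its own offset into the prompt embedding. A toy check of the indexing (shapes and offsets invented):

```python
import torch

dim = 8
input_embeds = torch.zeros(20, dim)
image_features = [torch.ones(5, dim), 2 * torch.ones(3, dim)]  # two images
start_idx, image_offsets = 0, [2, 10]

for j, image_off in enumerate(image_offsets):
    pad_len = image_features[j].shape[0]
    input_embeds[start_idx + image_off : start_idx + image_off + pad_len] = (
        image_features[j]
    )

assert input_embeds[2:7].eq(1).all() and input_embeds[10:13].eq(2).all()
```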
@@ -262,9 +369,16 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
-        self.vision_tower = CLIPVisionModel.from_pretrained(
-            vision_path, torch_dtype=torch.float16
-        ).cuda()
+        if "clip" in vision_path:
+            self.vision_tower = CLIPVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+        elif "siglip" in vision_path:
+            self.vision_tower = SiglipVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+            # Siglip needs all feature tokens
+            self.config.mm_vision_select_feature = "full"
         self.vision_tower.eval()

         self.vision_feature_layer = self.config.mm_vision_select_layer
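The tower is picked by substring match on the checkpoint path ("siglip" does not contain "clip", so the ordering is safe for standard model ids). SigLIP's ViT has no CLS token, which is why the branch forces `mm_vision_select_feature = "full"`. The same pattern standalone, with a public checkpoint id standing in for `config.mm_vision_tower`:

```python
import torch
from transformers import CLIPVisionModel, SiglipVisionModel

vision_path = "google/siglip-so400m-patch14-384"  # illustrative checkpoint
if "clip" in vision_path:
    tower = CLIPVisionModel.from_pretrained(vision_path, torch_dtype=torch.float16)
elif "siglip" in vision_path:
    tower = SiglipVisionModel.from_pretrained(vision_path, torch_dtype=torch.float16)
tower.eval()
```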
@@ -276,8 +390,11 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
         self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

-        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
-        if self.vision_feature_select_strategy == "patch":
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
             pass
         elif self.vision_feature_select_strategy == "cls_patch":
             self.image_feature_len += 1
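The `/` to `//` change matters for SigLIP: `int((384 / 14) ** 2)` gives 752, while the tower actually emits `(384 // 14) ** 2 = 729` patch tokens; CLIP-L/14-336 was unaffected because 336 / 14 is exactly 24. Worked numbers:

```python
def image_feature_len(image_size, patch_size, strategy):
    # Mirrors the hunk: "patch" and "full" keep all patch tokens,
    # "cls_patch" adds one for the CLS token.
    n = int((image_size // patch_size) ** 2)
    return n + 1 if strategy == "cls_patch" else n

assert image_feature_len(336, 14, "patch") == 576  # CLIP-ViT-L/14-336
assert image_feature_len(384, 14, "full") == 729   # SigLIP-so400m/14-384
```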