[Feat/WIP] add llava-onevision, with support for (1) siglip encoder, (2) qwen2 decoder, (3) openai api compatible server. (#1123)

Co-authored-by: Bo Li <drluodian@gmail.com>
Kaichen Zhang - NTU
2024-08-24 05:11:16 +08:00
committed by GitHub
parent 5fafcac008
commit a5b14ad043
13 changed files with 703 additions and 95 deletions


@@ -15,6 +15,8 @@ limitations under the License.
"""Inference-only LLaVa model compatible with HuggingFace weights."""
import math
import re
from typing import Iterable, List, Optional, Tuple
import numpy as np
@@ -26,6 +28,8 @@ from transformers import (
LlavaConfig,
MistralConfig,
Qwen2Config,
SiglipVisionConfig,
SiglipVisionModel,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.config import CacheConfig
@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
)
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
new_image_feature_len = self.image_feature_len
# currently only spatial_unpad + anyres are supported
if self.mm_patch_merge_type.startswith("spatial"):
height = width = self.num_patches_per_side
if pt_shape[0] > 1:
if self.image_aspect_ratio == "anyres":
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
if "unpad" in self.mm_patch_merge_type:
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_size)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id
new_input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
# hardcode for spatial_unpad + anyres
image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
offset_list = []
for image_s in image_size:
if len(image_size) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = (
math.ceil(self.image_size / self.patch_size / 2) ** 2
)
else:
new_image_feature_len = self.image_feature_len  # multi-image
height = width = self.num_patches_per_side
if "anyres" in image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_s,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_s)
if "anyres_max" in self.config.image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
r"anyres_max_(\d+)", self.config.image_aspect_ratio
)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
# times = math.sqrt(h * w / (max_num_patches * unit**2))
times = math.sqrt(
new_h * new_w / (max_num_patches * self.image_feature_len)
)
if times > 1.1:
new_h = int(new_h // times)
new_w = int(new_w // times)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
# print("calculated new_image_feature_len: ", new_image_feature_len)
try:
offset = input_ids.index(self.config.image_token_index)
except ValueError:
offset = 0
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
offset_list.append(offset)
return input_ids, offset_list
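For reference, the anyres_max budget math above can be exercised on its own. Below is a minimal sketch, assuming SigLIP-so400m geometry (384px input, 14px patches, hence 27 patches per side and 729 base tokens) and skipping the unpad step for brevity; padded_feature_len is an illustrative helper, not part of this commit:

import math
import re

def padded_feature_len(grid_hw, aspect_ratio, base_len=729, patches_per_side=27):
    # grid_hw: (num_patch_height, num_patch_width) tiles chosen from image_grid_pinpoints
    num_patch_h, num_patch_w = grid_hw
    new_h = num_patch_h * patches_per_side
    new_w = num_patch_w * patches_per_side
    m = re.match(r"anyres_max_(\d+)", aspect_ratio)
    if m:
        max_num_patches = int(m.group(1))
        # shrink factor so the tiled grid fits a budget of max_num_patches base-image units
        times = math.sqrt(new_h * new_w / (max_num_patches * base_len))
        if times > 1.1:
            new_h, new_w = int(new_h // times), int(new_w // times)
    # +1 per row accounts for the image_newline separator appended to each row
    return base_len + new_h * (new_w + 1)

print(padded_feature_len((2, 2), "anyres_max_9"))  # 729 + 54 * 55 = 3699 (within budget)
print(padded_feature_len((6, 6), "anyres_max_9"))  # grid exceeds budget, rescaled by 2x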
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
@@ -163,27 +193,73 @@ class LlavaLlamaForCausalLM(nn.Module):
if self.mm_patch_merge_type.startswith("spatial"):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
if len(image_sizes[image_idx]) == 1:
image_aspect_ratio = (
self.config.image_aspect_ratio
) # single image
else:
image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = (
# "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
# )
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if self.image_aspect_ratio == "anyres":
(
num_patch_width,
num_patch_height,
) = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
r"anyres_max_(\d+)", image_aspect_ratio
)
if matched_anyres_max_num_patches:
max_num_patches = int(
matched_anyres_max_num_patches.group(1)
)
if (
image_aspect_ratio == "anyres"
or "anyres_max" in image_aspect_ratio
):
vision_tower_image_size = self.image_size
try:
num_patch_width, num_patch_height = (
get_anyres_image_grid_shape(
image_sizes[image_idx][0],
self.config.image_grid_pinpoints,
vision_tower_image_size,
)
)
except Exception as e:
print(f"Error: {e}")
num_patch_width, num_patch_height = 2, 2
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
else:
raise NotImplementedError()
image_feature = image_feature.view(
2, 2, height, width, -1
)
# (
# num_patch_width,
# num_patch_height,
# ) = get_anyres_image_grid_shape(
# image_sizes[image_idx][0],
# self.image_grid_pinpoints,
# self.vision_tower.config.image_size,
# )
# image_feature = image_feature.view(
# num_patch_height, num_patch_width, height, width, -1
# )
if "unpad" in self.mm_patch_merge_type:
unit = image_feature.shape[2]
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
@@ -191,8 +267,23 @@ class LlavaLlamaForCausalLM(nn.Module):
2, 3
)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
image_feature, image_sizes[image_idx][0]
)
if (
"anyres_max" in image_aspect_ratio
and matched_anyres_max_num_patches
):
c, h, w = image_feature.shape
times = math.sqrt(
h * w / (max_num_patches * unit**2)
)
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature,
[int(h // times), int(w // times)],
mode="bilinear",
)[0]
image_feature = torch.cat(
(
image_feature,
@@ -213,16 +304,31 @@ class LlavaLlamaForCausalLM(nn.Module):
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
image_feature = image_feature.unsqueeze(0)
else:
image_feature = image_feature[0]
if "unpad" in self.mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[None],
),
dim=0,
if image_feature.shape[0] > 16: # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(
num_of_frames, height, width, -1
)
image_feature = image_feature.permute(
0, 3, 1, 2
).contiguous() # N, C, H, W
height, width = image_feature.shape[2:]
scaled_shape = [
math.ceil(height / 2),
math.ceil(width / 2),
]
image_feature = nn.functional.interpolate(
image_feature, size=scaled_shape, mode="bilinear"
)
image_feature = (
image_feature.flatten(2)
.transpose(1, 2)
.contiguous()
) # N, C, H*W
new_image_features.append(image_feature)
image_features = new_image_features
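The video branch above (shape[0] > 16) halves each spatial side with bilinear pooling before flattening back to a token sequence, cutting tokens per frame by roughly 4x. A standalone sketch of that reshape-pool-flatten round trip, with assumed shapes (32 frames, a 27x27 patch grid, hidden size 3584):

import math
import torch
import torch.nn.functional as F

frames, side, dim = 32, 27, 3584                        # assumed sizes, not fixed by the model
image_feature = torch.randn(frames, side * side, dim)   # per-frame patch features

x = image_feature.view(frames, side, side, dim)
x = x.permute(0, 3, 1, 2).contiguous()                  # N, C, H, W
scaled_shape = [math.ceil(side / 2), math.ceil(side / 2)]
x = F.interpolate(x, size=scaled_shape, mode="bilinear")
pooled = x.flatten(2).transpose(1, 2).contiguous()      # N, H*W/4, C
print(pooled.shape)                                     # torch.Size([32, 196, 3584])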
@@ -233,21 +339,22 @@ class LlavaLlamaForCausalLM(nn.Module):
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape # 576, 4096
pad_dim = image_features[pt].shape[-1] # 576, 4096
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
for j, image_off in enumerate(image_offsets[i]):
# print("actual image_features length: ", image_features[pt][j].shape[0])
pad_len = image_features[pt][j].shape[0]
input_embeds[
start_idx + image_off : start_idx + image_off + pad_len
] = image_features[pt][j]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(image_features[pt].shape)
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
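image_offsets[i] is now a list of per-image offsets rather than a single scalar, so the fill loop writes each image's features into its own slot of the prompt embeddings. A toy repro of the indexing, with made-up sizes:

import torch

dim = 3584
input_embeds = torch.zeros(64, dim)                          # embeddings of the extended prompt
image_features = [torch.randn(8, dim), torch.randn(8, dim)]  # two images, 8 tokens each
start_idx, image_offsets = 0, [5, 30]                        # offsets recorded by pad_input_ids
for j, image_off in enumerate(image_offsets):
    pad_len = image_features[j].shape[0]
    input_embeds[start_idx + image_off : start_idx + image_off + pad_len] = image_features[j]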
@@ -262,9 +369,16 @@ class LlavaLlamaForCausalLM(nn.Module):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
if "clip" in vision_path:
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
elif "siglip" in vision_path:
self.vision_tower = SiglipVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
# Siglip needs all feature tokens
self.config.mm_vision_select_feature = "full"
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
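A minimal sketch of the SigLIP branch on its own, assuming the so400m checkpoint commonly paired with llava-onevision; since SigLIP emits no CLS token, mm_vision_select_feature = "full" keeps all patch tokens:

import torch
from transformers import SiglipVisionModel

vision_path = "google/siglip-so400m-patch14-384"     # illustrative checkpoint name
tower = SiglipVisionModel.from_pretrained(
    vision_path, torch_dtype=torch.float16
).cuda().eval()
pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float16).cuda()
with torch.no_grad():
    out = tower(pixel_values, output_hidden_states=True)
feats = out.hidden_states[-2]    # mm_vision_select_layer = -2 is a typical choice
print(feats.shape)               # torch.Size([1, 729, 1152]) -- all 729 patch tokens, no CLS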
@@ -276,8 +390,11 @@ class LlavaLlamaForCausalLM(nn.Module):
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
if self.vision_feature_select_strategy == "patch":
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
if (
self.vision_feature_select_strategy == "patch"
or self.vision_feature_select_strategy == "full"
):
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
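The switch from / to // in image_feature_len above matters for SigLIP, where 384 is not divisible by 14 and true division inflates the token count. A quick check:

image_size, patch_size = 384, 14
print(int((image_size / patch_size) ** 2))    # 752 -- 27.43**2 truncated, wrong
print(int((image_size // patch_size) ** 2))   # 729 -- 27 patches per side, squared, correct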