[Feat/WIP] add llava-onevision, with support for (1) siglip encoder, (2) qwen2 decoder, (3) openai api compatible server. (#1123)
Co-authored-by: Bo Li <drluodian@gmail.com>
committed by GitHub
parent 5fafcac008
commit a5b14ad043
@@ -20,7 +20,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
+srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
        "packaging", "pillow", "psutil", "pydantic", "python-multipart",
        "torch", "uvicorn", "uvloop", "zmq",
        "vllm==0.5.4", "outlines>=0.0.44"]
 
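The new decord dependency is what lets the srt server decode video inputs for OneVision's video mode. A minimal sketch of how frames could be sampled with decord; the helper name and frame count below are ours, not from this commit:

import numpy as np
from decord import VideoReader

def sample_frames(video_path: str, num_frames: int = 16) -> np.ndarray:
    vr = VideoReader(video_path)
    # Evenly spaced frame indices across the whole clip.
    indices = np.linspace(0, len(vr) - 1, num_frames, dtype=int)
    # (num_frames, H, W, 3) uint8 array, ready for the image processor.
    return vr.get_batch(indices).asnumpy()
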
@@ -137,7 +137,7 @@ register_chat_template(
 register_chat_template(
     ChatTemplate(
         name="chatml-llava",
-        default_system_prompt="Answer the questions.",
+        default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
             "user": ("<|im_start|>user\n", "<|im_end|>\n"),
@@ -145,7 +145,7 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token=" <image>\n",
+        image_token="<image>\n",
     )
 )
 
@@ -322,12 +322,17 @@ def match_chat_ml(model_path: str):
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     # Now the suffix for qwen2 chat model is "instruct"
-    if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
+    if (
+        "qwen" in model_path
+        and ("chat" in model_path or "instruct" in model_path)
+        and ("llava" not in model_path)
+    ):
         return get_chat_template("qwen")
     if (
         "llava-v1.6-34b" in model_path
         or "llava-v1.6-yi-34b" in model_path
         or "llava-next-video-34b" in model_path
+        or "llava-onevision-qwen2" in model_path
     ):
         return get_chat_template("chatml-llava")
 
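Net effect of the matcher change: plain Qwen2 chat/instruct checkpoints keep the qwen template, while any llava-flavored Qwen2 path falls through to chatml-llava. A rough illustration, assuming the matcher receives lower-cased model paths:

match_chat_ml("qwen2-7b-instruct").name            # -> "qwen"
match_chat_ml("llava-onevision-qwen2-7b-ov").name  # -> "chatml-llava"
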
@@ -34,6 +34,7 @@ class SeparatorStyle(IntEnum):
     NO_COLON_TWO = auto()
     ADD_NEW_LINE_SINGLE = auto()
     LLAMA2 = auto()
+    LLAMA3 = auto()
     CHATGLM = auto()
     CHATML = auto()
     CHATINTERN = auto()
@@ -137,6 +138,20 @@ class Conversation:
             else:
                 ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.LLAMA3:
+            ret = "<|begin_of_text|>"
+            if self.system_message:
+                ret += system_prompt
+            else:
+                ret += ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                    ret += f"{message.strip()}<|eot_id|>"
+                else:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+            # print(ret)
+            return ret
         elif self.sep_style == SeparatorStyle.LLAMA2:
             seps = [self.sep, self.sep2]
             if self.system_message:
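For reference, the new LLAMA3 branch renders conversations in the Llama-3 header format. A sketch of the string it produces for a single user turn, with the assistant header left open for generation:

# <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>
# <|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>\n\n
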
@@ -379,12 +394,23 @@ def generate_chat_conv(
                 conv.append_message(conv.roles[0], message.content)
             else:
                 real_content = ""
+                # calculate number of image_url
+                num_image_url = 0
+                for content in message.content:
+                    if content.type == "image_url":
+                        num_image_url += 1
+                if num_image_url > 1:
+                    image_token = "<image>"
+                else:
+                    image_token = "<image>\n"
                 for content in message.content:
                     if content.type == "text":
+                        if num_image_url > 16:
+                            real_content += "\n"  # for video
                         real_content += content.text
                     elif content.type == "image_url":
                         # NOTE: Only works for llava
-                        real_content += "<image>\n"
+                        real_content += image_token
                         conv.append_image(content.image_url.url)
                 conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
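The counting pass above changes how image placeholders are emitted: a single image keeps the trailing newline ("<image>\n"), several images use bare "<image>" tokens, and more than 16 images (treated as video frames) also prepend a newline before the text. A hypothetical OpenAI-style message this path would consume:

message = {
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
        {"type": "image_url", "image_url": {"url": "https://example.com/b.png"}},
        {"type": "text", "text": "Compare the two images."},
    ],
}
# num_image_url == 2, so each image contributes "<image>" without a newline,
# and real_content becomes "<image><image>Compare the two images."
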
@@ -425,6 +451,18 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="chatml-llava",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep_style=SeparatorStyle.CHATML,
+        sep="<|im_end|>",
+        stop_str=["<|endoftext|>", "<|im_end|>"],
+    )
+)
+
 register_conv_template(
     Conversation(
         name="vicuna_v1.1",
@@ -437,6 +475,17 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="llava_llama_3",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+    )
+)
+
 # Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442
 register_conv_template(
     Conversation(
@@ -131,11 +131,49 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
-    async def get_pixel_values(self, image_data):
-        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None
-        )
+    async def get_pixel_values(self, image_data, aspect_ratio=None):
+        aspect_ratio = (
+            getattr(self.hf_config, "image_aspect_ratio", None)
+            if aspect_ratio is None
+            else aspect_ratio
+        )
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+            and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            pixel_values, image_hash, image_size = [], [], []
+            if len(image_data) > 1:
+                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                for img_data in image_data:
+                    pixel_v, image_h, image_s = await self._process_single_image(
+                        img_data, aspect_ratio, grid_pinpoints
+                    )
+                    pixel_values.append(pixel_v)
+                    image_hash.append(image_h)
+                    image_size.append(image_s)
+                pixel_values = np.stack(pixel_values, axis=0)
+            else:
+                pixel_values, image_hash, image_size = await self._process_single_image(
+                    image_data[0], aspect_ratio, grid_pinpoints
+                )
+                image_hash = [image_hash]
+                image_size = [image_size]
+        elif isinstance(image_data, str):
+            pixel_values, image_hash, image_size = await self._process_single_image(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+            image_hash = [image_hash]
+            image_size = [image_size]
+        else:
+            pixel_values, image_hash, image_size = None, None, None
+
+        return pixel_values, image_hash, image_size
+
+    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
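In short, the tokenizer now normalizes everything to lists: a single image yields one-element image_hash/image_size lists, while a multi-image (or video-frame) request is processed image-by-image under the pad aspect ratio and stacked. Rough shape of the results, with sizes assumed for a 384px SigLIP tower:

# one image, anyres: pixel_values (n_patches, 3, 384, 384), image_hash [h0], image_size [(W, H)]
# k images / video:  pixel_values (k, 3, 384, 384), image_hash [h0, ..., hk-1], image_size [k entries]
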
@@ -194,8 +232,8 @@ class TokenizerManager:
             )
 
         if self.is_generation:
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if not_use_index else obj.image_data[index]
+            pixel_values, image_hash, image_size = await self.get_pixel_values(
+                obj.image_data
             )
             return_logprob = (
                 obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -704,7 +742,7 @@ def get_pixel_values(
             tuple(int(x * 255) for x in processor.image_processor.image_mean),
         )
         pixel_values = processor.image_processor(image)["pixel_values"][0]
-    elif image_aspect_ratio == "anyres":
+    elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
         pixel_values = process_anyres_image(
             image, processor.image_processor, image_grid_pinpoints
         )
@@ -322,11 +322,16 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
+                image_hash = (
+                    hash(tuple(recv_req.image_hash))
+                    if isinstance(recv_req.image_hash, list)
+                    else recv_req.image_hash
+                )
                 req.pad_value = [
-                    (recv_req.image_hash) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                    (image_hash) % self.model_config.vocab_size,
+                    (image_hash >> 16) % self.model_config.vocab_size,
+                    (image_hash >> 32) % self.model_config.vocab_size,
+                    (image_hash >> 64) % self.model_config.vocab_size,
                 ]
                 req.image_size = recv_req.image_size
                 (
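The pad values are radix-style slices of the image hash, so the same image always maps to the same placeholder tokens (which keeps prefix caching sound); the only new step is folding a list of per-image hashes into one with hash(tuple(...)). A standalone sketch of the scheme, with the vocab size assumed:

vocab_size = 152064  # assumed; e.g. a Qwen2 vocab

def pad_values_from_hash(image_hash: int) -> list:
    # Spread one hash across four in-vocab token ids.
    return [
        image_hash % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,
    ]

per_image_hashes = [1234567890, 987654321]  # multi-image request
combined = hash(tuple(per_image_hashes))    # the list branch above
pad_value = pad_values_from_hash(combined)
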
@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
+# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
+"""
+Utilities for multi-modal models.
+
+This file mainly contains utilities used in the image-processing logic of
+LLaVA-NeXT, including operations such as anyres and anyres_max.
+
+The anyres and anyres_max operations are currently supported for CLIP and
+SigLIP. For more information, refer to the paper or the blog:
+
+LLaVA-NeXT: https://llava-vl.github.io/blog/2024-01-30-llava-next/
+LLaVA-OneVision: https://arxiv.org/pdf/2408.03326
+
+"""
 import ast
 import base64
+import math
+import re
 from io import BytesIO
 
 import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
     min_wasted_resolution = float("inf")
 
     for width, height in possible_resolutions:
+        # Calculate the downscaled size to keep the aspect ratio
         scale = min(width / original_width, height / original_height)
         downscaled_width, downscaled_height = int(original_width * scale), int(
            original_height * scale
        )
+
+        # Calculate effective and wasted resolutions
         effective_resolution = min(
             downscaled_width * downscaled_height, original_width * original_height
         )
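A quick worked example of this selection: for a 1000x800 image with candidates [(768, 768), (1152, 384)], fitting gives 768x614 (effective 471,552 px, wasted 118,272 px) versus 480x384 (effective 184,320 px, wasted 258,048 px), so (768, 768) wins:

best = select_best_resolution((1000, 800), [(768, 768), (1152, 384)])
# best == (768, 768)
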
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     Returns:
         tuple: The shape of the image patch grid in the format (width, height).
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
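The string form lets a config say "(1x1),...,(6x6)" instead of enumerating every resolution; only the first and last tuples are read, and the full grid between them is generated. Worked example with a 384px patch size:

import re

grid_pinpoints = "(1x1),...,(6x6)"
patch_size = 384
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)  # [('1', '1'), ('6', '6')]
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grids = [
    (i, j)
    for i in range(range_start[0], range_end[0] + 1)
    for j in range(range_start[1], range_end[1] + 1)
]
resolutions = [[dim * patch_size for dim in pair] for pair in grids]
# 36 pinpoints, from [384, 384] up to [2304, 2304]
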
@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
     Returns:
         np.array: An np array containing the processed image patches.
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        try:
+            patch_size = processor.size[0]
+        except Exception as e:
+            patch_size = processor.size["shortest_edge"]
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
@@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints):
     best_resolution = select_best_resolution(image.size, possible_resolutions)
     image_padded = resize_and_pad_image(image, best_resolution)
 
-    patches = divide_to_patches(image_padded, processor.crop_size["height"])
-
-    image_original_resize = image.resize(
-        (processor.size["shortest_edge"], processor.size["shortest_edge"])
-    )
+    # For Siglip processor, only have size but no crop size
+    crop_size = (
+        processor.crop_size["height"]
+        if "crop_size" in processor.__dict__
+        else processor.size["height"]
+    )
+    shortest_edge = (
+        processor.size["shortest_edge"]
+        if "shortest_edge" in processor.size
+        else processor.size["height"]
+    )
+    patches = divide_to_patches(image_padded, crop_size)
+
+    image_original_resize = image.resize((shortest_edge, shortest_edge))
 
     image_patches = [image_original_resize] + patches
     image_patches = [
-        processor.preprocess(image_patch)["pixel_values"][0]
+        processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
         for image_patch in image_patches
     ]
     return np.stack(image_patches, axis=0)
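The crop_size/shortest_edge fallbacks exist because SiglipImageProcessor only exposes size={"height": ..., "width": ...}, while CLIP's processor carries crop_size and size["shortest_edge"]. Assuming a 384px SigLIP tower and a best resolution of 768x768, the stack comes out as:

# patches = divide_to_patches(image_padded, 384)  -> 4 patches of 384x384
# stacked output: (5, 3, 384, 384) = base resize + 4 anyres patches
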
@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
             )
             image = image_processor.preprocess(image)["pixel_values"][0]
             new_images.append(image)
-    elif image_aspect_ratio == "anyres":
+    elif "anyres" in image_aspect_ratio:
         for image in images:
             image = process_anyres_image(
                 image, image_processor, model_cfg.image_grid_pinpoints
@@ -88,14 +88,19 @@ class InputMetadata:
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            (
-                (r.image_offset - batch.prefix_lens_cpu[i])
-                if r.image_offset is not None
-                else 0
-            )
-            for i, r in enumerate(reqs)
-        ]
+        self.image_offsets = []
+        for r in reqs:
+            if isinstance(r.image_offset, list):
+                self.image_offsets.append(
+                    [
+                        (image_offset - len(r.prefix_indices))
+                        for image_offset in r.image_offset
+                    ]
+                )
+            elif isinstance(r.image_offset, int):
+                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
+            elif r.image_offset is None:
+                self.image_offsets.append(0)
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Inference-only LLaVa model compatible with HuggingFace weights."""
 
+import math
+import re
 from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
@@ -26,6 +28,8 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
+    SiglipVisionConfig,
+    SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
         )
 
     def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-        new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        if self.mm_patch_merge_type.startswith("spatial"):
-            height = width = self.num_patches_per_side
-            if pt_shape[0] > 1:
-                if self.image_aspect_ratio == "anyres":
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_size,
-                        self.image_grid_pinpoints,
-                        self.vision_tower.config.image_size,
-                    )
-                if "unpad" in self.mm_patch_merge_type:
-                    h = num_patch_height * height
-                    w = num_patch_width * width
-                    new_h, new_w = unpad_image_shape(h, w, image_size)
-                    new_image_feature_len += new_h * (new_w + 1)
-
-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
-        )
-        offset = input_ids.index(self.config.image_token_index)
-        # old_len + pad_len - 1, because we need to remove image_token_id
-        new_input_ids = (
-            input_ids[:offset]
-            + pad_ids[:new_image_feature_len]
-            + input_ids[offset + 1 :]
-        )
-        return new_input_ids, offset
+        # hardcode for spatial_unpad + anyres
+        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        offset_list = []
+        for image_s in image_size:
+            if len(image_size) > 16:
+                # 2x2 pooling with stride 2
+                new_image_feature_len = (
+                    math.ceil(self.image_size / self.patch_size / 2) ** 2
+                )
+            else:
+                new_image_feature_len = self.image_feature_len  # multiimage
+
+            height = width = self.num_patches_per_side
+            if "anyres" in image_aspect_ratio:
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_s,
+                    self.image_grid_pinpoints,
+                    self.vision_tower.config.image_size,
+                )
+                h = num_patch_height * height
+                w = num_patch_width * width
+                new_h, new_w = unpad_image_shape(h, w, image_s)
+
+                if "anyres_max" in self.config.image_aspect_ratio:
+                    matched_anyres_max_num_patches = re.match(
+                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
+                    )
+                    if matched_anyres_max_num_patches:
+                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
+                        # times = math.sqrt(h * w / (max_num_patches * unit**2))
+                        times = math.sqrt(
+                            new_h * new_w / (max_num_patches * self.image_feature_len)
+                        )
+                        if times > 1.1:
+                            new_h = int(new_h // times)
+                            new_w = int(new_w // times)
+                new_image_feature_len += new_h * (new_w + 1)
+
+            pad_ids = pad_value * (
+                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            )
+            # print("calculated new_image_feature_len: ", new_image_feature_len)
+            try:
+                offset = input_ids.index(self.config.image_token_index)
+            except ValueError:
+                offset = 0
+            # old_len + pad_len - 1, because we need to remove image_token_id
+            input_ids = (
+                input_ids[:offset]
+                + pad_ids[:new_image_feature_len]
+                + input_ids[offset + 1 :]
+            )
+            offset_list.append(offset)
+        return input_ids, offset_list
 
     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
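pad_input_ids now loops over every image in the request, replaces each <image> token with a run of pad tokens sized to that image's eventual feature length (video inputs use the 2x2-pooled per-frame length), and returns one offset per image. A toy trace with made-up numbers:

# input_ids             = [p0, p1, IMG, p2]         (IMG = image_token_index, offset 2)
# new_image_feature_len = 5
# pad_value             = [a, b, c, d]
# pad_ids               = [a, b, c, d, a, b, c, d]  -> take the first 5
# result                = [p0, p1, a, b, c, d, a, p2], offset_list == [2]
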
@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):
 
         # Embed text input
         input_embeds = self.language_model.model.embed_tokens(input_ids)
-
         # Embed vision input
         need_vision = (
             (positions[input_metadata.extend_start_loc] < self.image_feature_len)
@@ -163,27 +193,73 @@ class LlavaLlamaForCausalLM(nn.Module):
 
             if self.mm_patch_merge_type.startswith("spatial"):
                 new_image_features = []
+                height = width = self.num_patches_per_side
                 for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
+                    if len(image_sizes[image_idx]) == 1:
+                        image_aspect_ratio = (
+                            self.config.image_aspect_ratio
+                        )  # single image
+                    else:
+                        image_aspect_ratio = "pad"  # multi image
+                    # image_aspect_ratio = (
+                    #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
+                    # )
+                    if (
+                        image_feature.shape[0] > 1
+                        and "anyres" in image_aspect_ratio
+                    ):
                         base_image_feature = image_feature[0]
                         image_feature = image_feature[1:]
-                        height = width = self.num_patches_per_side
                         assert height * width == base_image_feature.shape[0]
-                        if self.image_aspect_ratio == "anyres":
-                            (
-                                num_patch_width,
-                                num_patch_height,
-                            ) = get_anyres_image_grid_shape(
-                                image_sizes[image_idx],
-                                self.image_grid_pinpoints,
-                                self.vision_tower.config.image_size,
-                            )
-                        else:
-                            raise NotImplementedError()
-                        image_feature = image_feature.view(
-                            num_patch_height, num_patch_width, height, width, -1
-                        )
+
+                        if "anyres_max" in image_aspect_ratio:
+                            matched_anyres_max_num_patches = re.match(
+                                r"anyres_max_(\d+)", image_aspect_ratio
+                            )
+                            if matched_anyres_max_num_patches:
+                                max_num_patches = int(
+                                    matched_anyres_max_num_patches.group(1)
+                                )
+
+                        if (
+                            image_aspect_ratio == "anyres"
+                            or "anyres_max" in image_aspect_ratio
+                        ):
+                            vision_tower_image_size = self.image_size
+                            try:
+                                num_patch_width, num_patch_height = (
+                                    get_anyres_image_grid_shape(
+                                        image_sizes[image_idx][0],
+                                        self.config.image_grid_pinpoints,
+                                        vision_tower_image_size,
+                                    )
+                                )
+                            except Exception as e:
+                                print(f"Error: {e}")
+                                num_patch_width, num_patch_height = 2, 2
+                            image_feature = image_feature.view(
+                                num_patch_height, num_patch_width, height, width, -1
+                            )
+                        else:
+                            image_feature = image_feature.view(
+                                2, 2, height, width, -1
+                            )
+
+                        # (
+                        #     num_patch_width,
+                        #     num_patch_height,
+                        # ) = get_anyres_image_grid_shape(
+                        #     image_sizes[image_idx][0],
+                        #     self.image_grid_pinpoints,
+                        #     self.vision_tower.config.image_size,
+                        # )
+
+                        # image_feature = image_feature.view(
+                        #     num_patch_height, num_patch_width, height, width, -1
+                        # )
 
                         if "unpad" in self.mm_patch_merge_type:
+                            unit = image_feature.shape[2]
                             image_feature = image_feature.permute(
                                 4, 0, 2, 1, 3
                             ).contiguous()
@@ -191,8 +267,23 @@ class LlavaLlamaForCausalLM(nn.Module):
                                 2, 3
                             )
                             image_feature = unpad_image(
-                                image_feature, image_sizes[image_idx]
+                                image_feature, image_sizes[image_idx][0]
                             )
+                            if (
+                                "anyres_max" in image_aspect_ratio
+                                and matched_anyres_max_num_patches
+                            ):
+                                c, h, w = image_feature.shape
+                                times = math.sqrt(
+                                    h * w / (max_num_patches * unit**2)
+                                )
+                                if times > 1.1:
+                                    image_feature = image_feature[None]
+                                    image_feature = nn.functional.interpolate(
+                                        image_feature,
+                                        [int(h // times), int(w // times)],
+                                        mode="bilinear",
+                                    )[0]
                             image_feature = torch.cat(
                                 (
                                     image_feature,
@@ -213,16 +304,31 @@ class LlavaLlamaForCausalLM(nn.Module):
                         image_feature = torch.cat(
                             (base_image_feature, image_feature), dim=0
                         )
+                        image_feature = image_feature.unsqueeze(0)
                     else:
-                        image_feature = image_feature[0]
-                        if "unpad" in self.mm_patch_merge_type:
-                            image_feature = torch.cat(
-                                (
-                                    image_feature,
-                                    self.language_model.model.image_newline[None],
-                                ),
-                                dim=0,
-                            )
+                        if image_feature.shape[0] > 16:  # video
+                            # 2x2 pooling
+                            num_of_frames = image_feature.shape[0]
+                            image_feature = image_feature.view(
+                                num_of_frames, height, width, -1
+                            )
+                            image_feature = image_feature.permute(
+                                0, 3, 1, 2
+                            ).contiguous()  # N, C, H, W
+                            height, weight = image_feature.shape[2:]
+                            scaled_shape = [
+                                math.ceil(height / 2),
+                                math.ceil(weight / 2),
+                            ]
+                            image_feature = nn.functional.interpolate(
+                                image_feature, size=scaled_shape, mode="bilinear"
+                            )
+                            image_feature = (
+                                image_feature.flatten(2)
+                                .transpose(1, 2)
+                                .contiguous()
+                            )  # N, C, H*W
 
                     new_image_features.append(image_feature)
                 image_features = new_image_features
 
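The video branch halves each spatial side by bilinear interpolation, cutting per-frame tokens roughly 4x; for an assumed 384px SigLIP tower (27x27 = 729 patch tokens per frame) the shape flow is:

# (N, 729, D) -> view (N, 27, 27, D) -> permute (N, D, 27, 27)
# -> interpolate to (N, D, 14, 14)   (math.ceil(27 / 2) == 14)
# -> flatten + transpose -> (N, 196, D) tokens per frame
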
@@ -233,21 +339,22 @@ class LlavaLlamaForCausalLM(nn.Module):
                 continue
 
             start_idx = extend_start_loc_cpu[i]
-            pad_len, pad_dim = image_features[pt].shape  # 576, 4096
+            pad_dim = image_features[pt].shape[-1]  # 576, 4096
             dim = input_embeds.shape[1]
             assert (
                 pad_dim == dim
             ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
             # Fill in the placeholder for the image
             try:
-                input_embeds[
-                    start_idx
-                    + image_offsets[i] : start_idx
-                    + image_offsets[i]
-                    + pad_len
-                ] = image_features[pt]
+                for j, image_off in enumerate(image_offsets[i]):
+                    # print("actual image_features length: ", image_features[pt][j].shape[0])
+                    pad_len = image_features[pt][j].shape[0]
+                    input_embeds[
+                        start_idx + image_off : start_idx + image_off + pad_len
+                    ] = image_features[pt][j]
             except RuntimeError as e:
                 print(f"RuntimeError in llava image encoding: {e}")
+                print(image_features[pt].shape)
                 print(input_embeds.shape)
                 print(start_idx, image_offsets[i])
             pt += 1
@@ -262,9 +369,16 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load clip vision model by cfg['mm_vision_tower']:
         #   huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
-        self.vision_tower = CLIPVisionModel.from_pretrained(
-            vision_path, torch_dtype=torch.float16
-        ).cuda()
+        if "clip" in vision_path:
+            self.vision_tower = CLIPVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+        elif "siglip" in vision_path:
+            self.vision_tower = SiglipVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+            # Siglip needs all feature tokens
+            self.config.mm_vision_select_feature = "full"
         self.vision_tower.eval()
 
         self.vision_feature_layer = self.config.mm_vision_select_layer
@@ -276,8 +390,11 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
         self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
 
-        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
-        if self.vision_feature_select_strategy == "patch":
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
             pass
         elif self.vision_feature_select_strategy == "cls_patch":
             self.image_feature_len += 1