[Feat/WIP] add llava-onevision, with support for (1) siglip encoder, (2) qwen2 decoder (3) openai api compatible server. (#1123)

Author: Kaichen Zhang - NTU
Co-authored-by: Bo Li <drluodian@gmail.com>
Date: 2024-08-24 05:11:16 +08:00
Committed by: GitHub
Parent: 5fafcac008
Commit: a5b14ad043

13 changed files with 703 additions and 95 deletions

View File

@@ -20,7 +20,7 @@ dependencies = [
]
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torch", "uvicorn", "uvloop", "zmq",
"vllm==0.5.4", "outlines>=0.0.44"]

View File

@@ -137,7 +137,7 @@ register_chat_template(
register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="Answer the questions.",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
@@ -145,7 +145,7 @@ register_chat_template(
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token=" <image>\n",
image_token="<image>\n",
)
)
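
With the new default system prompt and the leading space dropped from image_token, a single-image request formatted with the chatml-llava template comes out roughly as below; the question text is illustrative and the exact placement of the image token follows the template substitution.

    # Illustrative chatml-llava rendering for one image plus one question.
    prompt = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "<image>\nWhat is shown in this image?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )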
@@ -322,12 +322,17 @@ def match_chat_ml(model_path: str):
if "tinyllama" in model_path:
return get_chat_template("chatml")
# Now the suffix for qwen2 chat model is "instruct"
if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
if (
"qwen" in model_path
and ("chat" in model_path or "instruct" in model_path)
and ("llava" not in model_path)
):
return get_chat_template("qwen")
if (
"llava-v1.6-34b" in model_path
or "llava-v1.6-yi-34b" in model_path
or "llava-next-video-34b" in model_path
or "llava-onevision-qwen2" in model_path
):
return get_chat_template("chatml-llava")

View File

@@ -34,6 +34,7 @@ class SeparatorStyle(IntEnum):
NO_COLON_TWO = auto()
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
LLAMA3 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
@@ -137,6 +138,20 @@ class Conversation:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA3:
ret = "<|begin_of_text|>"
if self.system_message:
ret += system_prompt
else:
ret += ""
for i, (role, message) in enumerate(self.messages):
if message:
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
ret += f"{message.strip()}<|eot_id|>"
else:
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
# print(ret)
return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
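
The new LLAMA3 separator style wraps every turn in Llama-3 header and <|eot_id|> markers. Using the llava_llama_3 template registered later in this file, a one-turn request renders roughly as follows (system message abbreviated):

    # Illustrative LLAMA3-style rendering; the real llava_llama_3 system message is
    # the longer "helpful language and vision assistant" prompt.
    prompt = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful language and vision assistant. ...<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "<image>\nDescribe this image.<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )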
@@ -379,12 +394,23 @@ def generate_chat_conv(
conv.append_message(conv.roles[0], message.content)
else:
real_content = ""
# calculate number of image_url
num_image_url = 0
for content in message.content:
if content.type == "image_url":
num_image_url += 1
if num_image_url > 1:
image_token = "<image>"
else:
image_token = "<image>\n"
for content in message.content:
if content.type == "text":
if num_image_url > 16:
real_content += "\n" # for video
real_content += content.text
elif content.type == "image_url":
# NOTE: Only works for llava
real_content += "<image>\n"
real_content += image_token
conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
@@ -425,6 +451,18 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="chatml-llava",
system_template="<|im_start|>system\n{system_message}",
system_message="You are a helpful assistant.",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
stop_str=["<|endoftext|>", "<|im_end|>"],
)
)
register_conv_template(
Conversation(
name="vicuna_v1.1",
@@ -437,6 +475,17 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="llava_llama_3",
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
sep="",
stop_str=["<|end_of_text|>", "<|eot_id|>"],
)
)
# Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442
register_conv_template(
Conversation(

View File

@@ -131,11 +131,49 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None
async def get_pixel_values(self, image_data, aspect_ratio=None):
aspect_ratio = (
getattr(self.hf_config, "image_aspect_ratio", None)
if aspect_ratio is None
else aspect_ratio
)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, list) and len(image_data) > 0:
pixel_values, image_hash, image_size = [], [], []
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
pixel_values.append(pixel_v)
image_hash.append(image_h)
image_size.append(image_s)
pixel_values = np.stack(pixel_values, axis=0)
else:
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
elif isinstance(image_data, str):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
else:
pixel_values, image_hash, image_size = None, None, None
return pixel_values, image_hash, image_size
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
@@ -194,8 +232,8 @@ class TokenizerManager:
)
if self.is_generation:
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index]
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data
)
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -704,7 +742,7 @@ def get_pixel_values(
tuple(int(x * 255) for x in processor.image_processor.image_mean),
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
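
The anyres branch of the worker-side get_pixel_values now also matches OneVision's anyres_max settings. A small illustration of which configured aspect ratios take the tiling path after this change (the config values shown are examples):

    # Which image_aspect_ratio values route through process_anyres_image() now.
    for ratio in ("square", "pad", "anyres", "anyres_max_9"):
        tiled = ratio == "anyres" or "anyres_max" in ratio
        print(f"{ratio:>12} -> {'anyres tiling' if tiled else 'plain resize / pad path'}")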

View File

@@ -322,11 +322,16 @@ class ModelTpServer:
if self.model_runner.is_generation:
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
image_hash = (
hash(tuple(recv_req.image_hash))
if isinstance(recv_req.image_hash, list)
else recv_req.image_hash
)
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
(image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
(
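
The image hash is folded into four pseudo-token ids that pad the prompt where image embeddings will later be written; a request with several images first collapses its per-image hashes with hash(tuple(...)). A standalone sketch of the derivation:

    # Sketch of the pad-value derivation used above.
    def pad_values(image_hash, vocab_size: int):
        if isinstance(image_hash, list):          # multi-image request
            image_hash = hash(tuple(image_hash))
        return [
            image_hash % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]

    # usage: req.pad_value = pad_values(recv_req.image_hash, self.model_config.vocab_size)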

View File

@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
"""
Utilities for multi-modal models.

This file mainly contains utilities used in the image processing logic of
LLaVA-NeXT, such as the anyres and anyres_max operations. It currently
supports anyres and anyres_max for CLIP and SigLIP. For more information,
refer to the blog post or the paper:
LLaVA-NeXT: https://llava-vl.github.io/blog/2024-01-30-llava-next/
LLaVA-OneVision: https://arxiv.org/pdf/2408.03326
"""
import ast
import base64
import math
import re
from io import BytesIO
import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
# Calculate the downscaled size to keep the aspect ratio
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(
original_height * scale
)
# Calculate effective and wasted resolutions
effective_resolution = min(
downscaled_width * downscaled_height, original_width * original_height
)
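
select_best_resolution scales the original image into each candidate resolution and keeps the candidate with the highest effective (non-wasted) resolution, breaking ties by least wasted area. A worked call against the helper above, with an illustrative candidate list:

    best = select_best_resolution(
        (1000, 600), [(384, 384), (768, 384), (384, 768), (768, 768)]
    )
    # scale = min(W/1000, H/600); effective = min(scaled area, original area):
    #   (384, 384): keeps 384x230 ->  88,320 effective,  59,136 wasted
    #   (768, 384): keeps 640x384 -> 245,760 effective,  49,152 wasted
    #   (384, 768): keeps 384x230 ->  88,320 effective, 206,592 wasted
    #   (768, 768): keeps 768x460 -> 353,280 effective, 236,544 wasted
    # best == (768, 768), the highest effective resolution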
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [
224,
336,
384,
448,
512,
], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [
(i, j)
for i in range(range_start[0], range_end[0] + 1)
for j in range(range_start[1], range_end[1] + 1)
]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
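
For OneVision checkpoints, image_grid_pinpoints may arrive as a range string rather than an explicit list; the new branch expands it into pixel resolutions. Note that patch_size here is the vision tower's input resolution (224-512), not the ViT patch size. A small expansion example with an illustrative range (real configs typically span (1x1) through (6x6)):

    import re

    grid_pinpoints = "(1x1),...,(2x2)"    # illustrative range string
    patch_size = 384                      # SigLIP-384 tower input resolution

    matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
    start, end = tuple(map(int, matches[0])), tuple(map(int, matches[-1]))
    grids = [
        (i, j)
        for i in range(start[0], end[0] + 1)
        for j in range(start[1], end[1] + 1)
    ]
    # grids == [(1, 1), (1, 2), (2, 1), (2, 2)]
    resolutions = [[dim * patch_size for dim in pair] for pair in grids]
    # resolutions == [[384, 384], [384, 768], [768, 384], [768, 768]]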
@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
Returns:
np.array: An np array containing the processed image patches.
"""
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
try:
patch_size = processor.size[0]
except Exception as e:
patch_size = processor.size["shortest_edge"]
assert patch_size in [
224,
336,
384,
448,
512,
], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [
(i, j)
for i in range(range_start[0], range_end[0] + 1)
for j in range(range_start[1], range_end[1] + 1)
]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
@@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints):
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
patches = divide_to_patches(image_padded, processor.crop_size["height"])
image_original_resize = image.resize(
(processor.size["shortest_edge"], processor.size["shortest_edge"])
# For Siglip processor, only have size but no crop size
crop_size = (
processor.crop_size["height"]
if "crop_size" in processor.__dict__
else processor.size["height"]
)
shortest_edge = (
processor.size["shortest_edge"]
if "shortest_edge" in processor.size
else processor.size["height"]
)
patches = divide_to_patches(image_padded, crop_size)
image_original_resize = image.resize((shortest_edge, shortest_edge))
image_patches = [image_original_resize] + patches
image_patches = [
processor.preprocess(image_patch)["pixel_values"][0]
processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
for image_patch in image_patches
]
return np.stack(image_patches, axis=0)
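
With a SigLIP image processor there is no crop_size, so both the tile size and the "shortest edge" resize fall back to processor.size["height"]. A shape sketch under an assumed best resolution:

    tile = 384                              # processor.size["height"] for SigLIP-384
    best_resolution = (768, 1152)           # assumed pick for a tall image
    num_tiles = (best_resolution[0] // tile) * (best_resolution[1] // tile)   # 2 * 3 = 6
    # process_anyres_image() then stacks 1 base resize + 6 tiles:
    # the returned array has shape (7, 3, 384, 384).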
@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
)
image = image_processor.preprocess(image)["pixel_values"][0]
new_images.append(image)
elif image_aspect_ratio == "anyres":
elif "anyres" in image_aspect_ratio:
for image in images:
image = process_anyres_image(
image, image_processor, model_cfg.image_grid_pinpoints

View File

@@ -88,14 +88,19 @@ class InputMetadata:
reqs = batch.reqs
self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [
(
(r.image_offset - batch.prefix_lens_cpu[i])
if r.image_offset is not None
else 0
)
for i, r in enumerate(reqs)
]
self.image_offsets = []
for r in reqs:
if isinstance(r.image_offset, list):
self.image_offsets.append(
[
(image_offset - len(r.prefix_indices))
for image_offset in r.image_offset
]
)
elif isinstance(r.image_offset, int):
self.image_offsets.append(r.image_offset - len(r.prefix_indices))
elif r.image_offset is None:
self.image_offsets.append(0)
def compute_positions(self, batch: ScheduleBatch):
position_ids_offsets = batch.position_ids_offsets
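
Image offsets are recorded as absolute token positions at tokenization time; when the batch is built they are shifted by the length of the cached prefix, and OneVision requests carry one offset per image. A tiny sketch of that normalization (numbers illustrative):

    prefix_len = 32                       # len(r.prefix_indices): tokens served from cache
    image_offset = [35, 800]              # list form: one offset per image
    relative = [off - prefix_len for off in image_offset]   # [3, 768]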

View File

@@ -15,6 +15,8 @@ limitations under the License.
"""Inference-only LLaVa model compatible with HuggingFace weights."""
import math
import re
from typing import Iterable, List, Optional, Tuple
import numpy as np
@@ -26,6 +28,8 @@ from transformers import (
LlavaConfig,
MistralConfig,
Qwen2Config,
SiglipVisionConfig,
SiglipVisionModel,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.config import CacheConfig
@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
)
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
new_image_feature_len = self.image_feature_len
# now only support spatial_unpad + anyres
if self.mm_patch_merge_type.startswith("spatial"):
height = width = self.num_patches_per_side
if pt_shape[0] > 1:
if self.image_aspect_ratio == "anyres":
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
if "unpad" in self.mm_patch_merge_type:
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_size)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id
new_input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
# hardcode for spatial_unpad + anyres
image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
offset_list = []
for image_s in image_size:
if len(image_size) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = (
math.ceil(self.image_size / self.patch_size / 2) ** 2
)
else:
new_image_feature_len = self.image_feature_len # multiimage
height = width = self.num_patches_per_side
if "anyres" in image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_s,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_s)
if "anyres_max" in self.config.image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
r"anyres_max_(\d+)", self.config.image_aspect_ratio
)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
# times = math.sqrt(h * w / (max_num_patches * unit**2))
times = math.sqrt(
new_h * new_w / (max_num_patches * self.image_feature_len)
)
if times > 1.1:
new_h = int(new_h // times)
new_w = int(new_w // times)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
# print("calculated new_image_feature_len: ", new_image_feature_len)
try:
offset = input_ids.index(self.config.image_token_index)
except ValueError:
offset = 0
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
offset_list.append(offset)
return input_ids, offset_list
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
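
pad_input_ids replaces the single image token in the prompt with pad_value tokens, one per expected image feature. A rough count, assuming the OneVision SigLIP tower (384px input, 14px patches):

    import math

    image_size, patch_size = 384, 14                 # assumed tower dimensions
    image_feature_len = (image_size // patch_size) ** 2            # 27**2 = 729 per image
    video_frame_len = math.ceil(image_size / patch_size / 2) ** 2  # 2x2 pooled: 14**2 = 196

    # For a single anyres image the length additionally grows by the unpadded grid
    # tokens: new_image_feature_len += new_h * (new_w + 1).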
@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
@@ -163,27 +193,73 @@ class LlavaLlamaForCausalLM(nn.Module):
if self.mm_patch_merge_type.startswith("spatial"):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
if len(image_sizes[image_idx]) == 1:
image_aspect_ratio = (
self.config.image_aspect_ratio
) # single image
else:
image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = (
# "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
# )
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if self.image_aspect_ratio == "anyres":
(
num_patch_width,
num_patch_height,
) = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
r"anyres_max_(\d+)", image_aspect_ratio
)
if matched_anyres_max_num_patches:
max_num_patches = int(
matched_anyres_max_num_patches.group(1)
)
if (
image_aspect_ratio == "anyres"
or "anyres_max" in image_aspect_ratio
):
vision_tower_image_size = self.image_size
try:
num_patch_width, num_patch_height = (
get_anyres_image_grid_shape(
image_sizes[image_idx][0],
self.config.image_grid_pinpoints,
vision_tower_image_size,
)
)
except Exception as e:
print(f"Error: {e}")
num_patch_width, num_patch_height = 2, 2
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
else:
raise NotImplementedError()
image_feature = image_feature.view(
2, 2, height, width, -1
)
# (
# num_patch_width,
# num_patch_height,
# ) = get_anyres_image_grid_shape(
# image_sizes[image_idx][0],
# self.image_grid_pinpoints,
# self.vision_tower.config.image_size,
# )
# image_feature = image_feature.view(
# num_patch_height, num_patch_width, height, width, -1
# )
if "unpad" in self.mm_patch_merge_type:
unit = image_feature.shape[2]
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
@@ -191,8 +267,23 @@ class LlavaLlamaForCausalLM(nn.Module):
2, 3
)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
image_feature, image_sizes[image_idx][0]
)
if (
"anyres_max" in image_aspect_ratio
and matched_anyres_max_num_patches
):
c, h, w = image_feature.shape
times = math.sqrt(
h * w / (max_num_patches * unit**2)
)
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature,
[int(h // times), int(w // times)],
mode="bilinear",
)[0]
image_feature = torch.cat(
(
image_feature,
@@ -213,16 +304,31 @@ class LlavaLlamaForCausalLM(nn.Module):
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
image_feature = image_feature.unsqueeze(0)
else:
image_feature = image_feature[0]
if "unpad" in self.mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[None],
),
dim=0,
if image_feature.shape[0] > 16: # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(
num_of_frames, height, width, -1
)
image_feature = image_feature.permute(
0, 3, 1, 2
).contiguous() # N, C, H, W
height, weight = image_feature.shape[2:]
scaled_shape = [
math.ceil(height / 2),
math.ceil(weight / 2),
]
image_feature = nn.functional.interpolate(
image_feature, size=scaled_shape, mode="bilinear"
)
image_feature = (
image_feature.flatten(2)
.transpose(1, 2)
.contiguous()
) # N, C, H*W
new_image_features.append(image_feature)
image_features = new_image_features
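
For anyres_max_N, the tiled feature map is shrunk whenever it would exceed roughly N base tiles worth of tokens: times = sqrt(h * w / (N * unit**2)), with a bilinear resize applied only when times > 1.1. A worked example for anyres_max_9 under the 27-tokens-per-tile-side assumption above:

    import math

    unit = 27                # tokens per tile side for a 384/14 SigLIP tower (assumed)
    max_num_patches = 9      # parsed from "anyres_max_9"

    def maybe_downscale(h, w):
        times = math.sqrt(h * w / (max_num_patches * unit**2))
        return (int(h // times), int(w // times)) if times > 1.1 else (h, w)

    maybe_downscale(81, 81)    # 3x3 tile grid: times = 1.0, kept as (81, 81)
    maybe_downscale(108, 81)   # 4x3 tile grid: times ~ 1.155, shrunk to (93, 70)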
@@ -233,21 +339,22 @@ class LlavaLlamaForCausalLM(nn.Module):
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape # 576, 4096
pad_dim = image_features[pt].shape[-1] # 576, 4096
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
for j, image_off in enumerate(image_offsets[i]):
# print("actual image_features length: ", image_features[pt][j].shape[0])
pad_len = image_features[pt][j].shape[0]
input_embeds[
start_idx + image_off : start_idx + image_off + pad_len
] = image_features[pt][j]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(image_features[pt].shape)
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
@@ -262,9 +369,16 @@ class LlavaLlamaForCausalLM(nn.Module):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
if "clip" in vision_path:
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
elif "siglip" in vision_path:
self.vision_tower = SiglipVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
# Siglip needs all feature tokens
self.config.mm_vision_select_feature = "full"
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
@@ -276,8 +390,11 @@ class LlavaLlamaForCausalLM(nn.Module):
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
if self.vision_feature_select_strategy == "patch":
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
if (
self.vision_feature_select_strategy == "patch"
or self.vision_feature_select_strategy == "full"
):
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
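
Because SigLIP has no CLS token, the loader forces mm_vision_select_feature to "full" and keeps every patch token, so the per-image feature length is simply (image_size // patch_size) ** 2. A minimal loading sketch, assuming the google/siglip-so400m-patch14-384 checkpoint that OneVision configs typically point mm_vision_tower at:

    import torch
    from transformers import SiglipVisionModel

    # Assumed checkpoint; the real path comes from config.mm_vision_tower.
    vision_tower = SiglipVisionModel.from_pretrained(
        "google/siglip-so400m-patch14-384", torch_dtype=torch.float16
    ).cuda()

    cfg = vision_tower.config
    image_feature_len = (cfg.image_size // cfg.patch_size) ** 2   # (384 // 14) ** 2 = 729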