Llava-hd Support (#92)
Co-authored-by: Haotian Liu <liuhaotian.cn@gmail.com>
@@ -18,7 +18,7 @@ dependencies = [
 ]

 [project.optional-dependencies]
-srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
+srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
        "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
        "pydantic", "diskcache", "cloudpickle"]
 openai = ["openai>=1.0", "numpy"]
@@ -62,6 +62,7 @@ class TokenizedGenerateReqInput:
     input_ids: List[int]
     pixel_values: List[float]
     image_hash: int
+    image_size: List[int]
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
@@ -26,6 +26,7 @@ class Req:
         self.input_ids = []
         self.output_ids = []
         self.pixel_values = None
+        self.image_size = None
         self.image_offset = 0
         self.sampling_params = None
         self.return_logprob = False
@@ -104,6 +105,7 @@ class Batch:

     # for multimodal
     pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

     # other arguments for control
@@ -195,6 +197,7 @@ class Batch:
             flatten_input_ids, dtype=torch.int32, device=device
         )
         self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
         ]
@@ -203,6 +203,7 @@ class ModelRpcServer(rpyc.Service):
         req = Req(recv_req.rid)
         req.input_ids = recv_req.input_ids
         req.pixel_values = recv_req.pixel_values
+        req.image_size = recv_req.image_size
         if req.pixel_values is not None:
             pad_value = [
                 (recv_req.image_hash) % self.model_config.vocab_size,
@@ -211,7 +212,7 @@ class ModelRpcServer(rpyc.Service):
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                req.input_ids, pad_value
+                req.input_ids, pad_value, req.pixel_values.shape, req.image_size
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
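For context on the pad_value lines above: the image hash is folded, chunk by chunk, into ids that fit the vocabulary, so the placeholder tokens inserted for an image are a deterministic function of the image itself. A minimal illustrative sketch (the chunk count and 32-bit shift width are assumptions, not taken from this diff):

def hash_to_pad_tokens(image_hash: int, vocab_size: int, num_chunks: int = 4) -> list:
    # Fold the hash into vocabulary-range ids; the same image hash always
    # produces the same placeholder pattern.
    return [(image_hash >> (32 * i)) % vocab_size for i in range(num_chunks)]

print(hash_to_pad_tokens(0x1234_5678_9ABC_DEF0, vocab_size=32000))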
@@ -409,6 +409,7 @@ class ModelRunner:
         self,
         input_ids,
         pixel_values,
+        image_sizes,
         image_offsets,
         req_pool_indices,
         seq_lens,
@@ -433,6 +434,7 @@ class ModelRunner:
             input_metadata.positions,
             input_metadata,
             pixel_values,
+            image_sizes,
             image_offsets,
         )
@@ -441,6 +443,7 @@ class ModelRunner:
         kwargs = {
             "input_ids": batch.input_ids,
             "pixel_values": batch.pixel_values,
+            "image_sizes": batch.image_sizes,
             "image_offsets": batch.image_offsets,
             "req_pool_indices": batch.req_pool_indices,
             "seq_lens": batch.seq_lens,
@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
@@ -48,14 +49,25 @@ def init_global_processor(server_args: ServerArgs):
     )


-def get_pixel_values(image_data, processor=None):
+def get_pixel_values(image_data, model_cfg, processor=None):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
     try:
         processor = processor or global_processor
         image = load_image(image_data)
         image_hash = hash(image_data)
-        pixel_values = processor.image_processor(image)["pixel_values"][0]
+        if image_aspect_ratio == "pad":
+            image = expand2square(
+                image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
+            )
+            pixel_values = processor.image_processor(image)["pixel_values"][0]
+        elif image_aspect_ratio == "anyres":
+            pixel_values = process_anyres_image(
+                image, processor.image_processor, model_cfg.image_grid_pinpoints
+            )
+        else:
+            pixel_values = processor.image_processor(image)["pixel_values"][0]
         pixel_values = pixel_values.astype(np.float16)
-        return pixel_values, image_hash
+        return pixel_values, image_hash, image.size
     except Exception:
         print("Exception in TokenizerManager:\n" + get_exception_traceback())
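The three aspect-ratio modes above produce differently shaped pixel_values arrays. A rough, illustrative exercise of the same helpers (the CLIP checkpoint name and the grid pinpoints are assumptions, not taken from this commit):

from PIL import Image
from transformers import CLIPImageProcessor

from sglang.srt.mm_utils import expand2square, process_anyres_image

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
image = Image.new("RGB", (1000, 400), (120, 50, 200))

square = processor(image)["pixel_values"][0]          # default: (3, 336, 336)
padded = processor(expand2square(image, (127, 127, 127)))["pixel_values"][0]
anyres = process_anyres_image(image, processor, [[336, 672], [672, 336], [672, 672]])
print(square.shape, padded.shape, anyres.shape)       # anyres: (3, 3, 336, 336) here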
@@ -77,6 +89,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path, trust_remote_code=server_args.trust_remote_code
         )

         self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):
@@ -104,10 +117,10 @@ class TokenizerManager:
         if self.executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
-                self.executor, get_pixel_values, image_data
+                self.executor, get_pixel_values, image_data, self.hf_config
             )
         else:
-            return get_pixel_values(image_data, self.processor)
+            return get_pixel_values(image_data, self.hf_config, self.processor)

     async def generate_request(self, obj: GenerateReqInput):
         if self.to_create_loop:
@@ -123,14 +136,17 @@ class TokenizerManager:
             sampling_params.normalize(self.tokenizer)
             sampling_params.verify()
             if obj.image_data is None:
-                pixel_values, image_hash = None, None
+                pixel_values, image_hash, image_size = None, None, None
             else:
-                pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
+                pixel_values, image_hash, image_size = await self.get_pixel_values(
+                    obj.image_data
+                )
             tokenized_obj = TokenizedGenerateReqInput(
                 rid=rid,
                 input_ids=input_ids,
                 pixel_values=pixel_values,
                 image_hash=image_hash,
+                image_size=image_size,
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob,
                 logprob_start_len=obj.logprob_start_len,
@@ -162,9 +178,9 @@ class TokenizerManager:
             sampling_params.normalize(self.tokenizer)
             sampling_params.verify()
             if obj.image_data[i] is None:
-                pixel_values, image_hash = None, None
+                pixel_values, image_hash, image_size = None, None, None
             else:
-                pixel_values, image_hash = await self.get_pixel_values(
+                pixel_values, image_hash, image_size = await self.get_pixel_values(
                     obj.image_data[i]
                 )
             tokenized_obj = TokenizedGenerateReqInput(
@@ -172,6 +188,7 @@ class TokenizerManager:
                 input_ids=input_ids,
                 pixel_values=pixel_values,
                 image_hash=image_hash,
+                image_size=image_size,
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob[i],
                 logprob_start_len=obj.logprob_start_len[i],
python/sglang/srt/mm_utils.py (new file, 251 lines)
@@ -0,0 +1,251 @@
# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
import ast
import base64
import math
from io import BytesIO

import numpy as np
from PIL import Image

def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(
            original_height * scale
        )
        effective_resolution = min(
            downscaled_width * downscaled_height, original_width * original_height
        )
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution
            and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit
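A worked example of the selection rule (illustrative; the pinpoint list below is an assumption in the style of LLaVA-1.6 grids of 336-pixel tiles, not something defined in this file):

pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
# A wide 1000x400 image keeps the most pixels when downscaled into 1008x336
# (840x336 effective area), so that resolution wins:
assert select_best_resolution((1000, 400), pinpoints) == (1008, 336)
# A square 500x500 image keeps all of its pixels only under the 672x672 grid:
assert select_best_resolution((500, 500), pinpoints) == (672, 672)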

def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image

def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches
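Together, resize_and_pad_image and divide_to_patches implement the tiling step of anyres preprocessing. An illustrative run with assumed sizes (not part of this file):

from PIL import Image

example_image = Image.new("RGB", (1000, 400), (30, 30, 30))
padded = resize_and_pad_image(example_image, (1008, 336))  # scaled to 840x336, centered on a 1008x336 canvas
patches = divide_to_patches(padded, 336)                   # a 3x1 grid of 336x336 tiles
assert padded.size == (1008, 336) and len(patches) == 3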

def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
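A quick numeric check of the grid computation (illustrative; the pinpoint list and sizes are assumptions):

# 1000x400 fits best into 1008x336, and with 336-pixel patches that is a 3x1 grid.
assert get_anyres_image_grid_shape((1000, 400), [[336, 672], [672, 336], [1008, 336]], 336) == (3, 1)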

def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        np.array: An np array containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size["height"])

    image_original_resize = image.resize(
        (processor.size["shortest_edge"], processor.size["shortest_edge"])
    )

    image_patches = [image_original_resize] + patches
    image_patches = [
        processor.preprocess(image_patch)["pixel_values"][0]
        for image_patch in image_patches
    ]
    return np.stack(image_patches, axis=0)
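The stacking order matters downstream: index 0 is the globally resized base image and the remaining entries are the tiles (the llava model later splits them exactly that way). An illustrative check with an assumed CLIP processor and pinpoints:

from PIL import Image
from transformers import CLIPImageProcessor

example_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
example_image = Image.new("RGB", (800, 800), (0, 128, 255))
example_patches = process_anyres_image(example_image, example_processor, [[672, 672]])
# one resized base image + a 2x2 grid of 336-pixel tiles -> 5 entries
assert example_patches.shape == (5, 3, 336, 336)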

def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
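A small illustrative check of the padding behavior (sizes and background color assumed):

from PIL import Image

example = Image.new("RGB", (200, 100), (255, 0, 0))
squared = expand2square(example, (0, 0, 0))
# the shorter side is padded symmetrically, so a 200x100 image becomes 200x200
assert squared.size == (200, 200)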

def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image (width, height).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor

def unpad_image_shape(current_height, current_width, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image
    and returns the new shape.
    """
    original_width, original_height = original_size

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        new_shape = (current_height - 2 * padding, current_width)
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        new_shape = (current_height, current_width - 2 * padding)

    return new_shape
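Worked numbers for the unpad arithmetic (illustrative): a wide original mapped onto a square feature grid keeps only the central rows.

# 24x24 feature grid, original image 1000x400 (aspect 2.5 > 1.0), so rows are cropped:
#   scale_factor = 24 / 1000 = 0.024
#   new_height   = int(400 * 0.024) = 9
#   padding      = (24 - 9) // 2 = 7   ->  24 - 2*7 = 10 rows kept
assert unpad_image_shape(24, 24, (1000, 400)) == (10, 24)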

def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == "pad":
        for image in images:
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean)
            )
            image = image_processor.preprocess(image)["pixel_values"][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(
                image, image_processor, model_cfg.image_grid_pinpoints
            )
            new_images.append(image)
    else:
        return image_processor(images)["pixel_values"]
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = np.stack(new_images, axis=0)
    return new_images
@@ -1,15 +1,18 @@
 """Inference-only LLaVa model compatible with HuggingFace weights."""
-import json
-import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional

 import numpy as np
 import torch
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
 from sglang.srt.models.llama2 import LlamaForCausalLM
 from torch import nn
-from transformers import CLIPImageProcessor, CLIPVisionModel, LlavaConfig
+from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.weight_utils import (
@@ -31,26 +34,64 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.config.text_config.hidden_size = config.hidden_size
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.language_model = LlamaForCausalLM(config, linear_method)
+        if "unpad" in getattr(config, "mm_patch_merge_type"):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16))

+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        if self.mm_patch_merge_type.startswith("spatial"):
+            height = width = self.num_patches_per_side
+            if pt_shape[0] > 1:
+                if self.image_aspect_ratio == "anyres":
+                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                        image_size,
+                        self.image_grid_pinpoints,
+                        self.vision_tower.config.image_size,
+                    )
+                if "unpad" in self.mm_patch_merge_type:
+                    h = num_patch_height * height
+                    w = num_patch_width * width
+                    new_h, new_w = unpad_image_shape(h, w, image_size)
+                    new_image_feature_len += new_h * (new_w + 1)

-    def pad_input_ids(self, input_ids, pad_value):
         pad_ids = pad_value * (
-            (self.image_feature_len + len(pad_value)) // len(pad_value)
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
         )
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
         new_input_ids = (
             input_ids[:offset]
-            + pad_ids[: self.image_feature_len]
+            + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
         return new_input_ids, offset
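To make the padded length concrete, here is an illustrative calculation (not part of the diff), assuming a 336-pixel tower with 14-pixel patches, i.e. a 24x24 = 576-token base image, and grid pinpoints of [[672, 672]]:

from sglang.srt.mm_utils import get_anyres_image_grid_shape, unpad_image_shape

base_len = (336 // 14) ** 2                                                   # 576
grid_w, grid_h = get_anyres_image_grid_shape((800, 600), [[672, 672]], 336)   # (2, 2)
h, w = grid_h * 24, grid_w * 24                                               # 48, 48
new_h, new_w = unpad_image_shape(h, w, (800, 600))                            # (36, 48)
total_len = base_len + new_h * (new_w + 1)                                    # 576 + 36 * 49
assert total_len == 2340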

+    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.

+        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
+        if self.vision_feature_select_strategy in ["default", "patch"]:
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+            )
+        image_features = self.multi_modal_projector(selected_image_feature)

+        return image_features

     def forward(
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         pixel_values: Optional[List[Optional[np.array]]] = None,
+        image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
         if input_metadata.forward_mode == ForwardMode.EXTEND:
@@ -75,23 +116,86 @@ class LlavaLlamaForCausalLM(nn.Module):
                 device=self.vision_tower.device,
             )

-            image_outputs = self.vision_tower(
-                pixel_values, output_hidden_states=True
-            )
-            # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
+            ########## Encode Image ########

-            selected_image_feature = image_outputs.hidden_states[
-                self.vision_feature_layer
-            ]
-            if self.vision_feature_select_strategy in ["default", "patch"]:
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif self.vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
+            if pixel_values.ndim == 5:
+                # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
+                concat_images = torch.cat(
+                    [image for image in pixel_values], dim=0
+                )  # ndim=4
+                image_features = self.encode_images(concat_images)
+                split_sizes = [image.shape[0] for image in pixel_values]
+                image_features = torch.split(image_features, split_sizes, dim=0)
+                # hd image_features: BS, num_patch, 576, 4096
             else:
-                raise ValueError(
-                    f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
-                )
-            image_features = self.multi_modal_projector(selected_image_feature)
+                # normal pixel: BS, C=3, H=336, W=336
+                image_features = self.encode_images(pixel_values)
+                # image_features: BS, 576, 4096

+            if self.mm_patch_merge_type.startswith("spatial"):
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    if image_feature.shape[0] > 1:
+                        base_image_feature = image_feature[0]
+                        image_feature = image_feature[1:]
+                        height = width = self.num_patches_per_side
+                        assert height * width == base_image_feature.shape[0]
+                        if self.image_aspect_ratio == "anyres":
+                            (
+                                num_patch_width,
+                                num_patch_height,
+                            ) = get_anyres_image_grid_shape(
+                                image_sizes[image_idx],
+                                self.image_grid_pinpoints,
+                                self.vision_tower.config.image_size,
+                            )
+                            image_feature = image_feature.view(
+                                num_patch_height, num_patch_width, height, width, -1
+                            )
+                        else:
+                            raise NotImplementedError
+                        if "unpad" in self.mm_patch_merge_type:
+                            image_feature = image_feature.permute(
+                                4, 0, 2, 1, 3
+                            ).contiguous()
+                            image_feature = image_feature.flatten(1, 2).flatten(
+                                2, 3
+                            )
+                            image_feature = unpad_image(
+                                image_feature, image_sizes[image_idx]
+                            )
+                            image_feature = torch.cat(
+                                (
+                                    image_feature,
+                                    self.language_model.model.image_newline[
+                                        :, None, None
+                                    ].expand(*image_feature.shape[:-1], 1),
+                                ),
+                                dim=-1,
+                            )
+                            image_feature = image_feature.flatten(1, 2).transpose(
+                                0, 1
+                            )
+                        else:
+                            image_feature = image_feature.permute(
+                                0, 2, 1, 3, 4
+                            ).contiguous()
+                            image_feature = image_feature.flatten(0, 3)
+                        image_feature = torch.cat(
+                            (base_image_feature, image_feature), dim=0
+                        )
+                    else:
+                        image_feature = image_feature[0]
+                        if "unpad" in self.mm_patch_merge_type:
+                            image_feature = torch.cat(
+                                (
+                                    image_feature,
+                                    self.language_model.model.image_newline[None],
+                                ),
+                                dim=0,
+                            )
+                    new_image_features.append(image_feature)
+                image_features = new_image_features

             extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
             pt = 0
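To see what the spatial-unpad merge above produces, here is a toy walkthrough (illustrative only; a small hidden size is used so it runs instantly, and the 2x2 grid / 800x600 image size mirror the pad_input_ids example earlier):

import torch

from sglang.srt.mm_utils import unpad_image, unpad_image_shape

D, height, width = 8, 24, 24
num_patch_height, num_patch_width = 2, 2
image_size = (800, 600)

tile_features = torch.randn(num_patch_height * num_patch_width, height * width, D)
x = tile_features.view(num_patch_height, num_patch_width, height, width, -1)
x = x.permute(4, 0, 2, 1, 3).contiguous()        # (D, grid_h, h, grid_w, w)
x = x.flatten(1, 2).flatten(2, 3)                # (D, grid_h*h, grid_w*w)
x = unpad_image(x, image_size)                   # crop the padded rows/columns
new_h, new_w = unpad_image_shape(num_patch_height * height, num_patch_width * width, image_size)
image_newline = torch.zeros(D)                   # stands in for model.image_newline
x = torch.cat((x, image_newline[:, None, None].expand(*x.shape[:-1], 1)), dim=-1)
x = x.flatten(1, 2).transpose(0, 1)              # (new_h * (new_w + 1), D)
assert x.shape == (new_h * (new_w + 1), D)       # 36 * 49 = 1764 tile tokens (+ 576 base tokens)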
@@ -100,7 +204,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                     continue

                 start_idx = extend_start_loc_cpu[i]
-                pad_len, pad_dim = image_features[pt].shape
+                pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                 dim = input_embeds.shape[1]
                 assert (
                     pad_dim == dim
@@ -146,6 +250,11 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.vision_feature_select_strategy = self.config.mm_vision_select_feature
         self.image_size = self.vision_tower.config.image_size
         self.patch_size = self.vision_tower.config.patch_size
+
+        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
+
         self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
         if self.vision_feature_select_strategy == "patch":
             pass
@@ -159,13 +268,14 @@ class LlavaLlamaForCausalLM(nn.Module):
         projector_weights = {
             "model.mm_projector.0": "multi_modal_projector.linear_1",
             "model.mm_projector.2": "multi_modal_projector.linear_2",
+            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
             model_name_or_path, cache_dir, load_format, revision
         ):
             # FIXME: why are the projector weights read twice?
-            if "projector" in name:
+            if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)
@@ -180,6 +290,10 @@ class LlavaLlamaForCausalLM(nn.Module):

         monkey_path_clip_vision_embed_forward()

+    @property
+    def num_patches_per_side(self):
+        return self.image_size // self.patch_size
+

 first_call = True
@@ -469,7 +469,6 @@ class Runtime:
         prompt: str,
         sampling_params,
     ) -> None:

         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,