1345 lines
60 KiB
Python
1345 lines
60 KiB
Python
|
|
import ast
|
||
|
|
import contextlib
|
||
|
|
import gc
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from functools import partial
|
||
|
|
from itertools import chain
|
||
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.distributed as dist
|
||
|
|
import torch.nn as nn
|
||
|
|
from einops import rearrange
|
||
|
|
from timm.layers import LayerNorm, LayerNorm2d
|
||
|
|
from timm.models.regnet import RegStage
|
||
|
|
from torch.nn import CrossEntropyLoss
|
||
|
|
from transformers import (
|
||
|
|
AutoConfig,
|
||
|
|
AutoModel,
|
||
|
|
AutoModelForCausalLM,
|
||
|
|
AutoTokenizer,
|
||
|
|
PreTrainedModel,
|
||
|
|
)
|
||
|
|
from transformers.generation.utils import GenerationMixin
|
||
|
|
from transformers.modeling_utils import (
|
||
|
|
is_fsdp_enabled,
|
||
|
|
is_local_dist_rank_0,
|
||
|
|
no_init_weights,
|
||
|
|
)
|
||
|
|
from transformers.models.auto import CONFIG_MAPPING
|
||
|
|
from transformers.utils import ModelOutput
|
||
|
|
|
||
|
|
from .configuration_hyperclovax import HCXVisionConfig
|
||
|
|
from .image_processing_hyperclovax import select_best_resolution
|
||
|
|
|
||
|
|
# Special-token strings referenced by this model.
EOT = "<|endofturn|>"  # end-of-turn marker, judging by its name — TODO confirm against the tokenizer
IMAGE_LOC = "<|dummy3|>"  # image placeholder: forward()'s docstring says image positions in input_ids hold this token
VIDEO_LOC = "<|_unuse_missing_100270|>"  # presumably the video placeholder token (its use is not visible in this chunk)
|
||
|
|
|
||
|
|
|
||
|
|
def get_rank():
    """Return this process's rank in the default process group.

    Falls back to 0 when ``torch.distributed`` has not been initialized (single-process runs).
    """
    return dist.get_rank() if dist.is_initialized() else 0
|
||
|
|
|
||
|
|
|
||
|
|
def get_world_size():
    """Return the number of processes in the default process group.

    Returns 1 when ``torch.distributed`` has not been initialized.
    """
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
|
||
|
|
|
||
|
|
|
||
|
|
def unpad_image(tensor: torch.Tensor, original_size: Tuple[int, int]) -> torch.Tensor:
    """Crop the padding off a ``C x H x W`` tensor that was resized and then padded.

    The image is assumed to have been scaled to fit the current canvas with its aspect ratio
    preserved, then padded symmetrically on the side that did not fill. Comparing the
    original and current aspect ratios decides whether rows or columns are cropped.

    Args:
        tensor: Image (or feature map) tensor in ``C x H x W`` layout.
        original_size: Original image size as ``(width, height)``.

    Returns:
        A slice of ``tensor`` with the padded rows (or columns) removed.

    Examples:
        >>> import torch
        >>> unpad_image(torch.randn(1, 64, 48), (32, 32)).shape  # height padded
        torch.Size([1, 48, 48])
        >>> unpad_image(torch.randn(1, 48, 64), (32, 32)).shape  # width padded
        torch.Size([1, 48, 48])
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h <= cur_w / cur_h:
        # Original is relatively taller (or the same shape): the width carries the padding.
        kept_w = int(orig_w * (cur_h / orig_h))
        margin = (cur_w - kept_w) // 2
        return tensor[:, :, margin : cur_w - margin]

    # Original is relatively wider: the height carries the padding.
    kept_h = int(orig_h * (cur_w / orig_w))
    margin = (cur_h - kept_h) // 2
    return tensor[:, margin : cur_h - margin, :]
|
||
|
|
|
||
|
|
|
||
|
|
def get_anyres_image_grid_shape(
    image_size: Tuple[int, int],
    grid_pinpoints: Union[str, List[Tuple[int, int]]],
    patch_size: int,
) -> Tuple[int, int]:
    """Return the grid layout (columns, rows) that AnyRes preprocessing uses for an image.

    The best-fitting resolution is picked from ``grid_pinpoints`` by ``select_best_resolution``,
    which is called with the image size as ``(height, width)`` and whose result is unpacked as
    ``(height, width)``. Each side of that resolution is then integer-divided by ``patch_size``.

    Args:
        image_size: Original image size as ``(width, height)``.
        grid_pinpoints: Candidate resolutions as ``(height, width)`` pairs. This is either a
            sequence (list or tuple) or its string form, e.g. ``"[(224, 224), (336, 336)]"``.
            Strings are parsed with ``ast.literal_eval``.
        patch_size: Side length, in pixels, of one grid cell.

    Returns:
        ``(num_patches_width, num_patches_height)``, i.e.
        ``(best_width // patch_size, best_height // patch_size)``.

    Raises:
        ValueError, SyntaxError: If a string ``grid_pinpoints`` is not a valid Python literal.

    Example:
        The result depends on which resolution ``select_best_resolution`` picks. For a
        1000x800 image with ``patch_size=112``, the result is ``(4, 4)`` if it picks
        ``(448, 448)`` and ``(2, 2)`` if it picks ``(224, 224)``.
    """
    # Only strings need parsing. Any other sequence (list, tuple, ...) is used as-is, because
    # ast.literal_eval rejects non-string input such as a tuple of tuples.
    if isinstance(grid_pinpoints, str):
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    else:
        possible_resolutions = grid_pinpoints

    original_width, original_height = image_size
    best_height, best_width = select_best_resolution((original_height, original_width), possible_resolutions)
    return best_width // patch_size, best_height // patch_size
|
||
|
|
|
||
|
|
|
||
|
|
def reshape_and_unpad_image_features(
    image_feature: torch.Tensor,
    height: int,
    width: int,
    image_size: Tuple[int, int],
    possible_resolutions: List[Tuple[int, int]],
    grid_size: int,
    unpad: bool,
    image_newline: torch.Tensor,
) -> torch.Tensor:
    """Stitch AnyRes grid features into one token sequence, optionally removing padding.

    ``image_feature[0]`` is the base (global) view, and ``image_feature[1:]`` are the grid crops.
    The crops are viewed as a ``num_patch_height x num_patch_width`` grid of ``height x width``
    token maps. The grid layout comes from ``get_anyres_image_grid_shape``. The result is then
    built as follows:

    * ``unpad=True``: the crops are assembled into one channel-first ``(C, H, W)`` canvas and
      cropped by ``unpad_image`` according to ``image_size``. An ``image_newline`` column is
      appended to every row, and the canvas is flattened row by row into ``(H * (W + 1), C)``.
    * ``unpad=False``: the crops are flattened in row-major order over the stitched canvas into
      ``(num_patch_height * height * num_patch_width * width, C)``. No newline tokens are added.

    The base view's tokens are prepended to the result.

    Args:
        image_feature: Stack of per-grid token maps. ``image_feature[0]`` must have
            ``height * width`` rows. ``image_feature[1:]`` must be viewable as
            ``(num_patch_height, num_patch_width, height, width, -1)``. The stack is presumably
            shaped ``(1 + num_grids, height * width, C)`` (TODO confirm).
        height: Number of token rows per grid crop (not pixels).
        width: Number of token columns per grid crop (not pixels).
        image_size: Original image size as ``(width, height)``.
        possible_resolutions: Candidate ``(height, width)`` resolutions used to recover the grid layout.
        grid_size: Pixel side length of one grid crop. It is passed as ``patch_size`` to
            ``get_anyres_image_grid_shape``.
        unpad: Whether to crop padding and insert row-newline tokens.
        image_newline: 1-D ``(C,)`` embedding, which is indexed as ``image_newline[:, None, None]``.

    Returns:
        Tensor of shape ``(height * width + N, C)``, where ``N`` depends on ``unpad`` as described above.

    Raises:
        AssertionError: If ``height * width`` does not equal the base view's token count.
    """
    base_image_feature, grid_features = image_feature[0], image_feature[1:]

    assert (
        height * width == base_image_feature.shape[0]
    ), f"height: {height}, width: {width}, base_image_feature.shape[0]: {base_image_feature.shape[0]}"

    grid_cols, grid_rows = get_anyres_image_grid_shape(image_size, possible_resolutions, grid_size)
    tiles = grid_features.view(grid_rows, grid_cols, height, width, -1)

    if not unpad:
        # (rows, h, cols, w, C) -> one token per canvas position, in raster order.
        stitched = tiles.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 3)
        return torch.cat((base_image_feature, stitched), dim=0)

    # (C, rows, h, cols, w) -> (C, rows*h, cols*w): channel-first canvas for unpad_image.
    canvas = tiles.permute(4, 0, 2, 1, 3).contiguous().flatten(1, 2).flatten(2, 3)
    canvas = unpad_image(canvas, image_size)
    newline_column = image_newline[:, None, None].expand(*canvas.shape[:-1], 1).to(canvas.device)
    canvas = torch.cat((canvas, newline_column), dim=-1)
    # (C, H, W + 1) -> (H * (W + 1), C): each row followed by its newline token.
    stitched = canvas.flatten(1, 2).transpose(0, 1)
    return torch.cat((base_image_feature, stitched), dim=0)
|
||
|
|
|
||
|
|
|
||
|
|
def anyres_postprocessing(
    image_forward_outs: List[torch.FloatTensor],
    image_sizes: List[List[int]],
    possible_resolutions: List[Tuple[int, int]],
    patch_size: int,
    grid_size: int,
    image_newline: torch.FloatTensor,
    num_queries_vis_abstractor: int = -1,
    unpad: bool = False,
) -> List[torch.FloatTensor]:
    """Flatten per-image AnyRes grid features into 1D token sequences.

    Each entry of ``image_forward_outs`` is handled as follows:

    * More than one grid (``shape[0] > 1``): the first grid is the base (global) view, and the
      rest are stitched back together by ``reshape_and_unpad_image_features``. When ``unpad``
      is True, padding is cropped using ``image_sizes`` and ``image_newline`` ends every row.
      When ``unpad`` is False, no newline tokens are inserted.
    * A single grid: its tokens are returned with one ``image_newline`` appended at the end.

    Args:
        image_forward_outs: One tensor per image. Only ``shape[0]`` (the number of grids) and
            indexing along it are used here. The tensors are presumably shaped
            ``(num_grids, tokens_per_grid, feature_dim)`` (TODO confirm with caller).
        image_sizes: Original ``[width, height]`` of each image, indexed in step with
            ``image_forward_outs``. Only read for multi-grid entries.
        possible_resolutions: Candidate ``(height, width)`` resolutions used to recover each
            image's grid layout.
        patch_size: Vision-encoder patch size in pixels. Together with ``grid_size`` it gives
            ``grid_size // patch_size`` tokens per side of a grid crop. It is not used when
            ``num_queries_vis_abstractor > 0``.
        grid_size: Pixel side length of one grid crop.
        image_newline: 1-D ``(feature_dim,)`` embedding. It is appended once per single-grid
            image and at the end of each row when ``unpad`` is True.
        num_queries_vis_abstractor: If > 0, the number of tokens per grid produced by a
            fixed-query visual abstractor. It must be a perfect square, and its square root is
            used as the tokens-per-side. Defaults to -1 (use ``grid_size // patch_size``).
        unpad: Whether to crop padding tokens from multi-grid images. Defaults to False.

    Returns:
        One 2D tensor per input image, each a sequence of feature vectors.

    Raises:
        AssertionError: If ``num_queries_vis_abstractor`` is > 0 but not a perfect square.
    """
    if num_queries_vis_abstractor > 0:
        assert (num_queries_vis_abstractor**0.5).is_integer(), "n_queries must be square number"
        height = width = int(num_queries_vis_abstractor**0.5)
    else:
        # Only needed without the abstractor. Computing it unconditionally made abstractor-mode
        # calls fail when patch_size was 0 or None.
        height = width = grid_size // patch_size

    # Post-processing: stitch/unpad multi-grid images, and add newline tokens.
    new_image_features = []
    for image_idx, image_feature in enumerate(image_forward_outs):
        if image_feature.shape[0] > 1:
            image_feature = reshape_and_unpad_image_features(
                image_feature=image_feature,
                height=height,
                width=width,
                image_size=image_sizes[image_idx],
                possible_resolutions=possible_resolutions,
                grid_size=grid_size,
                unpad=unpad,
                image_newline=image_newline,
            )
        else:
            # Single grid: keep its tokens and close the sequence with one newline embedding.
            image_feature = torch.cat((image_feature[0], image_newline[None].to(image_feature.device)), dim=0)
        new_image_features.append(image_feature)
    return new_image_features
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class HCXVisionOutput(ModelOutput):
    """Output of ``HCXVisionForCausalLM.forward`` when ``return_dict`` is True.

    Args:
        loss (Optional[torch.FloatTensor], optional): Cross-entropy loss computed from the
            shifted logits and labels. None when no labels are given.
        loss_per_sample (Optional[torch.FloatTensor], optional): Per-sample loss values. In
            ``forward`` this is only computed on rank 0 and is None on other ranks.
        logits (torch.FloatTensor): Scores before softmax, produced by ``language_model.lm_head``
            over the final hidden states. They are presumably shaped
            (batch_size, sequence_length, vocab_size) (TODO confirm).
        past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]], optional): Cached key/value
            states returned by the language model, which can be passed back as
            ``past_key_values`` to speed up sequential decoding.
        hidden_states (Optional[Tuple[torch.FloatTensor]], optional): Hidden states returned by
            the language model when ``output_hidden_states`` is set. These are typically the
            embeddings plus one entry per layer, each (batch_size, sequence_length, hidden_size).
        attentions (Optional[Tuple[torch.FloatTensor]], optional): Attention weights returned by
            the language model when ``output_attentions`` is set. These are typically one entry
            per layer, each (batch_size, num_heads, sequence_length, sequence_length).
    """

    loss: Optional[torch.FloatTensor] = None
    loss_per_sample: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None  # defaults to None like every ModelOutput field
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||
|
|
|
||
|
|
|
||
|
|
class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin):
    """HCX Vision model for causal language modeling with vision-language capabilities.

    Combines a vision encoder (``self.vision_model``), a projector into the language model's
    hidden size (``self.mm_projector``) and a causal language model (``self.language_model``),
    so that images or videos can condition text generation.

    Attributes:
        config_class: Configuration class for the model.
        vision_model_name: Attribute name under which the vision encoder is stored.
        _no_split_modules: Module class names that should not be split across devices during
            model-parallel placement.
        supports_gradient_checkpointing: Whether the model supports gradient checkpointing.
        _skip_keys_device_placement: Keys to skip during device placement.
        _supports_flash_attn_2: Declares FlashAttention-2 support to transformers.
        _supports_sdpa: Declares PyTorch SDPA attention support to transformers.
    """

    config_class = HCXVisionConfig
    vision_model_name = "vision_model"
    _no_split_modules = ["SiglipEncoderLayer", "LlamaDecoderLayer", "HyperCLOVAXDecoderLayer"]
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
|
||
|
|
|
||
|
|
def __init__(
    self,
    config: HCXVisionConfig,
    **kwargs: Optional[Any],
) -> None:
    """Initialize the HCXVisionForCausalLM model.

    Builds the vision encoder, the multimodal projector and the language model from ``config``.

    Args:
        config: Configuration object for the model, containing parameters for both the vision
            and language components.
        **kwargs: Optional feature flags. All are popped from ``kwargs`` except where noted:
            - use_liger: Apply the liger kernel converter to a "hyperclovax" language model.
            - use_fused_ce: Stored on ``self.use_fused_ce``.
            - use_sum_loss: Select "sum" loss reduction (see ``_init_reduction_type``).
            - use_meansum_loss / use_turnmeansum_loss: Select "none" reduction. At most one
              of them may be set.
            - freeze_before_sampler: Stored on ``self.freeze_before_sampler``.
            - vision_input_chunk_size: Stored on ``self.vision_input_chunk_size`` (default None).
            - is_safetensor_save: Read with ``kwargs.get``, not popped (default True).

    Raises:
        ValueError: If vision_config or text_config is not defined.
        AssertionError: If both use_meansum_loss and use_turnmeansum_loss are set.
    """
    super().__init__(config)  # sets self.config = config

    # Read the feature flags before anything else. `self.use_liger` is consulted when the
    # language model is set up below; assigning it afterwards raised AttributeError for
    # "hyperclovax" text models.
    self.use_liger = kwargs.pop("use_liger", False)
    self.use_fused_ce = kwargs.pop("use_fused_ce", False)
    self.use_meansum_loss = kwargs.pop("use_meansum_loss", False)
    self.freeze_before_sampler = kwargs.pop("freeze_before_sampler", False)
    self.use_turnmeansum_loss = kwargs.pop("use_turnmeansum_loss", False)
    self.vision_input_chunk_size = kwargs.pop("vision_input_chunk_size", None)
    self.is_safetensor_save = kwargs.get("is_safetensor_save", True)

    use_sum_loss = bool(kwargs.pop("use_sum_loss", False))
    self.reduction = self._init_reduction_type(use_sum_loss)

    # init configs
    text_config = self._init_text_config(config)
    vision_config = self._init_vision_config(config)

    # possible_resolutions must match preprocessor_config.json
    config.possible_resolutions = self._init_possible_resolutions(config, vision_config)

    # init models & parameters
    with no_init_weights():  # weights will be loaded in from_pretrained
        self.vision_model = AutoModel.from_config(vision_config, trust_remote_code=True)

    self.mm_projector = self._init_mm_projector(config, text_config, vision_config)

    self.language_model = AutoModelForCausalLM.from_config(text_config)
    self.lm_head_vocab_size = getattr(text_config, "padded_vocab_size", text_config.vocab_size)
    self.language_model.lm_head = nn.Linear(text_config.hidden_size, self.lm_head_vocab_size, bias=False)

    if config.anyres:
        # Allocated uninitialized (torch.empty). Presumably it is filled from the checkpoint
        # (TODO confirm for from-scratch training).
        self.image_newline = nn.Parameter(torch.empty(text_config.hidden_size, dtype=self.dtype))

    # modify configs or model settings
    if text_config.model_type in ["llama", "hyperclovax", "gpt2"]:
        self.language_model.gradient_checkpointing_enable()
    if text_config.model_type == "hyperclovax" and self.use_liger:
        self.language_model._get_apply_liger_kernel_converter()(model=self.language_model)

    # update configs with the ones the instantiated sub-models actually use
    self.vision_config = vision_config = self.vision_model.config
    self.text_config = text_config = self.language_model.config
    config.update({"vision_config": vision_config})
    config.update({"text_config": text_config})

    self.vision_model_use_no_grad = None  # checked and assigned during forward

    # Per the original comment, this is the part of self.post_init() that checks whether
    # gradient checkpointing is possible and turns it on.
    self._backward_compatibility_gradient_checkpointing()
|
||
|
|
|
||
|
|
def _init_weights(self, module):
|
||
|
|
# copies from https://github.com/kakaobrain/honeybee/blob/main/honeybee/common_layers.py#L55
|
||
|
|
if (
|
||
|
|
isinstance(module, nn.Conv2d) # noqa: SIM101
|
||
|
|
or isinstance(module, nn.Embedding)
|
||
|
|
or isinstance(module, nn.Linear)
|
||
|
|
):
|
||
|
|
module.weight.data.normal_(mean=0.0, std=0.02)
|
||
|
|
if hasattr(module, "bias") and module.bias is not None:
|
||
|
|
module.bias.data.zero_()
|
||
|
|
|
||
|
|
elif isinstance(module, nn.LayerNorm):
|
||
|
|
module.bias.data.zero_()
|
||
|
|
module.weight.data.fill_(1.0)
|
||
|
|
elif isinstance(module, nn.Parameter):
|
||
|
|
embed_std = 1 / torch.sqrt(torch.tensor(module.size(0), dtype=torch.float)).to(module.dtype)
|
||
|
|
module.data.normal_(mean=0.0, std=embed_std)
|
||
|
|
|
||
|
|
def _init_reduction_type(self, use_sum_loss):
|
||
|
|
assert not (
|
||
|
|
self.use_meansum_loss and self.use_turnmeansum_loss
|
||
|
|
), "use_meansum_loss and use_turnmeansum_loss cannot both be True; only one or neither may be True."
|
||
|
|
if self.use_meansum_loss or self.use_turnmeansum_loss:
|
||
|
|
reduction = "none"
|
||
|
|
elif use_sum_loss:
|
||
|
|
reduction = "sum"
|
||
|
|
else:
|
||
|
|
reduction = "mean"
|
||
|
|
return reduction
|
||
|
|
|
||
|
|
def _init_vision_config(self, config):
    """Build the vision-encoder config from ``config``.

    If ``config.vision_config.model_type`` is registered in ``CONFIG_MAPPING``, that config
    class is rebuilt from ``config.vision_config.to_dict()`` and its ``auto_map`` is cleared.
    Otherwise, the config is loaded with ``AutoConfig.from_pretrained(..., trust_remote_code=True)``.
    The source is ``config.vision_model_name_or_path`` when set, else
    ``config.vision_config._name_or_path``.

    The result also receives ``anyres`` and ``max_num_grids`` from ``config``.

    Raises:
        ValueError: If the model type is not registered and no non-empty path is available.
    """
    vision_model_type = config.vision_config.model_type
    if vision_model_type in CONFIG_MAPPING:
        vision_config = CONFIG_MAPPING[vision_model_type](**config.vision_config.to_dict())
        # Presumably prevents remote-code auto_map entries from being carried over (TODO confirm).
        vision_config.auto_map = {}
    else:
        # Test truthiness, not `is not None`: PretrainedConfig defaults `_name_or_path` to "".
        # With the old check, an unset path reached AutoConfig.from_pretrained("") instead of
        # the ValueError below.
        pretrained_path = getattr(config, "vision_model_name_or_path", None) or getattr(
            config.vision_config, "_name_or_path", None
        )
        if not pretrained_path:
            raise ValueError("vision_config is not defined")
        vision_config = AutoConfig.from_pretrained(pretrained_path, trust_remote_code=True)

    vision_config.anyres = config.anyres
    vision_config.max_num_grids = config.max_num_grids
    return vision_config
|
||
|
|
|
||
|
|
def _init_text_config(self, config):
|
||
|
|
if hasattr(config, "text_config") and config.text_config is not None:
|
||
|
|
model_type = config.text_config.model_type
|
||
|
|
text_config = CONFIG_MAPPING[model_type](**config.text_config.to_dict())
|
||
|
|
else:
|
||
|
|
raise ValueError("text_config is not defined")
|
||
|
|
text_config._attn_implementation = config._attn_implementation
|
||
|
|
if text_config.model_type != "hyperclovax":
|
||
|
|
text_config.logits_scaling = 1.0
|
||
|
|
return text_config
|
||
|
|
|
||
|
|
def _init_possible_resolutions(self, config, vision_config):
|
||
|
|
"""possible_resolution should be matched with preprocessor_config.json"""
|
||
|
|
if not getattr(config, "possible_resolutions", []):
|
||
|
|
possible_resolutions = []
|
||
|
|
if config.anyres:
|
||
|
|
assert config.max_num_grids > 0
|
||
|
|
for i in range(1, config.max_num_grids + 1):
|
||
|
|
for j in range(1, config.max_num_grids + 1):
|
||
|
|
if i == 1 and j == 1 and not config.use_1x1_grid:
|
||
|
|
continue
|
||
|
|
if i * j <= config.max_num_grids:
|
||
|
|
possible_resolutions.append([i, j])
|
||
|
|
|
||
|
|
possible_resolutions = [
|
||
|
|
[ys * vision_config.image_size, xs * vision_config.image_size] for ys, xs in possible_resolutions
|
||
|
|
]
|
||
|
|
return possible_resolutions
|
||
|
|
else:
|
||
|
|
return config.possible_resolutions
|
||
|
|
|
||
|
|
def _init_mm_projector(self, config, text_config, vision_config):
|
||
|
|
input_hidden_size = vision_config.hidden_size
|
||
|
|
if config.mm_projector_type == "linear":
|
||
|
|
mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size)
|
||
|
|
mm_projector.dtype = next(mm_projector.parameters()).dtype
|
||
|
|
elif config.mm_projector_type == "cabstractor":
|
||
|
|
mm_projector = HCXVisionCAbstractor(
|
||
|
|
num_queries=config.num_queries_vis_abstractor_image,
|
||
|
|
num_input_tokens=(vision_config.image_size // vision_config.patch_size) ** 2,
|
||
|
|
encoder_hidden_size=input_hidden_size,
|
||
|
|
hidden_size=input_hidden_size,
|
||
|
|
output_hidden_size=text_config.hidden_size,
|
||
|
|
pos_emb=config.proj_pos_emb,
|
||
|
|
prenorm=config.proj_prenorm,
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
mm_projector = HCXVisionMlp(
|
||
|
|
config.mm_projector_type,
|
||
|
|
input_hidden_size,
|
||
|
|
hidden_features=input_hidden_size, # TODO: llava 처럼 hidden_size 를 input_hidden_size 가 아니라 LLM embedding size 로 바꿔주기
|
||
|
|
out_features=self.text_config.hidden_size,
|
||
|
|
)
|
||
|
|
return mm_projector
|
||
|
|
|
||
|
|
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values_images: Optional[List[List[torch.FloatTensor]]] = None,
    image_sizes_images: Optional[List[List[Tuple[int, int]]]] = None,
    pixel_values_videos: Optional[List[List[torch.FloatTensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, HCXVisionOutput]:
    """Forward pass of the multimodal causal language model.

    On a step without cache, input embeddings are built from ``input_ids``. When pixel values
    are given, they come from ``extract_inputs_embeds``, which merges the visual features.
    The embeddings are run through the language model's base model. The final hidden states
    are multiplied by ``text_config.logits_scaling`` and projected with ``language_model.lm_head``.

    Args:
        input_ids: Input token ids. Image positions hold the "<|dummy3|>" placeholder.
        pixel_values_images: Image pixel values, nested per sample. They are passed to
            ``extract_inputs_embeds``.
        image_sizes_images: Image sizes, nested per sample. They are passed to ``extract_inputs_embeds``.
        pixel_values_videos: Video frame pixel values, nested per sample. They are passed to
            ``extract_inputs_embeds``.
        past_key_values: Cached key/value states. When given and ``inputs_embeds`` is None,
            ``input_ids`` are fed to the language model directly, without merging visual features.
        attention_mask: Mask forwarded to the base model.
        position_ids: Position ids forwarded to the base model.
        inputs_embeds: Precomputed input embeddings. When given, ``input_ids`` is not used.
        labels: Target token ids, shifted by one position relative to the logits. Positions
            equal to ``config.ignore_index`` are excluded from the loss.
        use_cache: Forwarded to the base model.
        output_attentions: Defaults to ``config.vision_config.output_attentions``.
        output_hidden_states: Defaults to ``config.vision_config.output_hidden_states``.
        return_dict: Defaults to ``config.use_return_dict``.
        **kwargs: Accepted and ignored.

    Returns:
        If ``return_dict`` is True, an HCXVisionOutput containing:
            - loss: Mean cross-entropy over supervised tokens when labels are given. It is 0
              (still attached to the graph) if every label is ignored, and None without labels.
            - loss_per_sample: Per-sample mean loss, computed on rank 0 only (None elsewhere).
            - logits: Prediction scores of the language-model head.
            - past_key_values, hidden_states, attentions: From the base model.
        If ``return_dict`` is False, the tuple ``(loss, logits, *base_model_outputs[1:])``,
        with ``loss`` omitted when it is None. ``loss_per_sample`` is not included.
    """
    output_attentions = (
        output_attentions if output_attentions is not None else self.config.vision_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.vision_config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Build embeddings only on an uncached step; cached decoding feeds input_ids straight through.
    if inputs_embeds is None and past_key_values is None:
        if pixel_values_images is not None or pixel_values_videos is not None:
            inputs_embeds = self.extract_inputs_embeds(
                input_ids=input_ids,
                pixel_values_images=pixel_values_images,
                image_sizes_images=image_sizes_images,
                pixel_values_videos=pixel_values_videos,
            )
        else:
            inputs_embeds = self.get_input_embeddings()(input_ids)

    # The base model takes either input_ids or inputs_embeds, not both.
    if inputs_embeds is not None:
        input_ids = None

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.language_model.base_model(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0] * self.text_config.logits_scaling
    logits = self.language_model.lm_head(hidden_states)

    loss = None
    loss_per_sample = None
    if labels is not None:
        ignore_index = self.config.ignore_index

        # Shift so that tokens < n predict n, then flatten the tokens.
        shift_logits = logits[..., :-1, :].contiguous().view(-1, self.lm_head_vocab_size)
        # Move labels to the logits' device to support model/pipeline parallelism.
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

        # Pass config.ignore_index explicitly. CrossEntropyLoss defaults to -100, and any other
        # ignore value would otherwise be scored as a real (possibly out-of-range) class id.
        loss_fct = CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
        token_loss = loss_fct(shift_logits, shift_labels)
        supervised = shift_labels != ignore_index

        if get_rank() == 0:
            batch_size = logits.shape[0]
            # clamp(min=1): a sample with no supervised tokens gets 0 instead of 0/0 = NaN.
            num_supervised = supervised.view(batch_size, -1).sum(dim=1).clamp(min=1)
            loss_per_sample = token_loss.view(batch_size, -1).sum(dim=1) / num_supervised

        supervised_loss = token_loss[supervised]
        if supervised_loss.numel() > 0:
            loss = supervised_loss.mean()
        else:
            # Every label is ignored, and .mean() of an empty tensor is NaN, which would poison
            # the gradients. Return a zero that stays attached to the graph instead.
            loss = token_loss.sum() * 0.0

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return HCXVisionOutput(
        loss=loss,
        loss_per_sample=loss_per_sample,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
||
|
|
|
||
|
|
# Mirrors transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
    """Return the wrapped language model's input embedding module."""
    lm = self.language_model
    return lm.get_input_embeddings()
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
    """Install ``value`` as the input embedding module of the wrapped language model."""
    language_model = self.language_model
    language_model.set_input_embeddings(value)
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
    """Return the output embedding (LM head) module of the wrapped language model."""
    language_model = self.language_model
    return language_model.get_output_embeddings()
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the output embedding (LM head) of the wrapped language model."""
    language_model = self.language_model
    language_model.set_output_embeddings(new_embeddings)
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
    """Replace the decoder of the wrapped language model with ``decoder``."""
    language_model = self.language_model
    language_model.set_decoder(decoder)
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
def get_decoder(self):
    """Return the decoder of the wrapped language model."""
    language_model = self.language_model
    return language_model.get_decoder()
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
def tie_weights(self):
    """Delegate weight tying to the wrapped language model and return its result."""
    language_model = self.language_model
    return language_model.tie_weights()
|
||
|
|
|
||
|
|
# Adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
    """Resize the language model's token embeddings and keep cached vocab sizes in sync.

    Args:
        new_num_tokens: New vocabulary size forwarded to the language model.
        pad_to_multiple_of: Optional padding multiple forwarded to the language model.

    Returns:
        The resized input embedding module returned by the language model.
    """
    model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
    self.config.text_config.vocab_size = model_embeds.num_embeddings
    self.vocab_size = model_embeds.num_embeddings

    # forward() reshapes logits with `self.lm_head_vocab_size`; after a resize the LM head
    # has a new number of rows, so the cached value must follow it or the reshape breaks.
    output_embeddings = self.language_model.get_output_embeddings()
    if output_embeddings is not None:
        self.lm_head_vocab_size = output_embeddings.weight.shape[0]
    return model_embeds
|
||
|
|
|
||
|
|
def extract_inputs_embeds(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values_images: Optional[List[List[torch.FloatTensor]]] = None,
    image_sizes_images: Optional[List[List[Tuple[int, int]]]] = None,
    pixel_values_videos: Optional[List[List[torch.FloatTensor]]] = None,
):
    """Embed text tokens and overwrite image/video placeholder positions with visual features.

    The tokens are embedded with the language model's input embedding table. Then, per sample,
    every position whose id equals ``self.config.image_token_id`` (resp. ``video_token_id``)
    is overwritten, in order, with the concatenated features returned by
    ``self.forward_images`` (resp. ``self.forward_videos``) for that sample.

    Args:
        input_ids: Token ids containing the placeholder tokens; indexed per sample along dim 0.
        pixel_values_images: Per-sample lists of image tensors.
        image_sizes_images: Per-sample lists of image sizes, parallel to ``pixel_values_images``.
        pixel_values_videos: Per-sample lists of video tensors.

    Returns:
        The merged input embeddings, or ``None`` when no image and no video is supplied.

    Raises:
        ValueError: If a sample's number of placeholder tokens differs from the number of
            visual feature rows produced for it.
    """
    # Per-sample item counts, used to regroup the flattened vision outputs.
    len_pixel_values_images = [len(pixel_value) for pixel_value in pixel_values_images] if pixel_values_images else []
    len_pixel_values_videos = [len(pixel_value) for pixel_value in pixel_values_videos] if pixel_values_videos else []

    if sum(len_pixel_values_images) + sum(len_pixel_values_videos) == 0:
        return None

    inputs_embeds = self.get_input_embeddings()(input_ids)

    def _scatter(features_batch, token_id, kind):
        # Write each sample's visual features into the rows holding `token_id`.
        for i, features in enumerate(features_batch):
            if len(features) == 0:
                continue
            # A boolean mask keeps the index shape independent of the placeholder count
            # (unlike `.nonzero().squeeze()`, which changes rank for a single match).
            mask = input_ids[i] == token_id
            values = torch.cat(features).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            num_slots = int(mask.sum())
            if num_slots != values.shape[0]:
                raise ValueError(
                    f"Sample {i}: found {num_slots} {kind} placeholder tokens "
                    f"but {values.shape[0]} {kind} feature rows."
                )
            inputs_embeds[i][mask] = values

    if sum(len_pixel_values_images) > 0:
        image_features_batch = self.forward_images(pixel_values_images, image_sizes_images, len_pixel_values_images)
        _scatter(image_features_batch, self.config.image_token_id, "image")

    if sum(len_pixel_values_videos) > 0:
        video_features_batch = self.forward_videos(pixel_values_videos, len_pixel_values_videos)
        _scatter(video_features_batch, self.config.video_token_id, "video")

    return inputs_embeds
|
||
|
|
|
||
|
|
def forward_images(
    self,
    pixel_values_images: List[List[torch.FloatTensor]],
    image_sizes_images: List[List[Tuple[int, int]]],
    len_pixel_values_images: List[int],
) -> Optional[List[List[torch.Tensor]]]:
    """Encode all images of a batch and return their projected features regrouped per sample.

    All image tensors are concatenated along dim 0 and run through the vision encoder in one
    pass. The output is projected by ``self.mm_projector``, split back per image, and
    post-processed by ``anyres_postprocessing``.

    Args:
        pixel_values_images: Per-sample lists of image tensors.
        image_sizes_images: Per-sample lists of image sizes, parallel to ``pixel_values_images``.
        len_pixel_values_images: Number of images in each sample.

    Returns:
        One list of image features per sample, or ``None`` if there are no images at all.
    """
    if sum(len_pixel_values_images) == 0:
        return None

    flat_pixel_values = list(chain(*pixel_values_images))
    concat_pixel_values_images = torch.cat(flat_pixel_values, dim=0)

    # SigLIP outputs have no CLS token; other encoders put one at index 0, which is dropped.
    visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
    context_vision_model = torch.no_grad() if self.vision_model_use_no_grad else contextlib.nullcontext()
    with context_vision_model:
        if self.config.use_nth_layer == -1:
            # Use the last layer without its post_layernorm; the swap only needs to happen once.
            encoder = self.vision_model.vision_model
            if not isinstance(getattr(encoder, "post_layernorm", None), nn.Identity):
                encoder.post_layernorm = nn.Identity()
            image_forward_outs = self.vision_model(concat_pixel_values_images)
            image_forward_outs = image_forward_outs.last_hidden_state[:, visual_token_idx:]
        else:
            image_forward_outs = self.vision_model(concat_pixel_values_images, output_hidden_states=True)
            image_forward_outs = image_forward_outs.hidden_states[self.config.use_nth_layer][:, visual_token_idx:]

    image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
    image_forward_outs = self.mm_projector(image_forward_outs)  # b (h w) d

    # Split the projected features back per image, e.g. [18, 81, 3072] -> [[9, 81, 3072], [9, 81, 3072]].
    split_sizes = [pixel_value.shape[0] for pixel_value in flat_pixel_values]
    image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0)

    # Append newline embeddings (anyres post-processing).
    image_features = anyres_postprocessing(
        image_forward_outs=image_forward_outs,
        image_sizes=[image_size for image_sizes in image_sizes_images for image_size in image_sizes],
        num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image,
        unpad=self.config.unpad,
        patch_size=self.vision_config.patch_size,
        grid_size=self.vision_config.image_size,
        image_newline=self.image_newline,
        possible_resolutions=self.config.possible_resolutions,
    )

    # Restore the per-sample nesting of `pixel_values_images` with a running offset.
    regrouped = []
    start = 0
    for count in len_pixel_values_images:
        regrouped.append(image_features[start : start + count])
        start += count
    return regrouped
|
||
|
|
|
||
|
|
def forward_videos(
    self,
    pixel_values_videos: List[List[torch.FloatTensor]],
    len_pixel_values_videos: List[int],
) -> Optional[List[List[torch.Tensor]]]:
    """Encode all videos of a batch with the slow/fast abstractor schedule.

    Every frame of every video is encoded in a single vision pass. Each video's frames are then
    split into groups that the projector abstracts with either the "slow" or the "fast" number
    of queries:

    * default: first frame slow, remaining frames fast;
    * ``config.first_last_frames_slow``: first and last frames slow, middle frames fast
      (videos with at most two frames form a single slow group).

    Projected outputs are concatenated per video and regrouped per sample.

    Args:
        pixel_values_videos: Per-sample lists of video tensors (frames along dim 0).
        len_pixel_values_videos: Number of videos in each sample.

    Returns:
        One list of video feature tensors per sample, or ``None`` if there are no videos.
    """
    len_video_grids = sum(len_pixel_values_videos)
    if len_video_grids == 0:
        return None

    # Run the vision model once over all frames in the batch.
    concat_pixel_values_videos = torch.cat(list(chain(*pixel_values_videos)), dim=0)

    visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
    context_vision_model = torch.no_grad() if self.vision_model_use_no_grad else contextlib.nullcontext()
    with context_vision_model:
        if self.config.use_nth_layer == -1:
            # Replace post_layernorm of the last layer with Identity
            self.vision_model.vision_model.post_layernorm = nn.Identity()
            video_forward_outs = self.vision_model(concat_pixel_values_videos)
            video_forward_outs = video_forward_outs.last_hidden_state[:, visual_token_idx:]
        else:
            video_forward_outs = self.vision_model(concat_pixel_values_videos, output_hidden_states=True)
            video_forward_outs = video_forward_outs.hidden_states[self.config.use_nth_layer][:, visual_token_idx:]

    video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype)

    # Build the MM-projector schedule: frame group i spans rows num_grids[i]:num_grids[i + 1]
    # and is abstracted to num_queries_vis_abstractors[i] tokens per frame.
    # len(num_grids) == len(num_queries_vis_abstractors) + 1
    slow_queries = self.config.num_queries_vis_abstractor_video_slow
    fast_queries = self.config.num_queries_vis_abstractor_video_fast
    grid_idx = 0
    num_grids = [grid_idx]
    num_queries_vis_abstractors = []

    def _add_group(num_queries, num_frames):
        nonlocal grid_idx
        num_queries_vis_abstractors.append(num_queries)
        grid_idx += num_frames
        num_grids.append(grid_idx)

    # The schedule is built per video; a single schedule over the whole batch would
    # misalign groups with videos as soon as there is more than one video.
    for pixel_values_frames in pixel_values_videos:
        for pixel_values_frame in pixel_values_frames:
            num_frames = len(pixel_values_frame)
            if num_frames == 0:
                continue
            if self.config.first_last_frames_slow:
                if num_frames <= 2:
                    _add_group(slow_queries, num_frames)
                else:
                    _add_group(slow_queries, 1)
                    _add_group(fast_queries, num_frames - 2)
                    _add_group(slow_queries, 1)
            else:
                _add_group(slow_queries, 1)
                _add_group(fast_queries, num_frames - 1)

    video_forward_outs = self.mm_projector(video_forward_outs, num_queries_vis_abstractors, num_grids)

    # Concatenate projector outputs per video: collect groups until they cover the video's frame count.
    video_features = []  # what we want to return
    target_features = []
    target_group_size = 0
    group_counter = 0
    video_groups = [len(frame) for frames in pixel_values_videos for frame in frames]

    for forward_out in video_forward_outs:
        target_group_size += len(forward_out)
        target_features.append(forward_out.flatten(0, 1))

        video_group_size = video_groups[group_counter]
        if video_group_size == target_group_size:
            video_features.append(torch.cat(target_features, dim=0))
            target_features = []
            group_counter += 1
            target_group_size = 0
        elif video_group_size < target_group_size:
            raise RuntimeError(f"video_group_size < target_group_size!! [{video_group_size} < {target_group_size}]")

    assert len(target_features) == 0, f"target_features is not empty!! {target_features}"
    assert len(video_groups) == len(video_features)

    # Restore the per-sample nesting of `pixel_values_videos` with a running offset.
    regrouped = []
    start = 0
    for count in len_pixel_values_videos:
        regrouped.append(video_features[start : start + count])
        start += count
    return regrouped
|
||
|
|
|
||
|
|
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values_images: Optional[List[List[torch.FloatTensor]]] = None,
    image_sizes_images: Optional[List[List[Tuple[int, int]]]] = None,
    pixel_values_videos: Optional[List[List[torch.FloatTensor]]] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    bad_words_ids: Optional[List[List[int]]] = None,
    max_length: int = 196,
    min_length: int = 2,
    do_sample: bool = True,
    num_beams: int = 1,
    top_p: float = 0.6,
    top_k: int = 0,
    temperature: float = 0.5,
    repetition_penalty: float = 1.0,
    length_penalty: int = 1,
    use_cache: bool = True,
    verbose: bool = False,
    **kwargs,
) -> torch.LongTensor:
    """Generate text from a prompt that may contain image/video placeholder tokens.

    Without any visual input, generation runs directly on ``input_ids``. Otherwise the prompt
    is embedded with ``extract_inputs_embeds`` and generation runs on ``inputs_embeds``. Both
    paths use the same generation settings.

    Args:
        input_ids: Prompt token ids, including image/video placeholder tokens.
        pixel_values_images: Per-sample lists of image tensors.
        image_sizes_images: Per-sample lists of image sizes, parallel to ``pixel_values_images``.
        pixel_values_videos: Per-sample lists of video tensors.
        pad_token_id: Padding id; defaults to ``self.tokenizer.pad_token_id``.
        eos_token_id: Stop id; defaults to the id of ``"<|endofturn|>"``.
        bad_words_ids: Token sequences that must not be generated; defaults to the text
            config's BOS and EOS ids.
        max_length: Maximum number of NEW tokens (forwarded as ``max_new_tokens``). An explicit
            ``max_new_tokens`` keyword argument takes precedence.
        min_length: Forwarded as ``min_length``.
        do_sample: Whether to sample; forced to ``False`` when ``temperature == 0.0``.
        num_beams: Number of beams; ``early_stopping`` is enabled only when it is > 1.
        top_p: Nucleus sampling threshold.
        top_k: Top-k filtering size.
        temperature: Sampling temperature.
        repetition_penalty: Penalty for already generated tokens.
        length_penalty: Beam-search length penalty.
        use_cache: Whether to use the KV cache.
        verbose: Print the decoded first prompt and first prediction.
        **kwargs: ``max_new_tokens`` and ``attention_mask`` are honoured on both paths. Other
            keyword arguments are forwarded only on the text-only path (as before).

    Returns:
        Generated token ids.
    """
    if pad_token_id is None:
        pad_token_id = self.tokenizer.pad_token_id
    if eos_token_id is None:
        # add_special_tokens=False: a tokenizer that prepends BOS must not turn [0] into the BOS id.
        eos_token_id = self.tokenizer.encode("<|endofturn|>", add_special_tokens=False)[0]
    if bad_words_ids is None:
        bad_words_ids = [
            [self.config.text_config.bos_token_id],
            [self.config.text_config.eos_token_id],
        ]

    # Previously a caller-supplied `max_new_tokens` / `attention_mask` was silently dropped on the
    # multimodal path; honour them on both paths.
    max_new_tokens = kwargs.pop("max_new_tokens", max_length)
    attention_mask = kwargs.pop("attention_mask", None)

    generation_kwargs = dict(
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        bad_words_ids=bad_words_ids,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        num_beams=num_beams,
        do_sample=(False if temperature == 0.0 else do_sample),  # set do_sample=False if invalid temperature
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=num_beams > 1,  # early_stopping only makes sense for beam search
        use_cache=use_cache,
    )
    if attention_mask is not None:
        generation_kwargs["attention_mask"] = attention_mask

    no_images = pixel_values_images is None or all(len(pixel_values) == 0 for pixel_values in pixel_values_images)
    no_videos = pixel_values_videos is None or all(len(pixel_values) == 0 for pixel_values in pixel_values_videos)
    if no_images and no_videos:
        # Text-only: remaining kwargs override the shared settings, as they were forwarded before.
        return self.language_model.generate(input_ids, **{**generation_kwargs, **kwargs})

    # inputs_embeds: [batch, sequence incl. visual tokens, text tokens and system prompt, hidden]
    inputs_embeds = self.extract_inputs_embeds(
        input_ids=input_ids,
        pixel_values_images=pixel_values_images,
        image_sizes_images=image_sizes_images,
        pixel_values_videos=pixel_values_videos,
    )
    inputs_embeds = inputs_embeds.to(device=self.language_model.device, dtype=self.language_model.dtype)

    # pred: int64 [batch, generated token length]
    pred = self.language_model.generate(inputs_embeds=inputs_embeds, **generation_kwargs)

    if verbose:
        llm_query = self.tokenizer.batch_decode(
            [
                [token_id for token_id in input_ids_row if token_id != self.tokenizer.pad_token_id]
                for input_ids_row in input_ids.detach().cpu().tolist()
            ],
            skip_special_tokens=False,
        )[0]
        llm_pred = self.tokenizer.batch_decode(
            [
                [token_id for token_id in pred_row if token_id != self.tokenizer.pad_token_id]
                for pred_row in pred.detach().cpu().tolist()
            ],
            skip_special_tokens=False,
        )[0]
        print(f"# [info] llm_query: {llm_query}")
        print(f"# [info] llm_pred: {llm_pred}")

    return pred
|
||
|
|
|
||
|
|
def to_vision_model_device(self, input_tensor: Union[torch.Tensor, List]) -> Union[torch.Tensor, List]:
    """Recursively move a tensor, or an arbitrarily nested list of tensors, to the vision model's device.

    Args:
        input_tensor: A tensor or a (possibly nested) list of tensors.

    Returns:
        The same structure with every tensor placed on ``self.vision_model.device``.

    Raises:
        TypeError: If an element is neither a tensor nor a list.
    """
    if isinstance(input_tensor, torch.Tensor):
        return input_tensor.to(self.vision_model.device)
    if isinstance(input_tensor, list):
        return list(map(self.to_vision_model_device, input_tensor))
    raise TypeError("Unsupported data type. Only tensors and lists are allowed.")
|
||
|
|
|
||
|
|
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Prepare the model inputs for one generation step.

    On the first step (no cached content) ``inputs_embeds`` is used when given. On later steps
    only the last token of ``input_ids`` is fed, together with the cache.

    Args:
        input_ids: Token ids generated so far (``decoder_input_ids`` in kwargs takes precedence).
        past_key_values: Cached key/value states: a legacy tuple or a cache object.
        attention_mask: Mask over the full sequence.
        inputs_embeds: Prompt embeddings, used only on the first step.
        **kwargs: ``use_cache`` and ``pixel_values`` are passed through.

    Returns:
        Dictionary of model inputs.
    """
    input_ids = kwargs.get("decoder_input_ids", input_ids)

    # A cache object can be present but still empty on the first step (e.g. a fresh cache
    # instance), so decide on its content. The original mixed truthiness and `is None`
    # checks, which disagreed for such a cache and dropped `inputs_embeds`.
    if past_key_values is None:
        has_past = False
    elif hasattr(past_key_values, "get_seq_length"):
        has_past = past_key_values.get_seq_length() > 0
    else:
        has_past = bool(past_key_values)

    if has_past:
        input_ids = input_ids[:, -1:]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and not has_past:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
        }
    )
    return model_inputs
|
||
|
|
|
||
|
|
@classmethod
def from_config(cls, config, vision_model_name_or_path):
    """Instantiate the model from ``config`` and a vision model name or path."""
    model = cls(config, vision_model_name_or_path)
    return model
|
||
|
|
|
||
|
|
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
    *model_args,
    **kwargs,
) -> "HCXVisionForCausalLM":
    """Load a pretrained model, attach its tokenizer and resolve the image/video token ids.

    Args:
        pretrained_model_name_or_path: Checkpoint name or local path (required).
        *model_args: Forwarded to the parent ``from_pretrained``.
        **kwargs: Forwarded to the parent ``from_pretrained``, except the save options
            ``save_only_vision`` (default False), ``save_only_qformer`` (default False) and
            ``save_shard_size`` (default "5GB"), which are stored on the model.

    Returns:
        The loaded model.

    Raises:
        ValueError: If no path is given, or if a placeholder string does not encode to exactly
            one token.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if pretrained_model_name_or_path is None:
        raise ValueError("pretrained_model_name_or_path must be provided.")

    save_only_vision = kwargs.pop("save_only_vision", False)
    save_only_qformer = kwargs.pop("save_only_qformer", False)
    save_shard_size = kwargs.pop("save_shard_size", "5GB")

    model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
    model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

    image_token_id = model.tokenizer.encode(IMAGE_LOC, add_special_tokens=False)
    if len(image_token_id) != 1:
        raise ValueError(
            f'"<|dummy3|>" was not encoded into a single special token. Encoding result: {image_token_id}'
        )
    model.config.image_token_id = image_token_id[0]

    video_token_id = model.tokenizer.encode(VIDEO_LOC, add_special_tokens=False)
    if len(video_token_id) != 1:
        raise ValueError(
            f'"<|_unuse_missing_100270|>" was not encoded into a single special token. Encoding result: {video_token_id}'
        )
    model.config.video_token_id = video_token_id[0]

    model.save_only_vision = save_only_vision
    model.save_only_qformer = save_only_qformer
    model.save_shard_size = save_shard_size

    return model
|
||
|
|
|
||
|
|
def get_language_model(self):
    """Return the base (headless) model of the wrapped language model."""
    language_model = self.language_model
    return language_model.base_model
|
||
|
|
|
||
|
|
def get_vision_model(self):
    """Return the vision encoder module."""
    vision_model = self.vision_model
    return vision_model
|
||
|
|
|
||
|
|
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    *args,
    **kwargs,
):
    """Save the (optionally filtered) model weights via the parent ``save_pretrained``.

    The state dict (the ``state_dict`` kwarg, or ``self.state_dict()`` when absent or ``None``)
    is filtered by ``get_pretrained_state_dict``. ``safe_serialization`` is always set from
    ``self.is_safetensor_save``. ``max_shard_size`` defaults to ``self.save_shard_size``, or to
    "5GB" when that attribute was never set (e.g. for a model not built by ``from_pretrained``).
    """
    state_dict = kwargs.get("state_dict")
    if state_dict is None:
        # An explicit `state_dict=None` previously reached the filter and crashed on `.keys()`.
        state_dict = self.state_dict()
    kwargs["state_dict"] = self.get_pretrained_state_dict(state_dict, save_directory)
    kwargs["safe_serialization"] = self.is_safetensor_save
    kwargs.setdefault("max_shard_size", getattr(self, "save_shard_size", "5GB"))
    super().save_pretrained(save_directory, *args, **kwargs)
|
||
|
|
|
||
|
|
def get_pretrained_state_dict(self, state_dict, save_dir):
    """Return the subset of ``state_dict`` to be saved, without mutating the input.

    * ``save_only_vision``: drop language-model weights (keys containing "language_model.") and
      keys starting with "lm_head.".
    * ``save_only_qformer``: drop vision-encoder weights (keys containing "vision_model.").
    * otherwise: keep everything.

    Both flags default to False when they were never set on the model.

    Args:
        state_dict: Mapping of parameter names to tensors.
        save_dir: Target directory; unused, kept for interface compatibility.

    Returns:
        A new dict with the selected entries, in the original order.
    """
    vision_key = "vision_model."
    llm_keys = ["language_model."]
    head_key = "lm_head."
    save_only_vision = getattr(self, "save_only_vision", False)
    save_only_qformer = getattr(self, "save_only_qformer", False)

    def _excluded(key):
        if save_only_vision:
            return any(llm_key in key for llm_key in llm_keys) or key.startswith(head_key)
        if save_only_qformer:
            return vision_key in key
        return False

    # Build a new mapping: the input may be the caller's own dict (passed via save_pretrained).
    return {key: value for key, value in state_dict.items() if not _excluded(key)}
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
class HCXVisionMlp(nn.Module):
    """Two-layer feed-forward projector: Linear -> activation -> Linear.

    ``mm_projector_type`` selects the intermediate width:
    ``"mlp"`` uses ``hidden_features`` and ``"inverted_mlp"`` uses ``2 * hidden_features``.
    Missing ``hidden_features`` / ``out_features`` fall back to ``in_features``.
    """

    def __init__(
        self,
        mm_projector_type,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        output_width = out_features or in_features
        base_width = hidden_features or in_features
        self.mm_projector_type = mm_projector_type

        if mm_projector_type == "mlp":
            inner_width = base_width
        elif mm_projector_type == "inverted_mlp":
            inner_width = 2 * base_width
        else:
            raise NotImplementedError(f"{self.mm_projector_type} is not implemented")

        # Creation order (fc1, act, fc2) is kept so parameter init consumes the RNG identically.
        self.fc1 = nn.Linear(in_features, inner_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(inner_width, output_width)

    def forward(self, x):
        """Apply fc1, the activation, then fc2."""
        return self.fc2(self.act(self.fc1(x)))
|
||
|
|
|
||
|
|
|
||
|
|
class HCXVisionCAbstractor(nn.Module):
|
||
|
|
"""
|
||
|
|
This module is based on C-Abstractor, whose license is under apache-2.0.
|
||
|
|
You can check the original code at https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py
|
||
|
|
and we made necessary modifications.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(
    self,
    num_queries: int,
    num_input_tokens: int,
    encoder_hidden_size: int,
    hidden_size: int,
    output_hidden_size: int,
    pos_emb: bool = True,
    prenorm: bool = False,
):
    """Build the C-Abstractor projector.

    Args:
        num_queries: Default number of output tokens per image (must be a perfect square).
        num_input_tokens: Number of visual tokens per image coming from the encoder.
        encoder_hidden_size: Channel size of the encoder output.
        hidden_size: Channel size inside the RegStage blocks.
        output_hidden_size: Channel size of the projected output.
        pos_emb: Whether to add a learnable positional embedding to the input tokens.
        prenorm: Whether to apply LayerNorm to the input tokens first.
    """
    super().__init__()
    self.num_input_tokens = num_input_tokens
    self.output_hidden_size = output_hidden_size

    # Learnable positional embedding, initialised from N(0, 0.02) before the conv stages are
    # created (the creation order matters for RNG-reproducible init).
    if not pos_emb:
        self.pos_emb = None
    else:
        self.pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, encoder_hidden_size))
        self.pos_emb.data.normal_(mean=0.0, std=0.02)

    # Optional pre-normalisation of the incoming tokens.
    self.prenorm = LayerNorm(encoder_hidden_size) if prenorm else None

    self.build_net(num_queries, encoder_hidden_size, hidden_size, output_hidden_size)
    # NOTE(review): this is a snapshot taken at construction time; it is not updated by later
    # `.to(dtype)` / `.half()` calls on the module.
    self.dtype = next(self.parameters()).dtype
|
||
|
|
|
||
|
|
def forward(
    self,
    x: torch.Tensor,
    num_queries_vis_abstractors: Optional[List[List[int]]] = None,
    num_grids: Optional[List[int]] = None,
) -> torch.Tensor:
    """Project visual tokens to the language model's hidden size.

    Args:
        x: (B, L, encoder_hidden_size) tensor from the visual backbone.
        num_queries_vis_abstractors: Optional per-group output token counts (adaptive mode).
        num_grids: Group boundaries along the batch dim; required with the above.

    Returns:
        Output of ``self._forward`` (a tensor, or a list of tensors in adaptive mode).
    """
    features = x if self.prenorm is None else self.prenorm(x)
    if self.pos_emb is not None:
        features = features + self.pos_emb
    return self._forward(
        features,
        num_queries_vis_abstractors=num_queries_vis_abstractors,
        num_grids=num_grids,
    )
|
||
|
|
|
||
|
|
def _forward(
    self,
    x: torch.Tensor,
    num_queries_vis_abstractors: Optional[List[List[int]]] = None,
    num_grids: Optional[List[int]] = None,
) -> torch.Tensor:
    """Reshape tokens to a square feature map and run the abstractor network.

    ``x`` is (B, L, dim) with L assumed to be a perfect square. Without adaptive query counts, the
    full ``self.net`` (s1 -> pooling -> s2) runs and the result is flattened back to tokens
    before ``self.readout``. Otherwise the work is delegated to ``_forward_adaptive_num_query``.
    """
    _, seq_len, _ = x.shape
    side = int(seq_len**0.5)
    feature_map = rearrange(x, "b (h w) d -> b d h w", h=side, w=side)

    if num_queries_vis_abstractors is None:
        pooled = self.net(feature_map)
        tokens = rearrange(pooled, "b d h w -> b (h w) d")
        return self.readout(tokens)

    assert num_grids is not None
    return self._forward_adaptive_num_query(feature_map, num_queries_vis_abstractors, num_grids)
|
||
|
|
|
||
|
|
def _forward_adaptive_num_query(
    self,
    x: torch.Tensor,
    num_queries_vis_abstractors: Optional[List[List[int]]] = None,
    num_grids: Optional[List[int]] = None,
) -> List[torch.Tensor]:
    """Run the abstractor with a different output token count per group of rows.

    ``x`` is a (B, d, h, w) feature map. After the shared first stage (``self.net[0]``), rows
    ``num_grids[i]:num_grids[i + 1]`` are pooled to a sqrt(n) x sqrt(n) grid, where
    n = ``num_queries_vis_abstractors[i]``. The pooled rows then go through the second stage
    (``self.net[2]``) and ``self.readout``. The built-in sampler ``self.net[1]`` is bypassed.

    Returns:
        One (group_size, n, output_dim) tensor per group.
    """
    # self.net holds exactly (s1, sampler, s2)
    assert len(self.net) == 3
    first_stage = self.net[0]
    second_stage = self.net[2]

    shared = first_stage(x)
    outputs = []
    for idx, num_queries in enumerate(num_queries_vis_abstractors):
        side = int(num_queries**0.5)
        group = shared[num_grids[idx] : num_grids[idx + 1], :]
        pooled = nn.AdaptiveAvgPool2d((side, side))(group)
        refined = second_stage(pooled)
        tokens = rearrange(refined, "b d h w -> b (h w) d")
        outputs.append(self.readout(tokens))
    return outputs
|
||
|
|
|
||
|
|
def build_net(
    self,
    n_queries: int,
    encoder_hidden_size: int,
    hidden_size: int,
    output_hidden_size: int,
    depth: int = 3,
    mlp_depth: int = 2,
):
    """Create ``self.net`` (RegStage -> adaptive pooling -> RegStage) and ``self.readout``.

    Args:
        n_queries: Default number of output tokens; must be a perfect square.
        encoder_hidden_size: Input channels of the first stage.
        hidden_size: Channels inside and between the stages.
        output_hidden_size: Output size of the readout MLP.
        depth: Number of blocks per RegStage.
        mlp_depth: Number of Linear layers in the readout MLP.
    """
    root = n_queries**0.5
    assert root.is_integer(), f"n_queries must be square number. n_queries: {n_queries}"
    side = int(root)

    # RegBlock = ResBlock + SE, with stride/dilation 1, SiLU activations and channel LayerNorm.
    make_stage = partial(
        RegStage,
        stride=1,
        dilation=1,
        act_layer=nn.SiLU,
        norm_layer=LayerNorm2d,
    )

    # Creation order (s1, pool, s2, readout) is kept so parameter init consumes the RNG identically.
    stage_in = make_stage(depth, encoder_hidden_size, hidden_size)
    pool = nn.AdaptiveAvgPool2d((side, side))
    stage_out = make_stage(depth, hidden_size, hidden_size)

    self.net = nn.Sequential(stage_in, pool, stage_out)
    self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)
|
||
|
|
|
||
|
|
def build_mlp(
    self,
    depth: int,
    hidden_size: int,
    output_hidden_size: int,
):
    """Return a readout MLP: one Linear(hidden -> output), then (SiLU, Linear(output -> output)) x (depth - 1)."""
    modules = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(depth - 1):
        modules += [nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size)]
    return nn.Sequential(*modules)
|
||
|
|
|