[Feat/Fix] Refactoring Llava models into single file (#475)

Author: Li Bo
Date: 2024-05-27 03:29:51 +08:00
Committed by: GitHub
Parent: 947bda73fe
Commit: 2b605ab1d7

6 changed files with 118 additions and 688 deletions
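For context: this commit folds the Qwen and Mistral LLaVA variants into `llava.py` as subclasses of `LlavaLlamaForCausalLM`, so all three models register through a single `EntryClass` list. A minimal smoke-test sketch of exercising one of the new classes end to end; the checkpoint path is a placeholder, and the `sgl.Runtime` frontend API is assumed from sglang of this period, not shown in this diff:

```python
# Hypothetical smoke test, not part of this commit.
# The model path is a placeholder for a local LLaVA-Qwen checkpoint.
import sglang as sgl

runtime = sgl.Runtime(model_path="/path/to/llava-qwen-checkpoint")
sgl.set_default_backend(runtime)

@sgl.function
def describe(s, image_path):
    s += sgl.user(sgl.image(image_path) + "Describe this image.")
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

state = describe.run(image_path="example.jpg")
print(state["answer"])
runtime.shutdown()
```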


@@ -5,7 +5,7 @@ from typing import List, Iterable, Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, LlavaConfig
+from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -18,6 +18,8 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.mistral import MistralForCausalLM
 
 
 class LlavaLlamaForCausalLM(nn.Module):
@@ -287,8 +289,101 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.image_size // self.patch_size
 
 
-first_call = True
+class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 151646
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        if self.mm_patch_merge_type.startswith("spatial"):
+            height = width = self.num_patches_per_side
+            if pt_shape[0] > 1:
+                if self.image_aspect_ratio == "anyres":
+                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                        image_size,
+                        self.image_grid_pinpoints,
+                        self.vision_tower.config.image_size,
+                    )
+                if "unpad" in self.mm_patch_merge_type:
+                    h = num_patch_height * height
+                    w = num_patch_width * width
+                    new_h, new_w = unpad_image_shape(h, w, image_size)
+                    new_image_feature_len += new_h * (new_w + 1)
+
+        pad_ids = pad_value * (
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        )
+        offset = input_ids.index(self.config.image_token_index)
+        # old_len + pad_len - 1, because we need to remove image_token_id
+        new_input_ids = (
+            input_ids[:offset]
+            + pad_ids[:new_image_feature_len]
+            + input_ids[offset + 1 :]
+        )
+        return new_input_ids, offset
+
+
+class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = MistralConfig(self.config._name_or_path)
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 32000
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+first_call = True
 
 
 def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
     batch_size = pixel_values.shape[0]
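An aside on the splice arithmetic in `pad_input_ids` above: the single image token is cut out of `input_ids` and replaced by exactly `new_image_feature_len` pad ids, built from whole repeats of `pad_value` and then truncated. A standalone sketch with made-up numbers (nothing here comes from a real checkpoint; it only mirrors the list operations in the diff):

```python
# Illustrative sketch of the pad_input_ids splice; all numbers are invented.
image_token_index = 151646          # Qwen image token id from this diff
pad_value = [10, 11, 12]            # hypothetical per-model pad pattern
new_image_feature_len = 7           # pretend the image expands to 7 tokens

# Build a pad list at least new_image_feature_len long, in whole repeats
# of pad_value; the splice below truncates it to the exact length.
pad_ids = pad_value * ((new_image_feature_len + len(pad_value)) // len(pad_value))

input_ids = [1, 2, image_token_index, 3, 4]
offset = input_ids.index(image_token_index)
# The image token itself is dropped, hence len(input_ids) - 1 below.
new_input_ids = input_ids[:offset] + pad_ids[:new_image_feature_len] + input_ids[offset + 1:]

assert len(new_input_ids) == len(input_ids) - 1 + new_image_feature_len
print(new_input_ids)  # [1, 2, 10, 11, 12, 10, 11, 12, 10, 3, 4]
```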
@@ -319,4 +414,8 @@ def monkey_path_clip_vision_embed_forward():
     )
 
 
-EntryClass = LlavaLlamaForCausalLM
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM
+]
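The `EntryClass` list is how the file advertises its models to the loader: turning the single class into a list is what lets one module register all three variants. sglang's actual model loader is more involved, but the dispatch amounts to something like the following sketch, where the registry keyed on class names is illustrative rather than the real loader code:

```python
# Illustrative only: resolve a model class from the checkpoint's
# config.architectures entry by matching class names in EntryClass.
from typing import Dict, Type

MODEL_REGISTRY: Dict[str, Type] = {cls.__name__: cls for cls in EntryClass}

def resolve_model_class(architecture: str) -> Type:
    try:
        return MODEL_REGISTRY[architecture]
    except KeyError:
        raise ValueError(f"unsupported architecture: {architecture}") from None

# e.g. resolve_model_class("LlavaQwenForCausalLM") -> LlavaQwenForCausalLM
```

Matching on `__name__` also explains why the subclass names mirror the `architectures` strings in the corresponding HuggingFace configs.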