[Feat/Fix] Refactoring Llava models into single file (#475)
@@ -5,7 +5,7 @@ from typing import List, Iterable, Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, LlavaConfig
+from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -18,6 +18,8 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.mistral import MistralForCausalLM
 
 
 class LlavaLlamaForCausalLM(nn.Module):
@@ -287,8 +289,101 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.image_size // self.patch_size
 
 
+class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 151646
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        if self.mm_patch_merge_type.startswith("spatial"):
+            height = width = self.num_patches_per_side
+            if pt_shape[0] > 1:
+                if self.image_aspect_ratio == "anyres":
+                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                        image_size,
+                        self.image_grid_pinpoints,
+                        self.vision_tower.config.image_size,
+                    )
+                if "unpad" in self.mm_patch_merge_type:
+                    h = num_patch_height * height
+                    w = num_patch_width * width
+                    new_h, new_w = unpad_image_shape(h, w, image_size)
+                    new_image_feature_len += new_h * (new_w + 1)
+
+        pad_ids = pad_value * (
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        )
+        offset = input_ids.index(self.config.image_token_index)
+        # old_len + pad_len - 1, because we need to remove image_token_id
+        new_input_ids = (
+            input_ids[:offset]
+            + pad_ids[:new_image_feature_len]
+            + input_ids[offset + 1 :]
+        )
+        return new_input_ids, offset
+
+
+class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = MistralConfig(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 32000
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
 first_call = True
 
 
 def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
     batch_size = pixel_values.shape[0]
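
Note on the padding arithmetic in pad_input_ids above: the single image placeholder token is replaced by exactly new_image_feature_len pad ids, built by repeating the caller-supplied pad_value pattern and trimming it to size. Below is a minimal standalone sketch of just that arithmetic, with made-up token ids and lengths; it is not part of the commit.

IMAGE_TOKEN_INDEX = 151646  # placeholder id used by LlavaQwenForCausalLM above
pad_value = [10, 11, 12]    # per-image pad pattern supplied by the caller (made up)
new_image_feature_len = 8   # pretend the image expands to 8 feature tokens

# Repeat the pad pattern until it covers new_image_feature_len, then trim.
pad_ids = pad_value * ((new_image_feature_len + len(pad_value)) // len(pad_value))

input_ids = [1, 2, IMAGE_TOKEN_INDEX, 3, 4]
offset = input_ids.index(IMAGE_TOKEN_INDEX)

# The image token itself is dropped, so the result has
# len(input_ids) - 1 + new_image_feature_len tokens.
new_input_ids = input_ids[:offset] + pad_ids[:new_image_feature_len] + input_ids[offset + 1 :]

assert len(new_input_ids) == len(input_ids) - 1 + new_image_feature_len
assert offset == 2
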
@@ -319,4 +414,8 @@ def monkey_path_clip_vision_embed_forward():
     )
 
 
-EntryClass = LlavaLlamaForCausalLM
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM
+]
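
The last hunk turns EntryClass from a single class into a list, so this one file can export all three Llava variants. The loader that consumes EntryClass is not part of this diff; the sketch below only illustrates, as an assumption, how a registry could map class names to the exported classes (model_cls_registry and the surrounding logic are hypothetical, not sglang's actual loader code).

from sglang.srt.models import llava

# Accept either a single class or a list of classes, since both forms of
# EntryClass exist before and after this commit.
entry = llava.EntryClass
exported = entry if isinstance(entry, list) else [entry]

model_cls_registry = {cls.__name__: cls for cls in exported}
# e.g. model_cls_registry["LlavaQwenForCausalLM"] is LlavaQwenForCausalLM,
# and a checkpoint config's architectures entry could be resolved against it.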