Format code & move functions (#155)
This commit is contained in:
@@ -42,14 +42,14 @@ class QWenMLP(nn.Module):
|
||||
2 * [intermediate_size],
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
linear_method=linear_method
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method
|
||||
linear_method=linear_method,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -74,7 +74,7 @@ class QWenAttention(nn.Module):
|
||||
layer_id: int = 0,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -86,18 +86,18 @@ class QWenAttention(nn.Module):
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.c_attn = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -143,12 +143,16 @@ class QWenBlock(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
layer_id=layer_id,
|
||||
linear_method=linear_method
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
|
||||
self.mlp = QWenMLP(
|
||||
config.hidden_size,
|
||||
config.intermediate_size // 2,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -186,7 +190,10 @@ class QWenModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
|
||||
[
|
||||
QWenBlock(config, i, linear_method=linear_method)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
|
||||
@@ -4,14 +4,17 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from sglang.srt.models.llava import (
|
||||
LlavaLlamaForCausalLM,
|
||||
clip_vision_embed_forward,
|
||||
monkey_path_clip_vision_embed_forward,
|
||||
)
|
||||
from transformers import CLIPVisionModel, LlavaConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
|
||||
from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward
|
||||
|
||||
|
||||
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -19,7 +22,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
super().__init__(self.config)
|
||||
|
||||
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
|
||||
self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "") # Everything after "./"
|
||||
self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
|
||||
"./", ""
|
||||
) # Everything after "./"
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
@@ -30,7 +35,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
):
|
||||
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
|
||||
self.vision_tower = CLIPVisionModel.from_pretrained(
|
||||
model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
|
||||
model_name_or_path,
|
||||
torch_dtype=torch.float16,
|
||||
subfolder=self.vision_tower_subfolder,
|
||||
).cuda()
|
||||
|
||||
self.vision_tower.eval()
|
||||
@@ -80,14 +87,19 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
|
||||
monkey_path_clip_vision_embed_forward()
|
||||
|
||||
|
||||
class YiVLMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size
|
||||
)
|
||||
self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
|
||||
self.act = nn.GELU()
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
|
||||
|
||||
def forward(self, image_features):
|
||||
@@ -98,4 +110,5 @@ class YiVLMultiModalProjector(nn.Module):
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
EntryClass = YiVLForCausalLM
|
||||
|
||||
EntryClass = YiVLForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user