Format code & move functions (#155)

Author: Lianmin Zheng
Date: 2024-02-06 13:27:46 -08:00
Committed by: GitHub
Parent: a7334aeea1
Commit: 23f05005fd
14 changed files with 94 additions and 54 deletions


@@ -42,14 +42,14 @@ class QWenMLP(nn.Module):
2 * [intermediate_size],
bias=False,
gather_output=False,
-linear_method=linear_method
+linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
-linear_method=linear_method
+linear_method=linear_method,
)
if hidden_act != "silu":
raise ValueError(
@@ -74,7 +74,7 @@ class QWenAttention(nn.Module):
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
-linear_method: Optional[LinearMethodBase] = None
+linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = hidden_size
@@ -86,18 +86,18 @@ class QWenAttention(nn.Module):
# pylint: disable=invalid-name
self.c_attn = QKVParallelLinear(
-hidden_size,
-self.head_dim,
-self.total_num_heads,
-bias=True,
-linear_method=linear_method
+hidden_size,
+self.head_dim,
+self.total_num_heads,
+bias=True,
+linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
-linear_method=linear_method
+linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -143,12 +143,16 @@ class QWenBlock(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
layer_id=layer_id,
-linear_method=linear_method
+linear_method=linear_method,
)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
+self.mlp = QWenMLP(
+config.hidden_size,
+config.intermediate_size // 2,
+linear_method=linear_method,
+)
def forward(
self,
@@ -186,7 +190,10 @@ class QWenModel(nn.Module):
config.hidden_size,
)
self.h = nn.ModuleList(
-[QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
+[
+QWenBlock(config, i, linear_method=linear_method)
+for i in range(config.num_hidden_layers)
+]
)
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
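
The QWen hunks above are purely mechanical: trailing commas are added to the last argument of multi-line calls, and over-long lines are split one argument per line. The commit message does not name the formatter; the short sketch below uses black only to illustrate this style (its "magic trailing comma" behavior), so the tool choice is an assumption.

# Illustration only: black is an assumption, the commit does not name its formatter.
# A trailing comma inside a call makes black keep one argument per line, which is
# the layout produced throughout the hunks above.
import black

compact = "f(a, b, c)\n"    # short call without a trailing comma: left on one line
exploded = "f(a, b, c,)\n"  # trailing comma: expanded to one argument per line

print(black.format_str(compact, mode=black.Mode()), end="")
print(black.format_str(exploded, mode=black.Mode()), end="")
# f(a, b, c)
# f(
#     a,
#     b,
#     c,
# )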


@@ -4,14 +4,17 @@ from typing import List, Optional
import torch
import torch.nn as nn
+from sglang.srt.models.llava import (
+LlavaLlamaForCausalLM,
+clip_vision_embed_forward,
+monkey_path_clip_vision_embed_forward,
+)
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
-from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward
class YiVLForCausalLM(LlavaLlamaForCausalLM):
def __init__(self, *args, **kwargs):
@@ -19,7 +22,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
super().__init__(self.config)
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
-self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "") # Everything after "./"
+self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
+"./", ""
+) # Everything after "./"
def load_weights(
self,
@@ -30,7 +35,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
):
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
self.vision_tower = CLIPVisionModel.from_pretrained(
-model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
+model_name_or_path,
+torch_dtype=torch.float16,
+subfolder=self.vision_tower_subfolder,
).cuda()
self.vision_tower.eval()
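
For reference, the wrapped call above loads the CLIP vision tower from a subfolder of the main Yi-VL checkpoint. A standalone sketch of that usage follows; the repo id comes from the comment in the hunk, the subfolder string is a placeholder (in the model it is derived from config.mm_vision_tower), and .cuda() assumes a GPU is available.

# Sketch of the call reformatted above, not a verbatim excerpt.
# "<vision-tower-subfolder>" is a placeholder; in YiVLForCausalLM it comes from
# config.mm_vision_tower with the leading "./" stripped.
import torch
from transformers import CLIPVisionModel

vision_tower = CLIPVisionModel.from_pretrained(
    "01-ai/Yi-VL-6B",                      # main model directory (see comment above)
    torch_dtype=torch.float16,
    subfolder="<vision-tower-subfolder>",  # placeholder
).cuda()
vision_tower.eval()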
@@ -80,14 +87,19 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
monkey_path_clip_vision_embed_forward()
class YiVLMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig):
super().__init__()
-self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
+self.linear_1 = nn.Linear(
+config.vision_config.hidden_size, config.text_config.hidden_size
+)
self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
self.act = nn.GELU()
-self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
+self.linear_2 = nn.Linear(
+config.text_config.hidden_size, config.text_config.hidden_size
+)
self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
def forward(self, image_features):
@@ -98,4 +110,5 @@ class YiVLMultiModalProjector(nn.Module):
hidden_states = self.ln_2(hidden_states)
return hidden_states
-EntryClass = YiVLForCausalLM
+EntryClass = YiVLForCausalLM
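
Taken together, the YiVLMultiModalProjector hunks define a projector that maps CLIP vision features into the text hidden size through two linear layers, each followed by LayerNorm, with a GELU in between. A minimal self-contained sketch is below; the ordering inside forward() and the toy config sizes are assumptions, since the diff only shows __init__ and the tail of forward().

# Minimal sketch of YiVLMultiModalProjector, reconstructed from the hunks above.
# Assumptions: the forward() order (linear_1 -> ln_1 -> act -> linear_2 -> ln_2)
# and the toy config sizes; neither is shown in full in the diff.
import torch
import torch.nn as nn


class ToyConfig:
    class vision_config:
        hidden_size = 32

    class text_config:
        hidden_size = 64


class ProjectorSketch(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size, config.text_config.hidden_size
        )
        self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size
        )
        self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_2(hidden_states)
        return hidden_states


image_features = torch.randn(1, 196, ToyConfig.vision_config.hidden_size)
print(ProjectorSketch(ToyConfig)(image_features).shape)  # torch.Size([1, 196, 64])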