279 lines
9.0 KiB
Python
279 lines
9.0 KiB
Python
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (
|
|
LinearMethodBase,
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
|
get_tensor_model_parallel_world_size,
|
|
)
|
|
from vllm.model_executor.weight_utils import (
|
|
default_weight_loader,
|
|
hf_model_weights_iterator,
|
|
)
|
|
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
|
|
|
|
|
class QWenMLP(nn.Module):
    """Feed-forward sub-layer of a QWen transformer block.

    The input passes through a ``MergedColumnParallelLinear`` that produces
    two ``intermediate_size``-wide projections, then ``SiluAndMul``, then a
    ``RowParallelLinear`` (``c_proj``) back to ``hidden_size``.

    Args:
        hidden_size: Feature size of the input and of the output.
        intermediate_size: Output size of each of the two merged projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        linear_method: Optional linear method forwarded to both projections.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Validate first so an unsupported activation fails before any
        # projection layer is constructed.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            2 * [intermediate_size],
            bias=False,
            gather_output=False,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            linear_method=linear_method,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the merged projection, the activation and ``c_proj`` to ``x``."""
        # Both linear layers return a 2-tuple; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x
|
|
|
|
|
|
class QWenAttention(nn.Module):
    """Self-attention sub-layer of a QWen transformer block.

    Combines a fused ``QKVParallelLinear`` projection (``c_attn``), rotary
    position embedding from ``get_rope``, a ``RadixAttention`` layer and a
    ``RowParallelLinear`` output projection (``c_proj``).

    Args:
        hidden_size: Feature size of the hidden states.
        num_heads: Total number of attention heads; must be divisible by the
            tensor-model-parallel world size.
        max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
        layer_id: Forwarded to ``RadixAttention``.
        rope_theta: Passed to ``get_rope`` as ``base``.
        rope_scaling: Optional scaling configuration passed to ``get_rope``.
        linear_method: Optional linear method forwarded to both projections.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        # Heads handled by this tensor-parallel rank.
        self.num_heads = num_heads // tp_size
        self.head_dim = hidden_size // num_heads

        # pylint: disable=invalid-name
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.scaling = self.head_dim**-0.5
        # The KV head count is set equal to the query head count.
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output."""
        fused_qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection into three equal chunks along the last dim.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, input_metadata)
        projected, _ = self.c_proj(context)
        return projected
|
|
|
|
|
|
class QWenBlock(nn.Module):
    """One QWen transformer layer: norm -> attention -> norm -> MLP.

    Each sub-layer is preceded by an ``RMSNorm`` and its output is added back
    to the sub-layer's un-normalised input.

    Args:
        config: Model configuration; reads ``hidden_size``,
            ``layer_norm_epsilon``, ``num_attention_heads``,
            ``max_position_embeddings``, ``intermediate_size`` and the
            optional ``rope_theta`` / ``rope_scaling`` attributes.
        layer_id: Index of this layer, forwarded to ``QWenAttention``.
        linear_method: Optional linear method forwarded to the sub-layers.
    """

    def __init__(self, config: PretrainedConfig, layer_id, linear_method=None):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            # Fall back to defaults when the config does not define these.
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            layer_id=layer_id,
            linear_method=linear_method,
        )

        self.ln_2 = RMSNorm(width, eps=norm_eps)
        # The MLP is built with half of the configured intermediate size.
        self.mlp = QWenMLP(
            width,
            config.intermediate_size // 2,
            linear_method=linear_method,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Apply attention and MLP sub-layers, each with a residual add."""
        # Self-attention sub-layer.
        attn_out = self.attn(
            positions=positions,
            hidden_states=self.ln_1(hidden_states),
            input_metadata=input_metadata,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out
|
|
|
|
|
|
class QWenModel(nn.Module):
    """QWen transformer body: token embedding, layer stack and final norm.

    Args:
        config: Model configuration; reads ``vocab_size``, ``hidden_size``,
            ``num_hidden_layers`` and ``layer_norm_epsilon`` (plus whatever
            ``QWenBlock`` reads).
        linear_method: Optional linear method forwarded to every block.
    """

    def __init__(self, config: PretrainedConfig, linear_method=None):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        # The embedding table is sized to the vocabulary rounded up to the
        # next multiple of 64.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.wte = VocabParallelEmbedding(
            vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList(
            [
                QWenBlock(config, i, linear_method=linear_method)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer in order and apply ``ln_f``.

        Returns:
            The output of ``ln_f`` applied to the last layer's hidden states.
        """
        hidden_states = self.wte(input_ids)
        # Iterate the layers directly instead of indexing by position.
        for layer in self.h:
            hidden_states = layer(
                positions,
                hidden_states,
                input_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
|
|
|
|
|
|
class QWenLMHeadModel(nn.Module):
    """QWen causal language model: transformer body, LM head and logits step.

    Args:
        config: Model configuration; reads ``vocab_size`` and ``hidden_size``
            here and is passed on to ``QWenModel`` and ``LogitsProcessor``.
        linear_method: Optional linear method forwarded to ``QWenModel``.
    """

    def __init__(self, config: PretrainedConfig, linear_method=None):
        super().__init__()
        self.config = config
        self.transformer = QWenModel(config, linear_method=linear_method)
        # LM head rows: vocabulary size rounded up to a multiple of 64.
        padded_vocab = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ParallelLMHead(padded_vocab, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ):
        """Run the transformer and hand its output to the logits processor."""
        final_hidden = self.transformer(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, final_hidden, self.lm_head.weight, input_metadata
        )

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Copy checkpoint tensors into this model's parameters.

        Tensors whose name contains ``w2`` or ``w1`` are loaded into shard 0
        or shard 1 of the merged ``gate_up_proj`` parameter respectively.
        ``rotary_emb.inv_freq`` entries and ``.bias`` entries with no matching
        parameter are skipped. Every other tensor is loaded by name with the
        parameter's own ``weight_loader`` or ``default_weight_loader``.
        """
        # (merged parameter name, checkpoint name fragment, shard index)
        merged_shards = [
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        named_params = dict(self.named_parameters())
        checkpoint = hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        )
        for name, tensor in checkpoint:
            if "rotary_emb.inv_freq" in name:
                continue

            loaded_as_shard = False
            for merged_name, fragment, shard_index in merged_shards:
                if fragment in name:
                    name = name.replace(fragment, merged_name)
                    # Skip loading extra bias for GPTQ models. The renamed
                    # name is kept and the remaining fragments are still
                    # tried before the generic path below.
                    if not (name.endswith(".bias") and name not in named_params):
                        target = named_params[name]
                        target.weight_loader(target, tensor, shard_index)
                        loaded_as_shard = True
                        break
            if loaded_as_shard:
                continue

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in named_params:
                continue
            target = named_params[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)
|
|
|
|
|
|
# Module-level alias for the top-level model class of this file.
# NOTE(review): presumably the name a model loader looks up — confirm.
EntryClass = QWenLMHeadModel
|