407 lines
14 KiB
Python
407 lines
14 KiB
Python
"""
|
|
Copyright 2023-2024 SGLang Team
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
|
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
|
|
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import LlamaConfig
|
|
from vllm.config import CacheConfig
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
|
|
|
# Placeholders for the parallel linear layer classes. LlamaForCausalLM.__init__
# rebinds these module globals (from sglang or vllm, depending on its
# ``efficient_weight_load`` flag) before any layer below is constructed.
MergedColumnParallelLinear = QKVParallelLinear = RowParallelLinear = None
|
|
|
|
|
|
class LlamaMLP(nn.Module):
    """Feed-forward block: fused gate/up projection, activation, down projection.

    Only the ``"silu"`` activation is accepted; any other value raises
    ``ValueError`` at construction time.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input width of the fused projection and output width
                of the down projection.
            intermediate_size: Width of each of the two fused outputs and
                input width of the down projection.
            hidden_act: Activation name; must be ``"silu"``.
            quant_config: Forwarded unchanged to both linear layers.
            prefix: Name prefix; ``.gate_up_proj`` / ``.down_proj`` are
                appended for the respective layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # A single merged layer produces two outputs of equal width.
        fused_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the activation, then the down projection."""
        # Both linear layers return a pair; only the first element is used.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
|
|
|
class LlamaAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    ``RadixAttention`` and an output projection.

    Query and key/value head counts are divided by the tensor-parallel world
    size reported by ``get_tensor_model_parallel_world_size()``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: Model config; only its optional ``head_dim`` attribute is
                read here.
            hidden_size: Width of the hidden states.
            num_heads: Total number of query heads (before TP partitioning).
            num_kv_heads: Total number of key/value heads (before TP
                partitioning).
            layer_id: Forwarded to ``RadixAttention``.
            rope_theta: Forwarded to ``get_rope`` as ``base``.
            rope_scaling: Forwarded to ``get_rope``.
            rope_is_neox_style: Forwarded to ``get_rope`` as ``is_neox_style``.
            max_position_embeddings: Forwarded to ``get_rope`` as
                ``max_position``.
            quant_config: Forwarded to both linear layers.
            prefix: Name prefix; ``.qkv_proj`` / ``.o_proj`` are appended.

        Raises:
            AssertionError: If the head counts are not compatible with the
                tensor-parallel world size.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # At least as many KV heads as TP ranks: partition the KV heads
            # across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than TP ranks: replicate the KV heads across
            # the tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo.
        # A plain getattr default is not enough: when the attribute exists but
        # is None the default is ignored, so resolve None explicitly.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out."""
        # The linear layers return a pair; only the first element is used.
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into this rank's
        # query, key and value slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output
|
|
|
|
|
|
class LlamaDecoderLayer(nn.Module):
    """One decoder layer: norm -> self-attention -> norm -> MLP.

    The residual stream is threaded through the two ``RMSNorm`` calls and
    returned alongside the hidden states.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Read rope/size settings from ``config`` and build the sub-modules.

        Args:
            config: Model config supplying sizes, head counts, rope settings,
                activation name and the RMSNorm epsilon.
            layer_id: Forwarded to ``LlamaAttention``.
            quant_config: Forwarded to the attention and MLP sub-modules.
            prefix: Name prefix; ``.self_attn`` / ``.mlp`` are appended.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # Optional config fields fall back to defaults when absent.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_pos = getattr(config, "original_max_position_embeddings", None)
        if rope_scaling is not None and original_max_pos:
            # Copy the value into the rope_scaling dict (mutated in place)
            # before the dict is handed to the attention module.
            rope_scaling["original_max_position_embeddings"] = original_max_pos
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer; in that case the incoming
        hidden states become the residual.
        """
        # Self Attention
        if residual is not None:
            # Two-argument norm call returns an updated (hidden, residual) pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed)
        return mlp_out, residual
|
|
|
|
|
|
class LlamaModel(nn.Module):
    """Token embedding, a stack of ``LlamaDecoderLayer`` and a final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` layers and the norm.

        Args:
            config: Model config supplying vocabulary size, hidden size, layer
                count, pad token id and the RMSNorm epsilon.
            quant_config: Forwarded to every decoder layer.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        # Each layer gets its index and a "model.layers.<i>" name prefix.
        decoder_layers = [
            LlamaDecoderLayer(
                config,
                layer_idx,
                quant_config=quant_config,
                prefix=f"model.layers.{layer_idx}",
            )
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Return the normalized hidden states after all decoder layers.

        When ``input_embeds`` is given it is used directly and ``input_ids``
        is not embedded.
        """
        hidden_states = (
            self.embed_tokens(input_ids) if input_embeds is None else input_embeds
        )

        # The first layer receives residual=None; later layers receive the
        # residual returned by the previous one.
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                input_metadata,
                residual,
            )

        # The final norm returns a pair; only the hidden states are kept.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
|
|
|
|
|
class LlamaForCausalLM(nn.Module):
    """``LlamaModel`` plus an LM head and a logits processor.

    Also provides checkpoint loading that maps separately stored
    ``q_proj`` / ``k_proj`` / ``v_proj`` and ``gate_proj`` / ``up_proj``
    tensors onto this model's fused ``qkv_proj`` / ``gate_up_proj`` parameters.
    """

    # (fused_param_name, checkpoint_shard_name, shard_id, num_shards)
    # Single source of truth shared by get_module_name and load_weights.
    _STACKED_PARAMS_MAPPING = (
        ("qkv_proj", "q_proj", "q", 3),
        ("qkv_proj", "k_proj", "k", 3),
        ("qkv_proj", "v_proj", "v", 3),
        ("gate_up_proj", "gate_proj", 0, 2),
        ("gate_up_proj", "up_proj", 1, 2),
    )

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        efficient_weight_load=False,
    ) -> None:
        """Select the linear-layer implementation and build the model.

        Args:
            config: Model config, forwarded to ``LlamaModel`` and
                ``LogitsProcessor``.
            quant_config: Forwarded to ``LlamaModel``.
            cache_config: Accepted but not used by this class.
            efficient_weight_load: When true, the parallel linear classes are
                imported from ``sglang.srt.layers.linear``; otherwise from
                ``vllm.model_executor.layers.linear``.
        """
        # The layer classes above reference these module-level names, so they
        # must be rebound before LlamaModel is constructed. Note this affects
        # every model built afterwards in the same process.
        global MergedColumnParallelLinear
        global QKVParallelLinear
        global RowParallelLinear

        if efficient_weight_load:
            from sglang.srt.layers.linear import (
                MergedColumnParallelLinear,
                QKVParallelLinear,
                RowParallelLinear,
            )
        else:
            from vllm.model_executor.layers.linear import (
                MergedColumnParallelLinear,
                QKVParallelLinear,
                RowParallelLinear,
            )

        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder and return the logits processor's output."""
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def get_module_name(self, name):
        """Map a checkpoint tensor name to ``(module_name, num_shards)``.

        A name containing one of the separately stored shard names (e.g.
        ``q_proj``) is rewritten to the fused parameter name (``qkv_proj``)
        and reports that parameter's shard count; any other name reports 1.
        The final dotted component (``.weight``, ``.bias``, ...) is removed,
        so names that do not end in ``.weight`` are no longer truncated by a
        fixed number of characters.
        """
        num_shard = 1
        for param_name, weight_name, _shard_id, shard_count in (
            self._STACKED_PARAMS_MAPPING
        ):
            if weight_name in name:
                name = name.replace(weight_name, param_name)
                num_shard = shard_count
                break
        # Identical to stripping ".weight" for names that end with it.
        return name.rpartition(".")[0], num_shard

    def get_num_params(self):
        """Return the number of named parameters of this module."""
        return len(dict(self.named_parameters()))

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
    ):
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs. Only consumed when
                ``name`` or ``loaded_weight`` is ``None``.
            name: Optional single tensor name; together with
                ``loaded_weight`` loads just that tensor.
            loaded_weight: Tensor for ``name``.

        Raises:
            KeyError: If a tensor that is not skipped has no matching
                parameter.
        """
        params_dict = dict(self.named_parameters())

        def is_ignorable(param_name):
            # Skip extra biases (e.g. GPTQ checkpoints) and vision-tower
            # tensors that have no counterpart in this model.
            return param_name not in params_dict and (
                param_name.endswith(".bias")
                or param_name.startswith("model.vision_tower")
            )

        def load_weights_per_param(name, loaded_weight):
            if "rotary_emb.inv_freq" in name or "projector" in name:
                return
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                return

            for param_name, weight_name, shard_id, _num_shard in (
                self._STACKED_PARAMS_MAPPING
            ):
                if weight_name not in name:
                    continue
                # First matching shard name decides the fused target; the
                # original name is left untouched so nothing leaks into the
                # remaining checks.
                fused_name = name.replace(weight_name, param_name)
                if is_ignorable(fused_name):
                    return
                param = params_dict[fused_name]
                # Fused parameters must provide a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
                return

            if is_ignorable(name):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

        if name is None or loaded_weight is None:
            for tensor_name, tensor in weights:
                load_weights_per_param(tensor_name, tensor)
        else:
            load_weights_per_param(name, loaded_weight)
|
|
|
|
|
|
# Module-level alias for the model class defined in this file.
# NOTE(review): presumably looked up by name by the model loader to discover
# this module's model class; confirm against the loader.
EntryClass = LlamaForCausalLM
|