EXAONE 3.0 Model Support (#1258)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
5
python/sglang/srt/configs/__init__.py
Normal file
5
python/sglang/srt/configs/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from sglang.srt.configs.exaone import ExaoneConfig
|
||||
|
||||
__all__ = [
|
||||
"ExaoneConfig",
|
||||
]
|
||||
195
python/sglang/srt/configs/exaone.py
Normal file
195
python/sglang/srt/configs/exaone.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The LG AI Research EXAONE Lab. All rights reserved.
|
||||
# Copyright 2024 The LG CNS AI Engineering Team.
|
||||
# Copyright 2023-2024 SGLang Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" EXAONE model configuration """
|
||||
from typing import Any, Dict
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, Any] = {}
|
||||
|
||||
|
||||
# ruff: noqa: E501
|
||||
class ExaoneConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.ExaoneModel`. It is used to
|
||||
instantiate a EXAONE model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the Exaone
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 102400):
|
||||
Vocabulary size of the EXAONE model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`inputs_ids` passed when calling :class:`~transformers.ExaoneModel`. Vocabulary size of the model.
|
||||
Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of
|
||||
:class:`~transformers.EXAONEModel`.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 2048):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_layers (:obj:`int`, `optional`, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (:obj:`int`, `optional`):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to `hidden_size * 4`):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
rope_theta (:obj:`float`, `optional`, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (:obj:`Dict`, `optional`):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (:obj:`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (:obj:`float`, `optional`):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (:obj:`int`, `optional`):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (:obj:`float`, `optional`):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (:obj:`float`, `optional`):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (:obj:`float`, `optional`):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (:obj:`List[float]`, `optional`):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (:obj:`List[float]`, `optional`):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (:obj:`float`, `optional`):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (:obj:`float`, `optional`):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
embed_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if ``configs.is_decoder=True``.
|
||||
bos_token_id (:obj:`int`, `optional`, defaults to 0):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (:obj:`int`, `optional`, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to tie weight embeddings
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import EXAONEModel, ExaoneConfig
|
||||
|
||||
>>> # Initializing a EXAONE configuration
|
||||
>>> configuration = ExaoneConfig()
|
||||
|
||||
>>> # Initializing a model from configuration
|
||||
>>> model = EXAONEModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.configs
|
||||
"""
|
||||
|
||||
model_type = "exaone"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {"num_hidden_layers": "num_layers"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=102400,
|
||||
max_position_embeddings=2048,
|
||||
hidden_size=2048,
|
||||
num_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
intermediate_size=None,
|
||||
activation_function="silu",
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
embed_dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
**kwargs
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_hidden_layers = num_layers
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
if intermediate_size:
|
||||
self.intermediate_size = intermediate_size
|
||||
else:
|
||||
self.intermediate_size = hidden_size * 4
|
||||
self.activation_function = activation_function
|
||||
self.embed_dropout = embed_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs
|
||||
)
|
||||
@@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
"""Utilities for Huggingface Transformers."""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
@@ -34,14 +35,21 @@ from transformers import (
|
||||
try:
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
||||
|
||||
from sglang.srt.configs import ExaoneConfig
|
||||
|
||||
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
ChatGLMConfig.model_type: ChatGLMConfig,
|
||||
DbrxConfig.model_type: DbrxConfig,
|
||||
ExaoneConfig.model_type: ExaoneConfig,
|
||||
}
|
||||
except ImportError:
|
||||
# We want this file to run without vllm dependency
|
||||
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
|
||||
|
||||
for name, cls in _CONFIG_REGISTRY.items():
|
||||
with contextlib.suppress(ValueError):
|
||||
AutoConfig.register(name, cls)
|
||||
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
|
||||
|
||||
@@ -53,7 +61,7 @@ def download_from_hf(model_path: str):
|
||||
|
||||
|
||||
def get_config_json(model_path: str):
|
||||
with open(os.path.join(model_path, "config.json")) as f:
|
||||
with open(os.path.join(model_path, "configs.json")) as f:
|
||||
config = json.load(f)
|
||||
return config
|
||||
|
||||
@@ -89,7 +97,7 @@ CONTEXT_LENGTH_KEYS = [
|
||||
|
||||
|
||||
def get_context_length(config):
|
||||
"""Get the context length of a model from a huggingface model config."""
|
||||
"""Get the context length of a model from a huggingface model configs."""
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling:
|
||||
rope_scaling_factor = config.rope_scaling["factor"]
|
||||
|
||||
399
python/sglang/srt/models/exaone.py
Normal file
399
python/sglang/srt/models/exaone.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Copyright 2024 The LGcns AI Engineering Team
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Adapted from llama2.py
|
||||
"""Inference-only Exaone model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
|
||||
|
||||
class ExaoneGatedMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ExaoneAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
layer_id: int = 0,
|
||||
rope_theta: float = 500000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
rope_is_neox_style: bool = True,
|
||||
max_position_embeddings: int = 4096,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||
self.head_dim = getattr(
|
||||
config, "head_dim", self.hidden_size // self.total_num_heads
|
||||
)
|
||||
self.rotary_dim = int(
|
||||
self.head_dim * getattr(config, "partial_rotary_factor", 1)
|
||||
)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=rope_is_neox_style,
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class ExaoneDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 500000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None
|
||||
):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
|
||||
self.self_attn = ExaoneAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
rope_is_neox_style=rope_is_neox_style,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = ExaoneGatedMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.activation_function,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
rms_norm_eps = config.layer_norm_epsilon
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.ln_1(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.ln_2(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class ExaoneModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.wte = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
ExaoneDecoderLayer(
|
||||
config, i, quant_config=quant_config, prefix=f"model.h.{i}"
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
rms_norm_eps = config.layer_norm_epsilon
|
||||
self.ln_f = RMSNorm(config.hidden_size, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.wte(input_ids)
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
residual = None
|
||||
for i in range(len(self.h)):
|
||||
layer = self.h[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ExaoneForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
efficient_weight_load=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = ExaoneModel(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.transformer(
|
||||
input_ids, positions, input_metadata, input_embeds
|
||||
)
|
||||
logits_output = self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
)
|
||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
||||
return sample_output, logits_output
|
||||
|
||||
def get_module_name(self, name):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id, num_shard)
|
||||
("qkv_proj", "q_proj", "q", 3),
|
||||
("qkv_proj", "k_proj", "k", 3),
|
||||
("qkv_proj", "v_proj", "v", 3),
|
||||
("gate_up_proj", "c_fc_0", 0, 2),
|
||||
("gate_up_proj", "c_fc_1", 1, 2),
|
||||
]
|
||||
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
|
||||
if weight_name in name:
|
||||
return (
|
||||
name.replace(weight_name, param_name)[: -len(".weight")],
|
||||
num_shard,
|
||||
)
|
||||
return name[: -len(".weight")], 1
|
||||
|
||||
def get_num_params(self):
|
||||
params_dict = dict(self.named_parameters())
|
||||
return len(params_dict)
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
||||
):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "c_fc_0", 0),
|
||||
("gate_up_proj", "c_fc_1", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
def load_weights_per_param(name, loaded_weight):
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
return
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
return
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
return
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name is None or loaded_weight is None:
|
||||
for name, loaded_weight in weights:
|
||||
name = name.replace("attn.attention", "self_attn")
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
else:
|
||||
name = name.replace("attn.attention", "self_attn")
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = ExaoneForCausalLM
|
||||
Reference in New Issue
Block a user