# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/exaone/modular_exaone.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_exaone.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The LG AI Research and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LG AI Research EXAONE Lab"""

from collections.abc import Callable
from typing import Optional

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.generic import check_model_inputs, maybe_autocast

from .configuration_exaone import ExaoneConfig


@use_kernel_forward_from_hub("RMSNorm")
class ExaoneRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        ExaoneRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
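
# Illustrative sketch (not part of the generated model code): RMSNorm rescales each hidden
# vector by the reciprocal root-mean-square of its elements, computed in float32, e.g.
#   norm = ExaoneRMSNorm(hidden_size=64)
#   y = norm(torch.randn(2, 8, 64))  # same shape as the input, scaled per-channel by self.weight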


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
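
# Illustrative note: `rotate_half` maps the two halves (x1, x2) of the last dimension to
# (-x2, x1), the 90-degree rotation used by the RoPE identity below, e.g.
#   rotate_half(torch.tensor([1., 2., 3., 4.]))  # -> tensor([-3., -4., 1., 2.])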


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
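
# Illustrative note (assuming the default layout of q/k as [batch, heads, seq_len, head_dim]
# and cos/sin as [batch, seq_len, head_dim]): unsqueeze_dim=1 inserts a heads axis so the
# element-wise products broadcast, and the shapes of q and k are unchanged:
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)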


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
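
# Illustrative shape check: with 2 key/value heads repeated 4 times,
#   kv = torch.randn(1, 2, 16, 64)
#   repeat_kv(kv, 4).shape  # -> torch.Size([1, 8, 16, 64])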


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
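
# Illustrative note: the eager path is plain scaled dot-product attention,
#   softmax(Q @ K^T * scaling + causal_mask) @ V,
# with the softmax computed in float32 and the output transposed to
# [batch, seq_len, num_heads, head_dim] so the caller can reshape it back to hidden_size.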


@use_kernelized_func(apply_rotary_pos_emb)
class ExaoneAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: ExaoneConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
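
# Illustrative note: this is grouped-query attention. q_proj produces num_attention_heads
# query heads while k_proj/v_proj produce num_key_value_heads key/value heads; in the eager
# fallback above the key/value heads are repeated num_key_value_groups times via repeat_kv
# before the dot product, and out_proj maps the concatenated heads back to hidden_size.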


class ExaoneAttentionBlock(nn.Module):
    """Dummy wrapper class for EXAONE 3.5 structure"""

    def __init__(self, config: ExaoneConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention = ExaoneAttention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.attention(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )


class ExaoneMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.c_fc_0 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.c_fc_1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.c_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act = ACT2FN[config.hidden_act]

    def forward(self, x):
        output_proj = self.c_proj(self.act(self.c_fc_0(x)) * self.c_fc_1(x))
        return output_proj
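
# Illustrative note: this is a gated (SwiGLU-style) MLP. For an input x,
#   forward(x) == c_proj(act(c_fc_0(x)) * c_fc_1(x)),
# i.e. one up-projection passes through the activation and gates the other,
# before c_proj maps back from intermediate_size to hidden_size.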


class ExaoneDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ln_1 = ExaoneRMSNorm(hidden_size=self.hidden_size, eps=config.layer_norm_epsilon)
        self.attn = ExaoneAttentionBlock(config, layer_id)
        self.ln_2 = ExaoneRMSNorm(hidden_size=self.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = ExaoneMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)

        # Self Attention
        hidden_states, _ = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
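
# Illustrative note: the layer uses a pre-norm residual structure,
#   h = x + Attention(ln_1(x))
#   y = h + MLP(ln_2(h))
# with RMSNorm applied before each sub-block rather than after it.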


@auto_docstring
class ExaonePreTrainedModel(PreTrainedModel):
    config: ExaoneConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ExaoneDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": ExaoneDecoderLayer,
        "attentions": ExaoneAttention,
    }


class ExaoneRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ExaoneConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: ExaoneConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
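
# Illustrative note: with the default RoPE parameters the inverse frequencies are
#   inv_freq[i] = 1 / rope_theta ** (2 * i / head_dim),  i = 0, 1, ..., head_dim // 2 - 1,
# and forward() returns the cos/sin of position_ids (outer product with inv_freq),
# duplicated along the last dimension, scaled by attention_scaling, and cast to x.dtype.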


@auto_docstring
class ExaoneModel(ExaonePreTrainedModel):
    def __init__(self, config: ExaoneConfig):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.wte = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx)
        self.drop = nn.Dropout(float(config.embed_dropout))
        self.h = nn.ModuleList([ExaoneDecoderLayer(config, layer_id=i) for i in range(config.num_layers)])
        self.ln_f = ExaoneRMSNorm(hidden_size=self.hidden_size, eps=config.layer_norm_epsilon)
        self.rotary = ExaoneRotaryEmbedding(config)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.wte(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = (
                torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary(hidden_states, position_ids=position_ids)

        for decoder_layer in self.h[: self.config.num_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.ln_f(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


@auto_docstring
class ExaoneForCausalLM(ExaonePreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.transformer = ExaoneModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
                `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
                `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct", trust_remote_code=True)
        >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct")

        >>> prompt = "Explain how wonderful you are"
        >>> messages = [
        ...     {"role": "system", "content": "You are a helpful assistant."},
        ...     {"role": "user", "content": prompt},
        ... ]
        >>> input_ids = tokenizer.apply_chat_template(
        ...     messages,
        ...     tokenize=True,
        ...     add_generation_prompt=True,
        ...     return_tensors="pt",
        ... )

        >>> output = model.generate(input_ids.to(model.device), max_new_tokens=128)
        >>> tokenizer.decode(output[0], skip_special_tokens=True)
        '[|system|]You are a helpful assistant.\n[|user|]Explain how wonderful you are\n[|assistant|]As an AI assistant, I don\'t experience feelings or qualities like "wonderfulness" in the way humans do, but I can certainly highlight several aspects that make my capabilities and interactions valuable and beneficial:\n\n1. **Knowledge and Information**: I am equipped with extensive knowledge across a wide range of topics including science, technology, history, culture, and more. This allows me to provide accurate, informative responses to a vast array of inquiries, helping users learn and explore new ideas.\n\n2. **Accessibility**: I am available 24/7, meaning you can ask me questions or seek assistance at'
        ```
        """
        outputs: BaseModelOutputWithPast = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["ExaonePreTrainedModel", "ExaoneModel", "ExaoneForCausalLM"]