Files
GPT3-dev-125m-0104/modeling_gpt3dev.py

480 lines
17 KiB
Python
Raw Normal View History

import math
import torch
import torch.nn as nn
from packaging.version import Version
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from transformers import (
__version__ as TRANSFORMERS_VERSION,
AutoConfig,
AutoModel,
AutoModelForCausalLM
)
from transformers.modeling_outputs import (
CausalLMOutputWithCrossAttentions,
)
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import (
GPT2LMHeadModel,
GPT2Model,
GPT2Block,
GPT2Attention,
GPT2MLP,
CausalLMOutputWithCrossAttentions
)
# True when the installed transformers release is 5.0.0 or newer. The classes
# in this module branch on this flag to choose between the `past_key_values`
# keyword (v5 branch) and the `past_key_value` / `head_mask` keywords
# (pre-v5 branch), and to choose the block's return type.
IS_TRANSFORMERS_V5 = Version(TRANSFORMERS_VERSION) >= Version("5.0.0")
def _normalize_block_args(
    extra_args,
    *,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    use_cache=False,
    output_attentions=False,
):
    """Fold surplus positional block arguments into their named slots.

    Args:
        extra_args: Positional arguments a block's ``forward`` received beyond
            its named leading parameters.
        head_mask, encoder_hidden_states, encoder_attention_mask: Keyword
            values; each is replaced by its positional counterpart only when
            the keyword value is ``None``.
        use_cache, output_attentions: Keyword values; on the pre-v5 branch a
            positional counterpart, when present, always replaces them.

    Returns:
        The tuple ``(head_mask, encoder_hidden_states, encoder_attention_mask,
        use_cache, output_attentions)``.
    """
    if not extra_args:
        # Nothing positional to fold in; keyword values pass through untouched.
        return (
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            use_cache,
            output_attentions,
        )

    if IS_TRANSFORMERS_V5:
        # v5 branch: only the first surplus positional is interpreted, as the
        # encoder hidden states.
        if encoder_hidden_states is None:
            encoder_hidden_states = extra_args[0]
    else:
        # Pre-v5 branch: surplus positionals are read in the order
        # head_mask, encoder_hidden_states, encoder_attention_mask,
        # use_cache, output_attentions.
        count = len(extra_args)
        if head_mask is None:
            head_mask = extra_args[0]
        if count > 1 and encoder_hidden_states is None:
            encoder_hidden_states = extra_args[1]
        if count > 2 and encoder_attention_mask is None:
            encoder_attention_mask = extra_args[2]
        if count > 3:
            use_cache = extra_args[3]
        if count > 4:
            output_attentions = extra_args[4]

    return (
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        use_cache,
        output_attentions,
    )
class GPT3DevConfig(GPT2Config):
    """GPT2Config extended with the extra knobs used by the GPT3Dev model.

    Args:
        use_pre_layernorm: Stored on the config; read by ``GPT3DevBlock`` to
            choose between its pre- and post-LayerNorm code paths.
        window_size: Stored on the config; read by ``GPT3DevSparseAttention``
            as the width of its attention band.
        stride: Stored on the config as-is.
        **kwargs: Forwarded unchanged to ``GPT2Config.__init__``.
    """

    model_type = "gpt3dev"

    def __init__(
        self,
        use_pre_layernorm=True,
        window_size=256,
        stride=128,
        **kwargs,
    ):
        # Let the parent config consume everything it knows about first.
        super().__init__(**kwargs)

        self.use_pre_layernorm = use_pre_layernorm
        self.window_size = window_size
        self.stride = stride
class GPT3DevAttention(GPT2Attention):  # dense
    """GPT-3 style dense attention: nn.Linear instead of Conv1D.

    The projection layers created by ``GPT2Attention.__init__`` are replaced
    by ``nn.Linear`` layers. ``forward()`` is inherited unchanged.

    Args:
        config: Model configuration; ``config.hidden_size`` sets every
            projection width.
        is_cross_attention: Forwarded to ``GPT2Attention.__init__``. When
            true, the projections are built in the cross-attention layout
            (see the note in ``__init__``) instead of the fused QKV layout.
        layer_idx: Forwarded to ``GPT2Attention.__init__``.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx=layer_idx)
        hidden = config.hidden_size

        if is_cross_attention:
            # NOTE(review): this mirrors the cross-attention layout of
            # GPT2Attention (``c_attn`` yields keys and values, ``q_attn``
            # yields queries) — confirm against the installed transformers
            # version. Previously ``c_attn`` was always rebuilt with the
            # 3 * hidden self-attention width, which does not fit that layout.
            self.c_attn = nn.Linear(hidden, 2 * hidden, bias=True)
            self.q_attn = nn.Linear(hidden, hidden, bias=True)
        else:
            # Fused query/key/value projection.
            self.c_attn = nn.Linear(hidden, 3 * hidden, bias=True)

        # Output projection back to the hidden size.
        self.c_proj = nn.Linear(hidden, hidden, bias=True)

    # forward() inherited from GPT2Attention — no override needed
class GPT3DevSparseAttention(GPT3DevAttention):  # local sparse
    """GPT-3 style locally banded sparse attention.

    Every query position may attend to itself and to at most ``window_size``
    earlier key positions. The band is merged into the attention mask supplied
    by the caller (if any) and the result is handed to the inherited forward.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx=layer_idx)
        # Fall back to 256 when the config has no ``window_size`` attribute.
        self.window_size = getattr(config, "window_size", 256)

    def _allowed_positions(self, q_pos, seq_len, device):
        """Return a ``(len(q_pos), seq_len)`` bool tensor, True where attention is allowed.

        A key at position ``k`` is allowed for a query at position ``q`` when
        ``0 <= q - k <= self.window_size``.
        """
        k_pos = torch.arange(seq_len, device=device)
        distance = q_pos[:, None] - k_pos[None, :]
        return (distance >= 0) & (distance <= self.window_size)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
        cache_position=None,
        attention_mask=None,
        *extra_args,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        past_key_values=None,
        **kwargs,
    ):
        """Apply the banded causal mask and delegate to the parent attention.

        Args:
            hidden_states: Tensor whose size unpacks into three dimensions;
                the middle one is used as the number of query positions.
            past_key_value / past_key_values: Cache object under either
                spelling; ``past_key_value`` wins when both are given.
            cache_position: Optional 1-D tensor of query positions. When
                given, the key length is taken to be ``cache_position[-1] + 1``.
            attention_mask: Optional mask to merge with the band. Float masks
                are treated as additive; boolean masks are treated as
                "True = may attend" (the convention of torch's
                ``scaled_dot_product_attention``).
            *extra_args: Accepted for signature compatibility and ignored.
            head_mask: Forwarded to the parent only on the pre-v5 branch.
            encoder_hidden_states, encoder_attention_mask, output_attentions,
            **kwargs: Forwarded unchanged to the parent forward.

        Returns:
            Whatever the inherited ``forward`` returns.
        """
        if past_key_values is not None and past_key_value is None:
            past_key_value = past_key_values

        _, tgt_len, _ = hidden_states.size()
        device = hidden_states.device
        dtype = hidden_states.dtype

        # Determine query/key positions using cache_position (new API).
        if cache_position is not None:
            q_pos = cache_position  # shape: (tgt_len,)
            seq_len = int(q_pos[-1].item()) + 1
        else:
            q_pos = torch.arange(tgt_len, device=device)
            seq_len = tgt_len

        allowed = self._allowed_positions(q_pos, seq_len, device)

        if attention_mask is not None:
            # The caller's mask may carry extra key/query positions — trim it
            # to the band's shape before merging.
            if attention_mask.size(-1) != seq_len:
                attention_mask = attention_mask[..., :seq_len]
            if attention_mask.size(-2) != tgt_len:
                attention_mask = attention_mask[..., :tgt_len, :]

        if attention_mask is not None and attention_mask.dtype == torch.bool:
            # Boolean masks must be combined logically. Feeding them to
            # torch.minimum together with a float mask would promote False to
            # 0.0, i.e. "attend", and silently drop the caller's masking.
            attention_mask = attention_mask & allowed
        else:
            # Additive float band: 0 where allowed, most negative value elsewhere.
            sparse_mask = torch.zeros(
                (1, 1, tgt_len, seq_len), dtype=dtype, device=device
            )
            sparse_mask.masked_fill_(~allowed, torch.finfo(dtype).min)
            if attention_mask is not None:
                attention_mask = torch.minimum(attention_mask, sparse_mask)
            else:
                attention_mask = sparse_mask
            del sparse_mask
        del allowed

        forward_kwargs = dict(
            hidden_states=hidden_states,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            **kwargs,
        )
        # The cache keyword is spelled differently on each side of the v5
        # boundary, and head_mask is only passed on the pre-v5 side.
        if IS_TRANSFORMERS_V5:
            forward_kwargs["past_key_values"] = past_key_value
        else:
            forward_kwargs["past_key_value"] = past_key_value
            forward_kwargs["head_mask"] = head_mask
        return super().forward(**forward_kwargs)
class GPT3DevMLP(GPT2MLP):
    """GPT2MLP variant whose projections are ``nn.Linear`` and whose activation is ``nn.GELU``.

    Args:
        intermediate_size: Width of the hidden feed-forward layer.
        config: Model configuration; ``config.hidden_size`` is the input and
            output width.
    """

    def __init__(self, intermediate_size, config):
        super().__init__(intermediate_size, config)
        embed_dim = config.hidden_size

        # Replace the parent's projections and activation.
        self.c_fc = nn.Linear(embed_dim, intermediate_size, bias=True)
        self.c_proj = nn.Linear(intermediate_size, embed_dim, bias=True)
        self.act = nn.GELU()  # standard GeLU
class GPT3DevBlock(GPT2Block):
    """Transformer block with selectable LayerNorm placement and dense or banded attention.

    Args:
        config: Model configuration; ``use_pre_layernorm``, ``hidden_size`` and
            ``layer_norm_epsilon`` are read from it.
        is_sparse: Use ``GPT3DevSparseAttention`` when true, otherwise
            ``GPT3DevAttention``.
        layer_idx: Forwarded to the parent block and to the attention module.
    """

    def __init__(self, config, is_sparse: bool = False, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)
        self.use_pre_layernorm = config.use_pre_layernorm

        hidden = config.hidden_size
        self.ln_1 = nn.LayerNorm(hidden, eps=config.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(hidden, eps=config.layer_norm_epsilon)

        attention_cls = GPT3DevSparseAttention if is_sparse else GPT3DevAttention
        self.attn = attention_cls(config, layer_idx=layer_idx)

        # Feed-forward width is fixed at four times the hidden size.
        self.mlp = GPT3DevMLP(4 * hidden, config)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
        cache_position=None,
        attention_mask=None,
        *extra_args,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        past_key_values=None,
        **kwargs,
    ):
        """Run self-attention and the MLP with residual connections.

        Surplus positional arguments are mapped onto the keyword parameters by
        ``_normalize_block_args``. The cache may be supplied under either
        ``past_key_value`` or ``past_key_values``; the former wins when both
        are given.

        Returns:
            On the v5 branch, the output hidden states tensor. Otherwise a
            tuple ``(hidden_states,)``, extended with the attention weights
            when ``output_attentions`` is truthy.
        """
        if past_key_value is None and past_key_values is not None:
            past_key_value = past_key_values

        (
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            use_cache,
            output_attentions,
        ) = _normalize_block_args(
            extra_args,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        # Keywords whose spelling or presence depends on the transformers version.
        if IS_TRANSFORMERS_V5:
            version_kwargs = {"past_key_values": past_key_value}
        else:
            version_kwargs = {
                "past_key_value": past_key_value,
                "head_mask": head_mask,
            }

        pre_ln = self.use_pre_layernorm

        # Pre-LayerNorm normalizes the attention input; post-LayerNorm feeds
        # the raw hidden states. The residual is the un-normalized input in
        # both cases.
        attn_input = self.ln_1(hidden_states) if pre_ln else hidden_states
        attn_output, attn_weights = self.attn(
            hidden_states=attn_input,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            **kwargs,
            **version_kwargs,
        )
        hidden_states = hidden_states + attn_output

        if pre_ln:
            # Pre-LayerNorm (GPT-3): normalize, transform, then add the residual.
            hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
        else:
            # Post-LayerNorm: normalize after each residual addition.
            hidden_states = self.ln_1(hidden_states)
            hidden_states = self.ln_2(hidden_states + self.mlp(hidden_states))

        if IS_TRANSFORMERS_V5:
            return hidden_states
        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class GPT3DevModel(GPT2Model):
    """GPT2Model whose embeddings, blocks and final norm are rebuilt for GPT3Dev.

    Odd-indexed layers use banded sparse attention; even-indexed layers use
    dense attention.
    """

    config_class = GPT3DevConfig

    def __init__(self, config):
        super().__init__(config)
        hidden = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, hidden)
        self.wpe = nn.Embedding(config.n_positions, hidden)
        self.drop = nn.Dropout(config.embd_pdrop)

        # Alternate dense (even index) and sparse (odd index) blocks.
        self.h = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            block = GPT3DevBlock(
                config, is_sparse=(layer_idx % 2 == 1), layer_idx=layer_idx
            )
            self.h.append(block)

        self.ln_f = nn.LayerNorm(hidden, eps=config.layer_norm_epsilon)
        self.post_init()
        # NOTE: _apply_residual_scaling is called from GPT3DevLMHeadModel.__init__
        # AFTER the final post_init(), so it is NOT undone by re-initialization.

    def _apply_residual_scaling(self):
        """Multiply every block's output projections by ``1 / sqrt(2 * num_layers)`` in place.

        Both the attention and the MLP ``c_proj`` weights of each block are
        scaled.

        NOTE(review): the parent class's weight initialization may already
        draw ``c_proj.weight`` with a std scaled by the same factor; if so,
        this applies the scaling a second time — confirm against the installed
        transformers version.
        """
        # GPT-3/GPT-2 modified init: scale residuals by 1 / sqrt(2 * num_layers)
        scale = 1 / math.sqrt(2 * self.config.num_hidden_layers)
        for block in self.h:
            for projection in (block.attn.c_proj, block.mlp.c_proj):
                projection.weight.data.mul_(scale)
class GPT3DevLMHeadModel(GPT2LMHeadModel):
    """Causal language-modeling head on top of ``GPT3DevModel``."""

    config_class = GPT3DevConfig

    def __init__(self, config):
        super().__init__(config)
        # Replace the parent's backbone and head with the GPT3Dev variants.
        self.transformer = GPT3DevModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
        # GPT-3 modified init: scale residual projections by 1/sqrt(2*num_layers)
        # MUST be AFTER the final post_init() which re-initializes all weights
        self.transformer._apply_residual_scaling()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        cache_position=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        logits_to_keep=0,
        output_logits=None,  # Force returning full logits even with labels (for debugging/distillation)
        **kwargs,
    ):
        """Run the backbone, project to vocabulary logits and optionally compute the LM loss.

        Args:
            labels: Optional target ids. When given, logits are computed for
                every position and a shifted cross-entropy loss is returned.
            logits_to_keep: ``0`` keeps all positions. A positive int keeps
                only the last ``logits_to_keep`` positions. Any non-int value
                is used directly as an index along the sequence dimension.
            output_logits: When truthy, logits are computed for every position
                regardless of ``logits_to_keep``.
            return_dict: Falls back to ``self.config.use_return_dict`` when
                ``None``.
            head_mask, output_attentions, output_hidden_states: Forwarded to
                the backbone only on the pre-v5 branch.
            Remaining arguments and ``**kwargs`` are forwarded to the backbone.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` is
            truthy, otherwise a tuple ``(loss?, logits, *backbone_outputs[1:])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            cache_position=cache_position,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            past_key_values=past_key_values,
            **kwargs,
        )
        if not IS_TRANSFORMERS_V5:
            transformer_kwargs["head_mask"] = head_mask
            transformer_kwargs["output_attentions"] = output_attentions
            transformer_kwargs["output_hidden_states"] = output_hidden_states
            transformer_kwargs["return_dict"] = return_dict

        transformer_outputs = self.transformer(**transformer_kwargs)
        hidden_states = (
            transformer_outputs.last_hidden_state
            if hasattr(transformer_outputs, "last_hidden_state")
            else transformer_outputs[0]
        )

        # Only compare against 0 when logits_to_keep is an int. For a tensor
        # of indices, `logits_to_keep == 0` is itself a tensor and its truth
        # value is ambiguous (raises for more than one element).
        keep_all_positions = isinstance(logits_to_keep, int) and logits_to_keep == 0

        # The loss needs logits at every position, so labels force the full path.
        if labels is not None or output_logits or keep_all_positions:
            logits_hidden_states = hidden_states
        else:
            slice_indices = (
                slice(-logits_to_keep, None)
                if isinstance(logits_to_keep, int)
                else logits_to_keep
            )
            logits_hidden_states = hidden_states[:, slice_indices, :]

        lm_logits = self.lm_head(logits_hidden_states.contiguous())

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            return ((loss,) if loss is not None else ()) + (lm_logits,) + transformer_outputs[1:]

        # getattr with a default: not every backbone output carries every field.
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=getattr(transformer_outputs, "past_key_values", None),
            hidden_states=getattr(transformer_outputs, "hidden_states", None),
            attentions=getattr(transformer_outputs, "attentions", None),
            cross_attentions=getattr(transformer_outputs, "cross_attentions", None),
        )
# Register the custom config under the "gpt3dev" model type and bind the
# model classes to it, so the transformers Auto* factories can resolve them.
AutoConfig.register("gpt3dev", GPT3DevConfig)
AutoModel.register(GPT3DevConfig, GPT3DevModel)
AutoModelForCausalLM.register(GPT3DevConfig, GPT3DevLMHeadModel)
# ---- Transformers 5.x compatibility patch ----
# Keep references to the forwards as defined on the classes above; the
# wrapper functions that follow delegate to these after re-mapping arguments.
_ORIG_GPT3DEV_BLOCK_FORWARD = GPT3DevBlock.forward
_ORIG_GPT3DEV_SPARSE_FORWARD = GPT3DevSparseAttention.forward
def _patched_gpt3dev_block_forward(
    self,
    hidden_states,
    past_key_values=None,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    use_cache=False,
    **kwargs,
):
    """Adapter that calls the original block forward using keywords only.

    Accepts ``past_key_values`` and ``attention_mask`` as the second and third
    positional arguments, pulls ``cache_position``, ``output_attentions``,
    ``head_mask`` and the legacy ``past_key_value`` out of ``kwargs``, and
    hands everything to ``_ORIG_GPT3DEV_BLOCK_FORWARD`` by name.

    Returns:
        Whatever the original block forward returns.
    """
    cache_position = kwargs.pop("cache_position", None)
    output_attentions = kwargs.pop("output_attentions", False)
    head_mask = kwargs.pop("head_mask", None)
    legacy_cache = kwargs.pop("past_key_value", None)

    # Prefer the plural spelling; fall back to the legacy keyword.
    cache = legacy_cache if past_key_values is None else past_key_values

    return _ORIG_GPT3DEV_BLOCK_FORWARD(
        self,
        hidden_states,
        past_key_value=cache,
        cache_position=cache_position,
        attention_mask=attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        **kwargs,
    )
def _patched_gpt3dev_sparse_forward(
    self,
    hidden_states,
    past_key_values=None,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
    **kwargs,
):
    """Adapter that calls the original sparse-attention forward using keywords only.

    Accepts ``past_key_values`` and ``attention_mask`` as the second and third
    positional arguments, pulls ``cache_position``, ``head_mask`` and the
    legacy ``past_key_value`` out of ``kwargs``, and hands everything to
    ``_ORIG_GPT3DEV_SPARSE_FORWARD`` by name.

    Returns:
        Whatever the original sparse-attention forward returns.
    """
    cache_position = kwargs.pop("cache_position", None)
    head_mask = kwargs.pop("head_mask", None)
    legacy_cache = kwargs.pop("past_key_value", None)

    # Prefer the plural spelling; fall back to the legacy keyword.
    cache = legacy_cache if past_key_values is None else past_key_values

    return _ORIG_GPT3DEV_SPARSE_FORWARD(
        self,
        hidden_states,
        past_key_value=cache,
        cache_position=cache_position,
        attention_mask=attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        output_attentions=output_attentions,
        **kwargs,
    )
# Install the wrappers only on transformers >= 5.0.0. Their positional order
# (hidden_states, past_key_values, attention_mask, ...) differs from the order
# the unpatched forwards accept (hidden_states, past_key_value, cache_position,
# attention_mask, ...), so installing them unconditionally would misroute
# positional arguments from callers that use the unpatched order.
# NOTE(review): assumes pre-5.0 GPT2Model calls its blocks using the unpatched
# positional order — confirm against the oldest transformers version supported.
if IS_TRANSFORMERS_V5:
    GPT3DevBlock.forward = _patched_gpt3dev_block_forward
    GPT3DevSparseAttention.forward = _patched_gpt3dev_sparse_forward
# ---- End compatibility patch ----