初始化项目,由ModelHub XC社区提供模型
Model: k050506koch/GPT3-dev-125m-0104 Source: Original Platform
This commit is contained in:
480
modeling_gpt3dev.py
Normal file
480
modeling_gpt3dev.py
Normal file
@@ -0,0 +1,480 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging.version import Version
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
|
||||
|
||||
from transformers import (
|
||||
__version__ as TRANSFORMERS_VERSION,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM
|
||||
)
|
||||
|
||||
from transformers.modeling_outputs import (
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import (
|
||||
GPT2LMHeadModel,
|
||||
GPT2Model,
|
||||
GPT2Block,
|
||||
GPT2Attention,
|
||||
GPT2MLP,
|
||||
CausalLMOutputWithCrossAttentions
|
||||
)
|
||||
|
||||
IS_TRANSFORMERS_V5 = Version(TRANSFORMERS_VERSION) >= Version("5.0.0")
|
||||
|
||||
|
||||
def _normalize_block_args(
|
||||
extra_args,
|
||||
*,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
if IS_TRANSFORMERS_V5:
|
||||
if extra_args and encoder_hidden_states is None:
|
||||
encoder_hidden_states = extra_args[0]
|
||||
else:
|
||||
if extra_args:
|
||||
if head_mask is None:
|
||||
head_mask = extra_args[0]
|
||||
if len(extra_args) > 1 and encoder_hidden_states is None:
|
||||
encoder_hidden_states = extra_args[1]
|
||||
if len(extra_args) > 2 and encoder_attention_mask is None:
|
||||
encoder_attention_mask = extra_args[2]
|
||||
if len(extra_args) > 3:
|
||||
use_cache = extra_args[3]
|
||||
if len(extra_args) > 4:
|
||||
output_attentions = extra_args[4]
|
||||
|
||||
return (
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
|
||||
class GPT3DevConfig(GPT2Config):
|
||||
model_type = "gpt3dev"
|
||||
|
||||
def __init__(self, use_pre_layernorm=True, window_size=256, stride=128, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.use_pre_layernorm = use_pre_layernorm
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
|
||||
|
||||
class GPT3DevAttention(GPT2Attention): # dense
|
||||
"""GPT-3 style dense attention: nn.Linear instead of Conv1D."""
|
||||
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
||||
super().__init__(config, is_cross_attention, layer_idx=layer_idx)
|
||||
# GPT-3 uses nn.Linear instead of Conv1D
|
||||
self.c_attn = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True)
|
||||
self.c_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
|
||||
# forward() inherited from GPT2Attention — no override needed
|
||||
|
||||
|
||||
class GPT3DevSparseAttention(GPT3DevAttention): # local sparse
|
||||
"""GPT-3 style locally banded sparse attention."""
|
||||
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
||||
super().__init__(config, is_cross_attention, layer_idx=layer_idx)
|
||||
self.window_size = getattr(config, "window_size", 256)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
past_key_value=None,
|
||||
cache_position=None,
|
||||
attention_mask=None,
|
||||
*extra_args,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=False,
|
||||
past_key_values=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None and past_key_value is None:
|
||||
past_key_value = past_key_values
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
# Determine query/key positions using cache_position (new API)
|
||||
if cache_position is not None:
|
||||
q_pos = cache_position # shape: (tgt_len,)
|
||||
seq_len = int(q_pos[-1].item()) + 1
|
||||
else:
|
||||
q_pos = torch.arange(tgt_len, device=device)
|
||||
seq_len = tgt_len
|
||||
k_pos = torch.arange(seq_len, device=device)
|
||||
|
||||
diff = q_pos[:, None] - k_pos[None, :] # (tgt_len, seq_len)
|
||||
is_causal = diff >= 0
|
||||
within_window = diff.abs() <= self.window_size
|
||||
allow_attention = is_causal & within_window
|
||||
del is_causal, within_window, diff
|
||||
|
||||
sparse_mask = torch.zeros((1, 1, tgt_len, seq_len), dtype=dtype, device=device)
|
||||
sparse_mask.masked_fill_(~allow_attention, torch.finfo(dtype).min)
|
||||
del allow_attention
|
||||
|
||||
# Combine with parent's causal mask
|
||||
if attention_mask is not None:
|
||||
# Parent may create mask with extra KV positions — trim to match
|
||||
if attention_mask.size(-1) != sparse_mask.size(-1):
|
||||
attention_mask = attention_mask[..., :sparse_mask.size(-1)]
|
||||
if attention_mask.size(-2) != sparse_mask.size(-2):
|
||||
attention_mask = attention_mask[..., :sparse_mask.size(-2), :]
|
||||
attention_mask = torch.minimum(attention_mask, sparse_mask)
|
||||
else:
|
||||
attention_mask = sparse_mask
|
||||
del sparse_mask
|
||||
|
||||
forward_kwargs = dict(
|
||||
hidden_states=hidden_states,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
if IS_TRANSFORMERS_V5:
|
||||
forward_kwargs["past_key_values"] = past_key_value
|
||||
else:
|
||||
forward_kwargs["past_key_value"] = past_key_value
|
||||
forward_kwargs["head_mask"] = head_mask
|
||||
|
||||
return super().forward(**forward_kwargs)
|
||||
|
||||
|
||||
class GPT3DevMLP(GPT2MLP):
|
||||
def __init__(self, intermediate_size, config):
|
||||
super().__init__(intermediate_size, config)
|
||||
self.c_fc = nn.Linear(config.hidden_size, intermediate_size, bias=True)
|
||||
self.c_proj = nn.Linear(intermediate_size, config.hidden_size, bias=True)
|
||||
self.act = nn.GELU() # standard GeLU
|
||||
|
||||
|
||||
class GPT3DevBlock(GPT2Block):
|
||||
"""GPT-3 block with pre-LayerNorm and alternating dense/sparse attention."""
|
||||
def __init__(self, config, is_sparse: bool = False, layer_idx=None):
|
||||
super().__init__(config, layer_idx=layer_idx)
|
||||
self.use_pre_layernorm = config.use_pre_layernorm
|
||||
self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
if is_sparse:
|
||||
self.attn = GPT3DevSparseAttention(config, layer_idx=layer_idx)
|
||||
else:
|
||||
self.attn = GPT3DevAttention(config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = GPT3DevMLP(4 * config.hidden_size, config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
past_key_value=None,
|
||||
cache_position=None,
|
||||
attention_mask=None,
|
||||
*extra_args,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
past_key_values=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None and past_key_value is None:
|
||||
past_key_value = past_key_values
|
||||
(
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
) = _normalize_block_args(
|
||||
extra_args,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
if self.use_pre_layernorm:
|
||||
# Pre-LayerNorm (GPT-3)
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_kwargs = dict(
|
||||
hidden_states=hidden_states,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
if IS_TRANSFORMERS_V5:
|
||||
attn_kwargs["past_key_values"] = past_key_value
|
||||
attn_output, attn_weights = self.attn(**attn_kwargs)
|
||||
else:
|
||||
attn_kwargs["past_key_value"] = past_key_value
|
||||
attn_kwargs["head_mask"] = head_mask
|
||||
attn_output, attn_weights = self.attn(**attn_kwargs)
|
||||
hidden_states = residual + attn_output
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
else:
|
||||
# Post-LayerNorm (GPT-2)
|
||||
residual = hidden_states
|
||||
attn_kwargs = dict(
|
||||
hidden_states=hidden_states,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
if IS_TRANSFORMERS_V5:
|
||||
attn_kwargs["past_key_values"] = past_key_value
|
||||
attn_output, attn_weights = self.attn(**attn_kwargs)
|
||||
else:
|
||||
attn_kwargs["past_key_value"] = past_key_value
|
||||
attn_kwargs["head_mask"] = head_mask
|
||||
attn_output, attn_weights = self.attn(**attn_kwargs)
|
||||
hidden_states = residual + attn_output
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
|
||||
residual = hidden_states
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
|
||||
if IS_TRANSFORMERS_V5:
|
||||
return hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
return outputs
|
||||
|
||||
|
||||
class GPT3DevModel(GPT2Model):
|
||||
config_class = GPT3DevConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.wpe = nn.Embedding(config.n_positions, config.hidden_size)
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
self.h = nn.ModuleList()
|
||||
for i in range(config.num_hidden_layers):
|
||||
self.h.append(GPT3DevBlock(config, is_sparse=(i % 2 == 1), layer_idx=i))
|
||||
|
||||
self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.post_init()
|
||||
# NOTE: _apply_residual_scaling is called from GPT3DevLMHeadModel.__init__
|
||||
# AFTER the final post_init(), so it is NOT undone by re-initialization.
|
||||
|
||||
def _apply_residual_scaling(self):
|
||||
# GPT-3/GPT-2 modified init: scale residuals by 1 / sqrt(2 * num_layers)
|
||||
scale = 1 / math.sqrt(2 * self.config.num_hidden_layers)
|
||||
for block in self.h:
|
||||
block.attn.c_proj.weight.data.mul_(scale)
|
||||
block.mlp.c_proj.weight.data.mul_(scale)
|
||||
|
||||
|
||||
|
||||
|
||||
class GPT3DevLMHeadModel(GPT2LMHeadModel):
|
||||
config_class = GPT3DevConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = GPT3DevModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.post_init()
|
||||
# GPT-3 modified init: scale residual projections by 1/sqrt(2*num_layers)
|
||||
# MUST be AFTER the final post_init() which re-initializes all weights
|
||||
self.transformer._apply_residual_scaling()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
cache_position=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
logits_to_keep=0,
|
||||
output_logits=None, # Force returning full logits even with labels (for debugging/distillation)
|
||||
**kwargs,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
if not IS_TRANSFORMERS_V5:
|
||||
transformer_kwargs["head_mask"] = head_mask
|
||||
transformer_kwargs["output_attentions"] = output_attentions
|
||||
transformer_kwargs["output_hidden_states"] = output_hidden_states
|
||||
transformer_kwargs["return_dict"] = return_dict
|
||||
transformer_kwargs["past_key_values"] = past_key_values
|
||||
|
||||
transformer_outputs = self.transformer(**transformer_kwargs)
|
||||
|
||||
hidden_states = (
|
||||
transformer_outputs.last_hidden_state
|
||||
if hasattr(transformer_outputs, "last_hidden_state")
|
||||
else transformer_outputs[0]
|
||||
)
|
||||
|
||||
# Set up for loss computation if labels are provided
|
||||
compute_full_logits = labels is not None or output_logits or logits_to_keep == 0
|
||||
if compute_full_logits:
|
||||
logits_hidden_states = hidden_states
|
||||
else:
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits_hidden_states = hidden_states[:, slice_indices, :]
|
||||
lm_logits = self.lm_head(logits_hidden_states.contiguous())
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
return ((loss,) if loss is not None else ()) + (lm_logits,) + transformer_outputs[1:]
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=getattr(transformer_outputs, "past_key_values", None),
|
||||
hidden_states=getattr(transformer_outputs, "hidden_states", None),
|
||||
attentions=getattr(transformer_outputs, "attentions", None),
|
||||
cross_attentions=getattr(transformer_outputs, "cross_attentions", None),
|
||||
)
|
||||
|
||||
|
||||
AutoConfig.register("gpt3dev", GPT3DevConfig)
|
||||
AutoModel.register(GPT3DevConfig, GPT3DevModel)
|
||||
AutoModelForCausalLM.register(GPT3DevConfig, GPT3DevLMHeadModel)
|
||||
|
||||
# ---- Transformers 5.x compatibility patch ----
|
||||
_ORIG_GPT3DEV_BLOCK_FORWARD = GPT3DevBlock.forward
|
||||
_ORIG_GPT3DEV_SPARSE_FORWARD = GPT3DevSparseAttention.forward
|
||||
|
||||
def _patched_gpt3dev_block_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False,
|
||||
**kwargs,
|
||||
):
|
||||
cache_position = kwargs.pop("cache_position", None)
|
||||
output_attentions = kwargs.pop("output_attentions", False)
|
||||
head_mask = kwargs.pop("head_mask", None)
|
||||
past_key_value = kwargs.pop("past_key_value", None)
|
||||
if past_key_values is None:
|
||||
past_key_values = past_key_value
|
||||
|
||||
return _ORIG_GPT3DEV_BLOCK_FORWARD(
|
||||
self,
|
||||
hidden_states,
|
||||
past_key_value=past_key_values,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _patched_gpt3dev_sparse_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=False,
|
||||
**kwargs,
|
||||
):
|
||||
cache_position = kwargs.pop("cache_position", None)
|
||||
head_mask = kwargs.pop("head_mask", None)
|
||||
past_key_value = kwargs.pop("past_key_value", None)
|
||||
if past_key_values is None:
|
||||
past_key_values = past_key_value
|
||||
|
||||
return _ORIG_GPT3DEV_SPARSE_FORWARD(
|
||||
self,
|
||||
hidden_states,
|
||||
past_key_value=past_key_values,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
GPT3DevBlock.forward = _patched_gpt3dev_block_forward
|
||||
GPT3DevSparseAttention.forward = _patched_gpt3dev_sparse_forward
|
||||
# ---- End compatibility patch ----
|
||||
Reference in New Issue
Block a user