import math

import torch
import torch.nn as nn
from packaging.version import Version

from transformers import (
    __version__ as TRANSFORMERS_VERSION,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import (
    GPT2Attention,
    GPT2Block,
    GPT2LMHeadModel,
    GPT2Model,
    GPT2MLP,
)

IS_TRANSFORMERS_V5 = Version(TRANSFORMERS_VERSION) >= Version("5.0.0")


def _normalize_block_args(
    extra_args,
    *,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    use_cache=False,
    output_attentions=False,
):
    """Map extra positional args onto keyword args across transformers 4.x/5.x block signatures."""
    if IS_TRANSFORMERS_V5:
        if extra_args and encoder_hidden_states is None:
            encoder_hidden_states = extra_args[0]
    else:
        if extra_args:
            if head_mask is None:
                head_mask = extra_args[0]
            if len(extra_args) > 1 and encoder_hidden_states is None:
                encoder_hidden_states = extra_args[1]
            if len(extra_args) > 2 and encoder_attention_mask is None:
                encoder_attention_mask = extra_args[2]
            if len(extra_args) > 3:
                use_cache = extra_args[3]
            if len(extra_args) > 4:
                output_attentions = extra_args[4]

    return (
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        use_cache,
        output_attentions,
    )
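

# Positional layouts reconciled by the helper above (as assumed by its checks).
# The block/attention forwards defined later in this file are declared as
# `forward(hidden_states, past_key_value, cache_position, attention_mask, *extra_args, ...)`,
# so `extra_args` is everything after `attention_mask`:
#   transformers < 5:  extra_args = (head_mask, encoder_hidden_states,
#                                    encoder_attention_mask, use_cache, output_attentions)
#   transformers >= 5: extra_args = (encoder_hidden_states, ...)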


class GPT3DevConfig(GPT2Config):
    model_type = "gpt3dev"

    def __init__(self, use_pre_layernorm=True, window_size=256, stride=128, **kwargs):
        super().__init__(**kwargs)
        self.use_pre_layernorm = use_pre_layernorm
        self.window_size = window_size
        self.stride = stride
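

# Example (hypothetical sizes, for illustration only):
#   config = GPT3DevConfig(n_embd=768, n_layer=12, n_head=12,
#                          use_pre_layernorm=True, window_size=256, stride=128)
# `window_size` is consumed by GPT3DevSparseAttention below; `stride` is stored on
# the config but not referenced elsewhere in this module.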


class GPT3DevAttention(GPT2Attention):  # dense
    """GPT-3 style dense attention: nn.Linear instead of Conv1D."""
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx=layer_idx)
        # GPT-3 uses nn.Linear instead of Conv1D
        self.c_attn = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True)
        self.c_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
        # forward() inherited from GPT2Attention — no override needed


class GPT3DevSparseAttention(GPT3DevAttention):  # local sparse
    """GPT-3 style locally banded sparse attention."""
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx=layer_idx)
        self.window_size = getattr(config, "window_size", 256)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
        cache_position=None,
        attention_mask=None,
        *extra_args,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        past_key_values=None,
        **kwargs,
    ):
        if past_key_values is not None and past_key_value is None:
            past_key_value = past_key_values

        bsz, tgt_len, _ = hidden_states.size()
        device = hidden_states.device
        dtype = hidden_states.dtype

        # Determine query/key positions using cache_position (new API)
        if cache_position is not None:
            q_pos = cache_position  # shape: (tgt_len,)
            seq_len = int(q_pos[-1].item()) + 1
        else:
            q_pos = torch.arange(tgt_len, device=device)
            seq_len = tgt_len
        k_pos = torch.arange(seq_len, device=device)

        diff = q_pos[:, None] - k_pos[None, :]  # (tgt_len, seq_len)
        is_causal = diff >= 0
        within_window = diff.abs() <= self.window_size
        allow_attention = is_causal & within_window
        del is_causal, within_window, diff

        sparse_mask = torch.zeros((1, 1, tgt_len, seq_len), dtype=dtype, device=device)
        sparse_mask.masked_fill_(~allow_attention, torch.finfo(dtype).min)
        del allow_attention

        # Combine with parent's causal mask
        if attention_mask is not None:
            # Parent may create mask with extra KV positions — trim to match
            if attention_mask.size(-1) != sparse_mask.size(-1):
                attention_mask = attention_mask[..., :sparse_mask.size(-1)]
            if attention_mask.size(-2) != sparse_mask.size(-2):
                attention_mask = attention_mask[..., :sparse_mask.size(-2), :]
            attention_mask = torch.minimum(attention_mask, sparse_mask)
        else:
            attention_mask = sparse_mask
        del sparse_mask

        forward_kwargs = dict(
            hidden_states=hidden_states,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            **kwargs,
        )
        if IS_TRANSFORMERS_V5:
            forward_kwargs["past_key_values"] = past_key_value
        else:
            forward_kwargs["past_key_value"] = past_key_value
            forward_kwargs["head_mask"] = head_mask

        return super().forward(**forward_kwargs)
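

# Worked example of the banded causal pattern built in the forward above
# (window_size=2, five positions; "x" = may attend, "." = masked):
#
#          k=0 k=1 k=2 k=3 k=4
#   q=0     x   .   .   .   .
#   q=1     x   x   .   .   .
#   q=2     x   x   x   .   .
#   q=3     .   x   x   x   .
#   q=4     .   .   x   x   x
#
# Each query sees itself plus at most `window_size` previous positions, so the
# per-layer attention cost is O(seq_len * window_size) rather than O(seq_len ** 2).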


class GPT3DevMLP(GPT2MLP):
    def __init__(self, intermediate_size, config):
        super().__init__(intermediate_size, config)
        self.c_fc = nn.Linear(config.hidden_size, intermediate_size, bias=True)
        self.c_proj = nn.Linear(intermediate_size, config.hidden_size, bias=True)
        self.act = nn.GELU()  # standard GeLU


class GPT3DevBlock(GPT2Block):
    """GPT-3 block with pre-LayerNorm and alternating dense/sparse attention."""
    def __init__(self, config, is_sparse: bool = False, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)
        self.use_pre_layernorm = config.use_pre_layernorm
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        if is_sparse:
            self.attn = GPT3DevSparseAttention(config, layer_idx=layer_idx)
        else:
            self.attn = GPT3DevAttention(config, layer_idx=layer_idx)

        self.mlp = GPT3DevMLP(4 * config.hidden_size, config)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
        cache_position=None,
        attention_mask=None,
        *extra_args,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        past_key_values=None,
        **kwargs,
    ):
        if past_key_values is not None and past_key_value is None:
            past_key_value = past_key_values
        (
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            use_cache,
            output_attentions,
        ) = _normalize_block_args(
            extra_args,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        if self.use_pre_layernorm:
            # Pre-LayerNorm (GPT-3)
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
            attn_kwargs = dict(
                hidden_states=hidden_states,
                cache_position=cache_position,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                **kwargs,
            )
            if IS_TRANSFORMERS_V5:
                attn_kwargs["past_key_values"] = past_key_value
                attn_output, attn_weights = self.attn(**attn_kwargs)
            else:
                attn_kwargs["past_key_value"] = past_key_value
                attn_kwargs["head_mask"] = head_mask
                attn_output, attn_weights = self.attn(**attn_kwargs)
            hidden_states = residual + attn_output

            residual = hidden_states
            hidden_states = self.ln_2(hidden_states)
            feed_forward_hidden_states = self.mlp(hidden_states)
            hidden_states = residual + feed_forward_hidden_states
        else:
            # Post-LayerNorm (original Transformer ordering; GPT-2 itself uses pre-LN)
            residual = hidden_states
            attn_kwargs = dict(
                hidden_states=hidden_states,
                cache_position=cache_position,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                **kwargs,
            )
            if IS_TRANSFORMERS_V5:
                attn_kwargs["past_key_values"] = past_key_value
                attn_output, attn_weights = self.attn(**attn_kwargs)
            else:
                attn_kwargs["past_key_value"] = past_key_value
                attn_kwargs["head_mask"] = head_mask
                attn_output, attn_weights = self.attn(**attn_kwargs)
            hidden_states = residual + attn_output
            hidden_states = self.ln_1(hidden_states)

            residual = hidden_states
            feed_forward_hidden_states = self.mlp(hidden_states)
            hidden_states = residual + feed_forward_hidden_states
            hidden_states = self.ln_2(hidden_states)

        if IS_TRANSFORMERS_V5:
            return hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class GPT3DevModel(GPT2Model):
    config_class = GPT3DevConfig

    def __init__(self, config):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.n_positions, config.hidden_size)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.h.append(GPT3DevBlock(config, is_sparse=(i % 2 == 1), layer_idx=i))

        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.post_init()
        # NOTE: _apply_residual_scaling is called from GPT3DevLMHeadModel.__init__
        # AFTER the final post_init(), so it is NOT undone by re-initialization.

    def _apply_residual_scaling(self):
        # GPT-3/GPT-2 modified init: scale residuals by 1 / sqrt(2 * num_layers)
        scale = 1 / math.sqrt(2 * self.config.num_hidden_layers)
        for block in self.h:
            block.attn.c_proj.weight.data.mul_(scale)
            block.mlp.c_proj.weight.data.mul_(scale)
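

# Illustrative numbers for _apply_residual_scaling above: with num_hidden_layers = 12
# there are N = 2 * 12 = 24 residual additions, so the projection weights are scaled
# by 1 / sqrt(24) ≈ 0.204 (the GPT-2/GPT-3 "1/sqrt(N)" modified initialization).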


class GPT3DevLMHeadModel(GPT2LMHeadModel):
    config_class = GPT3DevConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT3DevModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()
        # GPT-3 modified init: scale residual projections by 1/sqrt(2*num_layers)
        # MUST be AFTER the final post_init() which re-initializes all weights
        self.transformer._apply_residual_scaling()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        cache_position=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        logits_to_keep=0,
        output_logits=None,  # Force returning full logits even with labels (for debugging/distillation)
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            cache_position=cache_position,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            **kwargs,
        )
        if not IS_TRANSFORMERS_V5:
            transformer_kwargs["head_mask"] = head_mask
            transformer_kwargs["output_attentions"] = output_attentions
            transformer_kwargs["output_hidden_states"] = output_hidden_states
            transformer_kwargs["return_dict"] = return_dict
        transformer_kwargs["past_key_values"] = past_key_values

        transformer_outputs = self.transformer(**transformer_kwargs)

        hidden_states = (
            transformer_outputs.last_hidden_state
            if hasattr(transformer_outputs, "last_hidden_state")
            else transformer_outputs[0]
        )

        # Set up for loss computation if labels are provided
        compute_full_logits = labels is not None or output_logits or logits_to_keep == 0
        if compute_full_logits:
            logits_hidden_states = hidden_states
        else:
            slice_indices = (
                slice(-logits_to_keep, None)
                if isinstance(logits_to_keep, int)
                else logits_to_keep
            )
            logits_hidden_states = hidden_states[:, slice_indices, :]
        lm_logits = self.lm_head(logits_hidden_states.contiguous())

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            return ((loss,) if loss is not None else ()) + (lm_logits,) + transformer_outputs[1:]

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=getattr(transformer_outputs, "past_key_values", None),
            hidden_states=getattr(transformer_outputs, "hidden_states", None),
            attentions=getattr(transformer_outputs, "attentions", None),
            cross_attentions=getattr(transformer_outputs, "cross_attentions", None),
        )


AutoConfig.register("gpt3dev", GPT3DevConfig)
AutoModel.register(GPT3DevConfig, GPT3DevModel)
AutoModelForCausalLM.register(GPT3DevConfig, GPT3DevLMHeadModel)
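

# With the registrations above, the custom architecture resolves through the Auto
# classes, e.g. (illustrative):
#   config = AutoConfig.for_model("gpt3dev", n_layer=4, n_head=4, n_embd=64)
#   model = AutoModelForCausalLM.from_config(config)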


# ---- Transformers 5.x compatibility patch ----
_ORIG_GPT3DEV_BLOCK_FORWARD = GPT3DevBlock.forward
_ORIG_GPT3DEV_SPARSE_FORWARD = GPT3DevSparseAttention.forward


def _patched_gpt3dev_block_forward(
    self,
    hidden_states,
    past_key_values=None,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    use_cache=False,
    **kwargs,
):
    cache_position = kwargs.pop("cache_position", None)
    output_attentions = kwargs.pop("output_attentions", False)
    head_mask = kwargs.pop("head_mask", None)
    past_key_value = kwargs.pop("past_key_value", None)
    if past_key_values is None:
        past_key_values = past_key_value

    return _ORIG_GPT3DEV_BLOCK_FORWARD(
        self,
        hidden_states,
        past_key_value=past_key_values,
        cache_position=cache_position,
        attention_mask=attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        **kwargs,
    )


def _patched_gpt3dev_sparse_forward(
    self,
    hidden_states,
    past_key_values=None,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
    **kwargs,
):
    cache_position = kwargs.pop("cache_position", None)
    head_mask = kwargs.pop("head_mask", None)
    past_key_value = kwargs.pop("past_key_value", None)
    if past_key_values is None:
        past_key_values = past_key_value

    return _ORIG_GPT3DEV_SPARSE_FORWARD(
        self,
        hidden_states,
        past_key_value=past_key_values,
        cache_position=cache_position,
        attention_mask=attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        output_attentions=output_attentions,
        **kwargs,
    )


GPT3DevBlock.forward = _patched_gpt3dev_block_forward
GPT3DevSparseAttention.forward = _patched_gpt3dev_sparse_forward
# ---- End compatibility patch ----
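

# Minimal smoke test (illustrative sketch, not part of the module API above). All
# sizes below are made-up toy values; it assumes a transformers version covered by
# the 4.x/5.x compatibility handling in this file.
if __name__ == "__main__":
    demo_config = GPT3DevConfig(
        vocab_size=1000,
        n_positions=128,
        n_embd=64,
        n_layer=4,
        n_head=4,
        window_size=16,
    )
    demo_model = GPT3DevLMHeadModel(demo_config)
    demo_input_ids = torch.randint(0, demo_config.vocab_size, (2, 32))
    with torch.no_grad():
        demo_outputs = demo_model(input_ids=demo_input_ids, labels=demo_input_ids)
    # Expect full logits of shape (batch, seq_len, vocab_size) = (2, 32, 1000)
    print("logits shape:", tuple(demo_outputs.logits.shape))
    print("loss:", float(demo_outputs.loss))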