1695 lines
65 KiB
Python
1695 lines
65 KiB
Python
# coding=utf-8
|
|
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
|
# Modifications copyright 2022 Xinyang Geng
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import Callable, Optional, Tuple
|
|
from collections import OrderedDict
|
|
from typing import Mapping
|
|
|
|
import numpy as np
|
|
|
|
import flax.linen as nn
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
|
from flax.linen import combine_masks, make_causal_mask
|
|
from flax.linen import partitioning as nn_partitioning
|
|
from flax.linen.attention import dot_product_attention_weights
|
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
|
from jax import lax
|
|
from jax.sharding import PartitionSpec
|
|
|
|
from transformers.configuration_utils import PretrainedConfig
|
|
from transformers.modeling_flax_outputs import (
|
|
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
|
FlaxBaseModelOutputWithPooling,
|
|
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
|
|
FlaxCausalLMOutputWithCrossAttentions,
|
|
FlaxMaskedLMOutput,
|
|
FlaxMultipleChoiceModelOutput,
|
|
FlaxQuestionAnsweringModelOutput,
|
|
FlaxSequenceClassifierOutput,
|
|
FlaxTokenClassifierOutput,
|
|
)
|
|
from transformers.modeling_flax_utils import (
|
|
ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring,
|
|
overwrite_call_docstring
|
|
)
|
|
from transformers.utils import (
|
|
add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
|
)
|
|
from transformers import AutoTokenizer
|
|
|
|
from ml_collections import ConfigDict
|
|
from ml_collections.config_dict import config_dict
|
|
from mlxu import function_args_to_config, load_pickle
|
|
|
|
from EasyLM.jax_utils import with_sharding_constraint, get_jax_mesh
|
|
|
|
|
|
"""
|
|
The following code is taken from
|
|
transformers/src/transformers/models/roberta/configuration_roberta.py
|
|
and modified to work with EasyLM.
|
|
"""
|
|
|
|
|
|
# Maps each published RoBERTa checkpoint name to the URL of its `config.json`
# on the Hugging Face hub. Every URL follows the same pattern, so the table is
# generated from the ordered list of checkpoint names.
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    checkpoint: f"https://huggingface.co/{checkpoint}/resolve/main/config.json"
    for checkpoint in (
        "roberta-base",
        "roberta-large",
        "roberta-large-mnli",
        "distilroberta-base",
        "roberta-base-openai-detector",
        "roberta-large-openai-detector",
    )
}
|
|
|
|
|
|
class RobertaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is
    used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa
    [roberta-base](https://huggingface.co/roberta-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50265):
            Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 514):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 1):
            The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.

    Examples:

    ```python
    >>> from transformers import RobertaConfig, RobertaModel

    >>> # Initializing a RoBERTa configuration
    >>> configuration = RobertaConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = RobertaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "roberta"

    def __init__(
        self,
        vocab_size=50265,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=514,
        type_vocab_size=1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs
    ):
        # Special-token ids and any extra keyword arguments are handled by the base class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout

    @classmethod
    def get_default_config(cls, updates=None):
        """Build a config object from the `__init__` defaults, optionally overridden by `updates`.

        `tie_word_embeddings` is additionally set to True on the returned config.
        """
        # `classifier_dropout` defaults to None, so its intended type is declared explicitly.
        none_arg_types = dict(
            classifier_dropout=float,
        )
        config = function_args_to_config(cls.__init__, none_arg_types=none_arg_types)
        config.tie_word_embeddings = True

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())

        return config

    @staticmethod
    def get_jax_mesh(axis_dims):
        """Create a JAX mesh with the axis names ('dp', 'fsdp', 'mp') via the module-level helper."""
        # Inside the method body the name resolves to the imported module-level function,
        # not to this static method.
        return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))

    @staticmethod
    def get_partition_rules():
        """ Partition rules for Roberta model. """
        return (
            ('embeddings/(position_embeddings|token_type_embeddings)/embedding', PartitionSpec()),
            ('embeddings/word_embeddings/embedding', PartitionSpec()),
            ('attention/self/(key|query|value)/kernel', PartitionSpec('fsdp', 'mp')),
            ('attention/self/(key|query|value)/bias', PartitionSpec()),
            ('attention/output/dense/kernel', PartitionSpec('mp', 'fsdp')),
            ('attention/output/dense/bias', PartitionSpec()),
            ('(LayerNorm|layer_norm)/(bias|scale)', PartitionSpec()),
            ('intermediate/dense/kernel', PartitionSpec('fsdp', 'mp')),
            ('intermediate/dense/bias', PartitionSpec('mp')),
            ('output/dense/kernel', PartitionSpec('mp', 'fsdp')),
            ('output/dense/bias', PartitionSpec()),
            ('lm_head/dense/kernel', PartitionSpec()),
            ('lm_head/dense/bias', PartitionSpec()),
            ('lm_head/decoder/kernel', PartitionSpec('fsdp', 'mp')),
            ('lm_head/decoder/bias', PartitionSpec('mp')),
            # Catch-all: anything not matched above gets an empty PartitionSpec.
            ('.*', PartitionSpec()),
        )

    @staticmethod
    def get_weight_decay_exclusions():
        """Return the parameter-name patterns listed as weight-decay exclusions."""
        return ('bias', 'LayerNorm/scale', 'layer_norm/scale')

    @staticmethod
    def rng_keys():
        """Return the names of the RNG streams used by the model."""
        return ('params', 'dropout')

    @staticmethod
    def get_tokenizer_config(updates=None):
        """Return the tokenizer config (default name 'roberta-base'), optionally overridden by `updates`."""
        config = ConfigDict()
        config.name = 'roberta-base'

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())

        return config

    @classmethod
    def get_tokenizer(cls, config):
        """Load the tokenizer named by `config` (merged over the defaults) with `AutoTokenizer`."""
        config = cls.get_tokenizer_config(config)
        return AutoTokenizer.from_pretrained(
            config.name,
        )

    @staticmethod
    def load_pretrained(name):
        """Load pretrained masked-LM parameters for `name` and return them frozen under a 'params' key."""
        # Loading is pinned to the first CPU device.
        with jax.default_device(jax.devices("cpu")[0]):
            params = FlaxRobertaForMaskedLM.from_pretrained(name, _do_init=False)[1]
            params = freeze({'params': params})
        return params

    @classmethod
    def load_config(cls, path):
        """Load a config from a `"<load_type>::<load_path>"` specification.

        Supported load types are `pickle` (reads the `roberta_config` entry of the
        pickled object) and `huggingface` (delegates to `from_pretrained`).

        Raises:
            ValueError: if `path` has no `::` separator or the load type is unsupported.
        """
        load_type, separator, load_path = path.partition('::')
        if not separator:
            # Without this check a malformed path surfaces as an opaque unpacking error.
            raise ValueError(
                f'Invalid config path {path!r}: expected the form "<load_type>::<load_path>"'
            )
        if load_type == 'pickle':
            return cls.from_dict(load_pickle(load_path)['roberta_config'])
        elif load_type == 'huggingface':
            return cls.from_pretrained(load_path)
        else:
            raise ValueError(f'Unsupported load config type: {load_type}')
|
|
|
|
|
|
"""
|
|
The following code is taken from
|
|
transformers/src/transformers/models/roberta/modeling_flax_roberta.py
|
|
and modified to work with EasyLM.
|
|
"""
|
|
|
|
|
|
# Module-level logger obtained through the `transformers.utils.logging` helper.
logger = logging.get_logger(__name__)

# Checkpoint / config names kept as constants; presumably consumed by the
# docstring helpers imported above (not referenced in the visible part of the file).
_CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"

# Short alias for the Flax partitioning `remat` transform; used further down to
# wrap encoder layers when gradient checkpointing is enabled.
remat = nn_partitioning.remat
|
|
|
|
|
|
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Build position ids in which every non-padding token is numbered consecutively, starting at
    `padding_idx + 1`, and every padding token is assigned `padding_idx`. Adapted from fairseq's
    `utils.make_positions`.

    Args:
        input_ids: jnp.ndarray
        padding_idx: int

    Returns: jnp.ndarray
    """
    # The explicit int32 casts are deliberate (kept for ONNX export / XLA compatibility).
    not_padding = (input_ids != padding_idx).astype("i4")
    needs_flatten = not_padding.ndim > 2

    # Inputs with extra leading dimensions are collapsed to 2-D so the running
    # count is always taken along axis 1.
    rows = not_padding.reshape((-1, not_padding.shape[-1])) if needs_flatten else not_padding

    # Running count of real tokens per row; multiplying by the mask zeroes the padding slots.
    positions = jnp.cumsum(rows, axis=1).astype("i4") * rows
    if needs_flatten:
        positions = positions.reshape(input_ids.shape)

    return positions.astype("i4") + padding_idx
|
|
|
|
|
|
ROBERTA_START_DOCSTRING = r"""
|
|
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
|
|
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
|
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
|
general usage and behavior.
|
|
Finally, this model supports inherent JAX features such as:
|
|
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
|
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
|
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
|
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
|
Parameters:
|
|
config ([`RobertaConfig`]): Model configuration class with all the parameters of the
|
|
model. Initializing with a config file does not load the weights associated with the model, only the
|
|
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
|
"""
|
|
|
|
ROBERTA_INPUTS_DOCSTRING = r"""
|
|
Args:
|
|
input_ids (`numpy.ndarray` of shape `({0})`):
|
|
Indices of input sequence tokens in the vocabulary.
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
[`PreTrainedTokenizer.__call__`] for details.
|
|
[What are input IDs?](../glossary#input-ids)
|
|
attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
- 1 for tokens that are **not masked**,
|
|
- 0 for tokens that are **masked**.
|
|
[What are attention masks?](../glossary#attention-mask)
|
|
token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
|
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
|
1]`:
|
|
- 0 corresponds to a *sentence A* token,
|
|
- 1 corresponds to a *sentence B* token.
|
|
[What are token type IDs?](../glossary#token-type-ids)
|
|
position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
|
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
|
config.max_position_embeddings - 1]`.
|
|
head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
|
|
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
|
- 1 indicates the head is **not masked**,
|
|
- 0 indicates the head is **masked**.
|
|
return_dict (`bool`, *optional*):
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
"""
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
|
class FlaxRobertaEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then apply LayerNorm and dropout."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        # One lookup table per input kind; all share the hidden size and use a
        # normal initializer with the configured standard deviation.
        self.word_embeddings = nn.Embed(
            num_embeddings=cfg.vocab_size,
            features=cfg.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            num_embeddings=cfg.max_position_embeddings,
            features=cfg.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            num_embeddings=cfg.type_vocab_size,
            features=cfg.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # `attention_mask` is accepted for signature compatibility but is not used here.
        word_part = self.word_embeddings(input_ids.astype("i4"))
        position_part = self.position_embeddings(position_ids.astype("i4"))
        type_part = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Combine the three embeddings, normalise, then regularise.
        combined = word_part + type_part + position_part
        normalised = self.LayerNorm(combined)
        return self.dropout(normalised, deterministic=deterministic)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
|
|
class FlaxRobertaSelfAttention(nn.Module):
    """Multi-head attention used for self-attention and, when `key_value_states` is given, cross-attention.

    With `causal=True` a causal mask is applied and keys/values can be cached in the
    "cache" variable collection for step-by-step decoding.
    """

    config: RobertaConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            # The message must be an f-string, otherwise the placeholders are printed literally.
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
                f" : {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        if self.causal:
            # Precompute the causal mask once for the maximum supported length; it is sliced per call.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        """Reshape the last dimension into (num_attention_heads, head_dim), keeping the first two dims."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of `_split_heads`: collapse the head dimensions back into `hidden_size`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252

        Returns:
            The (possibly cache-extended) key and value tensors and the attention mask combined with a
            mask that hides cache slots which have not been written yet.
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend
            # to those key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states: Optional[jnp.array] = None,
        init_cache: bool = False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        """Run attention over `hidden_states`.

        Args:
            hidden_states: input to the query projection (and to key/value when self-attending).
            attention_mask: optional mask; expanded with two singleton axes before use.
            layer_head_mask: optional per-head multiplier applied to the attention weights.
            key_value_states: if given, keys/values are projected from it (cross-attention).
            init_cache: create the key/value cache variables on this call.
            deterministic: disables attention dropout when True.
            output_attentions: also return the attention weights.

        Returns:
            `(attn_output,)` or `(attn_output, attn_weights)` when `output_attentions` is set.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.query(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.key(key_value_states)
            value_states = self.value(key_value_states)
        else:
            # self_attention
            key_states = self.key(hidden_states)
            value_states = self.value(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # With a cache, slice the precomputed mask starting at the current cache position.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias: 0 where attending is allowed,
            # the most negative finite value of the compute dtype elsewhere.
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        # A dropout RNG is only drawn when dropout is actually active.
        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        # Merge the (heads, head_dim) axes back into a single feature axis.
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
|
|
class FlaxRobertaSelfOutput(nn.Module):
    """Project the attention output, apply dropout, add the residual input and layer-normalise."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.dense = nn.Dense(
            features=cfg.hidden_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        projected = self.dropout(self.dense(hidden_states), deterministic=deterministic)
        # Residual connection is added before normalisation.
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
|
class FlaxRobertaAttention(nn.Module):
    """Attention sub-layer: `FlaxRobertaSelfAttention` followed by `FlaxRobertaSelfOutput`."""

    config: RobertaConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        inner_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )

        # The original input serves as the residual for the output projection.
        attended = self.output(inner_outputs[0], hidden_states, deterministic=deterministic)

        if output_attentions:
            return (attended, inner_outputs[1])
        return (attended,)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
|
class FlaxRobertaIntermediate(nn.Module):
    """Feed-forward expansion: dense projection to `intermediate_size` followed by the configured activation."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.dense = nn.Dense(
            features=cfg.intermediate_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        # Activation is looked up by the key stored in the config.
        self.activation = ACT2FN[cfg.hidden_act]

    def __call__(self, hidden_states):
        return self.activation(self.dense(hidden_states))
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
|
class FlaxRobertaOutput(nn.Module):
    """Feed-forward contraction: dense projection to `hidden_size`, dropout, residual add and LayerNorm."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.dense = nn.Dense(
            features=cfg.hidden_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        projected = self.dropout(self.dense(hidden_states), deterministic=deterministic)
        # `attention_output` is the residual branch.
        return self.LayerNorm(projected + attention_output)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
|
|
class FlaxRobertaLayer(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, then the feed-forward block."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        # Self-attention is causal only when the config marks the model as a decoder.
        self.attention = FlaxRobertaAttention(cfg, causal=cfg.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxRobertaIntermediate(cfg, dtype=self.dtype)
        self.output = FlaxRobertaOutput(cfg, dtype=self.dtype)
        if cfg.add_cross_attention:
            self.crossattention = FlaxRobertaAttention(cfg, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # Self Attention
        self_attn_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attended = self_attn_outputs[0]

        # Cross-Attention Block: runs whenever encoder states are supplied.
        has_encoder_states = encoder_hidden_states is not None
        if has_encoder_states:
            cross_attn_outputs = self.crossattention(
                attended,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attended = cross_attn_outputs[0]

        # Feed-forward block with `attended` as the residual.
        expanded = self.intermediate(attended)
        layer_output = self.output(expanded, attended, deterministic=deterministic)

        result = (layer_output,)
        if output_attentions:
            result += (self_attn_outputs[1],)
            if has_encoder_states:
                result += (cross_attn_outputs[1],)
        return result
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
|
class FlaxRobertaLayerCollection(nn.Module):
    """Stack of `num_hidden_layers` RoBERTa layers, optionally wrapped with `remat` for gradient checkpointing."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        # Pick the layer class once, then instantiate it for every depth index.
        if self.gradient_checkpointing:
            layer_cls = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
        else:
            layer_cls = FlaxRobertaLayer
        self.layers = [
            layer_cls(self.config, name=str(depth), dtype=self.dtype)
            for depth in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        has_encoder_states = encoder_hidden_states is not None

        # Accumulators stay None when the corresponding output was not requested.
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and has_encoder_states) else None

        # Check if head_mask has a correct number of layers specified if desired
        if head_mask is not None and head_mask.shape[0] != (len(self.layers)):
            raise ValueError(
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
                f" {head_mask.shape[0]}."
            )

        for depth, layer in enumerate(self.layers):
            if output_hidden_states:
                # Record the input of each layer.
                all_hidden_states += (hidden_states,)

            per_layer_head_mask = None if head_mask is None else head_mask[depth]
            # Arguments are passed positionally because the remat wrapper uses `static_argnums`.
            layer_outputs = layer(
                hidden_states,
                attention_mask,
                per_layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)
                if has_encoder_states:
                    all_cross_attentions += (layer_outputs[2],)

        if output_hidden_states:
            # Also record the output of the final layer.
            all_hidden_states += (hidden_states,)

        if not return_dict:
            candidates = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
            return tuple(item for item in candidates if item is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
|
class FlaxRobertaEncoder(nn.Module):
    """Encoder wrapper that owns a `FlaxRobertaLayerCollection` and forwards every call to it.

    The layer stack is registered under the attribute name `layer`.
    """

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # handed unchanged to the layer collection

    def setup(self):
        """Create the layer collection from this module's config, dtype and checkpointing flag."""
        self.layer = FlaxRobertaLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the layer collection on the inputs and return its result unchanged."""
        # Everything except the two leading tensors is forwarded by keyword.
        forwarded = dict(
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.layer(hidden_states, attention_mask, **forwarded)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
|
class FlaxRobertaPooler(nn.Module):
    """Pool a sequence by passing the entry at index 0 of axis 1 through a dense layer and tanh."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the dense projection (output width `config.hidden_size`)."""
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
        self.dense = nn.Dense(self.config.hidden_size, kernel_init=kernel_init, dtype=self.dtype)

    def __call__(self, hidden_states):
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        first_position = hidden_states[:, 0]
        return nn.tanh(self.dense(first_position))
|
|
|
|
|
|
class FlaxRobertaLMHead(nn.Module):
    """Language-modeling head: dense -> GELU -> layer norm -> vocabulary projection plus bias.

    The vocabulary projection itself has no bias; a separate `bias` parameter of length
    `config.vocab_size` is added to its output.
    """

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        """Create the transform layers, the bias-free decoder and the standalone output bias."""
        init_range = self.config.initializer_range
        self.dense = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(init_range),
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.decoder = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(init_range),
        )
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        """Compute vocabulary scores for `hidden_states`.

        When `shared_embedding` is given, its transpose is used as the decoder kernel in place of
        the decoder's own parameters; otherwise the decoder is applied normally.
        """
        transformed = self.layer_norm(ACT2FN["gelu"](self.dense(hidden_states)))

        if shared_embedding is None:
            scores = self.decoder(transformed)
        else:
            tied_params = {"params": {"kernel": shared_embedding.T}}
            scores = self.decoder.apply(tied_params, transformed)

        # The bias parameter is cast to the computation dtype before being added.
        return scores + jnp.asarray(self.bias, self.dtype)
|
|
|
|
|
|
class FlaxRobertaClassificationHead(nn.Module):
    """Sequence-level classification head applied to the entry at index 0 of axis 1.

    Pipeline: dropout -> dense -> tanh -> dropout -> output projection to `config.num_labels`.
    """

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the dense layer, the dropout layer and the output projection."""
        self.dense = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # `classifier_dropout` wins when it is set; otherwise fall back to the hidden dropout rate.
        dropout_rate = self.config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=dropout_rate)
        self.out_proj = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

    def __call__(self, hidden_states, deterministic=True):
        """Return the output projection for the first position of `hidden_states`."""
        first_token = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
        first_token = self.dropout(first_token, deterministic=deterministic)
        projected = nn.tanh(self.dense(first_token))
        projected = self.dropout(projected, deterministic=deterministic)
        return self.out_proj(projected)
|
|
|
|
|
|
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    Subclasses set `module_class` to the Flax module that `__init__` instantiates.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"

    module_class: nn.Module = None

    def __init__(
        self,
        config: RobertaConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        """Instantiate `module_class` and pass it to the `FlaxPreTrainedModel` constructor.

        Extra keyword arguments are forwarded to `module_class`.
        """
        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
    def enable_gradient_checkpointing(self):
        """Replace `self._module` with a new `module_class` instance built with `gradient_checkpointing=True`."""
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )

    def init_weights(
        self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Optional[FrozenDict] = None
    ) -> FrozenDict:
        """Initialize module parameters by running `module.init` on dummy inputs.

        Args:
            rng: PRNG key; it is split into a `params` key and a `dropout` key.
            input_shape: shape of the dummy `input_ids`.
            params: optional existing parameters. When given, only the entries listed in
                `self._missing_keys` are filled in from the fresh initialization.

        Returns:
            The freshly initialized parameters, or `params` completed with the missing entries.
        """
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.ones_like(input_ids)
        position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            # Cross-attention needs dummy encoder states; the encoder mask reuses the input mask.
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )

        random_params = module_init_outputs["params"]

        if params is None:
            return random_params

        # Fill only the keys that were missing from the provided parameters.
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))

    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.

        Returns:
            The unfrozen `cache` collection produced by `module.init(..., init_cache=True)`.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: dict = None,
    ):
        # NOTE: no docstring here on purpose -- the decorator above builds this method's
        # documentation and appends any existing docstring to it.

        # Fall back to the config for every flag the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if head_mask is None:
            head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # Arguments common to both `apply` call sites below.
        positional = (jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"))
        shared_kwargs = dict(
            token_type_ids=jnp.array(token_type_ids, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            head_mask=jnp.array(head_mask, dtype="i4"),
            deterministic=not train,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            rngs=rngs,
        )

        if not self.config.add_cross_attention:
            # Without cross-attention the encoder inputs and `past_key_values` are not used.
            return self.module.apply(inputs, *positional, **shared_kwargs)

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
        # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
        # changed by FlaxRobertaAttention module
        #
        # A single flag drives both "mark the cache mutable" and "unpack the (outputs, cache)
        # tuple", so an empty cache dict can no longer make the two decisions disagree.
        use_cache = bool(past_key_values)
        if use_cache:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            *positional,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            mutable=mutable,
            **shared_kwargs,
        )

        # add updated cache to model output
        if use_cache:
            # With a mutable collection, `apply` returns (outputs, mutated_variables).
            outputs, mutated_variables = outputs
            updated_cache = unfreeze(mutated_variables["cache"])
            if return_dict:
                outputs["past_key_values"] = updated_cache
            else:
                outputs = outputs[:1] + (updated_cache,) + outputs[1:]

        return outputs
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
|
class FlaxRobertaModule(nn.Module):
    """Embeddings + encoder, with an optional pooler applied to the encoder output."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True  # when False, `__call__` never invokes the pooler
    gradient_checkpointing: bool = False  # handed unchanged to the encoder

    def setup(self):
        """Create the embeddings, encoder and pooler submodules."""
        self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxRobertaEncoder(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Embed the inputs, run the encoder and optionally pool the result.

        Returns a `FlaxBaseModelOutputWithPoolingAndCrossAttentions` when `return_dict` is true.
        Otherwise returns a tuple that starts with the last hidden state, followed by the pooled
        output (only when pooling is enabled) and the remaining encoder outputs.
        """
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            sequence_length = jnp.atleast_2d(input_ids).shape[-1]
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length), input_ids.shape)

        embedded = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        encoder_outputs = self.encoder(
            embedded,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        pooled = None
        if self.add_pooling_layer:
            pooled = self.pooler(sequence_output)

        if not return_dict:
            # if pooled is None, don't return it
            leading = (sequence_output,) if pooled is None else (sequence_output, pooled)
            return leading + encoder_outputs[1:]

        return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
|
|
|
|
|
|
@add_start_docstrings(
    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    ROBERTA_START_DOCSTRING,
)
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
    # The class documentation is supplied by the decorator above; this class only selects the module.
    module_class = FlaxRobertaModule


# Attach the generated usage sample to the model's `__call__` documentation.
append_call_sample_docstring(
    FlaxRobertaModel,
    _CHECKPOINT_FOR_DOC,
    FlaxBaseModelOutputWithPooling,
    _CONFIG_FOR_DOC,
)
|
|
|
|
|
|
class FlaxRobertaForMaskedLMModule(nn.Module):
    """RoBERTa backbone (no pooler) followed by `FlaxRobertaLMHead`."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        """Create the backbone with pooling disabled and the language-modeling head."""
        self.roberta = FlaxRobertaModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return vocabulary logits, as a `FlaxMaskedLMOutput` or as a tuple when `return_dict` is false."""
        # Model
        run_flags = dict(
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        outputs = self.roberta(input_ids, attention_mask, token_type_ids, position_ids, head_mask, **run_flags)
        sequence_output = outputs[0]

        # With tied embeddings the word-embedding matrix is passed to the head for reuse.
        shared_embedding = None
        if self.config.tie_word_embeddings:
            shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]

        # Compute the prediction scores
        logits = self.lm_head(sequence_output, shared_embedding=shared_embedding)

        if return_dict:
            return FlaxMaskedLMOutput(
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return (logits,) + outputs[1:]
|
|
|
|
|
|
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel):
    # The class documentation is supplied by the decorator above; this class only selects the module.
    module_class = FlaxRobertaForMaskedLMModule


# Attach the generated usage sample to the model's `__call__` documentation.
# The documented output type is `FlaxMaskedLMOutput`, which is what
# `FlaxRobertaForMaskedLMModule` returns (the previous `FlaxBaseModelOutputWithPooling`
# did not match the module's return value).
append_call_sample_docstring(
    FlaxRobertaForMaskedLM,
    _CHECKPOINT_FOR_DOC,
    FlaxMaskedLMOutput,
    _CONFIG_FOR_DOC,
    mask="<mask>",
)
|
|
|
|
|
|
class FlaxRobertaForSequenceClassificationModule(nn.Module):
    """RoBERTa backbone (no pooler) followed by `FlaxRobertaClassificationHead`."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        """Create the backbone with pooling disabled and the classification head."""
        self.roberta = FlaxRobertaModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return classification logits, as a `FlaxSequenceClassifierOutput` or as a tuple."""
        # Model
        run_flags = dict(
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        outputs = self.roberta(input_ids, attention_mask, token_type_ids, position_ids, head_mask, **run_flags)

        logits = self.classifier(outputs[0], deterministic=deterministic)

        if return_dict:
            return FlaxSequenceClassifierOutput(
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return (logits,) + outputs[1:]
|
|
|
|
|
|
@add_start_docstrings(
    """
    Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
    # The class documentation is supplied by the decorator above; this class only selects the module.
    module_class = FlaxRobertaForSequenceClassificationModule


# Attach the generated usage sample to the model's `__call__` documentation.
append_call_sample_docstring(
    FlaxRobertaForSequenceClassification, _CHECKPOINT_FOR_DOC, FlaxSequenceClassifierOutput, _CONFIG_FOR_DOC
)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta
|
|
class FlaxRobertaForMultipleChoiceModule(nn.Module):
    """RoBERTa backbone with pooler, dropout and a single-unit dense scorer per choice."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        """Create the backbone (pooling left enabled), the dropout layer and the scorer."""
        self.roberta = FlaxRobertaModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Score every choice and return logits reshaped to `(-1, num_choices)`.

        `num_choices` is read from axis 1 of `input_ids`; each input is reshaped to
        `(-1, last_dim)` before it is passed to the backbone.
        """

        def _merge_leading_axes(array):
            # Collapse everything except the last axis; `None` inputs stay `None`.
            return None if array is None else array.reshape(-1, array.shape[-1])

        num_choices = input_ids.shape[1]
        input_ids = _merge_leading_axes(input_ids)
        attention_mask = _merge_leading_axes(attention_mask)
        token_type_ids = _merge_leading_axes(token_type_ids)
        position_ids = _merge_leading_axes(position_ids)

        # Model
        run_flags = dict(
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        outputs = self.roberta(input_ids, attention_mask, token_type_ids, position_ids, head_mask, **run_flags)

        # Index 1 of the backbone output is the pooled representation.
        pooled_output = self.dropout(outputs[1], deterministic=deterministic)
        reshaped_logits = self.classifier(pooled_output).reshape(-1, num_choices)

        if return_dict:
            return FlaxMultipleChoiceModelOutput(
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return (reshaped_logits,) + outputs[2:]
|
|
|
|
|
|
@add_start_docstrings(
    """
    Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
    # The class documentation is supplied by the decorator above; this class only selects the module.
    module_class = FlaxRobertaForMultipleChoiceModule


# Multiple-choice inputs carry an extra `num_choices` axis, so the call documentation is
# overwritten with the matching shape string before the usage sample is appended.
overwrite_call_docstring(
    FlaxRobertaForMultipleChoice,
    ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"),
)
append_call_sample_docstring(
    FlaxRobertaForMultipleChoice, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta
|
|
class FlaxRobertaForTokenClassificationModule(nn.Module):
    """RoBERTa backbone (no pooler) followed by dropout and a dense layer of width `config.num_labels`."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        """Create the backbone with pooling disabled, the dropout layer and the classifier."""
        self.roberta = FlaxRobertaModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # `classifier_dropout` wins when it is set; otherwise fall back to the hidden dropout rate.
        dropout_rate = self.config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=dropout_rate)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return classifier logits for the backbone's hidden states, as a `FlaxTokenClassifierOutput` or a tuple."""
        # Model
        run_flags = dict(
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        outputs = self.roberta(input_ids, attention_mask, token_type_ids, position_ids, head_mask, **run_flags)

        sequence_output = self.dropout(outputs[0], deterministic=deterministic)
        logits = self.classifier(sequence_output)

        if return_dict:
            return FlaxTokenClassifierOutput(
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return (logits,) + outputs[1:]
|
|
|
|
|
|
@add_start_docstrings(
    """
    Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
    # The class documentation is supplied by the decorator above; this class only selects the module.
    module_class = FlaxRobertaForTokenClassificationModule


# Attach the generated usage sample to the model's `__call__` documentation.
append_call_sample_docstring(
    FlaxRobertaForTokenClassification, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
)
|
|
|
|
|
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta
|
|
class FlaxRobertaForQuestionAnsweringModule(nn.Module):
    """RoBERTa backbone (no pooler) followed by a dense layer whose output is split into start/end logits."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        """Create the backbone with pooling disabled and the span-logit projection."""
        self.roberta = FlaxRobertaModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return start and end logits, as a `FlaxQuestionAnsweringModelOutput` or as a tuple."""
        # Model
        run_flags = dict(
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        outputs = self.roberta(input_ids, attention_mask, token_type_ids, position_ids, head_mask, **run_flags)

        logits = self.qa_outputs(outputs[0])
        # Split the last axis into `num_labels` pieces (exactly two are unpacked here) and
        # drop the trailing axis of each piece.
        start_logits, end_logits = (
            piece.squeeze(-1) for piece in logits.split(self.config.num_labels, axis=-1)
        )

        if return_dict:
            return FlaxQuestionAnsweringModelOutput(
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return (start_logits, end_logits) + outputs[1:]
|
|
|
|
|
|
@add_start_docstrings(
    """
    Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
    # The class documentation is supplied by the decorator above; this class only selects the module.
    module_class = FlaxRobertaForQuestionAnsweringModule


# Attach the generated usage sample to the model's `__call__` documentation.
append_call_sample_docstring(
    FlaxRobertaForQuestionAnswering, _CHECKPOINT_FOR_DOC, FlaxQuestionAnsweringModelOutput, _CONFIG_FOR_DOC
)
|
|
|
|
|
|
class FlaxRobertaForCausalLMModule(nn.Module):
    """RoBERTa backbone (no pooler) followed by `FlaxRobertaLMHead`, with cache and encoder inputs exposed."""

    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        """Create the backbone with pooling disabled and the language-modeling head."""
        self.roberta = FlaxRobertaModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_type_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return vocabulary logits, as a `FlaxCausalLMOutputWithCrossAttentions` or as a tuple.

        Note that `position_ids` precedes `token_type_ids` in this signature, while the backbone
        is called with `token_type_ids` first.
        """
        # Model
        backbone_kwargs = dict(
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        outputs = self.roberta(
            input_ids, attention_mask, token_type_ids, position_ids, head_mask, **backbone_kwargs
        )
        sequence_output = outputs[0]

        # With tied embeddings the word-embedding matrix is passed to the head for reuse.
        shared_embedding = None
        if self.config.tie_word_embeddings:
            shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]

        # Compute the prediction scores
        logits = self.lm_head(sequence_output, shared_embedding=shared_embedding)

        if return_dict:
            return FlaxCausalLMOutputWithCrossAttentions(
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )
        return (logits,) + outputs[1:]
|
|
|
|
|
|
@add_start_docstrings(
    """
    Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
    autoregressive tasks.
    """,
    ROBERTA_START_DOCSTRING,
)
class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
    # The class documentation is supplied by the decorator above.
    module_class = FlaxRobertaForCausalLMModule

    # `jnp.ndarray` is used for the annotation (as elsewhere in this file) instead of
    # `jnp.DeviceArray`: the annotation is evaluated when the class body runs, and
    # `jax.numpy.DeviceArray` is not available in current JAX releases.
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.ndarray] = None):
        """Build the initial keyword arguments for auto-regressive decoding.

        Args:
            input_ids: 2-D prompt token ids; its shape is unpacked as `(batch_size, seq_length)`.
            max_length: length used for the cache and for the extended attention mask.
            attention_mask: optional mask for the prompt. When given, position ids are derived
                from its cumulative sum and it is written into the start of the extended mask.

        Returns:
            A dict with the keys `past_key_values`, `attention_mask` and `position_ids`.
        """
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the cache forward and advance the position ids by one step.

        `model_kwargs` is updated in place and returned: `past_key_values` is taken from
        `model_outputs`, and `position_ids` becomes the last column of the previous ids plus one.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs


# Attach the generated usage sample to the model's `__call__` documentation.
append_call_sample_docstring(
    FlaxRobertaForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)
|