First commit
87
pkgs/xformers/components/positional_embedding/__init__.py
Normal file
@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path
from typing import Any, Callable, Dict, Set, Union

from xformers.utils import (
    generate_matching_config,
    get_registry_decorator,
    import_all_modules,
)

from .base import PositionEmbedding, PositionEmbeddingConfig  # noqa

# CREDITS: Classy Vision registry mechanism

POSITION_EMBEDDING_REGISTRY: Dict[str, Any] = {}
POSITION_EMBEDDING_CLASS_NAMES: Set[str] = set()


def build_positional_embedding(config: Union[Dict[str, Any], PositionEmbeddingConfig]):
    """Builds a position encoding from a config.

    This assumes a 'name' key in the config which is used to determine what
    position encoding class to instantiate. For instance, a config `{"name": "my_position_encoding",
    "foo": "bar"}` will find a class that was registered as "my_position_encoding"
    (see :func:`register_positional_embedding`) and call .from_config on it."""

    if not isinstance(config, PositionEmbeddingConfig):
        config_instance = generate_matching_config(
            config, POSITION_EMBEDDING_REGISTRY[config["name"]].config
        )
    else:
        config_instance = config

    return POSITION_EMBEDDING_REGISTRY[config_instance.name].constructor.from_config(
        config_instance
    )


"""Registers a PositionEncoding subclass.

This decorator allows xFormers to instantiate a subclass of PositionEncoding
from a configuration file, even if the class itself is not part of the
xFormers framework. To use it, apply this decorator to a `PositionEncoding`
subclass, like this:

.. code-block:: python

    @dataclass
    class MyConfig:
        ...

    @register_positional_embedding('my_encoding', MyConfig)
    class MyEncoding(PositionEncoding):
        ...

To instantiate a position encoding from a configuration file, see :func:`build_positional_embedding`."""
register_positional_embedding: Callable[
    [str, Any], Callable[[Any], Any]
] = get_registry_decorator(
    POSITION_EMBEDDING_REGISTRY,
    POSITION_EMBEDDING_CLASS_NAMES,
    PositionEmbedding,
    PositionEmbeddingConfig,
)


from .rotary import RotaryEmbedding  # noqa
from .sine import SinePositionalEmbedding  # type: ignore  # noqa
from .vocab import VocabEmbedding  # noqa

__all__ = [
    "RotaryEmbedding",
    "SinePositionalEmbedding",
    "VocabEmbedding",
    "build_positional_embedding",
    "register_positional_embedding",
]

# automatically import any Python files in the directory
import_all_modules(
    str(Path(__file__).parent), "xformers.components.positional_embedding"
)
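A minimal usage sketch for the registry above (not part of this commit), assuming the pkgs/ tree is importable as the xformers package and using the "sine" embedding registered in sine.py further down:

import torch

from xformers.components.positional_embedding import build_positional_embedding

# The "name" key selects the registered class; the remaining keys are matched
# against that class's config dataclass (here PositionEmbeddingConfig).
pos_emb = build_positional_embedding(
    {"name": "sine", "dim_model": 512, "seq_len": 1024}
)

x = torch.randn(2, 1024, 512)  # (batch, seq_len, dim_model)
y = pos_emb(x)                 # positional information added element-wise
assert y.shape == x.shape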
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
35
pkgs/xformers/components/positional_embedding/base.py
Normal file
@@ -0,0 +1,35 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Type, TypeVar

import torch.nn as nn

Self = TypeVar("Self", bound="PositionEmbedding")


@dataclass
class PositionEmbeddingConfig:
    name: str
    dim_model: int
    seq_len: int


class PositionEmbedding(nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    @classmethod
    def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self:
        # Generate the class inputs from the config
        fields = asdict(config)

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in fields.items() if v is not None}
        return cls(**fields)
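A small illustrative sketch of how from_config is meant to be used (not part of the commit); it assumes the sine embedding defined later in this commit:

from xformers.components.positional_embedding.base import PositionEmbeddingConfig
from xformers.components.positional_embedding.sine import SinePositionalEmbedding

cfg = PositionEmbeddingConfig(name="sine", dim_model=256, seq_len=512)

# from_config turns the dataclass into kwargs (dropping None values so class
# defaults apply) and calls the constructor; fields the constructor does not
# declare, such as "name" here, are absorbed by its **kwargs.
emb = SinePositionalEmbedding.from_config(cfg)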
54
pkgs/xformers/components/positional_embedding/param.py
Normal file
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass

import torch

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@dataclass
class LearnablePositionalEmbeddingConfig(PositionEmbeddingConfig):
    name: str
    seq_len: int
    dim_model: int
    add_class_token: bool


@register_positional_embedding("learnable", LearnablePositionalEmbeddingConfig)
class LearnablePositionalEmbedding(PositionEmbedding):
    def __init__(
        self, seq_len: int, dim_model: int, add_class_token: bool = False, *_, **__
    ):
        super().__init__()

        # 0.02 is BERT initialization
        self.pos_emb = torch.nn.Parameter(
            torch.randn(1, seq_len + int(add_class_token), dim_model) * 0.02
        )

        self.class_token = (
            torch.nn.Parameter(torch.zeros(dim_model)) if add_class_token else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.class_token is not None:
            # Prepend class token
            clf_token = (
                torch.ones(x.shape[0], 1, self.pos_emb.shape[-1], device=x.device)
                * self.class_token
            )
            x = torch.cat([clf_token, x], dim=1)

        if x.ndim == 2:
            x = x.unsqueeze(-1)

        return x + self.pos_emb
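A hedged shape check for the learnable embedding (illustrative only): with add_class_token=True the forward pass prepends one position, so the output is one step longer than the input.

import torch

from xformers.components.positional_embedding.param import LearnablePositionalEmbedding

emb = LearnablePositionalEmbedding(seq_len=16, dim_model=32, add_class_token=True)

x = torch.randn(4, 16, 32)     # (batch, seq_len, dim_model)
y = emb(x)
assert y.shape == (4, 17, 32)  # the class token occupies the extra slot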
91
pkgs/xformers/components/positional_embedding/rotary.py
Normal file
@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost the same right now, moving parts to Triton is the next step

from typing import Tuple

import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
    # NOTE: This could probably be moved to Triton

    # Handle a possible sequence length mismatch in between q and k
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]

    return (x * cos) + (rotate_half(x) * sin)


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this one.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox


    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
    """

    def __init__(self, dim_model: int, *_, **__):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seq_len != self._seq_len_cached
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
        ):
            self._seq_len_cached = seq_len
            t = torch.arange(
                x.shape[seq_dimension], device=x.device, dtype=torch.float32
            )
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dimension=-2
        )

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
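An illustrative sketch (not part of the commit): unlike the registered embeddings, RotaryEmbedding rotates the query and key tensors instead of adding to the input, and the cached tables assume a (batch, heads, seq_len, head_dim) layout.

import torch

from xformers.components.positional_embedding.rotary import RotaryEmbedding

rotary = RotaryEmbedding(dim_model=64)  # per-head dimension, must be even

q = torch.randn(2, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64)

q_rot, k_rot = rotary(q, k)     # same shapes, positions encoded as rotations
assert q_rot.shape == q.shape and k_rot.shape == k.shape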
46
pkgs/xformers/components/positional_embedding/sine.py
Normal file
@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# Silence Mypy errors in this file.
# type: ignore

import math

import torch

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@register_positional_embedding("sine", PositionEmbeddingConfig)
class SinePositionalEmbedding(PositionEmbedding):
    def __init__(self, dim_model: int, *args, **kwargs):
        super().__init__()
        self.dim_model = dim_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        pos = (
            torch.arange(0, seq_len, device=x.device, dtype=torch.float32)
            .unsqueeze(1)
            .repeat(1, self.dim_model)
        )
        dim = (
            torch.arange(0, self.dim_model, device=x.device, dtype=torch.float32)
            .unsqueeze(0)
            .repeat(seq_len, 1)
        )
        div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model))
        pos *= div
        pos[:, 0::2] = torch.sin(pos[:, 0::2])
        pos[:, 1::2] = torch.cos(pos[:, 1::2])

        output = x.unsqueeze(-1) if x.ndim == 2 else x

        return output + pos.unsqueeze(0)
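A short sketch (illustrative only): the sine embedding keeps no per-position state beyond dim_model, so one module handles any sequence length at forward time.

import torch

from xformers.components.positional_embedding.sine import SinePositionalEmbedding

emb = SinePositionalEmbedding(dim_model=32)

short = emb(torch.randn(2, 10, 32))   # (2, 10, 32)
longer = emb(torch.randn(2, 50, 32))  # (2, 50, 32), table recomputed per call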
65
pkgs/xformers/components/positional_embedding/vocab.py
Normal file
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@dataclass
class VocabEmbeddingConfig(PositionEmbeddingConfig):
    vocab_size: int
    dropout: float


@register_positional_embedding("vocab", VocabEmbeddingConfig)
class VocabEmbedding(PositionEmbedding):
    def __init__(
        self,
        dim_model: int,
        seq_len: int,
        vocab_size: int,
        dropout: float = 0.0,
        *args,
        **kwargs
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

        self.position_ids: Optional[torch.Tensor] = None

        self.init_weights()

    def init_weights(self, gain: float = 1.0):
        torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain)
        torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain)

    def forward(self, x: torch.Tensor):
        position_ids = torch.arange(x.shape[1], dtype=torch.long, device=x.device)[
            None, :
        ].repeat(x.shape[0], 1)

        X_token = self.word_embeddings(x)
        X_pos = self.position_embeddings(position_ids)

        X = X_token + X_pos
        X = self.dropout(X)

        return X
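A final sketch (not part of the commit): VocabEmbedding differs from the other modules in that it expects integer token ids and returns the sum of token and position embeddings.

import torch

from xformers.components.positional_embedding.vocab import VocabEmbedding

emb = VocabEmbedding(dim_model=64, seq_len=128, vocab_size=1000, dropout=0.1)

ids = torch.randint(0, 1000, (2, 128))  # (batch, seq_len) of token ids
out = emb(ids)                          # (2, 128, 64) dense embeddings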