First commit

2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions


@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import Any, Callable, Dict, Set, Union
from xformers.utils import (
generate_matching_config,
get_registry_decorator,
import_all_modules,
)
from .base import PositionEmbedding, PositionEmbeddingConfig # noqa
# CREDITS: Classy Vision registry mechanism
POSITION_EMBEDDING_REGISTRY: Dict[str, Any] = {}
POSITION_EMBEDDING_CLASS_NAMES: Set[str] = set()
def build_positional_embedding(config: Union[Dict[str, Any], PositionEmbeddingConfig]):
"""Builds a position encoding from a config.
This assumes a 'name' key in the config which is used to determine what
position encoding class to instantiate. For instance, a config `{"name": "my_position_encoding",
"foo": "bar"}` will find a class that was registered as "my_position_encoding"
(see :func:`register_positional_embedding`) and call .from_config on it."""
if not isinstance(config, PositionEmbeddingConfig):
config_instance = generate_matching_config(
config, POSITION_EMBEDDING_REGISTRY[config["name"]].config
)
else:
config_instance = config
return POSITION_EMBEDDING_REGISTRY[config_instance.name].constructor.from_config(
config_instance
)
"""Registers a PositionEncoding subclass.
This decorator allows xFormers to instantiate a subclass of PositionEncoding
from a configuration file, even if the class itself is not part of the
xFormers framework. To use it, apply this decorator to a `PositionEncoding`
subclass, like this:
.. code-block:: python
@dataclass
class MyConfig:
...
@register_positional_embedding('my_encoding', MyConfig)
class MyEncoding(PositionEmbedding):
...
To instantiate a position encoding from a configuration file, see :func:`build_positional_embedding`."""
register_positional_embedding: Callable[
[str, Any], Callable[[Any], Any]
] = get_registry_decorator(
POSITION_EMBEDDING_REGISTRY,
POSITION_EMBEDDING_CLASS_NAMES,
PositionEmbedding,
PositionEmbeddingConfig,
)
from .rotary import RotaryEmbedding # noqa
from .sine import SinePositionalEmbedding # type: ignore # noqa
from .vocab import VocabEmbedding # noqa
__all__ = [
"RotaryEmbedding",
"SinePositionalEmbedding",
"VocabEmbedding",
"build_positional_embedding",
"register_positional_embedding",
]
# automatically import any Python files in the directory
import_all_modules(
str(Path(__file__).parent), "xformers.components.positional_embedding"
)
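
A rough usage sketch for the registry above, assuming the package is importable and that generate_matching_config fills the registered config dataclass from the dict keys as its name suggests; the "sine" embedding registered later in this commit is used for illustration.

from xformers.components.positional_embedding import build_positional_embedding

# "name" selects the registered class; the remaining keys are matched against
# its config dataclass (PositionEmbeddingConfig for the "sine" entry)
pos_emb = build_positional_embedding(
    {"name": "sine", "dim_model": 512, "seq_len": 1024}
)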


@@ -0,0 +1,35 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Type, TypeVar
import torch.nn as nn
Self = TypeVar("Self", bound="PositionEmbedding")
@dataclass
class PositionEmbeddingConfig:
name: str
dim_model: int
seq_len: int
class PositionEmbedding(nn.Module, metaclass=ABCMeta):
@abstractmethod
def __init__(self, *args, **kwargs) -> None:
super().__init__()
@classmethod
def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)
# Skip all Nones so that default values are used
fields = {k: v for k, v in fields.items() if v is not None}
return cls(**fields)
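A hedged sketch of a custom subclass to illustrate the contract; MyConfig and MyEmbedding are illustrative names, not part of this commit, and PositionEmbedding / PositionEmbeddingConfig from the file above are assumed to be in scope. from_config simply forwards the dataclass fields (minus any Nones) as keyword arguments to __init__.

from dataclasses import dataclass
import torch

@dataclass
class MyConfig(PositionEmbeddingConfig):
    pass

class MyEmbedding(PositionEmbedding):
    def __init__(self, dim_model: int, seq_len: int, *_, **__):
        super().__init__()
        # Simple learned table, just to show the plumbing
        self.pos = torch.nn.Parameter(torch.zeros(seq_len, dim_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pos[: x.shape[1]]

# Equivalent to calling MyEmbedding(name="my", dim_model=16, seq_len=8)
emb = MyEmbedding.from_config(MyConfig(name="my", dim_model=16, seq_len=8))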


@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import torch
from xformers.components.positional_embedding import (
PositionEmbedding,
PositionEmbeddingConfig,
register_positional_embedding,
)
@dataclass
class LearnablePositionalEmbeddingConfig(PositionEmbeddingConfig):
name: str
seq_len: int
dim_model: int
add_class_token: bool
@register_positional_embedding("learnable", LearnablePositionalEmbeddingConfig)
class LearnablePositionalEmbedding(PositionEmbedding):
def __init__(
self, seq_len: int, dim_model: int, add_class_token: bool = False, *_, **__
):
super().__init__()
# 0.02 is the standard deviation used in BERT-style initialization
self.pos_emb = torch.nn.Parameter(
torch.randn(1, seq_len + int(add_class_token), dim_model) * 0.02
)
self.class_token = (
torch.nn.Parameter(torch.zeros(dim_model)) if add_class_token else None
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.class_token is not None:
# Prepend class token
clf_token = (
torch.ones(x.shape[0], 1, self.pos_emb.shape[-1], device=x.device)
* self.class_token
)
x = torch.cat([clf_token, x], dim=1)
if x.ndim == 2:
x = x.unsqueeze(-1)
return x + self.pos_emb
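
A small usage sketch for the class above, assuming a (batch, seq_len, dim_model) input:

# torch and LearnablePositionalEmbedding from above assumed in scope
emb = LearnablePositionalEmbedding(seq_len=16, dim_model=64, add_class_token=True)
x = torch.randn(2, 16, 64)   # (batch, seq_len, dim_model)
y = emb(x)                   # class token prepended, so y has shape (2, 17, 64)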


@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost identical to the GPT-NeoX version for now; moving parts of it to Triton is the next step
from typing import Tuple
import torch
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
# NOTE: This could probably be moved to Triton
# Handle a possible sequence length mismatch in between q and k
cos = cos[:, :, : x.shape[-2], :]
sin = sin[:, :, : x.shape[-2], :]
return (x * cos) + (rotate_half(x) * sin)
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et al.).
A crucial insight of the method is that queries and keys are transformed by
rotation matrices which depend only on their absolute positions, so that the
attention dot product becomes a function of the relative position alone.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_; the GPT-NeoX implementation served as the starting point here.
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
.. warning:: This embedding is intentionally not registered: it is transformative
(it does not create the embedding dimension) and is expected to be imported and used on an ad-hoc basis
"""
def __init__(self, dim_model: int, *_, **__):
super().__init__()
# Generate and save the inverse frequency buffer (non-trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x, seq_dimension=1):
seq_len = x.shape[seq_dimension]
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seq_len != self._seq_len_cached
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seq_len
t = torch.arange(
x.shape[seq_dimension], device=x.device, dtype=torch.float32
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
return self._cos_cached, self._sin_cached
def forward(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
k, seq_dimension=-2
)
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
)
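
A usage sketch for the module above: given the slicing on the second-to-last dimension, q and k are expected as (batch, heads, seq_len, head_dim), with dim_model set to the per-head dimension.

# torch and RotaryEmbedding from above assumed in scope
rotary = RotaryEmbedding(dim_model=64)
q = torch.randn(2, 8, 128, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64)
q_rot, k_rot = rotary(q, k)      # same shapes, rotation applied per position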


@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Silence Mypy errors in this file.
# type: ignore
import math
import torch
from xformers.components.positional_embedding import (
PositionEmbedding,
PositionEmbeddingConfig,
register_positional_embedding,
)
@register_positional_embedding("sine", PositionEmbeddingConfig)
class SinePositionalEmbedding(PositionEmbedding):
def __init__(self, dim_model: int, *args, **kwargs):
super().__init__()
self.dim_model = dim_model
def forward(self, x: torch.Tensor) -> torch.Tensor:
seq_len = x.shape[1]
pos = (
torch.arange(0, seq_len, device=x.device, dtype=torch.float32)
.unsqueeze(1)
.repeat(1, self.dim_model)
)
dim = (
torch.arange(0, self.dim_model, device=x.device, dtype=torch.float32)
.unsqueeze(0)
.repeat(seq_len, 1)
)
div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model))
pos *= div
pos[:, 0::2] = torch.sin(pos[:, 0::2])
pos[:, 1::2] = torch.cos(pos[:, 1::2])
output = x.unsqueeze(-1) if x.ndim == 2 else x
return output + pos.unsqueeze(0)
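
A usage sketch for the sinusoidal embedding above; the input is expected as (batch, seq_len, dim_model) and the output keeps that shape.

# torch and SinePositionalEmbedding from above assumed in scope
emb = SinePositionalEmbedding(dim_model=32)
x = torch.randn(4, 10, 32)   # (batch, seq_len, dim_model)
y = emb(x)                   # sinusoidal positions added, shape stays (4, 10, 32)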


@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.positional_embedding import (
PositionEmbedding,
PositionEmbeddingConfig,
register_positional_embedding,
)
@dataclass
class VocabEmbeddingConfig(PositionEmbeddingConfig):
vocab_size: int
dropout: float
@register_positional_embedding("vocab", VocabEmbeddingConfig)
class VocabEmbedding(PositionEmbedding):
def __init__(
self,
dim_model: int,
seq_len: int,
vocab_size: int,
dropout: float = 0.0,
*args,
**kwargs
):
super().__init__()
self.vocab_size = vocab_size
self.dim_model = dim_model
self.dropout = torch.nn.Dropout(p=dropout)
self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
self.position_ids: Optional[torch.Tensor] = None
self.init_weights()
def init_weights(self, gain: float = 1.0):
torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain)
torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain)
def forward(self, x: torch.Tensor):
position_ids = torch.arange(x.shape[1], dtype=torch.long, device=x.device)[
None, :
].repeat(x.shape[0], 1)
X_token = self.word_embeddings(x)
X_pos = self.position_embeddings(position_ids)
X = X_token + X_pos
X = self.dropout(X)
return X
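
A usage sketch for the module above: unlike the other embeddings in this commit, it takes integer token ids and returns embedded vectors.

# torch and VocabEmbedding from above assumed in scope
emb = VocabEmbedding(dim_model=64, seq_len=128, vocab_size=1000, dropout=0.1)
tokens = torch.randint(0, 1000, (2, 128))   # (batch, seq_len) of token ids
out = emb(tokens)                            # (batch, seq_len, dim_model) = (2, 128, 64)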