First commit
This commit is contained in:
65
pkgs/xformers/components/positional_embedding/vocab.py
Normal file
65
pkgs/xformers/components/positional_embedding/vocab.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.positional_embedding import (
|
||||
PositionEmbedding,
|
||||
PositionEmbeddingConfig,
|
||||
register_positional_embedding,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabEmbeddingConfig(PositionEmbeddingConfig):
|
||||
vocab_size: int
|
||||
dropout: float
|
||||
|
||||
|
||||
@register_positional_embedding("vocab", VocabEmbeddingConfig)
|
||||
class VocabEmbedding(PositionEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
seq_len: int,
|
||||
vocab_size: int,
|
||||
dropout: float = 0.0,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.dim_model = dim_model
|
||||
|
||||
self.dropout = torch.nn.Dropout(p=dropout)
|
||||
self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
|
||||
self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
|
||||
|
||||
self.position_ids: Optional[torch.Tensor] = None
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self, gain: float = 1.0):
|
||||
torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain)
|
||||
torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
position_ids = torch.arange(x.shape[1], dtype=torch.long, device=x.device)[
|
||||
None, :
|
||||
].repeat(x.shape[0], 1)
|
||||
|
||||
X_token = self.word_embeddings(x)
|
||||
X_pos = self.position_embeddings(position_ids)
|
||||
|
||||
X = X_token + X_pos
|
||||
X = self.dropout(X)
|
||||
|
||||
return X
|
||||
Reference in New Issue
Block a user