init ascend tts

This commit is contained in:
2025-09-05 10:49:17 +08:00
parent d53ac91bb6
commit c5a6692774
602 changed files with 590901 additions and 1 deletions

View File

@@ -0,0 +1,659 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from module import commons
from module.modules import LayerNorm
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=4,
isflow=False,
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
if isflow:
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"]
def forward(self, x, x_mask, g=None):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
if g is not None:
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class Decoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
"""
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert t_s == t_t, "Local attention is only available for self-attention."
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
import torch.nn as nn
from torch.nn.utils import remove_weight_norm, weight_norm
class Depthwise_Separable_Conv1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
bias=True,
padding_mode="zeros", # TODO: refine this type
device=None,
dtype=None,
):
super().__init__()
self.depth_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
groups=in_channels,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
self.point_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input):
return self.point_conv(self.depth_conv(input))
def weight_norm(self):
self.depth_conv = weight_norm(self.depth_conv, name="weight")
self.point_conv = weight_norm(self.point_conv, name="weight")
def remove_weight_norm(self):
self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
self.point_conv = remove_weight_norm(self.point_conv, name="weight")
class Depthwise_Separable_TransposeConv1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
bias=True,
dilation=1,
padding_mode="zeros", # TODO: refine this type
device=None,
dtype=None,
):
super().__init__()
self.depth_conv = nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
groups=in_channels,
stride=stride,
output_padding=output_padding,
padding=padding,
dilation=dilation,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
self.point_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input):
return self.point_conv(self.depth_conv(input))
def weight_norm(self):
self.depth_conv = weight_norm(self.depth_conv, name="weight")
self.point_conv = weight_norm(self.point_conv, name="weight")
def remove_weight_norm(self):
remove_weight_norm(self.depth_conv, name="weight")
remove_weight_norm(self.point_conv, name="weight")
def weight_norm_modules(module, name="weight", dim=0):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
module.weight_norm()
return module
else:
return weight_norm(module, name, dim)
def remove_weight_norm_modules(module, name="weight"):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
module.remove_weight_norm()
else:
remove_weight_norm(module, name)
class FFT(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers=1,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
isflow=False,
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
if isflow:
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"]
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g=None):
"""
x: decoder input
h: encoder output
"""
if g is not None:
g = self.cond_layer(g)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
x = x * x_mask
for i in range(self.n_layers):
if g is not None:
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
x = x * x_mask
return x
class TransformerCouplingLayer(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
n_layers,
n_heads,
p_dropout=0,
filter_channels=0,
mean_only=False,
wn_sharing_parameter=None,
gin_channels=0,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = (
Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
isflow=True,
gin_channels=gin_channels,
)
if wn_sharing_parameter is None
else wn_sharing_parameter
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x

View File

@@ -0,0 +1,385 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from module import commons
from typing import Optional
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=4,
isflow=True,
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
# if isflow:
# cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
# self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
# self.cond_layer = weight_norm(cond_layer, name='weight')
# self.gin_channels = 256
self.cond_layer_idx = self.n_layers
self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
if "gin_channels" in kwargs:
self.gin_channels = kwargs["gin_channels"]
if self.gin_channels != 0:
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
# vits2 says 3rd block, so idx is 2 by default
self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
logging.debug(self.gin_channels, self.cond_layer_idx)
assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
# def forward(self, x, x_mask, g=None):
# attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
# x = x * x_mask
# for i in range(self.n_layers):
# if i == self.cond_layer_idx and g is not None:
# g = self.spk_emb_linear(g.transpose(1, 2))
# g = g.transpose(1, 2)
# x = x + g
# x = x * x_mask
# y = self.attn_layers[i](x, x, attn_mask)
# y = self.drop(y)
# x = self.norm_layers_1[i](x + y)
# y = self.ffn_layers[i](x, x_mask)
# y = self.drop(y)
# x = self.norm_layers_2[i](x + y)
# x = x * x_mask
# return x
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
):
y = attn_layers(x, x, attn_mask)
y = self.drop(y)
x = norm_layers_1(x + y)
y = ffn_layers(x, x_mask)
y = self.drop(y)
x = norm_layers_2(x + y)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
# x, self.attn = self.attention(q, k, v, mask=attn_mask)
x, _ = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, _ = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1)
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, -1)
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))
slice_end_position = slice_start_position + 2 * length - 1
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation="",
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
# 从上下文看这里一定是 False
# if causal:
# self.padding = self._causal_padding
# else:
# self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def padding(self, x):
return self._same_padding(x)
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
class MRTE(nn.Module):
def __init__(
self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer=2,
):
super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x

View File

@@ -0,0 +1,185 @@
import math
import torch
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
# def convert_pad_shape(pad_shape):
# l = pad_shape[::-1]
# pad_shape = [item for sublist in l for item in sublist]
# return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
def squeeze(x, x_mask=None, n_sqz=2):
b, c, t = x.size()
t = (t // n_sqz) * n_sqz
x = x[:, :, :t]
x_sqz = x.view(b, c, t // n_sqz, n_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
if x_mask is not None:
x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
else:
x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
return x_sqz * x_mask, x_mask
def unsqueeze(x, x_mask=None, n_sqz=2):
b, c, t = x.size()
x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
if x_mask is not None:
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
else:
x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
return x_unsqz * x_mask, x_mask

View File

@@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import typing as tp
from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
dim, dtype = samples.shape[-1], samples.dtype
max_kmeans_samples = 500
samples = samples[:max_kmeans_samples, :]
means = sample_vectors(samples, num_clusters)
print("kmeans start ... ")
for _ in tqdm(range(num_iters)):
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
kmeans_init: int = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", torch.zeros(codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
@torch.jit.ignore
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
# broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
# broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
return x
def quantize(self, x):
embed = self.embed.t()
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices
return embed_ind
def postprocess_emb(self, embed_ind, shape):
return embed_ind.view(*shape[:-1])
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed)
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = self.preprocess(x)
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
shape, dtype = x.shape, x.dtype
x = self.preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = self.postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
if self.training:
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
return quantize, embed_ind
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
Currently supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: tp.Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.0,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(
dim=_codebook_dim,
codebook_size=codebook_size,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
out_quantized = []
n_q = n_q or len(self.layers)
for i, layer in enumerate(self.layers[:n_q]):
quantized, indices, loss = layer(residual)
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
if layers and i in layers:
out_quantized.append(quantized)
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
st = st or 0
for layer in self.layers[st:n_q]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[st + i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,70 @@
import math
import torch
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float()
dg = dg.float()
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
def mle_loss(z, m, logs, logdet, mask):
l = torch.sum(logs) + 0.5 * torch.sum(
torch.exp(-2 * logs) * ((z - m) ** 2)
) # neg normal likelihood w/o the constant term
l = l - torch.sum(logdet) # log jacobian determinant
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l

View File

@@ -0,0 +1,143 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.2:
print("min value is ", torch.min(y))
if torch.max(y) > 1.2:
print("max value is ", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
# wnsize_dtype_device = str(win_size) + '_' + dtype_device
key = "%s-%s-%s-%s-%s" % (dtype_device, n_fft, sampling_rate, hop_size, win_size)
# if wnsize_dtype_device not in hann_window:
if key not in hann_window:
# hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
# spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[key],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
# fmax_dtype_device = str(fmax) + '_' + dtype_device
key = "%s-%s-%s-%s-%s-%s" % (dtype_device, n_fft, num_mels, sampling_rate, fmin, fmax)
# if fmax_dtype_device not in mel_basis:
if key not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
# mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel_basis[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
# spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = torch.matmul(mel_basis[key], spec)
spec = spectral_normalize_torch(spec)
return spec
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.2:
print("min value is ", torch.min(y))
if torch.max(y) > 1.2:
print("max value is ", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
# fmax_dtype_device = str(fmax) + '_' + dtype_device
fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s" % (
dtype_device,
n_fft,
num_mels,
sampling_rate,
hop_size,
win_size,
fmin,
fmax,
)
# wnsize_dtype_device = str(win_size) + '_' + dtype_device
wnsize_dtype_device = fmax_dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,897 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm
from module import commons
from module.commons import init_weights, get_padding
from module.transforms import piecewise_rational_quadratic_transform
import torch.distributions as D
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(
self,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
):
super(WN, self).__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x, x_mask=None):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c2(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x, x_mask=None):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class Flip(nn.Module):
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=p_dropout,
gin_channels=gin_channels,
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ConvFlow(nn.Module):
def __init__(
self,
in_channels,
filter_channels,
kernel_size,
n_layers,
num_bins=10,
tail_bound=5.0,
):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class LinearNorm(nn.Module):
def __init__(
self,
in_channels,
out_channels,
bias=True,
spectral_norm=False,
):
super(LinearNorm, self).__init__()
self.fc = nn.Linear(in_channels, out_channels, bias)
if spectral_norm:
self.fc = nn.utils.spectral_norm(self.fc)
def forward(self, input):
out = self.fc(input)
return out
class Mish(nn.Module):
def __init__(self):
super(Mish, self).__init__()
def forward(self, x):
return x * torch.tanh(F.softplus(x))
class Conv1dGLU(nn.Module):
"""
Conv1d + GLU(Gated Linear Unit) with residual connection.
For GLU refer to https://arxiv.org/abs/1612.08083 paper.
"""
def __init__(self, in_channels, out_channels, kernel_size, dropout):
super(Conv1dGLU, self).__init__()
self.out_channels = out_channels
self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.conv1(x)
x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
x = x1 * torch.sigmoid(x2)
x = residual + self.dropout(x)
return x
class ConvNorm(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
spectral_norm=False,
):
super(ConvNorm, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
if spectral_norm:
self.conv = nn.utils.spectral_norm(self.conv)
def forward(self, input):
out = self.conv(input)
return out
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention module"""
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)
self.fc = nn.Linear(n_head * d_v, d_model)
self.dropout = nn.Dropout(dropout)
if spectral_norm:
self.w_qs = nn.utils.spectral_norm(self.w_qs)
self.w_ks = nn.utils.spectral_norm(self.w_ks)
self.w_vs = nn.utils.spectral_norm(self.w_vs)
self.fc = nn.utils.spectral_norm(self.fc)
def forward(self, x, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_x, _ = x.size()
residual = x
q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv
if mask is not None:
slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
else:
slf_mask = None
output, attn = self.attention(q, k, v, mask=slf_mask)
output = output.view(n_head, sz_b, len_x, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) # b x lq x (n*dv)
output = self.fc(output)
output = self.dropout(output) + residual
return output, attn
class ScaledDotProductAttention(nn.Module):
"""Scaled Dot-Product Attention"""
def __init__(self, temperature, dropout):
super().__init__()
self.temperature = temperature
self.softmax = nn.Softmax(dim=2)
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
p_attn = self.dropout(attn)
output = torch.bmm(p_attn, v)
return output, attn
class MelStyleEncoder(nn.Module):
"""MelStyleEncoder"""
def __init__(
self,
n_mel_channels=80,
style_hidden=128,
style_vector_dim=256,
style_kernel_size=5,
style_head=2,
dropout=0.1,
):
super(MelStyleEncoder, self).__init__()
self.in_dim = n_mel_channels
self.hidden_dim = style_hidden
self.out_dim = style_vector_dim
self.kernel_size = style_kernel_size
self.n_head = style_head
self.dropout = dropout
self.spectral = nn.Sequential(
LinearNorm(self.in_dim, self.hidden_dim),
Mish(),
nn.Dropout(self.dropout),
LinearNorm(self.hidden_dim, self.hidden_dim),
Mish(),
nn.Dropout(self.dropout),
)
self.temporal = nn.Sequential(
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
)
self.slf_attn = MultiHeadAttention(
self.n_head,
self.hidden_dim,
self.hidden_dim // self.n_head,
self.hidden_dim // self.n_head,
self.dropout,
)
self.fc = LinearNorm(self.hidden_dim, self.out_dim)
def temporal_avg_pool(self, x, mask=None):
if mask is None:
out = torch.mean(x, dim=1)
else:
len_ = (~mask).sum(dim=1).unsqueeze(1)
x = x.masked_fill(mask.unsqueeze(-1), 0)
dtype = x.dtype
x = x.float()
x = torch.div(x, len_.unsqueeze(1))
out = x.sum(dim=1).to(dtype)
return out
def forward(self, x, mask=None):
x = x.transpose(1, 2)
if mask is not None:
mask = (mask.int() == 0).squeeze(1)
max_len = x.shape[1]
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
# spectral
x = self.spectral(x)
# temporal
x = x.transpose(1, 2)
x = self.temporal(x)
x = x.transpose(1, 2)
# self-attention
if mask is not None:
x = x.masked_fill(mask.unsqueeze(-1), 0)
x, _ = self.slf_attn(x, mask=slf_attn_mask)
# fc
x = self.fc(x)
# temoral average pooling
w = self.temporal_avg_pool(x, mask=mask)
return w.unsqueeze(-1)
class MelStyleEncoderVAE(nn.Module):
def __init__(self, spec_channels, z_latent_dim, emb_dim):
super().__init__()
self.ref_encoder = MelStyleEncoder(spec_channels, style_vector_dim=emb_dim)
self.fc1 = nn.Linear(emb_dim, z_latent_dim)
self.fc2 = nn.Linear(emb_dim, z_latent_dim)
self.fc3 = nn.Linear(z_latent_dim, emb_dim)
self.z_latent_dim = z_latent_dim
def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu
def forward(self, inputs, mask=None):
enc_out = self.ref_encoder(inputs.squeeze(-1), mask).squeeze(-1)
mu = self.fc1(enc_out)
logvar = self.fc2(enc_out)
posterior = D.Normal(mu, torch.exp(logvar))
kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
loss_kl = kl_divergence.mean()
z = posterior.rsample()
style_embed = self.fc3(z)
return style_embed.unsqueeze(-1), loss_kl
def infer(self, inputs=None, random_sample=False, manual_latent=None):
if manual_latent is None:
if random_sample:
dev = next(self.parameters()).device
posterior = D.Normal(
torch.zeros(1, self.z_latent_dim, device=dev),
torch.ones(1, self.z_latent_dim, device=dev),
)
z = posterior.rsample()
else:
enc_out = self.ref_encoder(inputs.transpose(1, 2))
mu = self.fc1(enc_out)
z = mu
else:
z = manual_latent
style_embed = self.fc3(z)
return style_embed.unsqueeze(-1), z
class ActNorm(nn.Module):
def __init__(self, channels, ddi=False, **kwargs):
super().__init__()
self.channels = channels
self.initialized = not ddi
self.logs = nn.Parameter(torch.zeros(1, channels, 1))
self.bias = nn.Parameter(torch.zeros(1, channels, 1))
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)
self.initialized = True
if reverse:
z = (x - self.bias) * torch.exp(-self.logs) * x_mask
logdet = None
return z
else:
z = (self.bias + torch.exp(self.logs) * x) * x_mask
logdet = torch.sum(self.logs) * x_len # [b]
return z, logdet
def store_inverse(self):
pass
def set_ddi(self, ddi):
self.initialized = not ddi
def initialize(self, x, x_mask):
with torch.no_grad():
denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
self.bias.data.copy_(bias_init)
self.logs.data.copy_(logs_init)
class InvConvNear(nn.Module):
def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
super().__init__()
assert n_split % 2 == 0
self.channels = channels
self.n_split = n_split
self.no_jacobian = no_jacobian
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
b, c, t = x.size()
assert c % self.n_split == 0
if x_mask is None:
x_mask = 1
x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
else:
x_len = torch.sum(x_mask, [1, 2])
x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
if reverse:
if hasattr(self, "weight_inv"):
weight = self.weight_inv
else:
weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
logdet = None
else:
weight = self.weight
if self.no_jacobian:
logdet = 0
else:
logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
weight = weight.view(self.n_split, self.n_split, 1, 1)
z = F.conv2d(x, weight)
z = z.view(b, 2, self.n_split // 2, c // self.n_split, t)
z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
if reverse:
return z
else:
return z, logdet
def store_inverse(self):
self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)

View File

@@ -0,0 +1,173 @@
# This is Multi-reference timbre encoder
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention
class MRTE(nn.Module):
def __init__(
self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer=2,
):
super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
if ge == None:
ge = 0
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
elif test == 1:
x = ssl_enc + ge
elif test == 2:
x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
else:
raise ValueError("test should be 0,1,2")
else:
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x
class SpeakerEncoder(torch.nn.Module):
def __init__(
self,
mel_n_channels=80,
model_num_layers=2,
model_hidden_size=256,
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()
def forward(self, mels):
self.lstm.flatten_parameters()
_, (hidden, _) = self.lstm(mels.transpose(-1, -2))
embeds_raw = self.relu(self.linear(hidden[-1]))
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
class MELEncoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x):
# print(x.shape,x_lengths.shape)
x = self.pre(x)
x = self.enc(x)
x = self.proj(x)
return x
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
super(WN, self).__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = nn.Conv1d(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = weight_norm(in_layer)
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = x + res_acts
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output
def remove_weight_norm(self):
for l in self.in_layers:
remove_weight_norm(l)
for l in self.res_skip_layers:
remove_weight_norm(l)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input, n_channels):
n_channels_int = n_channels[0]
t_act = torch.tanh(input[:, :n_channels_int, :])
s_act = torch.sigmoid(input[:, n_channels_int:, :])
acts = t_act * s_act
return acts
if __name__ == "__main__":
content_enc = torch.randn(3, 192, 100)
content_mask = torch.ones(3, 1, 100)
ref_mel = torch.randn(3, 128, 30)
ref_mask = torch.ones(3, 1, 30)
model = MRTE()
out = model(content_enc, content_mask, ref_mel, ref_mask)
print(out.shape)

View File

@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Residual vector quantizer implementation."""
from dataclasses import dataclass, field
import typing as tp
import torch
from torch import nn
from module.core_vq import ResidualVectorQuantization
@dataclass
class QuantizedResult:
quantized: torch.Tensor
codes: torch.Tensor
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
penalty: tp.Optional[torch.Tensor] = None
metrics: dict = field(default_factory=dict)
class ResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer.
Args:
dimension (int): Dimension of the codebooks.
n_q (int): Number of residual vector quantizers used.
bins (int): Codebook size.
decay (float): Decay for exponential moving average over the codebooks.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
n_q: int = 8,
bins: int = 1024,
decay: float = 0.99,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.n_q = n_q
self.dimension = dimension
self.bins = bins
self.decay = decay
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.threshold_ema_dead_code = threshold_ema_dead_code
self.vq = ResidualVectorQuantization(
dim=self.dimension,
codebook_size=self.bins,
num_quantizers=self.n_q,
decay=self.decay,
kmeans_init=self.kmeans_init,
kmeans_iters=self.kmeans_iters,
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
def forward(
self,
x: torch.Tensor,
n_q: tp.Optional[int] = None,
layers: tp.Optional[list] = None,
) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
n_q (int): Number of quantizer used to quantize. Default: All quantizers.
layers (list): Layer that need to return quantized. Defalt: None.
Returns:
QuantizedResult:
The quantized (or approximately quantized) representation with
the associated numbert quantizers and layer quantized required to return.
"""
n_q = n_q if n_q else self.n_q
if layers and max(layers) >= n_q:
raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.
Args:
x (torch.Tensor): Input tensor.
n_q (int): Number of quantizer used to quantize. Default: All quantizers.
st (int): Start to encode input from which layers. Default: 0.
"""
n_q = n_q if n_q else self.n_q
st = st or 0
codes = self.vq.encode(x, n_q=n_q, st=st)
return codes
def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
"""Decode the given codes to the quantized representation.
Args:
codes (torch.Tensor): Input indices for each quantizer.
st (int): Start to decode input codes from which layers. Default: 0.
"""
quantized = self.vq.decode(codes, st=st)
return quantized

View File

@@ -0,0 +1,205 @@
import torch
from torch.nn import functional as F
import numpy as np
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs,
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet