init ascend tts
@@ -0,0 +1,67 @@
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


class Attend(nn.Module):
    def __init__(self, dropout=0.0, flash=False, scale=None):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), (
            "in order to use flash attention, you must be using pytorch 2.0 or above"
        )

    def flash_attn(self, q, k, v):
        # _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
        return F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            return self.flash_attn(q, k, v)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
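
For context, a minimal sketch of how this Attend module is driven (tensor sizes are illustrative, not from this commit):

    import torch
    from bs_roformer.attend import Attend

    attend = Attend(dropout=0.0, flash=False)  # flash=True requires torch >= 2.0
    q = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, dim_head)
    k = torch.randn(1, 8, 128, 64)
    v = torch.randn(1, 8, 128, 64)
    out = attend(q, k, v)  # -> (1, 8, 128, 64)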
@@ -0,0 +1,626 @@
from functools import partial

import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint

from typing import Tuple, Optional, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding

from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange

# helper functions


def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


# norm


def l2norm(t):
    return F.normalize(t, dim=-1, p=2)


class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma


# attention


class FeedForward(Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
    """

    # @beartype
    def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
        )

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)


class Transformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                )

            self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
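
# a quick shape sanity check for the Transformer above (illustrative, not part of the
# original module; flash_attn=False avoids the torch >= 2.0 requirement):
#
#   block = Transformer(dim=64, depth=2, dim_head=32, heads=4, flash_attn=False,
#                       rotary_embed=RotaryEmbedding(dim=32))
#   x = torch.randn(2, 128, 64)  # (batch, seq, dim)
#   y = block(x)                 # -> (2, 128, 64)
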
# bandsplit module


class BandSplit(Module):
    # @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))

            self.to_features.append(net)

    def forward(self, x):
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        return torch.stack(outs, dim=-2)


def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)


class MaskEstimator(Module):
    # @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)
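
# each per-band head above maps dim features to 2 * dim_in values, and nn.GLU(dim=-1)
# gates them back down to dim_in - the per-frequency (real, imag) mask values for that
# band. illustrative shape check:
#
#   head = nn.Sequential(MLP(64, 20 * 2, dim_hidden=256, depth=2), nn.GLU(dim=-1))
#   head(torch.randn(1, 10, 64)).shape  # -> (1, 10, 20): GLU halves the 40 outputs
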
# main class

# 62 bands; the counts sum to 1025 = 2048 // 2 + 1 frequency bins, matching the default stft_n_fft
DEFAULT_FREQS_PER_BANDS = (
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    128, 129,
)
class BSRoformer(Module):
    # @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        # in the paper, they divide into ~60 bands, test with 1 for starters
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
        )

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
            tran_modules.append(
                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
            )
            tran_modules.append(
                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
        )

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        freqs = torch.stft(
            torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert sum(freqs_per_bands) == freqs, (
            f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
        )

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """
        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        device = raw_audio.device

        # whether the model is loaded on MPS (MacOS GPU accelerator)
        x_is_mps = device.type == "mps"

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
            "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also needs to be False if mono (channel dimension of 1)"
        )

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

        # RuntimeError: FFT operations are only supported on MacOS 14+
        # Since it is tedious to detect the exact MacOS version, a simple try/except is used
        try:
            stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        except RuntimeError:
            stft_repr = torch.stft(
                raw_audio.cpu() if x_is_mps else raw_audio,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=True,
            ).to(device)

        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        if self.use_torch_checkpoint:
            x = checkpoint(self.band_split, x, use_reentrant=False)
        else:
            x = self.band_split(x)

        # axial / hierarchical attention

        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):
            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                x, ft_ps = pack([x], "b * d")
                if self.use_torch_checkpoint:
                    x = checkpoint(linear_transformer, x, use_reentrant=False)
                else:
                    x = linear_transformer(x)
                (x,) = unpack(x, ft_ps, "b * d")
            else:
                time_transformer, freq_transformer = transformer_block

            if self.skip_connection:
                # Sum all previous
                for j in range(i):
                    x = x + store[j]

            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            if self.use_torch_checkpoint:
                x = checkpoint(time_transformer, x, use_reentrant=False)
            else:
                x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            if self.use_torch_checkpoint:
                x = checkpoint(freq_transformer, x, use_reentrant=False)
            else:
                x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

            if self.skip_connection:
                store[i] = x

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        if self.use_torch_checkpoint:
            mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
        else:
            mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        # complex number multiplication

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        stft_repr = stft_repr * mask

        # istft

        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

        # same as the torch.stft() fix for MacOS MPS above
        try:
            recon_audio = torch.istft(
                stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1]
            )
        except RuntimeError:
            recon_audio = torch.istft(
                stft_repr.cpu() if x_is_mps else stft_repr,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            ).to(device)

        recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        # if a target is passed in, calculate loss for learning

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        target = target[..., : recon_audio.shape[-1]]  # protect against lost length on istft

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
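
A minimal smoke test for BSRoformer (small illustrative config; the default freqs_per_bands sums to the 1025 bins implied by the default stft_n_fft=2048):

    import torch

    model = BSRoformer(dim=32, depth=1, flash_attn=False)
    audio = torch.randn(1, 44100)  # (batch, time), mono
    out = model(audio)  # separated stem: (batch, channels, time)
    loss = model(audio, target=audio)  # scalar training loss when a target is given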
@@ -0,0 +1,606 @@
from functools import partial

import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint

from typing import Tuple, Optional, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding

from einops import rearrange, pack, unpack, reduce, repeat
from einops.layers.torch import Rearrange

from librosa import filters


# helper functions


def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def pad_at_dim(t, pad, dim=-1, value=0.0):
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


def l2norm(t):
    return F.normalize(t, dim=-1, p=2)


# norm


class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma


# attention


class FeedForward(Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
    """

    # @beartype
    def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
        )

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)


class Transformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                )

            self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
# bandsplit module


class BandSplit(Module):
    # @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))

            self.to_features.append(net)

    def forward(self, x):
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        return torch.stack(outs, dim=-2)


def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
    dims = (dim_in, *((dim_hidden,) * depth), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)


class MaskEstimator(Module):
    # @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)
# main class


class MelBandRoformer(Module):
    # @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        num_bands=60,
        dim_head=64,
        heads=8,
        attn_dropout=0.1,
        ff_dropout=0.1,
        flash_attn=True,
        dim_freqs_in=1025,
        sample_rate=44100,  # needed for mel filter bank from librosa
        stft_n_fft=2048,
        stft_hop_length=512,
        # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=1,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        match_input_audio_length=False,  # if True, pad output tensor to match length of input tensor
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
        )

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
            tran_modules.append(
                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
            )
            tran_modules.append(
                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized
        )

        freqs = torch.stft(
            torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True
        ).shape[1]

        # create mel filter bank
        # with librosa.filters.mel as in section 2 of paper

        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # for some reason, it doesn't include the first freq? just force a value for now

        mel_filter_bank[0][0] = 1.0

        # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
        # so let's force a positive value

        mel_filter_bank[-1, -1] = 1.0

        # binary as in paper (then estimated masks are averaged for overlapping regions)

        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(dim=0).all(), "every frequency needs to be covered by at least one band for now"

        repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band]

        if stereo:
            freq_indices = repeat(freq_indices, "f -> f s", s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, "f s -> (f s)")

        self.register_buffer("freq_indices", freq_indices, persistent=False)
        self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
        num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")

        self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
        self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)

        # band split and mask estimator

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)

        self.match_input_audio_length = match_input_audio_length

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """
        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        batch, channels, raw_audio_length = raw_audio.shape

        istft_length = raw_audio_length if self.match_input_audio_length else None

        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
            "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also needs to be False if mono (channel dimension of 1)"
        )

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        # index out all frequencies for all frequency ranges across bands ascending in one go

        batch_arange = torch.arange(batch, device=device)[..., None]

        # account for stereo

        x = stft_repr[batch_arange, self.freq_indices]

        # fold the complex (real and imag) into the frequencies dimension

        x = rearrange(x, "b f t c -> b t (f c)")

        if self.use_torch_checkpoint:
            x = checkpoint(self.band_split, x, use_reentrant=False)
        else:
            x = self.band_split(x)

        # axial / hierarchical attention

        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):
            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                x, ft_ps = pack([x], "b * d")
                if self.use_torch_checkpoint:
                    x = checkpoint(linear_transformer, x, use_reentrant=False)
                else:
                    x = linear_transformer(x)
                (x,) = unpack(x, ft_ps, "b * d")
            else:
                time_transformer, freq_transformer = transformer_block

            if self.skip_connection:
                # Sum all previous
                for j in range(i):
                    x = x + store[j]

            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            if self.use_torch_checkpoint:
                x = checkpoint(time_transformer, x, use_reentrant=False)
            else:
                x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            if self.use_torch_checkpoint:
                x = checkpoint(freq_transformer, x, use_reentrant=False)
            else:
                x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

            if self.skip_connection:
                store[i] = x

        num_stems = len(self.mask_estimators)
        if self.use_torch_checkpoint:
            masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
        else:
            masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        # complex number multiplication

        stft_repr = torch.view_as_complex(stft_repr)
        masks = torch.view_as_complex(masks)

        masks = masks.type(stft_repr.dtype)

        # need to average the estimated mask for the overlapped frequencies

        scatter_indices = repeat(self.freq_indices, "f -> b n f t", b=batch, n=num_stems, t=stft_repr.shape[-1])

        stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)

        denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)

        masks_averaged = masks_summed / denom.clamp(min=1e-8)

        # modulate stft repr with estimated mask

        stft_repr = stft_repr * masks_averaged

        # istft

        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

        recon_audio = torch.istft(
            stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=istft_length
        )

        recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", b=batch, s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        # if a target is passed in, calculate loss for learning

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        target = target[..., : recon_audio.shape[-1]]  # protect against lost length on istft

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
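
The mel bands overlap, so forward() averages the per-band mask estimates per frequency with scatter_add_ followed by a division by the coverage counts. A toy illustration of that averaging step (values made up):

    import torch

    vals = torch.tensor([1.0, 3.0, 5.0])  # per-band estimates
    idx = torch.tensor([0, 1, 1])  # the last two bands both cover frequency 1
    summed = torch.zeros(2).scatter_add_(0, idx, vals)  # -> [1., 8.]
    counts = torch.tensor([1.0, 2.0])  # num_bands_per_freq
    avg = summed / counts  # -> [1., 4.]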
ascend_910-gpt-sovits/GPT-SoVITS/tools/uvr5/bsroformer.py (new file, 304 lines)
@@ -0,0 +1,304 @@
# This code is modified from https://github.com/ZFTurbo/
import os
import warnings

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm

warnings.filterwarnings("ignore")


class Roformer_Loader:
    def get_config(self, config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            # use FullLoader to load the !!python/tuple tag; could be improved
            config = yaml.load(f, Loader=yaml.FullLoader)
        return config

    def get_default_config(self):
        default_config = None
        if self.model_type == "bs_roformer":
            # Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as the default configuration files
            # Other BS_Roformer models may not be compatible
            # fmt: off
            default_config = {
                "audio": {"chunk_size": 352800, "sample_rate": 44100},
                "model": {
                    "dim": 512,
                    "depth": 12,
                    "stereo": True,
                    "num_stems": 1,
                    "time_transformer_depth": 1,
                    "freq_transformer_depth": 1,
                    "linear_transformer_depth": 0,
                    "freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
                    "dim_head": 64,
                    "heads": 8,
                    "attn_dropout": 0.1,
                    "ff_dropout": 0.1,
                    "flash_attn": True,
                    "dim_freqs_in": 1025,
                    "stft_n_fft": 2048,
                    "stft_hop_length": 441,
                    "stft_win_length": 2048,
                    "stft_normalized": False,
                    "mask_estimator_depth": 2,
                    "multi_stft_resolution_loss_weight": 1.0,
                    "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
                    "multi_stft_hop_size": 147,
                    "multi_stft_normalized": False,
                },
                "training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
                "inference": {"batch_size": 2, "num_overlap": 2},
            }
            # fmt: on
        elif self.model_type == "mel_band_roformer":
            # Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as the default configuration file
            # Other Mel_Band_Roformer models may not be compatible
            default_config = {
                "audio": {"chunk_size": 352800, "sample_rate": 44100},
                "model": {
                    "dim": 384,
                    "depth": 12,
                    "stereo": True,
                    "num_stems": 1,
                    "time_transformer_depth": 1,
                    "freq_transformer_depth": 1,
                    "linear_transformer_depth": 0,
                    "num_bands": 60,
                    "dim_head": 64,
                    "heads": 8,
                    "attn_dropout": 0.1,
                    "ff_dropout": 0.1,
                    "flash_attn": True,
                    "dim_freqs_in": 1025,
                    "sample_rate": 44100,
                    "stft_n_fft": 2048,
                    "stft_hop_length": 441,
                    "stft_win_length": 2048,
                    "stft_normalized": False,
                    "mask_estimator_depth": 2,
                    "multi_stft_resolution_loss_weight": 1.0,
                    "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
                    "multi_stft_hop_size": 147,
                    "multi_stft_normalized": False,
                },
                "training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
                "inference": {"batch_size": 2, "num_overlap": 2},
            }

        return default_config
    def get_model_from_config(self):
        if self.model_type == "bs_roformer":
            from bs_roformer.bs_roformer import BSRoformer

            model = BSRoformer(**dict(self.config["model"]))
        elif self.model_type == "mel_band_roformer":
            from bs_roformer.mel_band_roformer import MelBandRoformer

            model = MelBandRoformer(**dict(self.config["model"]))
        else:
            print("Error: Unknown model: {}".format(self.model_type))
            model = None
        return model

    def demix_track(self, model, mix, device):
        C = self.config["audio"]["chunk_size"]  # chunk_size
        N = self.config["inference"]["num_overlap"]
        fade_size = C // 10
        step = int(C // N)
        border = C - step
        batch_size = self.config["inference"]["batch_size"]

        length_init = mix.shape[-1]
        progress_bar = tqdm(total=length_init // step + 1, desc="Processing", leave=False)

        # Pad at the beginning and the end so the sliding windows cover the borders better
        if length_init > 2 * border and (border > 0):
            mix = nn.functional.pad(mix, (border, border), mode="reflect")

        # Prepare the window arrays once for speed. This trick removes clicks at segment edges
        window_size = C
        fadein = torch.linspace(0, 1, fade_size)
        fadeout = torch.linspace(1, 0, fade_size)
        window_start = torch.ones(window_size)
        window_middle = torch.ones(window_size)
        window_finish = torch.ones(window_size)
        window_start[-fade_size:] *= fadeout  # First audio chunk, no fadein
        window_finish[:fade_size] *= fadein  # Last audio chunk, no fadeout
        window_middle[-fade_size:] *= fadeout
        window_middle[:fade_size] *= fadein

        with torch.amp.autocast("cuda"):
            with torch.inference_mode():
                if self.config["training"]["target_instrument"] is None:
                    req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
                else:
                    req_shape = (1,) + tuple(mix.shape)

                result = torch.zeros(req_shape, dtype=torch.float32)
                counter = torch.zeros(req_shape, dtype=torch.float32)
                i = 0
                batch_data = []
                batch_locations = []
                while i < mix.shape[1]:
                    part = mix[:, i : i + C].to(device)
                    length = part.shape[-1]
                    if length < C:
                        if length > C // 2 + 1:
                            part = nn.functional.pad(input=part, pad=(0, C - length), mode="reflect")
                        else:
                            part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode="constant", value=0)
                    if self.is_half:
                        part = part.half()
                    batch_data.append(part)
                    batch_locations.append((i, length))
                    i += step
                    progress_bar.update(1)

                    if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                        arr = torch.stack(batch_data, dim=0)
                        x = model(arr)

                        window = window_middle
                        if i - step == 0:  # First audio chunk, no fadein
                            window = window_start
                        elif i >= mix.shape[1]:  # Last audio chunk, no fadeout
                            window = window_finish

                        for j in range(len(batch_locations)):
                            start, l = batch_locations[j]
                            result[..., start : start + l] += x[j][..., :l].cpu() * window[..., :l]
                            counter[..., start : start + l] += window[..., :l]

                        batch_data = []
                        batch_locations = []

                estimated_sources = result / counter
                estimated_sources = estimated_sources.cpu().numpy()
                np.nan_to_num(estimated_sources, copy=False, nan=0.0)

                if length_init > 2 * border and (border > 0):
                    # Remove pad
                    estimated_sources = estimated_sources[..., border:-border]

        progress_bar.close()

        if self.config["training"]["target_instrument"] is None:
            return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)}
        else:
            return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)}
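
    # With the default config above (chunk_size C = 352800, num_overlap N = 2), this gives
    # step = C // N = 176400, border = C - step = 176400 and fade_size = C // 10 = 35280,
    # i.e. consecutive chunks overlap by half and are cross-faded over 35280 samples.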
    def run_folder(self, input, vocal_root, others_root, format):
        self.model.eval()
        path = input
        os.makedirs(vocal_root, exist_ok=True)
        os.makedirs(others_root, exist_ok=True)
        file_base_name = os.path.splitext(os.path.basename(path))[0]

        sample_rate = 44100
        if "sample_rate" in self.config["audio"]:
            sample_rate = self.config["audio"]["sample_rate"]

        try:
            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
        except Exception as e:
            print("Cannot read track: {}".format(path))
            print("Error message: {}".format(str(e)))
            return

        # in case the model only supports mono tracks
        isstereo = self.config["model"].get("stereo", True)
        if not isstereo and len(mix.shape) != 1:
            mix = np.mean(mix, axis=0)  # if more than 2 channels, take the mean
            print("Warning: Track has more than one channel, but the model is mono; taking the mean of all channels.")

        mix_orig = mix.copy()

        mixture = torch.tensor(mix, dtype=torch.float32)
        res = self.demix_track(self.model, mixture, self.device)

        if self.config["training"]["target_instrument"] is not None:
            # if a target instrument is specified, save it as the vocal and the remaining instruments as others
            # the other instruments are calculated by subtracting the target instrument from the mixture
            target_instrument = self.config["training"]["target_instrument"]
            other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument]
            other = mix_orig - res[target_instrument]  # calculate the other instruments

            path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument)
            path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0])
            self.save_audio(path_vocal, res[target_instrument].T, sr, format)
            self.save_audio(path_other, other.T, sr, format)
        else:
            # if no target instrument is specified, save the first instrument as the vocal and the rest as others
            vocal_inst = self.config["training"]["instruments"][0]
            path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst)
            self.save_audio(path_vocal, res[vocal_inst].T, sr, format)
            for other in self.config["training"]["instruments"][1:]:  # save the other instruments
                path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other)
                self.save_audio(path_other, res[other].T, sr, format)
    def save_audio(self, path, data, sr, format):
        # the input path should end with '.wav'
        if format in ["wav", "flac"]:
            if format == "flac":
                path = path[:-3] + "flac"
            sf.write(path, data, sr)
        else:
            sf.write(path, data, sr)
            os.system('ffmpeg -i "{}" -vn "{}" -q:a 2 -y'.format(path, path[:-3] + format))
            try:
                os.remove(path)
            except OSError:
                pass

    def __init__(self, model_path, config_path, device, is_half):
        self.device = device
        self.is_half = is_half
        self.model_type = None
        self.config = None

        # get model_type, first try: infer it from the model file name
        if "bs_roformer" in model_path.lower() or "bsroformer" in model_path.lower():
            self.model_type = "bs_roformer"
        elif "mel_band_roformer" in model_path.lower() or "melbandroformer" in model_path.lower():
            self.model_type = "mel_band_roformer"

        if not os.path.exists(config_path):
            if self.model_type is None:
                # if model_type is still None, raise an error
                raise ValueError(
                    "Error: Unknown model type. If you are using a model without a configuration file, ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights', ensure that the configuration file is named '<model_name>.yaml', and then try again."
                )
            self.config = self.get_default_config()
        else:
            # if there is a configuration file
            self.config = self.get_config(config_path)
            if self.model_type is None:
                # if model_type is still None, second try: get model_type from the configuration file
                if "freqs_per_bands" in self.config["model"]:
                    # if freqs_per_bands is in the config, it's a bs_roformer model
                    self.model_type = "bs_roformer"
                else:
                    # otherwise it's a mel_band_roformer model
                    self.model_type = "mel_band_roformer"

        print("Detected model type: {}".format(self.model_type))
        model = self.get_model_from_config()
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)

        if not is_half:
            self.model = model.to(device)
        else:
            self.model = model.half().to(device)

    def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
        self.run_folder(input, vocal_root, others_root, format)
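
Illustrative usage of the loader (the weight and config paths here are hypothetical):

    loader = Roformer_Loader(
        model_path="tools/uvr5/uvr5_weights/model_bs_roformer_ep_368_sdr_12.9628.ckpt",
        config_path="tools/uvr5/uvr5_weights/model_bs_roformer_ep_368_sdr_12.9628.yaml",
        device="cuda",
        is_half=False,
    )
    loader._path_audio_("mix.wav", others_root="opt/others", vocal_root="opt/vocals", format="wav")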
ascend_910-gpt-sovits/GPT-SoVITS/tools/uvr5/mdxnet.py (new file, 223 lines)
@@ -0,0 +1,223 @@
import os
import logging

logger = logging.getLogger(__name__)

import librosa
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm

cpu = torch.device("cpu")


class ConvTDFNetTrim:
    def __init__(self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024):
        super(ConvTDFNetTrim, self).__init__()

        self.dim_f = dim_f
        self.dim_t = 2**dim_t
        self.n_fft = n_fft
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
        self.target_name = target_name
        self.blender = "blender" in model_name

        self.dim_c = 4
        out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
        self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)

        self.n = L // 2

    def stft(self, x):
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
        return x[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
        x = torch.cat([x, freq_pad], -2)
        c = 4 * 2 if self.target_name == "*" else 2
        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1, c, self.chunk_size])


def get_models(device, dim_f, dim_t, n_fft):
    return ConvTDFNetTrim(
        device=device,
        model_name="Conv-TDF",
        target_name="vocals",
        L=11,
        dim_f=dim_f,
        dim_t=dim_t,
        n_fft=n_fft,
    )
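
# note on the chunk geometry above: dim_t is an exponent, e.g. dim_t=8 with hop=1024
# gives 2**8 = 256 STFT frames and chunk_size = 1024 * (256 - 1) = 261120 samples per chunk
# (numbers shown for illustration only)
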
class Predictor:
|
||||
def __init__(self, args):
|
||||
import onnxruntime as ort
|
||||
|
||||
logger.info(ort.get_available_providers())
|
||||
self.args = args
|
||||
self.model_ = get_models(device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft)
|
||||
self.model = ort.InferenceSession(
|
||||
os.path.join(args.onnx, self.model_.target_name + ".onnx"),
|
||||
providers=[
|
||||
"CUDAExecutionProvider",
|
||||
"DmlExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
],
|
||||
)
|
||||
logger.info("ONNX load done")
|
||||
|
||||

    def demix(self, mix):
        samples = mix.shape[-1]
        margin = self.args.margin
        chunk_size = self.args.chunks * 44100
        assert margin != 0, "margin cannot be zero!"
        if margin > chunk_size:
            margin = chunk_size

        segmented_mix = {}

        if self.args.chunks == 0 or samples < chunk_size:
            chunk_size = samples

        counter = -1
        for skip in range(0, samples, chunk_size):
            counter += 1

            s_margin = 0 if counter == 0 else margin
            end = min(skip + chunk_size + margin, samples)

            start = skip - s_margin

            segmented_mix[skip] = mix[:, start:end].copy()
            if end == samples:
                break

        # Shapes: mix is (2, n_samples); segmented_mix maps each chunk offset to a
        # (2, chunk + margins) slice; demix_base returns (1, 2, n_samples).
        sources = self.demix_base(segmented_mix, margin_size=margin)
        return sources

    def demix_base(self, mixes, margin_size):
        chunked_sources = []
        progress_bar = tqdm(total=len(mixes))
        progress_bar.set_description("Processing")
        for mix in mixes:
            cmix = mixes[mix]
            sources = []
            n_sample = cmix.shape[1]
            model = self.model_
            trim = model.n_fft // 2
            gen_size = model.chunk_size - 2 * trim
            pad = gen_size - n_sample % gen_size
            mix_p = np.concatenate((np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1)
            mix_waves = []
            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + model.chunk_size])
                mix_waves.append(waves)
                i += gen_size
            mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu)
            with torch.no_grad():
                _ort = self.model
                spek = model.stft(mix_waves)
                if self.args.denoise:
                    # sign-flip test-time augmentation: average the outputs for x and -x
                    spec_pred = (
                        -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
                        + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
                    )
                    tar_waves = model.istft(torch.tensor(spec_pred))
                else:
                    tar_waves = model.istft(torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0]))
            tar_signal = tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()[:, :-pad]

            start = 0 if mix == 0 else margin_size
            end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
            if margin_size == 0:
                end = None
            sources.append(tar_signal[:, start:end])

            progress_bar.update(1)

            chunked_sources.append(sources)
        _sources = np.concatenate(chunked_sources, axis=-1)
        # del self.model
        progress_bar.close()
        return _sources
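
    # Worked example (a sketch, assuming the dereverb defaults: chunks=15,
    # margin=44100, n_fft=6144): each segment is 15 * 44100 = 661500 samples plus a
    # one-second margin on each interior side; demix_base then pads with
    # trim = 6144 // 2 = 3072 zeros, slides a 523264-sample window in steps of
    # gen_size = 523264 - 2 * 3072 = 517120, and trims the margins again when the
    # per-chunk outputs are concatenated.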

    def prediction(self, m, vocal_root, others_root, format):
        os.makedirs(vocal_root, exist_ok=True)
        os.makedirs(others_root, exist_ok=True)
        basename = os.path.basename(m)
        mix, rate = librosa.load(m, mono=False, sr=44100)
        if mix.ndim == 1:
            mix = np.asfortranarray([mix, mix])
        mix = mix.T  # keep an (n_samples, 2) copy for writing; demix works on (2, n_samples)
        sources = self.demix(mix.T)
        opt = sources[0].T
        if format in ["wav", "flac"]:
            sf.write("%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate)
            sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate)
        else:
            path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename)
            path_other = "%s/%s_others.wav" % (others_root, basename)
            sf.write(path_vocal, mix - opt, rate)
            sf.write(path_other, opt, rate)
            opt_path_vocal = path_vocal[:-4] + ".%s" % format
            opt_path_other = path_other[:-4] + ".%s" % format
            if os.path.exists(path_vocal):
                os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_vocal, opt_path_vocal))
                if os.path.exists(opt_path_vocal):
                    try:
                        os.remove(path_vocal)
                    except Exception:
                        pass
            if os.path.exists(path_other):
                os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_other, opt_path_other))
                if os.path.exists(opt_path_other):
                    try:
                        os.remove(path_other)
                    except Exception:
                        pass


class MDXNetDereverb:
    def __init__(self, chunks):
        self.onnx = "%s/uvr5_weights/onnx_dereverb_By_FoxJoy" % os.path.dirname(os.path.abspath(__file__))
        self.shifts = 10  # 'Predict with randomised equivariant stabilisation'
        self.mixing = "min_mag"  # ['default','min_mag','max_mag']
        self.chunks = chunks
        self.margin = 44100
        self.dim_t = 9
        self.dim_f = 3072
        self.n_fft = 6144
        self.denoise = True
        # this object doubles as the `args` namespace consumed by Predictor
        self.pred = Predictor(self)
        self.device = cpu

    def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
        self.pred.prediction(input, vocal_root, others_root, format)
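
# Usage sketch (paths are hypothetical; the FoxJoy ONNX weights must exist under
# tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy):
#   dereverb = MDXNetDereverb(chunks=15)
#   dereverb._path_audio_("song.wav", others_root="out/others", vocal_root="out/vocals", format="wav")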
2 ascend_910-gpt-sovits/GPT-SoVITS/tools/uvr5/uvr5_weights/.gitignore vendored Normal file
@@ -0,0 +1,2 @@
*
!.gitignore
350 ascend_910-gpt-sovits/GPT-SoVITS/tools/uvr5/vr.py Normal file
@@ -0,0 +1,350 @@
import os

parent_directory = os.path.dirname(os.path.abspath(__file__))
import logging

logger = logging.getLogger(__name__)

import librosa
import numpy as np
import soundfile as sf
import torch
from lib.lib_v5 import nets_61968KB as Nets
from lib.lib_v5 import spec_utils
from lib.lib_v5.model_param_init import ModelParameters
from lib.lib_v5.nets_new import CascadedNet
from lib.utils import inference


class AudioPre:
    def __init__(self, agg, model_path, device, is_half, tta=False):
        self.model_path = model_path
        self.device = device
        self.data = {
            # Processing Options
            "postprocess": False,
            "tta": tta,
            # Constants
            "window_size": 512,
            "agg": agg,
            "high_end_process": "mirroring",
        }
        mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v2.json" % parent_directory)
        model = Nets.CascadedASPPNet(mp.param["bins"] * 2)
        cpk = torch.load(model_path, map_location="cpu")
        model.load_state_dict(cpk)
        model.eval()
        if is_half:
            model = model.half().to(device)
        else:
            model = model.to(device)

        self.mp = mp
        self.model = model

    def _path_audio_(self, music_file, ins_root=None, vocal_root=None, format="flac", is_hp3=False):
        if ins_root is None and vocal_root is None:
            return "No save root."
        name = os.path.basename(music_file)
        if ins_root is not None:
            os.makedirs(ins_root, exist_ok=True)
        if vocal_root is not None:
            os.makedirs(vocal_root, exist_ok=True)
        X_wave, y_wave, X_spec_s, y_spec_s = {}, {}, {}, {}
        bands_n = len(self.mp.param["band"])
        # print(bands_n)
        for d in range(bands_n, 0, -1):
            bp = self.mp.param["band"][d]
            if d == bands_n:  # high-end band
                (
                    X_wave[d],
                    _,
                ) = librosa.core.load(  # in theory librosa may misread some audio; reading via ffmpeg would be safer, but it was deemed too much hassle
                    music_file,
                    sr=bp["sr"],
                    mono=False,
                    dtype=np.float32,
                    res_type=bp["res_type"],
                )
                if X_wave[d].ndim == 1:
                    X_wave[d] = np.asfortranarray([X_wave[d], X_wave[d]])
            else:  # lower bands
                X_wave[d] = librosa.core.resample(
                    X_wave[d + 1],
                    orig_sr=self.mp.param["band"][d + 1]["sr"],
                    target_sr=bp["sr"],
                    res_type=bp["res_type"],
                )
            # STFT of wave source
            X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(
                X_wave[d],
                bp["hl"],
                bp["n_fft"],
                self.mp.param["mid_side"],
                self.mp.param["mid_side_b2"],
                self.mp.param["reverse"],
            )
            # pdb.set_trace()
            if d == bands_n and self.data["high_end_process"] != "none":
                input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (
                    self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"]
                )
                input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :]

        X_spec_m = spec_utils.combine_spectrograms(X_spec_s, self.mp)
        aggressive_set = float(self.data["agg"] / 100)
        aggressiveness = {
            "value": aggressive_set,
            "split_bin": self.mp.param["band"][1]["crop_stop"],
        }
        with torch.no_grad():
            pred, X_mag, X_phase = inference(X_spec_m, self.device, self.model, aggressiveness, self.data)
        # Postprocess
        if self.data["postprocess"]:
            pred_inv = np.clip(X_mag - pred, 0, np.inf)
            pred = spec_utils.mask_silence(pred, pred_inv)
        y_spec_m = pred * X_phase
        v_spec_m = X_spec_m - y_spec_m

        if is_hp3:
            ins_root, vocal_root = vocal_root, ins_root

        if ins_root is not None:
            if self.data["high_end_process"].startswith("mirroring"):
                input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], y_spec_m, input_high_end, self.mp)
                wav_instrument = spec_utils.cmb_spectrogram_to_wave(
                    y_spec_m, self.mp, input_high_end_h, input_high_end_
                )
            else:
                wav_instrument = spec_utils.cmb_spectrogram_to_wave(y_spec_m, self.mp)
            logger.info("%s instruments done" % name)
            if is_hp3:
                head = "vocal_"
            else:
                head = "instrument_"
            if format in ["wav", "flac"]:
                sf.write(
                    os.path.join(
                        ins_root,
                        head + "{}_{}.{}".format(name, self.data["agg"], format),
                    ),
                    (np.array(wav_instrument) * 32768).astype("int16"),
                    self.mp.param["sr"],
                )
            else:
                path = os.path.join(ins_root, head + "{}_{}.wav".format(name, self.data["agg"]))
                sf.write(
                    path,
                    (np.array(wav_instrument) * 32768).astype("int16"),
                    self.mp.param["sr"],
                )
                if os.path.exists(path):
                    opt_format_path = path[:-4] + ".%s" % format
                    cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
                    print(cmd)
                    os.system(cmd)
                    if os.path.exists(opt_format_path):
                        try:
                            os.remove(path)
                        except Exception:
                            pass
        if vocal_root is not None:
            if is_hp3:
                head = "instrument_"
            else:
                head = "vocal_"
            if self.data["high_end_process"].startswith("mirroring"):
                input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], v_spec_m, input_high_end, self.mp)
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp, input_high_end_h, input_high_end_)
            else:
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
            logger.info("%s vocals done" % name)
            if format in ["wav", "flac"]:
                sf.write(
                    os.path.join(
                        vocal_root,
                        head + "{}_{}.{}".format(name, self.data["agg"], format),
                    ),
                    (np.array(wav_vocals) * 32768).astype("int16"),
                    self.mp.param["sr"],
                )
            else:
                path = os.path.join(vocal_root, head + "{}_{}.wav".format(name, self.data["agg"]))
                sf.write(
                    path,
                    (np.array(wav_vocals) * 32768).astype("int16"),
                    self.mp.param["sr"],
                )
                if os.path.exists(path):
                    opt_format_path = path[:-4] + ".%s" % format
                    cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
                    print(cmd)
                    os.system(cmd)
                    if os.path.exists(opt_format_path):
                        try:
                            os.remove(path)
                        except Exception:
                            pass
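
# Usage sketch (weight file name and device are hypothetical; weights live under
# tools/uvr5/uvr5_weights):
#   pre = AudioPre(agg=10, model_path="uvr5_weights/HP2_all_vocals.pth", device="cuda", is_half=False)
#   pre._path_audio_("song.wav", ins_root="out/ins", vocal_root="out/vocal", format="wav")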


class AudioPreDeEcho:
    def __init__(self, agg, model_path, device, is_half, tta=False):
        self.model_path = model_path
        self.device = device
        self.data = {
            # Processing Options
            "postprocess": False,
            "tta": tta,
            # Constants
            "window_size": 512,
            "agg": agg,
            "high_end_process": "mirroring",
        }
        mp = ModelParameters("%s/lib/lib_v5/modelparams/4band_v3.json" % parent_directory)
        nout = 64 if "DeReverb" in model_path else 48
        model = CascadedNet(mp.param["bins"] * 2, nout)
        cpk = torch.load(model_path, map_location="cpu")
        model.load_state_dict(cpk)
        model.eval()
        if is_half:
            model = model.half().to(device)
        else:
            model = model.to(device)

        self.mp = mp
        self.model = model

    def _path_audio_(
        self, music_file, vocal_root=None, ins_root=None, format="flac", is_hp3=False
    ):  # the three VR DeEcho models return vocals and instruments swapped, hence the swapped argument order
        if ins_root is None and vocal_root is None:
            return "No save root."
        name = os.path.basename(music_file)
        if ins_root is not None:
            os.makedirs(ins_root, exist_ok=True)
        if vocal_root is not None:
            os.makedirs(vocal_root, exist_ok=True)
        X_wave, y_wave, X_spec_s, y_spec_s = {}, {}, {}, {}
        bands_n = len(self.mp.param["band"])
        # print(bands_n)
        for d in range(bands_n, 0, -1):
            bp = self.mp.param["band"][d]
            if d == bands_n:  # high-end band
                (
                    X_wave[d],
                    _,
                ) = librosa.core.load(  # in theory librosa may misread some audio; reading via ffmpeg would be safer, but it was deemed too much hassle
                    music_file,
                    sr=bp["sr"],
                    mono=False,
                    dtype=np.float32,
                    res_type=bp["res_type"],
                )
                if X_wave[d].ndim == 1:
                    X_wave[d] = np.asfortranarray([X_wave[d], X_wave[d]])
            else:  # lower bands
                X_wave[d] = librosa.core.resample(
                    X_wave[d + 1],
                    orig_sr=self.mp.param["band"][d + 1]["sr"],
                    target_sr=bp["sr"],
                    res_type=bp["res_type"],
                )
            # STFT of wave source
            X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(
                X_wave[d],
                bp["hl"],
                bp["n_fft"],
                self.mp.param["mid_side"],
                self.mp.param["mid_side_b2"],
                self.mp.param["reverse"],
            )
            # pdb.set_trace()
            if d == bands_n and self.data["high_end_process"] != "none":
                input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (
                    self.mp.param["pre_filter_stop"] - self.mp.param["pre_filter_start"]
                )
                input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :]

        X_spec_m = spec_utils.combine_spectrograms(X_spec_s, self.mp)
        aggressive_set = float(self.data["agg"] / 100)
        aggressiveness = {
            "value": aggressive_set,
            "split_bin": self.mp.param["band"][1]["crop_stop"],
        }
        with torch.no_grad():
            pred, X_mag, X_phase = inference(X_spec_m, self.device, self.model, aggressiveness, self.data)
        # Postprocess
        if self.data["postprocess"]:
            pred_inv = np.clip(X_mag - pred, 0, np.inf)
            pred = spec_utils.mask_silence(pred, pred_inv)
        y_spec_m = pred * X_phase
        v_spec_m = X_spec_m - y_spec_m

        if ins_root is not None:
            if self.data["high_end_process"].startswith("mirroring"):
                input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], y_spec_m, input_high_end, self.mp)
                wav_instrument = spec_utils.cmb_spectrogram_to_wave(
                    y_spec_m, self.mp, input_high_end_h, input_high_end_
                )
            else:
                wav_instrument = spec_utils.cmb_spectrogram_to_wave(y_spec_m, self.mp)
            logger.info("%s instruments done" % name)
            if format in ["wav", "flac"]:
                sf.write(
                    os.path.join(
                        ins_root,
                        "vocal_{}_{}.{}".format(name, self.data["agg"], format),
                    ),
                    (np.array(wav_instrument) * 32768).astype("int16"),
                    self.mp.param["sr"],
                )
            else:
                path = os.path.join(ins_root, "vocal_{}_{}.wav".format(name, self.data["agg"]))
                sf.write(
                    path,
                    (np.array(wav_instrument) * 32768).astype("int16"),
                    self.mp.param["sr"],
                )
                if os.path.exists(path):
                    opt_format_path = path[:-4] + ".%s" % format
                    cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
                    print(cmd)
                    os.system(cmd)
                    if os.path.exists(opt_format_path):
                        try:
                            os.remove(path)
                        except Exception:
                            pass
        if vocal_root is not None:
            if self.data["high_end_process"].startswith("mirroring"):
                input_high_end_ = spec_utils.mirroring(self.data["high_end_process"], v_spec_m, input_high_end, self.mp)
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp, input_high_end_h, input_high_end_)
            else:
                wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
            logger.info("%s vocals done" % name)
            if format in ["wav", "flac"]:
                sf.write(
                    os.path.join(
                        vocal_root,
                        "instrument_{}_{}.{}".format(name, self.data["agg"], format),
                    ),
                    (np.array(wav_vocals) * 32768).astype("int16"),
                    self.mp.param["sr"],
                )
            else:
                path = os.path.join(vocal_root, "instrument_{}_{}.wav".format(name, self.data["agg"]))
                sf.write(
                    path,
                    (np.array(wav_vocals) * 32768).astype("int16"),
                    self.mp.param["sr"],
                )
                if os.path.exists(path):
                    opt_format_path = path[:-4] + ".%s" % format
                    cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
                    print(cmd)
                    os.system(cmd)
                    if os.path.exists(opt_format_path):
                        try:
                            os.remove(path)
                        except Exception:
                            pass
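
# Usage sketch (weight file name is hypothetical). Note the swapped argument order
# and output naming relative to AudioPre, per the comment on _path_audio_ above:
#   deecho = AudioPreDeEcho(agg=10, model_path="uvr5_weights/VR-DeEchoNormal.pth", device="cuda", is_half=False)
#   deecho._path_audio_("song.wav", vocal_root="out/vocal", ins_root="out/ins", format="wav")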
224 ascend_910-gpt-sovits/GPT-SoVITS/tools/uvr5/webui.py Normal file
@@ -0,0 +1,224 @@
import logging
import os
import traceback

import gradio as gr

from tools.i18n.i18n import I18nAuto
from tools.my_utils import clean_path

i18n = I18nAuto()

logger = logging.getLogger(__name__)
import sys

import ffmpeg
import torch
from bsroformer import Roformer_Loader
from mdxnet import MDXNetDereverb
from vr import AudioPre, AudioPreDeEcho

weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = []
for name in os.listdir(weight_uvr5_root):
    if name.endswith(".pth") or name.endswith(".ckpt") or "onnx" in name:
        uvr5_names.append(name.replace(".pth", "").replace(".ckpt", ""))

device = sys.argv[1]
is_half = sys.argv[2].lower() == "true"  # parse "True"/"False" without eval()
webui_port_uvr5 = int(sys.argv[3])
is_share = sys.argv[4].lower() == "true"
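
# Invocation sketch (values are examples): the script expects four positional
# arguments, <device> <is_half> <port> <is_share>, e.g.
#   python tools/uvr5/webui.py cuda:0 True 9873 False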


def html_left(text, label="p"):
    return f"""<div style="text-align: left; margin: 0; padding: 0;">
    <{label} style="margin: 0; padding: 0;">{text}</{label}>
    </div>"""


def html_center(text, label="p"):
    return f"""<div style="text-align: center; margin: 100; padding: 50;">
    <{label} style="margin: 0; padding: 0;">{text}</{label}>
    </div>"""


def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format0):
    infos = []
    try:
        inp_root = clean_path(inp_root)
        save_root_vocal = clean_path(save_root_vocal)
        save_root_ins = clean_path(save_root_ins)
        is_hp3 = "HP3" in model_name
        if model_name == "onnx_dereverb_By_FoxJoy":
            pre_fun = MDXNetDereverb(15)
        elif "roformer" in model_name.lower():
            func = Roformer_Loader
            pre_fun = func(
                model_path=os.path.join(weight_uvr5_root, model_name + ".ckpt"),
                config_path=os.path.join(weight_uvr5_root, model_name + ".yaml"),
                device=device,
                is_half=is_half,
            )
            if not os.path.exists(os.path.join(weight_uvr5_root, model_name + ".yaml")):
                infos.append(
                    "Warning: You are using a model without a configuration file. The program will automatically use the default configuration file. However, the default configuration file cannot guarantee that all models will run successfully. You can manually place the model configuration file into 'tools/uvr5/uvr5_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again. (For example, the configuration file corresponding to the model 'bs_roformer_ep_368_sdr_12.9628.ckpt' should be 'bs_roformer_ep_368_sdr_12.9628.yaml'.) Or you can just ignore this warning."
                )
                yield "\n".join(infos)
        else:
            func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
            pre_fun = func(
                agg=int(agg),
                model_path=os.path.join(weight_uvr5_root, model_name + ".pth"),
                device=device,
                is_half=is_half,
            )
        if inp_root != "":
            paths = [os.path.join(inp_root, name) for name in os.listdir(inp_root)]
        else:
            paths = [path.name for path in paths]
        for path in paths:
            inp_path = os.path.join(inp_root, path)
            if not os.path.isfile(inp_path):
                continue
            need_reformat = 1
            done = 0
            try:
                info = ffmpeg.probe(inp_path, cmd="ffprobe")
                if info["streams"][0]["channels"] == 2 and info["streams"][0]["sample_rate"] == "44100":
                    need_reformat = 0
                    pre_fun._path_audio_(inp_path, save_root_ins, save_root_vocal, format0, is_hp3)
                    done = 1
            except Exception:
                need_reformat = 1
                traceback.print_exc()
            if need_reformat == 1:
                tmp_path = "%s/%s.reformatted.wav" % (
                    os.environ["TEMP"],
                    os.path.basename(inp_path),
                )
                os.system(f'ffmpeg -i "{inp_path}" -vn -acodec pcm_s16le -ac 2 -ar 44100 "{tmp_path}" -y')
                inp_path = tmp_path
            try:
                if done == 0:
                    pre_fun._path_audio_(inp_path, save_root_ins, save_root_vocal, format0, is_hp3)
                infos.append("%s->Success" % (os.path.basename(inp_path)))
                yield "\n".join(infos)
            except Exception:
                infos.append("%s->%s" % (os.path.basename(inp_path), traceback.format_exc()))
                yield "\n".join(infos)
    except Exception:
        infos.append(traceback.format_exc())
        yield "\n".join(infos)
    finally:
        try:
            if model_name == "onnx_dereverb_By_FoxJoy":
                del pre_fun.pred.model
                del pre_fun.pred.model_
            else:
                del pre_fun.model
            del pre_fun
        except Exception:
            traceback.print_exc()
        print("clean_empty_cache")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    yield "\n".join(infos)
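
# Sketch: uvr() is a generator of cumulative status text, so it can also be driven
# outside Gradio (model name and paths are hypothetical):
#   for status in uvr("HP5_only_main_vocal", "todo-songs", "out/vocal", [], "out/ins", 10, "wav"):
#       print(status)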


with gr.Blocks(title="UVR5 WebUI", analytics_enabled=False) as app:
    gr.Markdown(
        value=i18n("This software is open-sourced under the MIT license; the author has no control over it. Users of the software, and those who distribute audio exported with it, bear full responsibility.")
        + "<br>"
        + i18n("If you do not accept these terms, you may not use or reference any code or files in this package. See LICENSE in the root directory for details.")
    )
    with gr.Group():
        gr.Markdown(html_center(i18n("Vocal/accompaniment separation & de-reverb & de-echo"), "h2"))
        with gr.Group():
            gr.Markdown(
                value=html_left(
                    i18n("Batch vocal/accompaniment separation, using UVR5 models.")
                    + "<br>"
                    + i18n(
                        "Example of a valid folder path: E:\\codes\\py39\\vits_vc_gpu\\白鹭霜华测试样例 (just copy it from the file manager's address bar)."
                    )
                    + "<br>"
                    + i18n("The models fall into three categories:")
                    + "<br>"
                    + i18n(
                        "1. Preserve vocals: choose this for audio without harmonies; it preserves the lead vocal better than HP5. Two built-in models, HP2 and HP3; HP3 may leak a little accompaniment but preserves the lead vocal slightly better than HP2."
                    )
                    + "<br>"
                    + i18n("2. Preserve the lead vocal only: choose this for audio with harmonies; it may weaken the lead vocal. One built-in model, HP5.")
                    + "<br>"
                    + i18n("3. De-reverb and de-delay models (by FoxJoy):")
                    + "<br> "
                    + i18n("(1) MDX-Net (onnx_dereverb): the best choice for stereo reverb, but it cannot remove mono reverb;")
                    + "<br> "
                    + i18n(
                        "(234) DeEcho: removes delay effects. Aggressive removes more thoroughly than Normal; DeReverb additionally removes reverb and can remove mono reverb, but it does not fully clean heavy high-frequency plate reverb."
                    )
                    + "<br>"
                    + i18n("De-reverb/de-delay notes:")
                    + "<br>"
                    + i18n("1. The DeEcho-DeReverb model takes nearly twice as long as the other two DeEcho models;")
                    + "<br>"
                    + i18n("2. The MDX-Net-Dereverb model is quite slow;")
                    + "<br>"
                    + i18n("3. The cleanest configuration I can recommend is MDX-Net first, then DeEcho-Aggressive."),
                    "h4",
                )
            )
            with gr.Row():
                with gr.Column():
                    model_choose = gr.Dropdown(label=i18n("Model"), choices=uvr5_names)
                    dir_wav_input = gr.Textbox(
                        label=i18n("Path of the folder containing the audio files to process"),
                        placeholder="C:\\Users\\Desktop\\todo-songs",
                    )
                    wav_inputs = gr.File(
                        file_count="multiple", label=i18n("Audio files can also be uploaded in batch; pick one of the two options, the folder takes priority")
                    )
                with gr.Column():
                    agg = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=1,
                        label=i18n("Vocal extraction aggressiveness"),
                        value=10,
                        interactive=True,
                        visible=False,  # not exposed for tuning yet
                    )
                    opt_vocal_root = gr.Textbox(label=i18n("Output folder for lead vocals"), value="output/uvr5_opt")
                    opt_ins_root = gr.Textbox(label=i18n("Output folder for everything else"), value="output/uvr5_opt")
                    format0 = gr.Radio(
                        label=i18n("Export file format"),
                        choices=["wav", "flac", "mp3", "m4a"],
                        value="flac",
                        interactive=True,
                    )
                with gr.Column():
                    with gr.Row():
                        but2 = gr.Button(i18n("Convert"), variant="primary")
                    with gr.Row():
                        vc_output4 = gr.Textbox(label=i18n("Output information"), lines=3)
                    but2.click(
                        uvr,
                        [
                            model_choose,
                            dir_wav_input,
                            opt_vocal_root,
                            wav_inputs,
                            opt_ins_root,
                            agg,
                            format0,
                        ],
                        [vc_output4],
                        api_name="uvr_convert",
                    )
app.queue().launch(  # concurrency_count=511, max_size=1022
    server_name="0.0.0.0",
    inbrowser=True,
    share=is_share,
    server_port=webui_port_uvr5,
    # quiet=True,
)