init ascend tts
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
For the inference of the v3 model, if you find that the generated audio sounds somewhat muffled, you can try using this audio super-resolution model.
|
||||
对于v3模型的推理,如果你发现生成的音频比较闷,可以尝试这个音频超分模型。
|
||||
|
||||
put g_24kto48k.zip and config.json in this folder
|
||||
把g_24kto48k.zip and config.json下到这个文件夹
|
||||
|
||||
download link 下载链接:
|
||||
https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link
|
||||
|
||||
audio sr project page 音频超分项目主页:
|
||||
https://github.com/yxlu-0102/AP-BWE
|
||||
21
ascend_910-gpt-sovits/GPT-SoVITS/tools/AP_BWE_main/LICENSE
Normal file
21
ascend_910-gpt-sovits/GPT-SoVITS/tools/AP_BWE_main/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Ye-Xin Lu
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
91
ascend_910-gpt-sovits/GPT-SoVITS/tools/AP_BWE_main/README.md
Normal file
91
ascend_910-gpt-sovits/GPT-SoVITS/tools/AP_BWE_main/README.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# Towards High-Quality and Efficient Speech Bandwidth Extension with Parallel Amplitude and Phase Prediction
|
||||
### Ye-Xin Lu, Yang Ai, Hui-Peng Du, Zhen-Hua Ling
|
||||
|
||||
**Abstract:**
|
||||
Speech bandwidth extension (BWE) refers to widening the frequency bandwidth range of speech signals, enhancing the speech quality towards brighter and fuller.
|
||||
This paper proposes a generative adversarial network (GAN) based BWE model with parallel prediction of Amplitude and Phase spectra, named AP-BWE, which achieves both high-quality and efficient wideband speech waveform generation.
|
||||
The proposed AP-BWE generator is entirely based on convolutional neural networks (CNNs).
|
||||
It features a dual-stream architecture with mutual interaction, where the amplitude stream and the phase stream communicate with each other and respectively extend the high-frequency components from the input narrowband amplitude and phase spectra.
|
||||
To improve the naturalness of the extended speech signals, we employ a multi-period discriminator at the waveform level and design a pair of multi-resolution amplitude and phase discriminators at the spectral level, respectively.
|
||||
Experimental results demonstrate that our proposed AP-BWE achieves state-of-the-art performance in terms of speech quality for BWE tasks targeting sampling rates of both 16 kHz and 48 kHz.
|
||||
In terms of generation efficiency, due to the all-convolutional architecture and all-frame-level operations, the proposed AP-BWE can generate 48 kHz waveform samples 292.3 times faster than real-time on a single RTX 4090 GPU and 18.1 times faster than real-time on a single CPU.
|
||||
Notably, to our knowledge, AP-BWE is the first to achieve the direct extension of the high-frequency phase spectrum, which is beneficial for improving the effectiveness of existing BWE methods.
|
||||
|
||||
**We provide our implementation as open source in this repository. Audio samples can be found at the [demo website](http://yxlu-0102.github.io/AP-BWE).**
|
||||
|
||||
|
||||
## Pre-requisites
|
||||
0. Python >= 3.9.
|
||||
0. Clone this repository.
|
||||
0. Install python requirements. Please refer [requirements.txt](requirements.txt).
|
||||
0. Download datasets
|
||||
1. Download and extract the [VCTK-0.92 dataset](https://datashare.ed.ac.uk/handle/10283/3443), and move its `wav48` directory into [VCTK-Corpus-0.92](VCTK-Corpus-0.92) and rename it as `wav48_origin`.
|
||||
1. Trim the silence of the dataset, and the trimmed files will be saved to `wav48_silence_trimmed`.
|
||||
```
|
||||
cd VCTK-Corpus-0.92
|
||||
python flac2wav.py
|
||||
```
|
||||
1. Move all the trimmed training files from `wav48_silence_trimmed` to [wav48/train](wav48/train) following the indexes in [training.txt](VCTK-Corpus-0.92/training.txt), and move all the untrimmed test files from `wav48_origin` to [wav48/test](wav48/test) following the indexes in [test.txt](VCTK-Corpus-0.92/test.txt).
|
||||
|
||||
## Training
|
||||
```
|
||||
cd train
|
||||
CUDA_VISIBLE_DEVICES=0 python train_16k.py --config [config file path]
|
||||
CUDA_VISIBLE_DEVICES=0 python train_48k.py --config [config file path]
|
||||
```
|
||||
Checkpoints and copies of the configuration file are saved in the `cp_model` directory by default.<br>
|
||||
You can change the path by using the `--checkpoint_path` option.
|
||||
Here is an example:
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0 python train_16k.py --config ../configs/config_2kto16k.json --checkpoint_path ../checkpoints/AP-BWE_2kto16k
|
||||
```
|
||||
|
||||
## Inference
|
||||
```
|
||||
cd inference
|
||||
python inference_16k.py --checkpoint_file [generator checkpoint file path]
|
||||
python inference_48k.py --checkpoint_file [generator checkpoint file path]
|
||||
```
|
||||
You can download the [pretrained weights](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link) we provide and move all the files to the `checkpoints` directory.
|
||||
<br>
|
||||
Generated wav files are saved in `generated_files` by default.
|
||||
You can change the path by adding `--output_dir` option.
|
||||
Here is an example:
|
||||
```
|
||||
python inference_16k.py --checkpoint_file ../checkpoints/2kto16k/g_2kto16k --output_dir ../generated_files/2kto16k
|
||||
```
|
||||
|
||||
## Model Structure
|
||||

|
||||
|
||||
## Comparison with other speech BWE methods
|
||||
### 2k/4k/8kHz to 16kHz
|
||||
<p align="center">
|
||||
<img src="Figures/table_16k.png" alt="comparison" width="90%"/>
|
||||
</p>
|
||||
|
||||
### 8k/12k/16/24kHz to 16kHz
|
||||
<p align="center">
|
||||
<img src="Figures/table_48k.png" alt="comparison" width="100%"/>
|
||||
</p>
|
||||
|
||||
## Acknowledgements
|
||||
We referred to [HiFi-GAN](https://github.com/jik876/hifi-gan) and [NSPP](https://github.com/YangAi520/NSPP) to implement this.
|
||||
|
||||
## Citation
|
||||
```
|
||||
@article{lu2024towards,
|
||||
title={Towards high-quality and efficient speech bandwidth extension with parallel amplitude and phase prediction},
|
||||
author={Lu, Ye-Xin and Ai, Yang and Du, Hui-Peng and Ling, Zhen-Hua},
|
||||
journal={arXiv preprint arXiv:2401.06387},
|
||||
year={2024}
|
||||
}
|
||||
|
||||
@inproceedings{lu2024multi,
|
||||
title={Multi-Stage Speech Bandwidth Extension with Flexible Sampling Rate Control},
|
||||
author={Lu, Ye-Xin and Ai, Yang and Sheng, Zheng-Yan and Ling, Zhen-Hua},
|
||||
booktitle={Proc. Interspeech},
|
||||
pages={2270--2274},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torchaudio
|
||||
import torch.utils.data
|
||||
import torchaudio.functional as aF
|
||||
|
||||
|
||||
def amp_pha_stft(audio, n_fft, hop_size, win_size, center=True):
|
||||
hann_window = torch.hann_window(win_size).to(audio.device)
|
||||
stft_spec = torch.stft(
|
||||
audio,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window,
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
return_complex=True,
|
||||
)
|
||||
log_amp = torch.log(torch.abs(stft_spec) + 1e-4)
|
||||
pha = torch.angle(stft_spec)
|
||||
|
||||
com = torch.stack((torch.exp(log_amp) * torch.cos(pha), torch.exp(log_amp) * torch.sin(pha)), dim=-1)
|
||||
|
||||
return log_amp, pha, com
|
||||
|
||||
|
||||
def amp_pha_istft(log_amp, pha, n_fft, hop_size, win_size, center=True):
|
||||
amp = torch.exp(log_amp)
|
||||
com = torch.complex(amp * torch.cos(pha), amp * torch.sin(pha))
|
||||
hann_window = torch.hann_window(win_size).to(com.device)
|
||||
audio = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
def get_dataset_filelist(a):
|
||||
with open(a.input_training_file, "r", encoding="utf-8") as fi:
|
||||
training_indexes = [x.split("|")[0] for x in fi.read().split("\n") if len(x) > 0]
|
||||
|
||||
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
|
||||
validation_indexes = [x.split("|")[0] for x in fi.read().split("\n") if len(x) > 0]
|
||||
|
||||
return training_indexes, validation_indexes
|
||||
|
||||
|
||||
class Dataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
training_indexes,
|
||||
wavs_dir,
|
||||
segment_size,
|
||||
hr_sampling_rate,
|
||||
lr_sampling_rate,
|
||||
split=True,
|
||||
shuffle=True,
|
||||
n_cache_reuse=1,
|
||||
device=None,
|
||||
):
|
||||
self.audio_indexes = training_indexes
|
||||
random.seed(1234)
|
||||
if shuffle:
|
||||
random.shuffle(self.audio_indexes)
|
||||
self.wavs_dir = wavs_dir
|
||||
self.segment_size = segment_size
|
||||
self.hr_sampling_rate = hr_sampling_rate
|
||||
self.lr_sampling_rate = lr_sampling_rate
|
||||
self.split = split
|
||||
self.cached_wav = None
|
||||
self.n_cache_reuse = n_cache_reuse
|
||||
self._cache_ref_count = 0
|
||||
self.device = device
|
||||
|
||||
def __getitem__(self, index):
|
||||
filename = self.audio_indexes[index]
|
||||
if self._cache_ref_count == 0:
|
||||
audio, orig_sampling_rate = torchaudio.load(os.path.join(self.wavs_dir, filename + ".wav"))
|
||||
self.cached_wav = audio
|
||||
self._cache_ref_count = self.n_cache_reuse
|
||||
else:
|
||||
audio = self.cached_wav
|
||||
self._cache_ref_count -= 1
|
||||
|
||||
if orig_sampling_rate == self.hr_sampling_rate:
|
||||
audio_hr = audio
|
||||
else:
|
||||
audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.hr_sampling_rate)
|
||||
|
||||
audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.lr_sampling_rate)
|
||||
audio_lr = aF.resample(audio_lr, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate)
|
||||
audio_lr = audio_lr[:, : audio_hr.size(1)]
|
||||
|
||||
if self.split:
|
||||
if audio_hr.size(1) >= self.segment_size:
|
||||
max_audio_start = audio_hr.size(1) - self.segment_size
|
||||
audio_start = random.randint(0, max_audio_start)
|
||||
audio_hr = audio_hr[:, audio_start : audio_start + self.segment_size]
|
||||
audio_lr = audio_lr[:, audio_start : audio_start + self.segment_size]
|
||||
else:
|
||||
audio_hr = torch.nn.functional.pad(audio_hr, (0, self.segment_size - audio_hr.size(1)), "constant")
|
||||
audio_lr = torch.nn.functional.pad(audio_lr, (0, self.segment_size - audio_lr.size(1)), "constant")
|
||||
|
||||
return (audio_hr.squeeze(), audio_lr.squeeze())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audio_indexes)
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,464 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
|
||||
|
||||
# from utils import init_weights, get_padding
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
import numpy as np
|
||||
from typing import Tuple, List
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
intermediate_dim (int): Dimensionality of the intermediate layer.
|
||||
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
||||
Defaults to None.
|
||||
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
||||
None means non-conditional LayerNorm. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
layer_scale_init_value=None,
|
||||
adanorm_num_embeddings=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
||||
self.adanorm = adanorm_num_embeddings is not None
|
||||
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(dim, dim * 3) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = nn.Linear(dim * 3, dim)
|
||||
self.gamma = (
|
||||
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
||||
if layer_scale_init_value > 0
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, x, cond_embedding_id=None):
|
||||
residual = x
|
||||
x = self.dwconv(x)
|
||||
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
||||
if self.adanorm:
|
||||
assert cond_embedding_id is not None
|
||||
x = self.norm(x, cond_embedding_id)
|
||||
else:
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
if self.gamma is not None:
|
||||
x = self.gamma * x
|
||||
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
||||
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
class APNet_BWE_Model(torch.nn.Module):
|
||||
def __init__(self, h):
|
||||
super(APNet_BWE_Model, self).__init__()
|
||||
self.h = h
|
||||
self.adanorm_num_embeddings = None
|
||||
layer_scale_init_value = 1 / h.ConvNeXt_layers
|
||||
|
||||
self.conv_pre_mag = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
|
||||
self.norm_pre_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
|
||||
self.conv_pre_pha = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
|
||||
self.norm_pre_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
|
||||
|
||||
self.convnext_mag = nn.ModuleList(
|
||||
[
|
||||
ConvNeXtBlock(
|
||||
dim=h.ConvNeXt_channels,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
adanorm_num_embeddings=self.adanorm_num_embeddings,
|
||||
)
|
||||
for _ in range(h.ConvNeXt_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.convnext_pha = nn.ModuleList(
|
||||
[
|
||||
ConvNeXtBlock(
|
||||
dim=h.ConvNeXt_channels,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
adanorm_num_embeddings=self.adanorm_num_embeddings,
|
||||
)
|
||||
for _ in range(h.ConvNeXt_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_post_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
|
||||
self.norm_post_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
|
||||
self.apply(self._init_weights)
|
||||
self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
|
||||
self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
|
||||
self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, mag_nb, pha_nb):
|
||||
x_mag = self.conv_pre_mag(mag_nb)
|
||||
x_pha = self.conv_pre_pha(pha_nb)
|
||||
x_mag = self.norm_pre_mag(x_mag.transpose(1, 2)).transpose(1, 2)
|
||||
x_pha = self.norm_pre_pha(x_pha.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
for conv_block_mag, conv_block_pha in zip(self.convnext_mag, self.convnext_pha):
|
||||
x_mag = x_mag + x_pha
|
||||
x_pha = x_pha + x_mag
|
||||
x_mag = conv_block_mag(x_mag, cond_embedding_id=None)
|
||||
x_pha = conv_block_pha(x_pha, cond_embedding_id=None)
|
||||
|
||||
x_mag = self.norm_post_mag(x_mag.transpose(1, 2))
|
||||
mag_wb = mag_nb + self.linear_post_mag(x_mag).transpose(1, 2)
|
||||
|
||||
x_pha = self.norm_post_pha(x_pha.transpose(1, 2))
|
||||
x_pha_r = self.linear_post_pha_r(x_pha)
|
||||
x_pha_i = self.linear_post_pha_i(x_pha)
|
||||
pha_wb = torch.atan2(x_pha_i, x_pha_r).transpose(1, 2)
|
||||
|
||||
com_wb = torch.stack((torch.exp(mag_wb) * torch.cos(pha_wb), torch.exp(mag_wb) * torch.sin(pha_wb)), dim=-1)
|
||||
|
||||
return mag_wb, pha_wb, com_wb
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for i, l in enumerate(self.convs):
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if i > 0:
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(2),
|
||||
DiscriminatorP(3),
|
||||
DiscriminatorP(5),
|
||||
DiscriminatorP(7),
|
||||
DiscriminatorP(11),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class MultiResolutionAmplitudeDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolutions: Tuple[Tuple[int, int, int]] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
|
||||
num_embeddings: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorAR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for d in self.discriminators:
|
||||
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
||||
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorAR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: Tuple[int, int, int],
|
||||
channels: int = 64,
|
||||
in_channels: int = 1,
|
||||
num_embeddings: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
|
||||
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
|
||||
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
|
||||
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
|
||||
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
|
||||
]
|
||||
)
|
||||
if num_embeddings is not None:
|
||||
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
||||
torch.nn.init.zeros_(self.emb.weight)
|
||||
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
x = x.squeeze(1)
|
||||
|
||||
x = self.spectrogram(x)
|
||||
x = x.unsqueeze(1)
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
if cond_embedding_id is not None:
|
||||
emb = self.emb(cond_embedding_id)
|
||||
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
||||
else:
|
||||
h = 0
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x += h
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
|
||||
n_fft, hop_length, win_length = self.resolution
|
||||
amplitude_spectrogram = torch.stft(
|
||||
x,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=None, # interestingly rectangular window kind of works here
|
||||
center=True,
|
||||
return_complex=True,
|
||||
).abs()
|
||||
|
||||
return amplitude_spectrogram
|
||||
|
||||
|
||||
class MultiResolutionPhaseDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolutions: Tuple[Tuple[int, int, int]] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
|
||||
num_embeddings: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorPR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for d in self.discriminators:
|
||||
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
||||
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorPR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: Tuple[int, int, int],
|
||||
channels: int = 64,
|
||||
in_channels: int = 1,
|
||||
num_embeddings: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
|
||||
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
|
||||
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
|
||||
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
|
||||
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
|
||||
]
|
||||
)
|
||||
if num_embeddings is not None:
|
||||
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
||||
torch.nn.init.zeros_(self.emb.weight)
|
||||
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
x = x.squeeze(1)
|
||||
|
||||
x = self.spectrogram(x)
|
||||
x = x.unsqueeze(1)
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
if cond_embedding_id is not None:
|
||||
emb = self.emb(cond_embedding_id)
|
||||
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
||||
else:
|
||||
h = 0
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x += h
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
|
||||
n_fft, hop_length, win_length = self.resolution
|
||||
phase_spectrogram = torch.stft(
|
||||
x,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=None, # interestingly rectangular window kind of works here
|
||||
center=True,
|
||||
return_complex=True,
|
||||
).angle()
|
||||
|
||||
return phase_spectrogram
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
|
||||
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean(torch.clamp(1 - dg, min=0))
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
def phase_losses(phase_r, phase_g):
|
||||
ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
|
||||
gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
|
||||
iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
|
||||
|
||||
return ip_loss, gd_loss, iaf_loss
|
||||
|
||||
|
||||
def anti_wrapping_function(x):
|
||||
return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
|
||||
|
||||
|
||||
def stft_mag(audio, n_fft=2048, hop_length=512):
|
||||
hann_window = torch.hann_window(n_fft).to(audio.device)
|
||||
stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True)
|
||||
stft_mag = torch.abs(stft_spec)
|
||||
return stft_mag
|
||||
|
||||
|
||||
def cal_snr(pred, target):
|
||||
snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean()
|
||||
return snr
|
||||
|
||||
|
||||
def cal_lsd(pred, target):
|
||||
sp = torch.log10(stft_mag(pred).square().clamp(1e-8))
|
||||
st = torch.log10(stft_mag(target).square().clamp(1e-8))
|
||||
return (sp - st).square().mean(dim=1).sqrt().mean()
|
||||
Reference in New Issue
Block a user