init ascend tts

2025-09-05 10:49:17 +08:00
parent d53ac91bb6
commit c5a6692774
602 changed files with 590901 additions and 1 deletions


@@ -0,0 +1,11 @@
For v3 model inference, if the generated audio sounds somewhat muffled, you can try this audio super-resolution model.
Put g_24kto48k.zip and config.json in this folder.
Download link:
https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link
Audio super-resolution project page:
https://github.com/yxlu-0102/AP-BWE


@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Ye-Xin Lu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,91 @@
# Towards High-Quality and Efficient Speech Bandwidth Extension with Parallel Amplitude and Phase Prediction
### Ye-Xin Lu, Yang Ai, Hui-Peng Du, Zhen-Hua Ling
**Abstract:**
Speech bandwidth extension (BWE) refers to widening the frequency bandwidth of speech signals, enhancing speech quality so that it sounds brighter and fuller.
This paper proposes a generative adversarial network (GAN) based BWE model with parallel prediction of Amplitude and Phase spectra, named AP-BWE, which achieves both high-quality and efficient wideband speech waveform generation.
The proposed AP-BWE generator is entirely based on convolutional neural networks (CNNs).
It features a dual-stream architecture with mutual interaction, where the amplitude stream and the phase stream communicate with each other and respectively extend the high-frequency components from the input narrowband amplitude and phase spectra.
To improve the naturalness of the extended speech signals, we employ a multi-period discriminator at the waveform level and design a pair of multi-resolution amplitude and phase discriminators at the spectral level, respectively.
Experimental results demonstrate that our proposed AP-BWE achieves state-of-the-art performance in terms of speech quality for BWE tasks targeting sampling rates of both 16 kHz and 48 kHz.
In terms of generation efficiency, due to the all-convolutional architecture and all-frame-level operations, the proposed AP-BWE can generate 48 kHz waveform samples 292.3 times faster than real-time on a single RTX 4090 GPU and 18.1 times faster than real-time on a single CPU.
Notably, to our knowledge, AP-BWE is the first to achieve the direct extension of the high-frequency phase spectrum, which is beneficial for improving the effectiveness of existing BWE methods.
**We provide our implementation as open source in this repository. Audio samples can be found at the [demo website](http://yxlu-0102.github.io/AP-BWE).**
## Pre-requisites
0. Python >= 3.9.
0. Clone this repository.
0. Install Python requirements; please refer to [requirements.txt](requirements.txt).
0. Download datasets
1. Download and extract the [VCTK-0.92 dataset](https://datashare.ed.ac.uk/handle/10283/3443), move its `wav48` directory into [VCTK-Corpus-0.92](VCTK-Corpus-0.92), and rename it to `wav48_origin`.
1. Trim the silence of the dataset, and the trimmed files will be saved to `wav48_silence_trimmed`.
```
cd VCTK-Corpus-0.92
python flac2wav.py
```
1. Move all the trimmed training files from `wav48_silence_trimmed` to [wav48/train](wav48/train) following the indexes in [training.txt](VCTK-Corpus-0.92/training.txt), and move all the untrimmed test files from `wav48_origin` to [wav48/test](wav48/test) following the indexes in [test.txt](VCTK-Corpus-0.92/test.txt).
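The index-driven file moves in the last step can be sketched with a small stdlib helper. This is a hypothetical illustration, not part of the repo: `move_by_index` is an invented name, and the `|`-separated index format mirrors how the training code parses its file lists.

```python
# Hypothetical sketch: move every file named in an index file (one
# basename per line, optionally followed by "|"-separated metadata)
# from src_dir to dst_dir.
import os
import shutil

def move_by_index(index_file, src_dir, dst_dir):
    os.makedirs(dst_dir, exist_ok=True)
    with open(index_file, "r", encoding="utf-8") as f:
        names = [line.split("|")[0].strip() for line in f if line.strip()]
    for name in names:
        shutil.move(os.path.join(src_dir, name + ".wav"),
                    os.path.join(dst_dir, name + ".wav"))
    return names
```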
## Training
```
cd train
CUDA_VISIBLE_DEVICES=0 python train_16k.py --config [config file path]
CUDA_VISIBLE_DEVICES=0 python train_48k.py --config [config file path]
```
Checkpoints and copies of the configuration file are saved in the `cp_model` directory by default.<br>
You can change the path by using the `--checkpoint_path` option.
Here is an example:
```
CUDA_VISIBLE_DEVICES=0 python train_16k.py --config ../configs/config_2kto16k.json --checkpoint_path ../checkpoints/AP-BWE_2kto16k
```
## Inference
```
cd inference
python inference_16k.py --checkpoint_file [generator checkpoint file path]
python inference_48k.py --checkpoint_file [generator checkpoint file path]
```
You can download the [pretrained weights](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link) we provide and move all the files to the `checkpoints` directory.
<br>
Generated wav files are saved in `generated_files` by default.
You can change the path by adding the `--output_dir` option.
Here is an example:
```
python inference_16k.py --checkpoint_file ../checkpoints/2kto16k/g_2kto16k --output_dir ../generated_files/2kto16k
```
## Model Structure
![model](Figures/model.png)
## Comparison with other speech BWE methods
### 2k/4k/8kHz to 16kHz
<p align="center">
<img src="Figures/table_16k.png" alt="comparison" width="90%"/>
</p>
### 8k/12k/16k/24kHz to 48kHz
<p align="center">
<img src="Figures/table_48k.png" alt="comparison" width="100%"/>
</p>
## Acknowledgements
We referred to [HiFi-GAN](https://github.com/jik876/hifi-gan) and [NSPP](https://github.com/YangAi520/NSPP) in this implementation.
## Citation
```
@article{lu2024towards,
title={Towards high-quality and efficient speech bandwidth extension with parallel amplitude and phase prediction},
author={Lu, Ye-Xin and Ai, Yang and Du, Hui-Peng and Ling, Zhen-Hua},
journal={arXiv preprint arXiv:2401.06387},
year={2024}
}
@inproceedings{lu2024multi,
title={Multi-Stage Speech Bandwidth Extension with Flexible Sampling Rate Control},
author={Lu, Ye-Xin and Ai, Yang and Sheng, Zheng-Yan and Ling, Zhen-Hua},
booktitle={Proc. Interspeech},
pages={2270--2274},
year={2024}
}
```


@@ -0,0 +1,108 @@
import os
import random
import torch
import torchaudio
import torch.utils.data
import torchaudio.functional as aF
def amp_pha_stft(audio, n_fft, hop_size, win_size, center=True):
hann_window = torch.hann_window(win_size).to(audio.device)
stft_spec = torch.stft(
audio,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
return_complex=True,
)
log_amp = torch.log(torch.abs(stft_spec) + 1e-4)
pha = torch.angle(stft_spec)
com = torch.stack((torch.exp(log_amp) * torch.cos(pha), torch.exp(log_amp) * torch.sin(pha)), dim=-1)
return log_amp, pha, com
def amp_pha_istft(log_amp, pha, n_fft, hop_size, win_size, center=True):
amp = torch.exp(log_amp)
com = torch.complex(amp * torch.cos(pha), amp * torch.sin(pha))
hann_window = torch.hann_window(win_size).to(com.device)
audio = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
return audio
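The per-bin decomposition used by `amp_pha_stft`/`amp_pha_istft` can be illustrated on a single complex STFT value with the standard library (an illustrative sketch; `z` is an invented sample value, not data from the model):

```python
# One hypothetical STFT bin: store it as (log amplitude, phase),
# then recombine as amp * (cos(pha) + i*sin(pha)), mirroring
# amp_pha_stft / amp_pha_istft above.
import cmath
import math

z = complex(3.0, 4.0)                  # |z| = 5
log_amp = math.log(abs(z) + 1e-4)      # log(|z| + eps), as in amp_pha_stft
pha = cmath.phase(z)                   # torch.angle equivalent
z_rec = math.exp(log_amp) * complex(math.cos(pha), math.sin(pha))
# The 1e-4 floor makes the round trip approximate rather than exact.
```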
def get_dataset_filelist(a):
with open(a.input_training_file, "r", encoding="utf-8") as fi:
training_indexes = [x.split("|")[0] for x in fi.read().split("\n") if len(x) > 0]
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
validation_indexes = [x.split("|")[0] for x in fi.read().split("\n") if len(x) > 0]
return training_indexes, validation_indexes
class Dataset(torch.utils.data.Dataset):
def __init__(
self,
training_indexes,
wavs_dir,
segment_size,
hr_sampling_rate,
lr_sampling_rate,
split=True,
shuffle=True,
n_cache_reuse=1,
device=None,
):
self.audio_indexes = training_indexes
random.seed(1234)
if shuffle:
random.shuffle(self.audio_indexes)
self.wavs_dir = wavs_dir
self.segment_size = segment_size
self.hr_sampling_rate = hr_sampling_rate
self.lr_sampling_rate = lr_sampling_rate
self.split = split
        self.cached_wav = None
        self.cached_sampling_rate = None
self.n_cache_reuse = n_cache_reuse
self._cache_ref_count = 0
self.device = device
def __getitem__(self, index):
filename = self.audio_indexes[index]
        if self._cache_ref_count == 0:
            audio, orig_sampling_rate = torchaudio.load(os.path.join(self.wavs_dir, filename + ".wav"))
            self.cached_wav = audio
            self.cached_sampling_rate = orig_sampling_rate
            self._cache_ref_count = self.n_cache_reuse
        else:
            # Reuse the cached waveform; the sampling rate must be cached
            # too, otherwise orig_sampling_rate is undefined on this path.
            audio = self.cached_wav
            orig_sampling_rate = self.cached_sampling_rate
            self._cache_ref_count -= 1
if orig_sampling_rate == self.hr_sampling_rate:
audio_hr = audio
else:
audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.hr_sampling_rate)
audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.lr_sampling_rate)
audio_lr = aF.resample(audio_lr, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate)
audio_lr = audio_lr[:, : audio_hr.size(1)]
if self.split:
if audio_hr.size(1) >= self.segment_size:
max_audio_start = audio_hr.size(1) - self.segment_size
audio_start = random.randint(0, max_audio_start)
audio_hr = audio_hr[:, audio_start : audio_start + self.segment_size]
audio_lr = audio_lr[:, audio_start : audio_start + self.segment_size]
else:
audio_hr = torch.nn.functional.pad(audio_hr, (0, self.segment_size - audio_hr.size(1)), "constant")
audio_lr = torch.nn.functional.pad(audio_lr, (0, self.segment_size - audio_lr.size(1)), "constant")
return (audio_hr.squeeze(), audio_lr.squeeze())
def __len__(self):
return len(self.audio_indexes)
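The split branch of `__getitem__` (random crop to `segment_size`, or zero-pad up to it) can be mimicked on plain lists. This is an illustrative stdlib sketch, not repo code; `crop_or_pad` is an invented helper:

```python
# Sketch of the segment logic: paired HR/LR sequences are either
# randomly cropped to segment_size (with a shared start offset so
# they stay aligned) or zero-padded up to segment_size.
import random

def crop_or_pad(hr, lr, segment_size, seed=0):
    rng = random.Random(seed)
    if len(hr) >= segment_size:
        start = rng.randint(0, len(hr) - segment_size)
        return hr[start:start + segment_size], lr[start:start + segment_size]
    return (hr + [0.0] * (segment_size - len(hr)),
            lr + [0.0] * (segment_size - len(lr)))
```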


@@ -0,0 +1,464 @@
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils import weight_norm, spectral_norm
from typing import Tuple, List

# Inlined from utils to keep this module self-contained.
def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

LRELU_SLOPE = 0.1
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels; the pointwise layers expand to 3 * dim internally.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
self,
dim: int,
layer_scale_init_value=None,
adanorm_num_embeddings=None,
):
super().__init__()
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.adanorm = adanorm_num_embeddings is not None
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, dim * 3) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(dim * 3, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )
def forward(self, x, cond_embedding_id=None):
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
if self.adanorm:
assert cond_embedding_id is not None
x = self.norm(x, cond_embedding_id)
else:
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class APNet_BWE_Model(torch.nn.Module):
def __init__(self, h):
super(APNet_BWE_Model, self).__init__()
self.h = h
self.adanorm_num_embeddings = None
layer_scale_init_value = 1 / h.ConvNeXt_layers
self.conv_pre_mag = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
self.norm_pre_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
self.conv_pre_pha = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
self.norm_pre_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
self.convnext_mag = nn.ModuleList(
[
ConvNeXtBlock(
dim=h.ConvNeXt_channels,
layer_scale_init_value=layer_scale_init_value,
adanorm_num_embeddings=self.adanorm_num_embeddings,
)
for _ in range(h.ConvNeXt_layers)
]
)
self.convnext_pha = nn.ModuleList(
[
ConvNeXtBlock(
dim=h.ConvNeXt_channels,
layer_scale_init_value=layer_scale_init_value,
adanorm_num_embeddings=self.adanorm_num_embeddings,
)
for _ in range(h.ConvNeXt_layers)
]
)
self.norm_post_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
self.norm_post_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
self.apply(self._init_weights)
self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, mag_nb, pha_nb):
x_mag = self.conv_pre_mag(mag_nb)
x_pha = self.conv_pre_pha(pha_nb)
x_mag = self.norm_pre_mag(x_mag.transpose(1, 2)).transpose(1, 2)
x_pha = self.norm_pre_pha(x_pha.transpose(1, 2)).transpose(1, 2)
for conv_block_mag, conv_block_pha in zip(self.convnext_mag, self.convnext_pha):
x_mag = x_mag + x_pha
x_pha = x_pha + x_mag
x_mag = conv_block_mag(x_mag, cond_embedding_id=None)
x_pha = conv_block_pha(x_pha, cond_embedding_id=None)
x_mag = self.norm_post_mag(x_mag.transpose(1, 2))
mag_wb = mag_nb + self.linear_post_mag(x_mag).transpose(1, 2)
x_pha = self.norm_post_pha(x_pha.transpose(1, 2))
x_pha_r = self.linear_post_pha_r(x_pha)
x_pha_i = self.linear_post_pha_i(x_pha)
pha_wb = torch.atan2(x_pha_i, x_pha_r).transpose(1, 2)
com_wb = torch.stack((torch.exp(mag_wb) * torch.cos(pha_wb), torch.exp(mag_wb) * torch.sin(pha_wb)), dim=-1)
return mag_wb, pha_wb, com_wb
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
]
)
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for i, l in enumerate(self.convs):
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
if i > 0:
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2),
DiscriminatorP(3),
DiscriminatorP(5),
DiscriminatorP(7),
DiscriminatorP(11),
]
)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class MultiResolutionAmplitudeDiscriminator(nn.Module):
def __init__(
self,
resolutions: Tuple[Tuple[int, int, int]] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
num_embeddings: int = None,
):
super().__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorAR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorAR(nn.Module):
def __init__(
self,
resolution: Tuple[int, int, int],
channels: int = 64,
in_channels: int = 1,
num_embeddings: int = None,
):
super().__init__()
self.resolution = resolution
self.in_channels = in_channels
self.convs = nn.ModuleList(
[
weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
]
)
if num_embeddings is not None:
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
def forward(
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
x = x.squeeze(1)
x = self.spectrogram(x)
x = x.unsqueeze(1)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
x = torch.flatten(x, 1, -1)
return x, fmap
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
n_fft, hop_length, win_length = self.resolution
amplitude_spectrogram = torch.stft(
x,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=None, # interestingly rectangular window kind of works here
center=True,
return_complex=True,
).abs()
return amplitude_spectrogram
class MultiResolutionPhaseDiscriminator(nn.Module):
def __init__(
self,
resolutions: Tuple[Tuple[int, int, int]] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
num_embeddings: int = None,
):
super().__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorPR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorPR(nn.Module):
def __init__(
self,
resolution: Tuple[int, int, int],
channels: int = 64,
in_channels: int = 1,
num_embeddings: int = None,
):
super().__init__()
self.resolution = resolution
self.in_channels = in_channels
self.convs = nn.ModuleList(
[
weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
]
)
if num_embeddings is not None:
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
def forward(
self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
x = x.squeeze(1)
x = self.spectrogram(x)
x = x.unsqueeze(1)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
x = torch.flatten(x, 1, -1)
return x, fmap
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
n_fft, hop_length, win_length = self.resolution
phase_spectrogram = torch.stft(
x,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=None, # interestingly rectangular window kind of works here
center=True,
return_complex=True,
).angle()
return phase_spectrogram
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean(torch.clamp(1 - dg, min=0))
gen_losses.append(l)
loss += l
return loss, gen_losses
def phase_losses(phase_r, phase_g):
ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
return ip_loss, gd_loss, iaf_loss
def anti_wrapping_function(x):
return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
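The anti-wrapping function maps any phase difference to its principal absolute value in [0, pi], so a difference of exactly 2*pi counts as zero. A scalar stdlib version of the same formula (illustrative only; `anti_wrap` is an invented name):

```python
import math

def anti_wrap(x):
    # Same formula as anti_wrapping_function above, for a scalar:
    # subtract the nearest multiple of 2*pi, then take the absolute value.
    return abs(x - round(x / (2 * math.pi)) * 2 * math.pi)
```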
def stft_mag(audio, n_fft=2048, hop_length=512):
hann_window = torch.hann_window(n_fft).to(audio.device)
stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True)
stft_mag = torch.abs(stft_spec)
return stft_mag
def cal_snr(pred, target):
snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean()
return snr
def cal_lsd(pred, target):
sp = torch.log10(stft_mag(pred).square().clamp(1e-8))
st = torch.log10(stft_mag(target).square().clamp(1e-8))
return (sp - st).square().mean(dim=1).sqrt().mean()
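The SNR formula in `cal_snr` can be checked by hand on a tiny vector. A stdlib re-derivation (illustrative sketch, not the repo's function; `snr_db` is an invented name):

```python
import math

def snr_db(pred, target):
    # 20 * log10(||target|| / ||pred - target||), as in cal_snr,
    # with the same 1e-8 floor on the denominator.
    num = math.sqrt(sum(t * t for t in target))
    den = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))
    return 20 * math.log10(num / max(den, 1e-8))
```

For example, a prediction that is off by 10% in energy norm gives 20 * log10(10) = 20 dB.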