update
This commit is contained in:
52
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/README.md
Normal file
52
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
|
||||
# Evaluation
|
||||
|
||||
Install packages for evaluation:
|
||||
|
||||
```bash
|
||||
pip install -e .[eval]
|
||||
```
|
||||
|
||||
## Generating Samples for Evaluation
|
||||
|
||||
### Prepare Test Datasets
|
||||
|
||||
1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
|
||||
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
|
||||
3. Unzip the downloaded datasets and place them in the `data/` directory.
|
||||
4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
|
||||
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
|
||||
|
||||
### Batch Inference for Test Set
|
||||
|
||||
To run batch inference for evaluations, execute the following commands:
|
||||
|
||||
```bash
|
||||
# batch inference for evaluations
|
||||
accelerate config # if not set before
|
||||
bash src/f5_tts/eval/eval_infer_batch.sh
|
||||
```
|
||||
|
||||
## Objective Evaluation on Generated Results
|
||||
|
||||
### Download Evaluation Model Checkpoints
|
||||
|
||||
1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
|
||||
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
|
||||
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
|
||||
|
||||
Then update the following scripts with the paths where you placed the evaluation model checkpoints.
|
||||
|
||||
### Objective Evaluation
|
||||
|
||||
Update the paths to point to your batch-inference results, then run the WER / SIM / UTMOS evaluations:
|
||||
```bash
|
||||
# Evaluation [WER] for Seed-TTS test [ZH] set
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
|
||||
|
||||
# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
|
||||
python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
|
||||
|
||||
# Evaluation [UTMOS]. --ext: Audio extension
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
|
||||
```
|
||||
331
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/ecapa_tdnn.py
Normal file
331
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/ecapa_tdnn.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# just for speaker similarity evaluation, third-party code
|
||||
|
||||
# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
|
||||
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
""" Res2Conv1d + BatchNorm1d + ReLU
|
||||
"""
|
||||
|
||||
|
||||
class Res2Conv1dReluBn(nn.Module):
    """Res2-style multi-scale 1D convolution block (conv -> ReLU -> BatchNorm per branch).

    The channel axis is split into ``scale`` groups of ``channels // scale`` channels.
    The first ``nums`` groups are convolved one after another, each branch receiving
    the previous branch's output added to its own group; when ``scale != 1`` the last
    group is passed through untouched.

    in_channels == out_channels == channels
    """

    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        # With scale == 1 there is a single branch; otherwise the last group is a passthrough.
        self.nums = scale if scale == 1 else scale - 1

        conv_branches = nn.ModuleList()
        bn_branches = nn.ModuleList()
        for _ in range(self.nums):
            conv_branches.append(
                nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias)
            )
            bn_branches.append(nn.BatchNorm1d(self.width))
        self.convs = conv_branches
        self.bns = bn_branches

    def forward(self, x):
        """Apply the hierarchical branches and concatenate the results on the channel axis."""
        groups = torch.split(x, self.width, 1)
        pieces = []
        carried = None
        for idx, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            # The first branch sees its group alone; later ones add the previous output.
            carried = groups[idx] if carried is None else carried + groups[idx]
            # Order: conv -> relu -> bn
            carried = bn(F.relu(conv(carried)))
            pieces.append(carried)
        if self.scale != 1:
            pieces.append(groups[self.nums])
        return torch.cat(pieces, dim=1)
|
||||
|
||||
|
||||
""" Conv1d + BatchNorm1d + ReLU
|
||||
"""
|
||||
|
||||
|
||||
class Conv1dReluBn(nn.Module):
    """1D convolution followed by ReLU and then batch normalization."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Return ``bn(relu(conv(x)))``."""
        activated = F.relu(self.conv(x))
        return self.bn(activated)
|
||||
|
||||
|
||||
""" The SE connection of 1D case.
|
||||
"""
|
||||
|
||||
|
||||
class SE_Connect(nn.Module):
    """Squeeze-and-excitation gate for 1D feature maps.

    The input is averaged over its last axis, passed through a two-layer
    bottleneck (ReLU then sigmoid), and the resulting per-channel gate
    rescales the input.
    """

    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        """Scale ``x`` channel-wise by the learned gate."""
        squeezed = x.mean(dim=2)
        hidden = F.relu(self.linear1(squeezed))
        gate = torch.sigmoid(self.linear2(hidden))
        return x * gate.unsqueeze(2)
|
||||
|
||||
|
||||
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
||||
"""
|
||||
|
||||
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
||||
# return nn.Sequential(
|
||||
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
||||
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
||||
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
||||
# SE_Connect(channels)
|
||||
# )
|
||||
|
||||
|
||||
class SE_Res2Block(nn.Module):
    """Residual block: 1x1 conv -> Res2 conv -> 1x1 conv -> SE gate, plus a skip connection.

    When ``in_channels != out_channels`` the skip path goes through a 1x1
    convolution so the two branches can be added.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        # Attribute names are kept as-is: they determine the state_dict keys of pretrained checkpoints.
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
        self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if needs_projection
            else None
        )

    def forward(self, x):
        """Return the block output added to the (optionally projected) input."""
        skip = x if self.shortcut is None else self.shortcut(x)

        hidden = self.Conv1dReluBn1(x)
        hidden = self.Res2Conv1dReluBn(hidden)
        hidden = self.Conv1dReluBn2(hidden)
        hidden = self.SE_Connect(hidden)

        return hidden + skip
|
||||
|
||||
|
||||
""" Attentive weighted mean and standard deviation pooling.
|
||||
"""
|
||||
|
||||
|
||||
class AttentiveStatsPool(nn.Module):
    """Attention-weighted mean and standard-deviation pooling over the last axis.

    With ``global_context_att`` enabled, the attention network additionally
    receives the per-channel mean and standard deviation of the input,
    broadcast along the last axis.
    """

    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        attention_in_dim = in_dim * 3 if global_context_att else in_dim
        self.linear1 = nn.Conv1d(attention_in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """Return ``cat([weighted_mean, weighted_std], dim=1)``."""
        attention_input = x
        if self.global_context_att:
            global_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            global_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            attention_input = torch.cat((x, global_mean, global_std), dim=1)

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        scores = self.linear2(torch.tanh(self.linear1(attention_input)))
        weights = torch.softmax(scores, dim=2)

        weighted_mean = torch.sum(weights * x, dim=2)
        weighted_var = torch.sum(weights * (x**2), dim=2) - weighted_mean**2
        # Clamp before the square root: the variance estimate can dip below zero numerically.
        weighted_std = torch.sqrt(weighted_var.clamp(min=1e-9))
        return torch.cat([weighted_mean, weighted_std], dim=1)
|
||||
|
||||
|
||||
class ECAPA_TDNN(nn.Module):
    """ECAPA-TDNN speaker-embedding network on top of an s3prl feature extractor.

    The upstream extractor is loaded through ``torch.hub`` (a local s3prl
    checkout is tried first, then the ``s3prl/s3prl`` hub repo). Its selected
    features are combined with learned softmax weights, instance-normalized,
    and fed through three SE-Res2 blocks, attentive statistics pooling and a
    linear projection to ``emb_dim``.
    """

    def __init__(
        self,
        feat_dim=80,
        channels=512,
        emb_dim=192,
        global_context_att=False,
        feat_type="wavlm_large",
        sr=16000,
        feature_selection="hidden_states",
        update_extract=False,
        config_path=None,
    ):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        # Replaces torch.hub's private fork validation with a no-op (kept from the original code).
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        try:
            local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
            self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
        except Exception:
            # Any failure of the local load falls back to the hub repo. Catching
            # Exception (not a bare except) lets Ctrl-C / SystemExit propagate.
            self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)

        encoder_layers = self.feature_extract.model.encoder.layers
        if len(encoder_layers) == 24:
            # Switch off fp32 attention on layers 23 and 11 when the attribute is present.
            for layer_idx in (23, 11):
                self_attn = encoder_layers[layer_idx].self_attn
                if hasattr(self_attn, "fp32_attention"):
                    self_attn.fp32_attention = False

        self.feat_num = self.get_feat_num()
        # One learnable weight per selected feature; softmax-normalized in get_feat.
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != "fbank" and feat_type != "mfcc":
            freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
            for name, param in self.feature_extract.named_parameters():
                if any(freeze_val in name for freeze_val in freeze_list):
                    param.requires_grad = False

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        cat_channels = channels * 3  # outputs of layer2..layer4 are concatenated in forward
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1], attention_channels=128, global_context_att=global_context_att
        )
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        """Return how many features the extractor yields under ``feature_selection``.

        Runs the extractor once on ``self.sr`` random samples (this also puts
        the extractor in eval mode). Returns the length of the selected entry
        when it is a list/tuple, otherwise 1.
        """
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        return 1

    def get_feat(self, x):
        """Extract, weight-combine and instance-normalize upstream features for a batch ``x``."""
        if self.update_extract:
            x = self.feature_extract(list(x))
        else:
            # Frozen extractor: skip building an autograd graph through it.
            with torch.no_grad():
                if self.feat_type == "fbank" or self.feat_type == "mfcc":
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract(list(x))

        if self.feat_type == "fbank":
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            # Softmax over the stacked-feature axis, broadcast across the remaining three axes.
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        """Return the ``emb_dim``-dimensional embedding for each item in the batch."""
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out
|
||||
|
||||
|
||||
def ECAPA_TDNN_SMALL(
    feat_dim,
    emb_dim=256,
    feat_type="wavlm_large",
    sr=16000,
    feature_selection="hidden_states",
    update_extract=False,
    config_path=None,
):
    """Build an ``ECAPA_TDNN`` with ``channels`` fixed to 512; other arguments are forwarded as-is."""
    model_kwargs = dict(
        feat_dim=feat_dim,
        channels=512,
        emb_dim=emb_dim,
        feat_type=feat_type,
        sr=sr,
        feature_selection=feature_selection,
        update_extract=update_extract,
        config_path=config_path,
    )
    return ECAPA_TDNN(**model_kwargs)
|
||||
210
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/eval_infer_batch.py
Normal file
210
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/eval_infer_batch.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from importlib.resources import files
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from accelerate import Accelerator
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.eval.utils_eval import (
|
||||
get_inference_prompt,
|
||||
get_librispeech_test_clean_metainfo,
|
||||
get_seedtts_testset_metainfo,
|
||||
)
|
||||
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
accelerator = Accelerator()
# Use the node-local rank: the global process_index exceeds the local GPU count
# on every node but the first in a multi-node launch. Kept as a string, as before.
device = f"cuda:{accelerator.local_process_index}"

use_ema = True  # forwarded to load_checkpoint(..., use_ema=use_ema)
target_rms = 0.1  # forwarded to get_inference_prompt and used to rescale generated audio

# Two directory levels above the installed f5_tts package.
rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
def main():
    """Run batch inference over one evaluation test set and save the generated wavs.

    Command-line options select the experiment config, checkpoint step, ODE
    settings and test set. Prompts are split across the accelerate processes;
    each process writes ``<utt>.wav`` files into a results directory whose
    name encodes the sampling settings.

    Raises:
        ValueError: if ``--testset`` or the configured ``mel_spec_type`` is not supported.
    """
    parser = argparse.ArgumentParser(description="batch inference")

    parser.add_argument("-s", "--seed", default=None, type=int)
    parser.add_argument("-n", "--expname", required=True)
    parser.add_argument("-c", "--ckptstep", default=1250000, type=int)

    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
    parser.add_argument("-o", "--odemethod", default="euler")
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

    parser.add_argument("-t", "--testset", required=True)

    args = parser.parse_args()

    seed = args.seed
    exp_name = args.expname
    ckpt_step = args.ckptstep

    nfe_step = args.nfestep
    ode_method = args.odemethod
    sway_sampling_coef = args.swaysampling

    testset = args.testset

    infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
    cfg_strength = 2.0
    speed = 1.0
    use_truth_duration = False
    no_ref_audio = False

    model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch

    dataset_name = model_cfg.datasets.name
    tokenizer = model_cfg.model.tokenizer

    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
    target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
    n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
    hop_length = model_cfg.model.mel_spec.hop_length
    win_length = model_cfg.model.mel_spec.win_length
    n_fft = model_cfg.model.mel_spec.n_fft

    if testset == "ls_pc_test_clean":
        metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
        librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
        metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

    elif testset == "seedtts_test_zh":
        metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    elif testset == "seedtts_test_en":
        metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)

    else:
        # Previously an unknown name fell through to a NameError on `metainfo`.
        raise ValueError(
            f"Unknown testset: {testset!r} (expected 'ls_pc_test_clean', 'seedtts_test_zh' or 'seedtts_test_en')"
        )

    # path to save generated wavs
    # Note: a sway sampling coefficient of 0 is falsy, so the "_ss..." part is omitted for it.
    output_dir = (
        f"{rel_path}/"
        f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
        f"_cfg{cfg_strength}_speed{speed}"
        f"{'_gt-dur' if use_truth_duration else ''}"
        f"{'_no-ref-audio' if no_ref_audio else ''}"
    )

    # -------------------------------------------------#

    prompts_all = get_inference_prompt(
        metainfo,
        speed=speed,
        tokenizer=tokenizer,
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
        mel_spec_type=mel_spec_type,
        target_rms=target_rms,
        use_truth_duration=use_truth_duration,
        infer_batch_size=infer_batch_size,
    )

    # Vocoder model
    local = False
    if mel_spec_type == "vocos":
        vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    elif mel_spec_type == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    else:
        # Previously this surfaced later as a NameError on `vocoder_local_path`.
        raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type!r} (expected 'vocos' or 'bigvgan')")
    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)

    # Tokenizer
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    # Model
    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # Prefer ckpts/<exp>/model_<step>.pt, then .safetensors, then the config's save_dir.
    ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
    if os.path.exists(ckpt_prefix + ".pt"):
        ckpt_path = ckpt_prefix + ".pt"
    elif os.path.exists(ckpt_prefix + ".safetensors"):
        ckpt_path = ckpt_prefix + ".safetensors"
    else:
        print("Loading from self-organized training checkpoints rather than released pretrained.")
        ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"

    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

    # Only the main process creates the directory; exist_ok avoids the check-then-create race.
    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)

    # start batch inference
    accelerator.wait_for_everyone()
    start = time.time()

    with accelerator.split_between_processes(prompts_all) as prompts:
        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
            utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
            ref_mels = ref_mels.to(device)
            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

            # Inference
            with torch.inference_mode():
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=final_text_list,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                    no_ref_audio=no_ref_audio,
                    seed=seed,
                )
                # Final result
                for i, gen in enumerate(generated):
                    # Keep only the frames after the reference prompt, up to the total length.
                    gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                    gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
                    if mel_spec_type == "vocos":
                        generated_wave = vocoder.decode(gen_mel_spec).cpu()
                    else:  # "bigvgan" (mel_spec_type was validated above)
                        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

                    # Scale the output down when the reference RMS is below target_rms.
                    if ref_rms_list[i] < target_rms:
                        generated_wave = generated_wave * ref_rms_list[i] / target_rms
                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        timediff = time.time() - start
        print(f"Done batch inference in {timediff / 60:.2f} minutes.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
18
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/eval_infer_batch.sh
Normal file
18
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/eval_infer_batch.sh
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/bin/bash

# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16

# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0

# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
# The directory name encodes the inference settings, so it must say nfe16 to match the "-nfe 16" runs above.
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0

# etc.
|
||||
@@ -0,0 +1,89 @@
|
||||
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import multiprocessing as mp
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
|
||||
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
|
||||
|
||||
|
||||
rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line options of the LibriSpeech test-clean evaluation script."""
    option_specs = (
        (("-e", "--eval_task"), dict(type=str, default="wer", choices=["sim", "wer"])),
        (("-l", "--lang"), dict(type=str, default="en")),
        (("-g", "--gen_wav_dir"), dict(type=str, required=True)),
        (("-p", "--librispeech_test_clean_path"), dict(type=str, required=True)),
        (("-n", "--gpu_nums"), dict(type=int, default=8, help="Number of GPUs to use")),
        (("--local",), dict(action="store_true", help="Use local custom checkpoint directory")),
    )
    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Evaluate generated LibriSpeech-PC test-clean audio with WER or speaker similarity.

    The test set is sharded over ``--gpu_nums`` worker processes; per-sample
    results and the mean metric are written to
    ``<gen_wav_dir>/_<eval_task>_results.jsonl`` and a summary is printed.

    Raises:
        ValueError: if ``eval_task`` is neither ``"wer"`` nor ``"sim"``.
    """
    args = get_args()
    eval_task = args.eval_task
    lang = args.lang
    librispeech_test_clean_path = args.librispeech_test_clean_path  # test-clean path
    gen_wav_dir = args.gen_wav_dir
    metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"

    gpus = list(range(args.gpu_nums))
    test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)

    ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
    ## leading to a low similarity for the ground truth in some cases.
    # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth

    local = args.local
    if local:  # use local custom checkpoint dir
        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
    else:
        asr_ckpt_dir = ""  # auto download to cache dir
    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

    # --------------------------------------------------------------------------

    full_results = []
    metrics = []

    # Build the per-worker argument tuples under their own name instead of rebinding `args`.
    if eval_task == "wer":
        worker_fn = run_asr_wer
        worker_args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
    elif eval_task == "sim":
        worker_fn = run_sim
        worker_args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
    else:
        raise ValueError(f"Unknown metric type: {eval_task}")

    with mp.Pool(processes=len(gpus)) as pool:
        results = pool.map(worker_fn, worker_args)
    for r in results:
        full_results.extend(r)

    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
    # Records are dumped with ensure_ascii=False, so pin UTF-8 rather than rely on the locale encoding.
    with open(result_path, "w", encoding="utf-8") as f:
        for line in full_results:
            metrics.append(line[eval_task])
            f.write(json.dumps(line, ensure_ascii=False) + "\n")
        # With no results, report nan directly instead of calling np.mean on an empty list.
        metric = round(np.mean(metrics), 5) if metrics else float("nan")
        f.write(f"\n{eval_task.upper()}: {metric}\n")

    print(f"\nTotal {len(metrics)} samples")
    print(f"{eval_task.upper()}: {metric}")
    print(f"{eval_task.upper()} results saved to {result_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,88 @@
|
||||
# Evaluate with Seed-TTS testset
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import multiprocessing as mp
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
|
||||
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
|
||||
|
||||
|
||||
rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line options of the Seed-TTS test set evaluation script."""
    option_specs = (
        (("-e", "--eval_task"), dict(type=str, default="wer", choices=["sim", "wer"])),
        (("-l", "--lang"), dict(type=str, default="en", choices=["zh", "en"])),
        (("-g", "--gen_wav_dir"), dict(type=str, required=True)),
        (("-n", "--gpu_nums"), dict(type=int, default=8, help="Number of GPUs to use")),
        (("--local",), dict(action="store_true", help="Use local custom checkpoint directory")),
    )
    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Evaluate generated Seed-TTS test set audio with WER or speaker similarity.

    The test set is sharded over ``--gpu_nums`` worker processes; per-sample
    results and the mean metric are written to
    ``<gen_wav_dir>/_<eval_task>_results.jsonl`` and a summary is printed.

    Raises:
        ValueError: if ``eval_task`` is neither ``"wer"`` nor ``"sim"``.
    """
    args = get_args()
    eval_task = args.eval_task
    lang = args.lang
    gen_wav_dir = args.gen_wav_dir
    metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset

    # NOTE: paraformer-zh results differ slightly with the number of GPUs, because the batch size differs.
    # zh 1.254 seems to be the result of 4 workers (wer_seed_tts)
    gpus = list(range(args.gpu_nums))
    test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)

    local = args.local
    if local:  # use local custom checkpoint dir
        if lang == "zh":
            asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
        elif lang == "en":
            asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
    else:
        asr_ckpt_dir = ""  # auto download to cache dir
    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

    # --------------------------------------------------------------------------

    full_results = []
    metrics = []

    # Build the per-worker argument tuples under their own name instead of rebinding `args`.
    if eval_task == "wer":
        worker_fn = run_asr_wer
        worker_args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
    elif eval_task == "sim":
        worker_fn = run_sim
        worker_args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
    else:
        raise ValueError(f"Unknown metric type: {eval_task}")

    with mp.Pool(processes=len(gpus)) as pool:
        results = pool.map(worker_fn, worker_args)
    for r in results:
        full_results.extend(r)

    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
    # Records are dumped with ensure_ascii=False (the zh transcripts are non-ASCII),
    # so pin UTF-8 rather than rely on the locale encoding.
    with open(result_path, "w", encoding="utf-8") as f:
        for line in full_results:
            metrics.append(line[eval_task])
            f.write(json.dumps(line, ensure_ascii=False) + "\n")
        # With no results, report nan directly instead of calling np.mean on an empty list.
        metric = round(np.mean(metrics), 5) if metrics else float("nan")
        f.write(f"\n{eval_task.upper()}: {metric}\n")

    print(f"\nTotal {len(metrics)} samples")
    print(f"{eval_task.upper()}: {metric}")
    print(f"{eval_task.upper()} results saved to {result_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
42
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/eval_utmos.py
Normal file
42
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/eval_utmos.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
    """Score every ``*.<ext>`` file under ``--audio_dir`` with UTMOS and report the mean.

    Per-file scores and the average are written to
    ``<audio_dir>/_utmos_results.jsonl``; the average is also printed.
    """
    parser = argparse.ArgumentParser(description="UTMOS Evaluation")
    parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.")
    parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
    args = parser.parse_args()

    # Guard the xpu lookup: torch builds without the `torch.xpu` attribute would raise AttributeError.
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = "xpu"
    else:
        device = "cpu"

    predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
    predictor = predictor.to(device)

    # Sorted so the order of lines in the results file is reproducible.
    audio_paths = sorted(Path(args.audio_dir).rglob(f"*.{args.ext}"))
    utmos_score = 0

    utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
    with open(utmos_result_path, "w", encoding="utf-8") as f:
        for audio_path in tqdm(audio_paths, desc="Processing"):
            wav, sr = librosa.load(audio_path, sr=None, mono=True)
            wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
            # Scoring only: no autograd graph is needed.
            with torch.no_grad():
                score_value = predictor(wav_tensor, sr).item()
            line = {"wav": str(audio_path.stem), "utmos": score_value}
            utmos_score += score_value
            f.write(json.dumps(line, ensure_ascii=False) + "\n")
        avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
        f.write(f"\nUTMOS: {avg_score:.4f}\n")

    print(f"UTMOS: {avg_score:.4f}")
    print(f"UTMOS results saved to {utmos_result_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
419
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/utils_eval.py
Normal file
419
bi_v100-f5-tts/F5-TTS/src/f5_tts/eval/utils_eval.py
Normal file
@@ -0,0 +1,419 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
|
||||
from f5_tts.model.modules import MelSpec
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
||||
def get_seedtts_testset_metainfo(metalst):
    """Parse a Seed-TTS style meta list into prompt/target tuples.

    Every non-blank line of ``metalst`` is ``|``-separated and carries either
    five fields (``utt|prompt_text|prompt_wav|gt_text|gt_wav``) or four fields
    (``utt|prompt_text|prompt_wav|gt_text``). With four fields the ground-truth
    wav is taken to be ``<dir of metalst>/wavs/<utt>.wav``. A relative
    ``prompt_wav`` is resolved against the directory containing ``metalst``.

    Args:
        metalst: path to the meta list file.

    Returns:
        A list of ``(utt, prompt_text, prompt_wav, gt_text, gt_wav)`` tuples,
        in file order.

    Raises:
        ValueError: if a non-blank line has neither four nor five fields.
    """
    meta_dir = os.path.dirname(metalst)
    metainfo = []
    with open(metalst, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # blank lines (e.g. a trailing newline) carry no utterance
                continue
            fields = line.split("|")
            if len(fields) == 5:
                utt, prompt_text, prompt_wav, gt_text, gt_wav = fields
            elif len(fields) == 4:
                utt, prompt_text, prompt_wav, gt_text = fields
                gt_wav = os.path.join(meta_dir, "wavs", utt + ".wav")
            else:
                # never fall through with the previous line's values still bound
                raise ValueError(f"{metalst}:{line_no}: expected 4 or 5 '|'-separated fields, got {len(fields)}")
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(meta_dir, prompt_wav)
            metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo
|
||||
|
||||
|
||||
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
|
||||
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
    """Parse a LibriSpeech cross-sentence meta list into prompt/target tuples.

    Every non-blank line of ``metalst`` holds six tab-separated fields:
    ``ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt``. Utterance ids
    have the form ``<speaker>-<chapter>-<index>`` and map to
    ``<librispeech_test_clean_path>/<speaker>/<chapter>/<utt>.flac``.

    Args:
        metalst: path to the meta list file.
        librispeech_test_clean_path: root directory of LibriSpeech test-clean.

    Returns:
        A list of ``(gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav)``
        tuples in file order; the text to generate is prefixed with one space.

    Raises:
        ValueError: if a non-blank line does not have exactly six fields or an
            utterance id does not have exactly three ``-``-separated parts.
    """

    def flac_path(utt):
        spk_id, chaptr_id, _ = utt.split("-")
        return os.path.join(librispeech_test_clean_path, spk_id, chaptr_id, utt + ".flac")

    metainfo = []
    with open(metalst, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # tolerate blank lines instead of failing on the unpack below
                continue
            ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.split("\t")

            # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
            ref_wav = flac_path(ref_utt)

            # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
            gen_wav = flac_path(gen_utt)

            metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))

    return metainfo
|
||||
|
||||
|
||||
# padded to max length mel batch
|
||||
def padded_mel_batch(ref_mels):
    """Zero-pad mel spectrograms to a common length and stack them into a batch.

    Args:
        ref_mels: non-empty sequence of 2-D tensors whose last dimension is
            the one being padded; the first dimension must match across items
            so they can be stacked.

    Returns:
        The stacked, right-padded tensor with its last two dimensions swapped,
        i.e. shape ``(len(ref_mels), max_last_dim, first_dim)``.
    """
    # plain Python int: no tensor round-trip, and F.pad receives int pad widths
    max_mel_length = max(mel.shape[-1] for mel in ref_mels)
    padded_ref_mels = [F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) for mel in ref_mels]
    return torch.stack(padded_ref_mels).permute(0, 2, 1)
|
||||
|
||||
|
||||
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
||||
|
||||
|
||||
def get_inference_prompt(
    metainfo,
    speed=1.0,
    tokenizer="pinyin",
    polyphone=True,
    target_sample_rate=24000,
    n_fft=1024,
    win_length=1024,
    n_mel_channels=100,
    hop_length=256,
    mel_spec_type="vocos",
    target_rms=0.1,
    use_truth_duration=False,
    infer_batch_size=1,
    num_buckets=200,
    min_secs=3,
    max_secs=40,
):
    """Turn test-set meta information into length-bucketed inference batches.

    For every utterance the prompt audio is loaded, boosted to ``target_rms``
    if quieter, resampled to ``target_sample_rate`` and converted to a mel
    spectrogram; the prompt and target texts are concatenated; and the total
    length in mel frames (prompt + part to generate) is computed. Utterances
    are grouped into ``num_buckets`` buckets by total length, and a bucket is
    emitted as a batch once its accumulated total frame count reaches
    ``infer_batch_size``. Whatever remains in the buckets is emitted at the end.

    Args:
        metainfo: iterable of ``(utt, prompt_text, prompt_wav, gt_text, gt_wav)``.
        speed: divisor applied to the length of the part to generate.
        tokenizer: ``"pinyin"`` runs the text through ``convert_char_to_pinyin``;
            any other value keeps the raw text.
        polyphone: forwarded to ``convert_char_to_pinyin``.
        target_sample_rate, n_fft, win_length, n_mel_channels, hop_length,
            mel_spec_type: forwarded to ``MelSpec``.
        target_rms: prompts with a lower RMS are scaled up to this value.
        use_truth_duration: if true, the generated length comes from the
            ground-truth wav; otherwise it is estimated from the prompt's mel
            length scaled by the UTF-8 byte-length ratio of target to prompt text.
        infer_batch_size: per-batch budget, counted in total mel frames.
        num_buckets: number of length buckets.
        min_secs, max_secs: allowed range for the total duration in seconds.

    Returns:
        A list of ``(utts, ref_rms_list, padded_ref_mels, ref_mel_lens,
        total_mel_lens, final_text_list)`` tuples, shuffled with the fixed
        seed 666 (this reseeds the global ``random`` module).

    Raises:
        AssertionError: if a prompt has 5000 samples or fewer, if
            ``infer_batch_size`` is not positive, or if an utterance's total
            length falls outside ``[min_secs, max_secs]``.
    """
    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length

    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
        [[] for _ in range(num_buckets)] for _ in range(6)
    )
    per_bucket_lists = (utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list)

    mel_spectrogram = MelSpec(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    # one Resample module per source sample rate, built on first use and reused
    resamplers = {}

    def to_target_rate(audio, orig_sr):
        if orig_sr == target_sample_rate:
            return audio
        if orig_sr not in resamplers:
            resamplers[orig_sr] = torchaudio.transforms.Resample(orig_sr, target_sample_rate)
        return resamplers[orig_sr](audio)

    def batch_of(bucket_i):
        return (
            utts[bucket_i],
            ref_rms_list[bucket_i],
            padded_mel_batch(ref_mels[bucket_i]),
            ref_mel_lens[bucket_i],
            total_mel_lens[bucket_i],
            final_text_list[bucket_i],
        )

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
        # Audio
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        # boost quiet prompts up to target_rms; an all-zero prompt (rms == 0) is left
        # untouched, because dividing by zero would fill it with NaN
        if 0 < ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        ref_audio = to_target_rate(ref_audio, ref_sr)

        # Text
        # a prompt ending in a single-byte (ASCII) character gets a separating space
        if len(prompt_text[-1].encode("utf-8")) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone=polyphone)
        else:
            text_list = text

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        # Duration, mel frame length
        ref_mel_len = ref_mel.shape[-1]

        if use_truth_duration:
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            gt_audio = to_target_rate(gt_audio, gt_sr)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)

            # # test vocoder resynthesis
            # ref_audio = gt_audio
        else:
            ref_text_len = len(prompt_text.encode("utf-8"))
            gen_text_len = len(gt_text.encode("utf-8"))
            total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
        assert min_tokens <= total_mel_len <= max_tokens, (
            f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        )
        bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        if batch_accum[bucket_i] >= infer_batch_size:
            prompts_all.append(batch_of(bucket_i))
            batch_accum[bucket_i] = 0
            # rebind to fresh lists (not .clear()): the emitted batch keeps the old ones
            for bucket_lists in per_bucket_lists:
                bucket_lists[bucket_i] = []

    # add residual
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append(batch_of(bucket_i))

    # not only leave easy work for last workers
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all
|
||||
|
||||
|
||||
# get wav_res_ref_text of seed-tts test metalst
|
||||
# https://github.com/BytedanceSpeech/seed-tts-eval
|
||||
|
||||
|
||||
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
    """Pair generated wavs with their prompts and split the work across GPUs.

    Every non-blank line of ``metalst`` is ``|``-separated with four fields
    (``utt|prompt_text|prompt_wav|gt_text``) or five (a trailing ``gt_wav``).
    Utterances without ``<gen_wav_dir>/<utt>.wav`` are skipped. A relative
    ``prompt_wav`` is resolved against the directory containing ``metalst``.

    Args:
        metalst: path to the Seed-TTS meta list file.
        gen_wav_dir: directory holding the generated ``<utt>.wav`` files.
        gpus: sequence of GPU ids; one job is produced per entry.

    Returns:
        A list of ``(gpu, items)`` pairs, one per entry of ``gpus``, where
        ``items`` is a contiguous slice of ``(gen_wav, prompt_wav, gt_text)``
        tuples in file order.

    Raises:
        ValueError: if a non-blank line has neither four nor five fields.
    """
    meta_dir = os.path.dirname(metalst)
    with open(metalst, encoding="utf-8") as f:
        lines = f.readlines()

    test_set_ = []
    for line_no, line in enumerate(tqdm(lines), start=1):
        line = line.strip()
        if not line:
            continue
        fields = line.split("|")
        if len(fields) not in (4, 5):
            # never fall through with the previous line's values still bound
            raise ValueError(f"{metalst}:{line_no}: expected 4 or 5 '|'-separated fields, got {len(fields)}")
        # an optional fifth field (gt_wav) is not needed here
        utt, _prompt_text, prompt_wav, gt_text = fields[:4]

        gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
        if not os.path.exists(gen_wav):
            continue
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(meta_dir, prompt_wav)

        test_set_.append((gen_wav, prompt_wav, gt_text))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    # ceiling division: an evenly divisible workload no longer leaves the last GPUs idle
    wav_per_job = math.ceil(len(test_set_) / num_jobs)
    return [(gpu, test_set_[i * wav_per_job : (i + 1) * wav_per_job]) for i, gpu in enumerate(gpus)]
|
||||
|
||||
|
||||
# get librispeech test-clean cross sentence test
|
||||
|
||||
|
||||
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
    """Pair LibriSpeech evaluation wavs with their references and split across GPUs.

    Every non-blank line of ``metalst`` holds six tab-separated fields:
    ``ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt``. Utterance ids
    have the form ``<speaker>-<chapter>-<index>`` and map to
    ``<librispeech_test_clean_path>/<speaker>/<chapter>/<utt>.flac``.

    Args:
        metalst: path to the meta list file.
        gen_wav_dir: directory holding the generated ``<gen_utt>.wav`` files.
        gpus: sequence of GPU ids; one job is produced per entry.
        librispeech_test_clean_path: root directory of LibriSpeech test-clean.
        eval_ground_truth: if true, evaluate the original LibriSpeech
            recording of ``gen_utt`` instead of a generated wav.

    Returns:
        A list of ``(gpu, items)`` pairs, one per entry of ``gpus``, where
        ``items`` is a contiguous slice of ``(gen_wav, ref_wav, gen_txt)``
        tuples in file order.

    Raises:
        FileNotFoundError: if ``eval_ground_truth`` is false and a generated
            wav is missing.
        ValueError: if a non-blank line does not have exactly six fields.
    """

    def flac_path(utt):
        spk_id, chaptr_id, _ = utt.split("-")
        return os.path.join(librispeech_test_clean_path, spk_id, chaptr_id, utt + ".flac")

    with open(metalst, encoding="utf-8") as f:
        lines = f.readlines()

    test_set_ = []
    for line in tqdm(lines):
        line = line.strip()
        if not line:
            # tolerate blank lines instead of failing on the unpack below
            continue
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.split("\t")

        if eval_ground_truth:
            gen_wav = flac_path(gen_utt)
        else:
            gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
            if not os.path.exists(gen_wav):
                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")

        test_set_.append((gen_wav, flac_path(ref_utt), gen_txt))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    # ceiling division: an evenly divisible workload no longer leaves the last GPUs idle
    wav_per_job = math.ceil(len(test_set_) / num_jobs)
    return [(gpu, test_set_[i * wav_per_job : (i + 1) * wav_per_job]) for i, gpu in enumerate(gpus)]
|
||||
|
||||
|
||||
# load asr model
|
||||
|
||||
|
||||
def load_asr_model(lang, ckpt_dir=""):
    """Load the ASR model used for WER evaluation.

    Args:
        lang: ``"zh"`` loads funasr ``paraformer-zh`` from
            ``<ckpt_dir>/paraformer-zh``; ``"en"`` loads a faster-whisper
            model on CUDA in float16.
        ckpt_dir: checkpoint location. For ``"en"`` an empty string selects
            the ``"large-v3"`` model name, otherwise ``ckpt_dir`` itself is
            passed to ``WhisperModel``.

    Returns:
        The loaded model object.

    Raises:
        NotImplementedError: if ``lang`` is neither ``"zh"`` nor ``"en"``.
    """
    if lang == "zh":
        from funasr import AutoModel

        return AutoModel(
            model=os.path.join(ckpt_dir, "paraformer-zh"),
            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
            # spk_model = os.path.join(ckpt_dir, "cam++"),
            disable_update=True,
        )  # following seed-tts setting

    if lang == "en":
        from faster_whisper import WhisperModel

        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
        return WhisperModel(model_size, device="cuda", compute_type="float16")

    # fail with a clear message instead of an UnboundLocalError on `model`
    raise NotImplementedError(
        "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
    )
|
||||
|
||||
|
||||
# WER Evaluation, the way Seed-TTS does
|
||||
|
||||
|
||||
def run_asr_wer(args):
    """Transcribe generated wavs and compute a per-utterance word error rate.

    Args:
        args: a ``(rank, lang, test_set, ckpt_dir)`` tuple. ``rank`` is the
            GPU index, ``lang`` is ``"zh"`` or ``"en"``, ``test_set`` is an
            iterable of ``(gen_wav, prompt_wav, truth)`` and ``ckpt_dir`` is
            forwarded to ``load_asr_model``.

    Returns:
        A list with one dict per utterance: ``"wav"`` (file stem of the
        generated wav), ``"truth"`` and ``"hypo"`` (texts before
        normalisation) and ``"wer"``.

    Raises:
        NotImplementedError: if ``lang`` is neither ``"zh"`` nor ``"en"``.
    """
    rank, lang, test_set, ckpt_dir = args

    if lang == "zh":
        import zhconv

        torch.cuda.set_device(rank)
    elif lang == "en":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    else:
        raise NotImplementedError(
            "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
        )

    asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)

    from zhon.hanzi import punctuation

    punctuation_all = punctuation + string.punctuation
    # deletes every punctuation character in a single pass over the string
    strip_punctuation = str.maketrans("", "", punctuation_all)
    wer_results = []

    from jiwer import compute_measures

    def normalize(text):
        # drop punctuation, then collapse the double spaces this can leave behind
        text = text.translate(strip_punctuation).replace("  ", " ")
        if lang == "zh":
            # score per character: put a space between every character
            return " ".join(text)
        return text.lower()

    for gen_wav, prompt_wav, truth in tqdm(test_set):
        if lang == "zh":
            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
            hypo = zhconv.convert(res[0]["text"], "zh-cn")
        else:
            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
            # each segment is prefixed with a space, exactly as before
            hypo = "".join(" " + segment.text for segment in segments)

        # `measures` also carries "substitutions" / "deletions" / "insertions"
        # counts should a per-error-type breakdown be needed
        measures = compute_measures(normalize(truth), normalize(hypo))

        wer_results.append(
            {
                "wav": Path(gen_wav).stem,
                "truth": truth,
                "hypo": hypo,
                "wer": measures["wer"],
            }
        )

    return wer_results
|
||||
|
||||
|
||||
# SIM Evaluation
|
||||
|
||||
|
||||
def run_sim(args):
    """Compute speaker similarity between generated wavs and their prompts.

    Args:
        args: a ``(rank, test_set, ckpt_dir)`` tuple. ``rank`` is the GPU
            index, ``test_set`` is an iterable of ``(gen_wav, prompt_wav,
            truth)`` and ``ckpt_dir`` is the checkpoint file handed to
            ``torch.load`` (its ``"model"`` entry is loaded non-strictly).

    Returns:
        A list with one dict per pair: ``"wav"`` (file stem of the generated
        wav) and ``"sim"`` (cosine similarity of the two embeddings).
    """
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"

    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict["model"], strict=False)

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model = model.cuda(device)
    model.eval()

    # one Resample module per source sample rate, built on first use and reused
    resamplers = {}

    def load_16k(path):
        wav, sr = torchaudio.load(path)
        if sr not in resamplers:
            resamplers[sr] = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        wav = resamplers[sr](wav)
        return wav.cuda(device) if use_gpu else wav

    sim_results = []
    for gen_wav, prompt_wav, truth in tqdm(test_set):
        wav1 = load_16k(gen_wav)
        wav2 = load_16k(prompt_wav)

        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)

        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_results.append(
            {
                "wav": Path(gen_wav).stem,
                "sim": sim,
            }
        )

    return sim_results
|
||||
Reference in New Issue
Block a user