369 lines
10 KiB
Python
369 lines
10 KiB
Python
import argparse
|
|
import codecs
|
|
import os
|
|
import re
|
|
from datetime import datetime
|
|
from importlib.resources import files
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import tomli
|
|
from cached_path import cached_path
|
|
from hydra.utils import get_class
|
|
from omegaconf import OmegaConf
|
|
|
|
from f5_tts.infer.utils_infer import (
|
|
cfg_strength,
|
|
cross_fade_duration,
|
|
device,
|
|
fix_duration,
|
|
infer_process,
|
|
load_model,
|
|
load_vocoder,
|
|
mel_spec_type,
|
|
nfe_step,
|
|
preprocess_ref_audio_text,
|
|
remove_silence_for_generated_wav,
|
|
speed,
|
|
sway_sampling_coef,
|
|
target_rms,
|
|
)
|
|
|
|
|
|
def _build_parser():
    """Create the argument parser for the inference command-line interface.

    Only ``--config`` is given an explicit default. Every other option is
    registered without one, so an option that is not passed on the command
    line can later be filled in from the configuration file.

    Returns:
        The fully populated ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        prog="python3 infer-cli.py",
        description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
        epilog="Specify options above to override one or more settings from config.",
    )
    cli.add_argument(
        "-c",
        "--config",
        type=str,
        default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
        help="The configuration file, default see infer/examples/basic/basic.toml",
    )

    # Note: no default values are provided below, so that defaults are read from the config file.
    # The table order is the order in which the options are registered (and listed by --help).
    option_table = [
        (
            ("-m", "--model"),
            {"type": str, "help": "The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc."},
        ),
        (
            ("-mc", "--model_cfg"),
            {"type": str, "help": "The path to F5-TTS model config file .yaml"},
        ),
        (
            ("-p", "--ckpt_file"),
            {"type": str, "help": "The path to model checkpoint .pt, leave blank to use default"},
        ),
        (
            ("-v", "--vocab_file"),
            {"type": str, "help": "The path to vocab file .txt, leave blank to use default"},
        ),
        (
            ("-r", "--ref_audio"),
            {"type": str, "help": "The reference audio file."},
        ),
        (
            ("-s", "--ref_text"),
            {"type": str, "help": "The transcript/subtitle for the reference audio"},
        ),
        (
            ("-t", "--gen_text"),
            {"type": str, "help": "The text to make model synthesize a speech"},
        ),
        (
            ("-f", "--gen_file"),
            {"type": str, "help": "The file with text to generate, will ignore --gen_text"},
        ),
        (
            ("-o", "--output_dir"),
            {"type": str, "help": "The path to output folder"},
        ),
        (
            ("-w", "--output_file"),
            {"type": str, "help": "The name of output file"},
        ),
        (
            ("--save_chunk",),
            {"action": "store_true", "help": "To save each audio chunks during inference"},
        ),
        (
            ("--remove_silence",),
            {"action": "store_true", "help": "To remove long silence found in ouput"},
        ),
        (
            ("--load_vocoder_from_local",),
            {
                "action": "store_true",
                "help": "To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz",
            },
        ),
        (
            ("--vocoder_name",),
            {
                "type": str,
                "choices": ["vocos", "bigvgan"],
                "help": f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
            },
        ),
        (
            ("--target_rms",),
            {
                "type": float,
                "help": f"Target output speech loudness normalization value, default {target_rms}",
            },
        ),
        (
            ("--cross_fade_duration",),
            {
                "type": float,
                "help": f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
            },
        ),
        (
            ("--nfe_step",),
            {
                "type": int,
                "help": f"The number of function evaluation (denoising steps), default {nfe_step}",
            },
        ),
        (
            ("--cfg_strength",),
            {"type": float, "help": f"Classifier-free guidance strength, default {cfg_strength}"},
        ),
        (
            ("--sway_sampling_coef",),
            {"type": float, "help": f"Sway Sampling coefficient, default {sway_sampling_coef}"},
        ),
        (
            ("--speed",),
            {"type": float, "help": f"The speed of the generated audio, default {speed}"},
        ),
        (
            ("--fix_duration",),
            {
                "type": float,
                "help": f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
            },
        ),
        (
            ("--device",),
            {"type": str, "help": "Specify the device to run on"},
        ),
    ]
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli


parser = _build_parser()
args = parser.parse_args()
|
|
|
|
|
|
# config file
|
|
|
|
# Load the TOML configuration; the context manager closes the file handle
# as soon as parsing is done.
with open(args.config, "rb") as config_fp:
    config = tomli.load(config_fp)


def _cli_or_config(cli_value, key, default):
    """Resolve one setting with the precedence CLI > config file > default.

    Unlike ``cli_value or ...``, only ``None`` (option not passed) falls
    through, so an explicit falsy override such as ``0`` or ``0.0`` is kept.

    Args:
        cli_value: Value parsed from the command line, ``None`` when absent.
        key: Name of the setting in the loaded config mapping.
        default: Value used when neither source provides the setting.

    Returns:
        The resolved setting.
    """
    if cli_value is not None:
        return cli_value
    return config.get(key, default)


# command-line interface parameters

# String options keep ``or`` on purpose: a blank value on the command line
# means "use the config/default" (see the "leave blank" help texts).
model = args.model or config.get("model", "F5TTS_v1_Base")
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
vocab_file = args.vocab_file or config.get("vocab_file", "")

ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
# An explicitly empty reference text is a valid override, hence the None-aware lookup.
ref_text = _cli_or_config(args.ref_text, "ref_text", "Some call me nature, others call me mother nature.")
gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
gen_file = args.gen_file or config.get("gen_file", "")

output_dir = args.output_dir or config.get("output_dir", "tests")
output_file = args.output_file or config.get(
    "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
)

# store_true flags are False when absent, so ``or`` correctly defers to the config.
save_chunk = args.save_chunk or config.get("save_chunk", False)
remove_silence = args.remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)

vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)

# Numeric options: zero is a legitimate explicit value and must not be
# mistaken for "not provided".
target_rms = _cli_or_config(args.target_rms, "target_rms", target_rms)
cross_fade_duration = _cli_or_config(args.cross_fade_duration, "cross_fade_duration", cross_fade_duration)
nfe_step = _cli_or_config(args.nfe_step, "nfe_step", nfe_step)
cfg_strength = _cli_or_config(args.cfg_strength, "cfg_strength", cfg_strength)
sway_sampling_coef = _cli_or_config(args.sway_sampling_coef, "sway_sampling_coef", sway_sampling_coef)
speed = _cli_or_config(args.speed, "speed", speed)
fix_duration = _cli_or_config(args.fix_duration, "fix_duration", fix_duration)
device = args.device or config.get("device", device)
|
|
|
|
|
|
# patches for pip pkg user
def _expand_example_path(path):
    """Return *path* rooted at the installed ``f5_tts`` package when it points into ``infer/examples/``.

    Paths that do not contain ``infer/examples/`` are returned unchanged.
    """
    if "infer/examples/" not in path:
        return path
    return str(files("f5_tts").joinpath(f"{path}"))


ref_audio = _expand_example_path(ref_audio)
gen_file = _expand_example_path(gen_file)
# Apply the same expansion to the reference audio of every configured voice.
for _voice_cfg in config.get("voices", {}).values():
    _voice_cfg["ref_audio"] = _expand_example_path(_voice_cfg["ref_audio"])
|
|
|
|
|
|
# ignore gen_text if gen_file provided
|
|
|
|
if gen_file:
    # Read the whole file as UTF-8 text; the context manager closes the
    # handle deterministically instead of leaving it to garbage collection.
    with codecs.open(gen_file, "r", "utf-8") as gen_fp:
        gen_text = gen_fp.read()
|
|
|
|
|
|
# output path
|
|
|
|
# Final audio file location: <output_dir>/<output_file>.
wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
if save_chunk:
    # Per-chunk audio goes to "<output_file stem>_chunks" inside the output folder.
    output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
    # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
    os.makedirs(output_chunk_dir, exist_ok=True)
|
|
|
|
|
|
# load vocoder
|
|
|
|
# Local checkpoint directory for each supported vocoder.
_VOCODER_LOCAL_PATHS = {
    "vocos": "../checkpoints/vocos-mel-24khz",
    "bigvgan": "../checkpoints/bigvgan_v2_24khz_100band_256x",
}

# The command line restricts --vocoder_name via ``choices``, but a value taken
# from the config file is unchecked; fail with a clear message instead of a
# NameError on an unbound ``vocoder_local_path``.
if vocoder_name not in _VOCODER_LOCAL_PATHS:
    raise ValueError(
        f"Unsupported vocoder_name {vocoder_name!r}; expected one of {sorted(_VOCODER_LOCAL_PATHS)}"
    )
vocoder_local_path = _VOCODER_LOCAL_PATHS[vocoder_name]

vocoder = load_vocoder(
    vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
)
|
|
|
|
|
|
# load TTS model
|
|
|
|
# Model config precedence: --model_cfg, then "model_cfg" from the config file,
# then the packaged configs/<model>.yaml.
model_cfg = OmegaConf.load(
    args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

# Default checkpoint coordinates; adjusted below for specific models.
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

# Explicit check instead of ``assert`` so it still runs under ``python -O``
# and reports which values disagree.
if model != "F5TTS_Base" and vocoder_name != model_cfg.model.mel_spec.mel_spec_type:
    raise ValueError(
        f"Vocoder {vocoder_name!r} does not match mel_spec_type "
        f"{model_cfg.model.mel_spec.mel_spec_type!r} from the config of model {model!r}"
    )

# override for previous models
if model == "F5TTS_Base":
    if vocoder_name == "vocos":
        ckpt_step = 1200000
    elif vocoder_name == "bigvgan":
        model = "F5TTS_Base_bigvgan"
        ckpt_type = "pt"
elif model == "E2TTS_Base":
    repo_name = "E2-TTS"
    ckpt_step = 1200000

# No checkpoint given: resolve the matching one through cached_path.
if not ckpt_file:
    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))

print(f"Using {model}...")
ema_model = load_model(
    model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
)
|
|
|
|
|
|
# inference process
|
|
|
|
|
|
# Zero-width split point in front of every "[voice]" tag, so each chunk keeps its own tag.
_VOICE_SPLIT_RE = re.compile(r"(?=\[\w+\])")
# A "[voice]" tag; group 1 is the voice name.
_VOICE_TAG_RE = re.compile(r"\[(\w+)\]")
# Path separators, characters reserved on common filesystems, and control characters.
_UNSAFE_FILENAME_CHARS_RE = re.compile(r'[<>:"/\\|?*\x00-\x1f]')


def _chunk_file_stem(index, text, max_chars=200):
    """Build the file stem ``"<index>_<text>"`` used for a saved audio chunk.

    Text longer than *max_chars* is truncated and suffixed with ``" ... "``.
    Characters that cannot safely appear in a file name are replaced by ``_``
    so that arbitrary generated text cannot escape the chunk directory or
    make the write fail.
    """
    if len(text) > max_chars:
        text = text[:max_chars] + " ... "
    return f"{index}_{_UNSAFE_FILENAME_CHARS_RE.sub('_', text)}"


def main():
    """Synthesize ``gen_text`` and write the concatenated audio to ``wave_path``.

    The text is split at ``[voice]`` tags. Each non-empty chunk is synthesized
    with the reference audio/text of the tagged voice; chunks without a tag, or
    with a tag that names no configured voice, use the ``main`` voice. When
    ``save_chunk`` is set, every chunk is also written to ``output_chunk_dir``.
    If no chunk produced audio, nothing is written.
    """
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice

    # Preprocess every voice's reference once, up front, and store the result back.
    for voice in voices:
        print("Voice:", voice)
        print("ref_audio ", voices[voice]["ref_audio"])
        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
            voices[voice]["ref_audio"], voices[voice]["ref_text"]
        )
        print("ref_audio_", voices[voice]["ref_audio"], "\n\n")

    generated_audio_segments = []
    final_sample_rate = None
    for text in _VOICE_SPLIT_RE.split(gen_text):
        if not text.strip():
            continue

        match = _VOICE_TAG_RE.match(text)
        if match:
            voice = match[1]
        else:
            print("No voice tag found, using main.")
            voice = "main"
        if voice not in voices:
            print(f"Voice {voice} not found, using main.")
            voice = "main"

        # Strip the tag(s) so only the spoken text is passed on.
        gen_text_ = _VOICE_TAG_RE.sub("", text).strip()
        print(f"Voice: {voice}")

        # The third returned value is not used here.
        audio_segment, final_sample_rate, _ = infer_process(
            voices[voice]["ref_audio"],
            voices[voice]["ref_text"],
            gen_text_,
            ema_model,
            vocoder,
            mel_spec_type=vocoder_name,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=device,
        )
        generated_audio_segments.append(audio_segment)

        if save_chunk:
            chunk_stem = _chunk_file_stem(len(generated_audio_segments) - 1, gen_text_)
            sf.write(
                os.path.join(output_chunk_dir, f"{chunk_stem}.wav"),
                audio_segment,
                final_sample_rate,
            )

    if not generated_audio_segments:
        return

    final_wave = np.concatenate(generated_audio_segments)
    os.makedirs(output_dir, exist_ok=True)

    # Write by path rather than through a separately opened handle, so that
    # only one writer touches the output file at a time.
    output_path = str(wave_path)
    sf.write(output_path, final_wave, final_sample_rate)
    # Remove silence
    if remove_silence:
        remove_silence_for_generated_wav(output_path)
    print(output_path)


if __name__ == "__main__":
    main()
|