commit 1df95ad2f6 (2025-09-10 10:56:53 +08:00)
606 changed files with 590904 additions and 0 deletions


@@ -0,0 +1,11 @@
FROM git.modelhub.org.cn:9443/enginex-iluvatar/mr100_corex:4.3.0
WORKDIR /workspace
COPY . /workspace/
RUN pip install -r requirements_f5.txt -c constraints_f5.txt -i https://nexus.4pd.io/repository/pypi-all/simple
RUN cd F5-TTS && pip install -e . -c ../constraints_f5.txt -i https://nexus.4pd.io/repository/pypi-all/simple
#ENTRYPOINT ["/bin/bash", "launch_f5.sh"]
ENTRYPOINT ["/bin/bash", "launch.sh"]

mr_v100-f5-tts/F5-TTS/.gitignore vendored Normal file

@@ -0,0 +1,171 @@
# Custom
.vscode/
tests/
runs/
data/
ckpts/
wandb/
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

mr_v100-f5-tts/F5-TTS/.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "src/third_party/BigVGAN"]
path = src/third_party/BigVGAN
url = https://github.com/NVIDIA/BigVGAN.git


@@ -0,0 +1,17 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.2
hooks:
- id: ruff
name: ruff linter
args: [--fix]
- id: ruff-format
name: ruff formatter
- id: ruff
name: ruff sorter
args: [--select, I, --fix]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-yaml


@@ -0,0 +1,30 @@
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
USER root
ARG DEBIAN_FRONTEND=noninteractive
LABEL github_repo="https://github.com/SWivid/F5-TTS"
RUN set -x \
&& apt-get update \
&& apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
&& apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
&& apt-get install -y librdmacm1 libibumad3 librdmacm-dev libibverbs1 libibverbs-dev ibverbs-utils ibverbs-providers \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
WORKDIR /workspace
RUN git clone https://github.com/SWivid/F5-TTS.git \
&& cd F5-TTS \
&& git submodule update --init --recursive \
&& pip install -e . --no-cache-dir
ENV SHELL=/bin/bash
VOLUME /root/.cache/huggingface/hub/
EXPOSE 7860
WORKDIR /workspace/F5-TTS


@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Yushen CHEN
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,261 @@
# F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
[![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
[![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
[![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/)
[![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
[![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
[![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
[![lab](https://img.shields.io/badge/Peng%20Cheng-Lab-grey?labelColor=lightgrey)](https://www.pcl.ac.cn)
<!-- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto"> -->
**F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster training and inference.
**E2 TTS**: Flat-UNet Transformer, closest reproduction from [paper](https://arxiv.org/abs/2406.18009).
**Sway Sampling**: Inference-time flow step sampling strategy that greatly improves performance.
### Thanks to all the contributors!
## News
- **2025/03/12**: 🔥 F5-TTS v1 base model with better training and inference performance. [A few demos](https://swivid.github.io/F5-TTS_updates).
- **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
## Installation
### Create a separate environment if needed
```bash
# Create a python 3.10 conda env (you could also use virtualenv)
conda create -n f5-tts python=3.10
conda activate f5-tts
```
### Install PyTorch with matched device
<details>
<summary>NVIDIA GPU</summary>
> ```bash
> # Install pytorch with your CUDA version, e.g.
> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
> ```
</details>
<details>
<summary>AMD GPU</summary>
> ```bash
> # Install pytorch with your ROCm version (Linux only), e.g.
> pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
> ```
</details>
<details>
<summary>Intel GPU</summary>
> ```bash
> # Install pytorch with your XPU version, e.g.
> # Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
> pip install torch torchaudio --index-url https://download.pytorch.org/whl/test/xpu
>
> # Intel GPU support is also available through IPEX (Intel® Extension for PyTorch)
> # IPEX does not require the Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit
> # See: https://pytorch-extension.intel.com/installation?request=platform
> ```
</details>
<details>
<summary>Apple Silicon</summary>
> ```bash
> # Install the stable pytorch, e.g.
> pip install torch torchaudio
> ```
</details>
### Then you can choose one from below:
> ### 1. As a pip package (if just for inference)
>
> ```bash
> pip install f5-tts
> ```
>
> ### 2. Local editable (if you also plan to do training or finetuning)
>
> ```bash
> git clone https://github.com/SWivid/F5-TTS.git
> cd F5-TTS
> # git submodule update --init --recursive # (optional, if use bigvgan as vocoder)
> pip install -e .
> ```
### Docker usage also available
```bash
# Build from Dockerfile
docker build -t f5tts:v1 .
# Run from GitHub Container Registry
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main
# Quickstart if you want to just run the web interface (not CLI)
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```
### Runtime
Deployment solution with Triton and TensorRT-LLM.
#### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
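RTF above is the real-time factor: decoding time divided by the duration of the generated audio, so values below 1 mean faster than real time. As a rough illustration of the bookkeeping (not part of the runtime code; `synthesize` stands for any callable returning a waveform array), one can measure it like this:
```python
import time

def measure_rtf(synthesize, text, sample_rate=24000):
    # Time one synthesis call and relate it to the length of the produced audio
    start = time.perf_counter()
    wav = synthesize(text)                      # 1-D waveform sampled at `sample_rate`
    decode_time = time.perf_counter() - start
    audio_seconds = len(wav) / sample_rate
    return decode_time / audio_seconds          # RTF < 1 means faster than real time
```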
## Inference
- To achieve the desired performance, take a moment to read the [detailed guidance](src/f5_tts/infer).
- Searching the [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) with keywords of the problem you encounter is often very helpful.
### 1. Gradio App
Currently supported features:
- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
```bash
# Launch a Gradio app (web interface)
f5-tts_infer-gradio
# Specify the port/host
f5-tts_infer-gradio --port 7860 --host 0.0.0.0
# Launch a share link
f5-tts_infer-gradio --share
```
<details>
<summary>NVIDIA device docker compose file example</summary>
```yaml
services:
f5-tts:
image: ghcr.io/swivid/f5-tts:main
ports:
- "7860:7860"
environment:
GRADIO_SERVER_PORT: 7860
entrypoint: ["f5-tts_infer-gradio", "--port", "7860", "--host", "0.0.0.0"]
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
volumes:
f5-tts:
driver: local
```
</details>
### 2. CLI Inference
```bash
# Run with flags
# Leaving --ref_text "" will have an ASR model transcribe the reference audio (extra GPU memory usage)
f5-tts_infer-cli --model F5TTS_v1_Base \
--ref_audio "provide_prompt_wav_path_here.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."
# Run with default setting. src/f5_tts/infer/examples/basic/basic.toml
f5-tts_infer-cli
# Or with your own .toml file
f5-tts_infer-cli -c custom.toml
# Multi voice. See src/f5_tts/infer/README.md
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
```
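Besides the CLI, the package also exposes a small Python API (the `F5TTS` class defined in `src/f5_tts/api.py` of this repo). A minimal usage sketch based on that file, assuming the released checkpoint can be fetched from the Hugging Face hub on first use:
```python
from importlib.resources import files

from f5_tts.api import F5TTS

f5tts = F5TTS(model="F5TTS_v1_Base")  # downloads the released checkpoint on first use
wav, sr, spec = f5tts.infer(
    ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
    ref_text="some call me nature, others call me mother nature.",
    gen_text="Some text you want the TTS model to generate for you.",
    file_wave="api_out.wav",  # optionally write the generated audio to disk
    seed=None,                # None picks a random seed, stored in f5tts.seed
)
print("seed:", f5tts.seed)
```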
## Training
### 1. With Hugging Face Accelerate
Refer to [training & finetuning guidance](src/f5_tts/train) for best practice.
### 2. With Gradio App
```bash
# Quick start with Gradio web interface
f5-tts_finetune-gradio
```
Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
## [Evaluation](src/f5_tts/eval)
## Development
Use pre-commit to ensure code quality (it will run linters and formatters automatically):
```bash
pip install pre-commit
pre-commit install
```
When making a pull request, before each commit, run:
```bash
pre-commit run --all-files
```
Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
## Acknowledgements
- [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
- [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
- [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
- [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
- [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
- [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~
## Citation
If our work and codebase are useful for you, please cite as:
```
@article{chen-etal-2024-f5tts,
title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
journal={arXiv preprint arXiv:2410.06885},
year={2024},
}
```
## License
Our code is released under the MIT License. The pre-trained models are licensed under the CC-BY-NC license due to the training data Emilia, which is an in-the-wild dataset. Sorry for any inconvenience this may cause.


@@ -0,0 +1,63 @@
[build-system]
requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.1.5"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
]
dependencies = [
"accelerate>=0.33.0,!=1.7.0",
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
"cached_path",
"click",
"datasets",
"ema_pytorch>=0.5.2",
"gradio>=3.45.2",
"hydra-core>=1.3.0",
"jieba",
"librosa",
"matplotlib",
"numpy<=1.26.4",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
"safetensors",
"soundfile",
"tomli",
"torch>=2.0.0",
"torchaudio>=2.0.0",
"torchdiffeq",
"tqdm>=4.65.0",
"transformers",
"transformers_stream_generator",
"vocos",
"wandb",
"x_transformers>=1.31.14",
]
[project.optional-dependencies]
eval = [
"faster_whisper==0.10.1",
"funasr",
"jiwer",
"modelscope",
"zhconv",
"zhon",
]
[project.urls]
Homepage = "https://github.com/SWivid/F5-TTS"
[project.scripts]
"f5-tts_infer-cli" = "f5_tts.infer.infer_cli:main"
"f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main"
"f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main"
"f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main"


@@ -0,0 +1,10 @@
line-length = 120
target-version = "py310"
[lint]
# Only ignore variables with names starting with "_".
dummy-variable-rgx = "^_.*$"
[lint.isort]
force-single-line = false
lines-after-imports = 2


@@ -0,0 +1,164 @@
import random
import sys
from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
)
from f5_tts.model.utils import seed_everything
class F5TTS:
def __init__(
self,
model="F5TTS_v1_Base",
ckpt_file="",
vocab_file="",
ode_method="euler",
use_ema=True,
vocoder_local_path=None,
device=None,
hf_cache_dir=None,
):
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
self.ode_method = ode_method
self.use_ema = use_ema
if device is not None:
self.device = device
else:
import torch
self.device = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
# Load models
self.vocoder = load_vocoder(
self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
)
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
# override for previous models
if model == "F5TTS_Base":
if self.mel_spec_type == "vocos":
ckpt_step = 1200000
elif self.mel_spec_type == "bigvgan":
model = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
if not ckpt_file:
ckpt_file = str(
cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
)
self.ema_model = load_model(
model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
)
def transcribe(self, ref_audio, language=None):
return transcribe(ref_audio, language)
def export_wav(self, wav, file_wave, remove_silence=False):
sf.write(file_wave, wav, self.target_sample_rate)
if remove_silence:
remove_silence_for_generated_wav(file_wave)
def export_spectrogram(self, spec, file_spec):
save_spectrogram(spec, file_spec)
def infer(
self,
ref_file,
ref_text,
gen_text,
show_info=print,
progress=tqdm,
target_rms=0.1,
cross_fade_duration=0.15,
sway_sampling_coef=-1,
cfg_strength=2,
nfe_step=32,
speed=1.0,
fix_duration=None,
remove_silence=False,
file_wave=None,
file_spec=None,
seed=None,
):
if seed is None:
seed = random.randint(0, sys.maxsize)
seed_everything(seed)
self.seed = seed
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
wav, sr, spec = infer_process(
ref_file,
ref_text,
gen_text,
self.ema_model,
self.vocoder,
self.mel_spec_type,
show_info=show_info,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=self.device,
)
if file_wave is not None:
self.export_wav(wav, file_wave, remove_silence)
if file_spec is not None:
self.export_spectrogram(spec, file_spec)
return wav, sr, spec
if __name__ == "__main__":
f5tts = F5TTS()
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,
)
print("seed :", f5tts.seed)


@@ -0,0 +1,49 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
model:
name: E2TTS_Base
tokenizer: pinyin
tokenizer_path: null # if 'custom' tokenizer, define the path to the vocab file you want to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 1024
depth: 24
heads: 16
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


@@ -0,0 +1,49 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0
bnb_optimizer: False
model:
name: E2TTS_Small
tokenizer: pinyin
tokenizer_path: null # if 'custom' tokenizer, define the path to the vocab file you want to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 768
depth: 20
heads: 12
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


@@ -0,0 +1,54 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
model:
name: F5TTS_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: null # if 'custom' tokenizer, define the path to the vocab file you want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


@@ -0,0 +1,54 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11 # only suitable for Emilia, if you want to train it on LibriTTS, set epoch 686
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
model:
name: F5TTS_Small
tokenizer: pinyin
tokenizer_path: null # if 'custom' tokenizer, define the path to the vocab file you want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 768
depth: 18
heads: 12
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


@@ -0,0 +1,55 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
model:
name: F5TTS_v1_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: null # if 'custom' tokenizer, define the path to the vocab file you want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: True
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
attn_backend: flash_attn # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


@@ -0,0 +1,52 @@
# Evaluation
Install packages for evaluation:
```bash
pip install -e .[eval]
```
## Generating Samples for Evaluation
### Prepare Test Datasets
1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
3. Unzip the downloaded datasets and place them in the `data/` directory.
4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py` (see the sketch after this list).
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
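For step 4, the path is a variable inside `main()` of `eval_infer_batch.py`; a sketch of the edit (replace only the placeholder path, keep the rest of the block unchanged):
```python
# in src/f5_tts/eval/eval_infer_batch.py
if testset == "ls_pc_test_clean":
    metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
    librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # point this at your local test-clean copy
```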
### Batch Inference for Test Set
To run batch inference for evaluations, execute the following commands:
```bash
# batch inference for evaluations
accelerate config # if not set before
bash src/f5_tts/eval/eval_infer_batch.sh
```
## Objective Evaluation on Generated Results
### Download Evaluation Model Checkpoints
1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
Then update the following scripts with the paths where you placed the evaluation model checkpoints.
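When run with `--local`, the scripts use local checkpoint directories instead of auto-downloading; the relevant excerpt from `eval_seedtts_testset.py` (the same pattern appears in `eval_librispeech_test_clean.py`), with the default paths you may need to adjust:
```python
local = args.local
if local:  # use local custom checkpoint dir
    if lang == "zh":
        asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
    elif lang == "en":
        asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
    asr_ckpt_dir = ""  # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
```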
### Objective Evaluation
Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
```bash
# Evaluation [WER] for Seed-TTS test [ZH] set
python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
# Evaluation [UTMOS]. --ext: Audio extension
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
```


@@ -0,0 +1,331 @@
# just for speaker similarity evaluation, third-party code
# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
""" Res2Conv1d + BatchNorm1d + ReLU
"""
class Res2Conv1dReluBn(nn.Module):
"""
in_channels == out_channels == channels
"""
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
super().__init__()
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
self.scale = scale
self.width = channels // scale
self.nums = scale if scale == 1 else scale - 1
self.convs = []
self.bns = []
for i in range(self.nums):
self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
self.bns.append(nn.BatchNorm1d(self.width))
self.convs = nn.ModuleList(self.convs)
self.bns = nn.ModuleList(self.bns)
def forward(self, x):
out = []
spx = torch.split(x, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
# Order: conv -> relu -> bn
sp = self.convs[i](sp)
sp = self.bns[i](F.relu(sp))
out.append(sp)
if self.scale != 1:
out.append(spx[self.nums])
out = torch.cat(out, dim=1)
return out
""" Conv1d + BatchNorm1d + ReLU
"""
class Conv1dReluBn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
super().__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
self.bn = nn.BatchNorm1d(out_channels)
def forward(self, x):
return self.bn(F.relu(self.conv(x)))
""" The SE connection of 1D case.
"""
class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
super().__init__()
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
def forward(self, x):
out = x.mean(dim=2)
out = F.relu(self.linear1(out))
out = torch.sigmoid(self.linear2(out))
out = x * out.unsqueeze(2)
return out
""" SE-Res2Block of the ECAPA-TDNN architecture.
"""
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
# SE_Connect(channels)
# )
class SE_Res2Block(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
super().__init__()
self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.Conv1dReluBn1(x)
x = self.Res2Conv1dReluBn(x)
x = self.Conv1dReluBn2(x)
x = self.SE_Connect(x)
return x + residual
""" Attentive weighted mean and standard deviation pooling.
"""
class AttentiveStatsPool(nn.Module):
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
super().__init__()
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
def forward(self, x):
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
alpha = torch.tanh(self.linear1(x_in))
# alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(residuals.clamp(min=1e-9))
return torch.cat([mean, std], dim=1)
class ECAPA_TDNN(nn.Module):
def __init__(
self,
feat_dim=80,
channels=512,
emb_dim=192,
global_context_att=False,
feat_type="wavlm_large",
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None,
):
super().__init__()
self.feat_type = feat_type
self.feature_selection = feature_selection
self.update_extract = update_extract
self.sr = sr
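# The next line bypasses torch.hub's forked-repo validation (avoids a GitHub API check when loading s3prl below)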
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
try:
local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
except: # noqa: E722
self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
):
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
):
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
if feat_type != "fbank" and feat_type != "mfcc":
freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
param.requires_grad = False
break
if not self.update_extract:
for param in self.feature_extract.parameters():
param.requires_grad = False
self.instance_norm = nn.InstanceNorm1d(feat_dim)
# self.channels = [channels] * 4 + [channels * 3]
self.channels = [channels] * 4 + [1536]
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
self.layer2 = SE_Res2Block(
self.channels[0],
self.channels[1],
kernel_size=3,
stride=1,
padding=2,
dilation=2,
scale=8,
se_bottleneck_dim=128,
)
self.layer3 = SE_Res2Block(
self.channels[1],
self.channels[2],
kernel_size=3,
stride=1,
padding=3,
dilation=3,
scale=8,
se_bottleneck_dim=128,
)
self.layer4 = SE_Res2Block(
self.channels[2],
self.channels[3],
kernel_size=3,
stride=1,
padding=4,
dilation=4,
scale=8,
se_bottleneck_dim=128,
)
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
cat_channels = channels * 3
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
self.pooling = AttentiveStatsPool(
self.channels[-1], attention_channels=128, global_context_att=global_context_att
)
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features[self.feature_selection]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def get_feat(self, x):
if self.update_extract:
x = self.feature_extract([sample for sample in x])
else:
with torch.no_grad():
if self.feat_type == "fbank" or self.feat_type == "mfcc":
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
else:
x = self.feature_extract([sample for sample in x])
if self.feat_type == "fbank":
x = x.log()
if self.feat_type != "fbank" and self.feat_type != "mfcc":
x = x[self.feature_selection]
if isinstance(x, (list, tuple)):
x = torch.stack(x, dim=0)
else:
x = x.unsqueeze(0)
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
x = (norm_weights * x).sum(dim=0)
x = torch.transpose(x, 1, 2) + 1e-6
x = self.instance_norm(x)
return x
def forward(self, x):
x = self.get_feat(x)
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out = torch.cat([out2, out3, out4], dim=1)
out = F.relu(self.conv(out))
out = self.bn(self.pooling(out))
out = self.linear(out)
return out
def ECAPA_TDNN_SMALL(
feat_dim,
emb_dim=256,
feat_type="wavlm_large",
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None,
):
return ECAPA_TDNN(
feat_dim=feat_dim,
channels=512,
emb_dim=emb_dim,
feat_type=feat_type,
sr=sr,
feature_selection=feature_selection,
update_extract=update_extract,
config_path=config_path,
)


@@ -0,0 +1,210 @@
import os
import sys
sys.path.append(os.getcwd())
import argparse
import time
from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from hydra.utils import get_class
from omegaconf import OmegaConf
from tqdm import tqdm
from f5_tts.eval.utils_eval import (
get_inference_prompt,
get_librispeech_test_clean_metainfo,
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
use_ema = True
target_rms = 0.1
rel_path = str(files("f5_tts").joinpath("../../"))
def main():
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1250000, type=int)
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument("-t", "--testset", required=True)
args = parser.parse_args()
seed = args.seed
exp_name = args.expname
ckpt_step = args.ckptstep
nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling
testset = args.testset
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.0
speed = 1.0
use_truth_duration = False
no_ref_audio = False
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
if testset == "ls_pc_test_clean":
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
elif testset == "seedtts_test_en":
metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
# path to save generated wavs
output_dir = (
f"{rel_path}/"
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
f"{'_no-ref-audio' if no_ref_audio else ''}"
)
# -------------------------------------------------#
prompts_all = get_inference_prompt(
metainfo,
speed=speed,
tokenizer=tokenizer,
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
mel_spec_type=mel_spec_type,
target_rms=target_rms,
use_truth_duration=use_truth_duration,
infer_batch_size=infer_batch_size,
)
# Vocoder model
local = False
if mel_spec_type == "vocos":
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif mel_spec_type == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)
ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
if os.path.exists(ckpt_prefix + ".pt"):
ckpt_path = ckpt_prefix + ".pt"
elif os.path.exists(ckpt_prefix + ".safetensors"):
ckpt_path = ckpt_prefix + ".safetensors"
else:
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
# start batch inference
accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond=ref_mels,
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
no_ref_audio=no_ref_audio,
seed=seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(gen_mel_spec).cpu()
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60:.2f} minutes.")
if __name__ == "__main__":
main()


@@ -0,0 +1,18 @@
#!/bin/bash
# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
# etc.


@@ -0,0 +1,89 @@
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
import argparse
import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
parser.add_argument("-l", "--lang", type=str, default="en")
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
return parser.parse_args()
def main():
args = get_args()
eval_task = args.eval_task
lang = args.lang
librispeech_test_clean_path = args.librispeech_test_clean_path # test-clean path
gen_wav_dir = args.gen_wav_dir
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
gpus = list(range(args.gpu_nums))
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
## leading to a low similarity for the ground truth in some cases.
# test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
local = args.local
if local: # use local custom checkpoint dir
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------------------------------------------------------
full_results = []
metrics = []
if eval_task == "wer":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for r in results:
full_results.extend(r)
elif eval_task == "sim":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for r in results:
full_results.extend(r)
else:
raise ValueError(f"Unknown metric type: {eval_task}")
result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
with open(result_path, "w") as f:
for line in full_results:
metrics.append(line[eval_task])
f.write(json.dumps(line, ensure_ascii=False) + "\n")
metric = round(np.mean(metrics), 5)
f.write(f"\n{eval_task.upper()}: {metric}\n")
print(f"\nTotal {len(metrics)} samples")
print(f"{eval_task.upper()}: {metric}")
print(f"{eval_task.upper()} results saved to {result_path}")
if __name__ == "__main__":
main()


@@ -0,0 +1,88 @@
# Evaluate with Seed-TTS testset
import argparse
import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
return parser.parse_args()
def main():
args = get_args()
eval_task = args.eval_task
lang = args.lang
gen_wav_dir = args.gen_wav_dir
metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
# NOTE. paraformer-zh result will differ slightly depending on the number of GPUs, because the batch size differs
# zh 1.254 seems a result of 4 workers wer_seed_tts
gpus = list(range(args.gpu_nums))
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
local = args.local
if local: # use local custom checkpoint dir
if lang == "zh":
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
elif lang == "en":
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------------------------------------------------------
full_results = []
metrics = []
if eval_task == "wer":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for r in results:
full_results.extend(r)
elif eval_task == "sim":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for r in results:
full_results.extend(r)
else:
raise ValueError(f"Unknown metric type: {eval_task}")
result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
with open(result_path, "w") as f:
for line in full_results:
metrics.append(line[eval_task])
f.write(json.dumps(line, ensure_ascii=False) + "\n")
metric = round(np.mean(metrics), 5)
f.write(f"\n{eval_task.upper()}: {metric}\n")
print(f"\nTotal {len(metrics)} samples")
print(f"{eval_task.upper()}: {metric}")
print(f"{eval_task.upper()} results saved to {result_path}")
if __name__ == "__main__":
main()


@@ -0,0 +1,42 @@
import argparse
import json
from pathlib import Path
import librosa
import torch
from tqdm import tqdm
def main():
parser = argparse.ArgumentParser(description="UTMOS Evaluation")
parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.")
parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
predictor = predictor.to(device)
audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
utmos_score = 0
utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
with open(utmos_result_path, "w", encoding="utf-8") as f:
for audio_path in tqdm(audio_paths, desc="Processing"):
wav, sr = librosa.load(audio_path, sr=None, mono=True)
wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
score = predictor(wav_tensor, sr)
line = {}
line["wav"], line["utmos"] = str(audio_path.stem), score.item()
utmos_score += score.item()
f.write(json.dumps(line, ensure_ascii=False) + "\n")
avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
f.write(f"\nUTMOS: {avg_score:.4f}\n")
print(f"UTMOS: {avg_score:.4f}")
print(f"UTMOS results saved to {utmos_result_path}")
if __name__ == "__main__":
main()


@@ -0,0 +1,419 @@
import math
import os
import random
import string
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
f = open(metalst)
lines = f.readlines()
f.close()
metainfo = []
for line in lines:
if len(line.strip().split("|")) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
elif len(line.strip().split("|")) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
return metainfo
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
f = open(metalst)
lines = f.readlines()
f.close()
metainfo = []
for line in lines:
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
# ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
# gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
return metainfo
# padded to max length mel batch
def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
for mel in ref_mels:
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
return padded_ref_mels
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_inference_prompt(
metainfo,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="vocos",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
):
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
# Audio
ref_audio, ref_sr = torchaudio.load(prompt_wav)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
if ref_rms < target_rms:
ref_audio = ref_audio * target_rms / ref_rms
assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio)
# Text
if len(prompt_text[-1].encode("utf-8")) == 1:
prompt_text = prompt_text + " "
text = [prompt_text + gt_text]
if tokenizer == "pinyin":
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
else:
text_list = text
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# Duration, mel frame length
ref_mel_len = ref_mel.shape[-1]
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
gt_audio = resampler(gt_audio)
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
# # test vocoder resynthesis
# ref_audio = gt_audio
else:
ref_text_len = len(prompt_text.encode("utf-8"))
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert min_tokens <= total_mel_len <= max_tokens, (
f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
)
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
final_text_list[bucket_i].extend(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = [], [], [], [], [], []
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
# shuffle so the last workers are not left with only the short/easy buckets
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
f = open(metalst)
lines = f.readlines()
f.close()
test_set_ = []
for line in tqdm(lines):
if len(line.strip().split("|")) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
elif len(line.strip().split("|")) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
continue
gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
test_set_.append((gen_wav, prompt_wav, gt_text))
num_jobs = len(gpus)
if num_jobs == 1:
return [(gpus[0], test_set_)]
wav_per_job = len(test_set_) // num_jobs + 1
test_set = []
for i in range(num_jobs):
test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
return test_set
# get librispeech test-clean cross sentence test
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
f = open(metalst)
lines = f.readlines()
f.close()
test_set_ = []
for line in tqdm(lines):
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
if eval_ground_truth:
gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
else:
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
test_set_.append((gen_wav, ref_wav, gen_txt))
num_jobs = len(gpus)
if num_jobs == 1:
return [(gpus[0], test_set_)]
wav_per_job = len(test_set_) // num_jobs + 1
test_set = []
for i in range(num_jobs):
test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
return test_set
# load asr model
def load_asr_model(lang, ckpt_dir=""):
if lang == "zh":
from funasr import AutoModel
model = AutoModel(
model=os.path.join(ckpt_dir, "paraformer-zh"),
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
# punc_model = os.path.join(ckpt_dir, "ct-punc"),
# spk_model = os.path.join(ckpt_dir, "cam++"),
disable_update=True,
) # following seed-tts setting
elif lang == "en":
from faster_whisper import WhisperModel
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
model = WhisperModel(model_size, device="cuda", compute_type="float16")
return model
# WER Evaluation, the way Seed-TTS does
def run_asr_wer(args):
rank, lang, test_set, ckpt_dir = args
if lang == "zh":
import zhconv
torch.cuda.set_device(rank)
elif lang == "en":
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
else:
raise NotImplementedError(
"lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
)
asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
from zhon.hanzi import punctuation
punctuation_all = punctuation + string.punctuation
wer_results = []
from jiwer import compute_measures
for gen_wav, prompt_wav, truth in tqdm(test_set):
if lang == "zh":
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
hypo = res[0]["text"]
hypo = zhconv.convert(hypo, "zh-cn")
elif lang == "en":
segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
hypo = ""
for segment in segments:
hypo = hypo + " " + segment.text
raw_truth = truth
raw_hypo = hypo
for x in punctuation_all:
truth = truth.replace(x, "")
hypo = hypo.replace(x, "")
truth = truth.replace("  ", " ")
hypo = hypo.replace("  ", " ")
if lang == "zh":
truth = " ".join([x for x in truth])
hypo = " ".join([x for x in hypo])
elif lang == "en":
truth = truth.lower()
hypo = hypo.lower()
measures = compute_measures(truth, hypo)
wer = measures["wer"]
# ref_list = truth.split(" ")
# subs = measures["substitutions"] / len(ref_list)
# dele = measures["deletions"] / len(ref_list)
# inse = measures["insertions"] / len(ref_list)
wer_results.append(
{
"wav": Path(gen_wav).stem,
"truth": raw_truth,
"hypo": raw_hypo,
"wer": wer,
}
)
return wer_results
# SIM Evaluation
def run_sim(args):
rank, test_set, ckpt_dir = args
device = f"cuda:{rank}"
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict["model"], strict=False)
use_gpu = True if torch.cuda.is_available() else False
if use_gpu:
model = model.cuda(device)
model.eval()
sim_results = []
for gen_wav, prompt_wav, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(gen_wav)
wav2, sr2 = torchaudio.load(prompt_wav)
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
wav1 = resample1(wav1)
wav2 = resample2(wav2)
if use_gpu:
wav1 = wav1.cuda(device)
wav2 = wav2.cuda(device)
with torch.no_grad():
emb1 = model(wav1)
emb2 = model(wav2)
sim = F.cosine_similarity(emb1, emb2)[0].item()
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
sim_results.append(
{
"wav": Path(gen_wav).stem,
"sim": sim,
}
)
return sim_results

View File

@@ -0,0 +1,177 @@
# Inference
The pretrained model checkpoints are available at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be downloaded automatically when running the inference scripts.
**More checkpoints contributed by the community can be found in [SHARED.md](SHARED.md), supporting more languages.**
A single generation currently supports **up to 30s**, which is the **total length** (the same logic applies with `fix_duration`) including both the prompt and the output audio. However, `infer_cli` and `infer_gradio` automatically perform chunked generation for longer text. Long reference audio will be **clipped to ~12s**.
To avoid possible inference failures, make sure you have read through the following instructions.
- Use reference audio shorter than 12s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncation in the middle of a word, leading to suboptimal generation.
- <ins>Uppercase letters</ins> (best written in the form K.F.C.) are uttered letter by letter, while lowercase letters are used for common words.
- Add some spaces (blank: " ") or punctuation marks (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
- If an English punctuation mark ends a sentence, make sure there is a space " " after it. Otherwise it is not treated as a sentence boundary when chunking (see the sketch after this list).
- <ins>Preprocess numbers</ins> into Chinese characters if you want them read in Chinese; otherwise they are read in English.
- If the generation output is blank (pure silence), <ins>check for FFmpeg installation</ins>.
- Try <ins>turning off `use_ema` when using an early-stage</ins> finetuned checkpoint (one trained for only a few updates).
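A minimal sketch of that chunking behavior, assuming the installed package exposes the `chunk_text` helper from `f5_tts.infer.utils_infer` (defined later in this repository):
```python
# Minimal sketch: how a trailing space after sentence-ending punctuation affects chunking.
# Assumes `chunk_text` is importable from the installed f5_tts package.
from f5_tts.infer.utils_infer import chunk_text

with_space = "First sentence. Second sentence continues here."
without_space = "First sentence.Second sentence continues here."

# With the space, the period is treated as a split point between sentences, so a
# small max_chars budget yields two chunks; without it, both sentences stay in one chunk.
print(chunk_text(with_space, max_chars=20))     # ['First sentence.', 'Second sentence continues here.']
print(chunk_text(without_space, max_chars=20))  # ['First sentence.Second sentence continues here.']
```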
## Gradio App
Currently supported features:
- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](SHARED.md)
The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.
The script will load model checkpoints from Hugging Face. You can also manually download the files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS model is loaded at first; the ASR model is loaded for transcription if `ref_text` is not provided, and the LLM model is loaded when Voice Chat is used.
More flag options:
```bash
# Automatically launch the interface in the default web browser
f5-tts_infer-gradio --inbrowser
# Set the root path of the application, if it's not served from the root ("/") of the domain
# For example, if the application is served at "https://example.com/myapp"
f5-tts_infer-gradio --root_path "/myapp"
```
It can also be used as a component of a larger application:
```python
import gradio as gr
from f5_tts.infer.infer_gradio import app
with gr.Blocks() as main_app:
gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")
# ... other Gradio components
app.render()
main_app.launch()
```
## CLI Inference
The CLI command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, a command-line tool for inference.
The script will load model checkpoints from Hugging Face. You can also manually download the files and use `--ckpt_file` to specify the model you want to load, or directly update the path in `infer_cli.py`.
To change the vocabulary, use `--vocab_file` to provide your own `vocab.txt` file.
Basically, you can run inference with flags:
```bash
# Leaving --ref_text "" will have the ASR model transcribe it (extra GPU memory usage)
f5-tts_infer-cli \
--model F5TTS_v1_Base \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."
# Use BigVGAN as the vocoder. Currently only supports F5TTS_Base.
f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
# Use custom path checkpoint, e.g.
f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors
# More instructions
f5-tts_infer-cli --help
```
A `.toml` file allows for more flexible usage.
```bash
f5-tts_infer-cli -c custom.toml
```
For example, you can use a `.toml` file to pass in variables; refer to `src/f5_tts/infer/examples/basic/basic.toml`:
```toml
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
```
You can also leverage a `.toml` file to do multi-style generation; refer to `src/f5_tts/infer/examples/multi/story.toml`.
```toml
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"
[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""
[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""
```
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voices; refer to `src/f5_tts/infer/examples/multi/story.txt`. A sketch of how these tags are parsed is shown below.
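For illustration, a minimal sketch of how the CLI splits the generation text on such tags, using the same regular expressions as `infer_cli.py` (the sample `gen_text` below is made up):
```python
# Minimal sketch of the voice-tag splitting performed in src/f5_tts/infer/infer_cli.py.
# The sample gen_text is hypothetical.
import re

gen_text = '[main] said he, [country] "I am off, back to my simple dinner of roots." [main] and off he went.'

chunks = re.split(r"(?=\[\w+\])", gen_text)  # split before each [voice] tag, keeping the tag
for chunk in chunks:
    if not chunk.strip():
        continue
    match = re.match(r"\[(\w+)\]", chunk)
    voice = match[1] if match else "main"           # fall back to the main voice if no tag
    text = re.sub(r"\[(\w+)\]", "", chunk).strip()  # strip the tag from the text to synthesize
    print(f"{voice}: {text}")
```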
## API Usage
```python
from importlib.resources import files
from f5_tts.api import F5TTS
f5tts = F5TTS()
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,
)
```
Check [api.py](../api.py) for more details.
## TensorRT-LLM Deployment
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
## Socket Real-time Service
Real-time voice output with chunk stream:
```bash
# Start socket server
python src/f5_tts/socket_server.py
# If PyAudio not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
# Communicate with socket client
python src/f5_tts/socket_client.py
```
## Speech Editing
To test speech editing capabilities, use the following command:
```bash
python src/f5_tts/infer/speech_edit.py
```

View File

@@ -0,0 +1,193 @@
<!-- omit in toc -->
# Shared Model Cards
<!-- omit in toc -->
### **Prerequisites of using**
- This document serves as a quick lookup table for community training/finetuning results, covering various languages.
- The models in this repository are open source and based on the voluntary contributions of the community.
- Use of these models must respect the respective creators; the convenience they bring comes from their efforts.
<!-- omit in toc -->
### **Welcome to share here**
- Have a pretrained/finetuned result: a model checkpoint (preferably pruned to facilitate inference, i.e. keeping only `ema_model_state_dict`; see the sketch after this list) and the corresponding vocab file (for tokenization).
- Host a public [huggingface model repository](https://huggingface.co/new) and upload the model related files.
- Make a pull request adding a model card to the current page, i.e. `src/f5_tts/infer/SHARED.md`.
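A minimal sketch of such pruning, assuming a `.pt` training checkpoint that stores an `ema_model_state_dict` key (paths are placeholders):
```python
# Minimal sketch: keep only the EMA weights before sharing a checkpoint.
# Assumes the training checkpoint is a .pt file containing "ema_model_state_dict".
import torch

ckpt = torch.load("ckpts/my_run/model_last.pt", map_location="cpu", weights_only=True)
pruned = {"ema_model_state_dict": ckpt["ema_model_state_dict"]}
torch.save(pruned, "ckpts/my_run/model_last_pruned.pt")
```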
<!-- omit in toc -->
### Supported Languages
- [Multilingual](#multilingual)
- [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
- [English](#english)
- [Finnish](#finnish)
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
- [French](#french)
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
- [German](#german)
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
- [Hindi](#hindi)
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
- [Italian](#italian)
- [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
- [Japanese](#japanese)
- [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
- [Mandarin](#mandarin)
- [Russian](#russian)
- [F5-TTS Base @ ru @ HotDro4illa](#f5-tts-base--ru--hotdro4illa)
- [Spanish](#spanish)
- [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)
## Multilingual
#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
```bash
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
*Other info, e.g. author info, GitHub repo, links to sampled results, usage instructions, tutorials (blog, video, etc.) ...*
## English
## Finnish
#### F5-TTS Base @ fi @ AsmoKoskinen
|Model|🤗Hugging Face|Data|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
```bash
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
## French
#### F5-TTS Base @ fr @ RASPIAUDIO
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
```bash
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
- [Tutorial video to train a new language model](https://www.youtube.com/watch?v=UO4usaOojys).
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
## German
#### F5-TTS Base @ de @ hvoss-techfak
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
```bash
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
## Hindi
#### F5-TTS Small @ hi @ SPRINGLab
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
```bash
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Authors: SPRING Lab, Indian Institute of Technology, Madras
- Website: https://asr.iitm.ac.in/
## Italian
#### F5-TTS Base @ it @ alien79
|Model|🤗Hugging Face|Data|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
```bash
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Trained by [Mithril Man](https://github.com/MithrilMan)
- Model details on [hf project home](https://huggingface.co/alien79/F5-TTS-italian)
- Open to collaborations to further improve the model
## Japanese
#### F5-TTS Base @ ja @ Jmica
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
```bash
Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt
Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
## Mandarin
## Russian
#### F5-TTS Base @ ru @ HotDro4illa
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hotstone228/F5-TTS-Russian)|[Common voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0)|cc-by-nc-4.0|
```bash
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
- Any improvements are welcome
## Spanish
#### F5-TTS Base @ es @ jpgallegoar
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
- @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model.

View File

@@ -0,0 +1,11 @@
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
output_file = "infer_cli_basic.wav"

View File

@@ -0,0 +1,20 @@
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"
output_file = "infer_cli_story.wav"
[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""
[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""

View File

@@ -0,0 +1 @@
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”

File diff suppressed because it is too large

View File

@@ -0,0 +1,368 @@
import argparse
import codecs
import os
import re
from datetime import datetime
from importlib.resources import files
from pathlib import Path
import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
cfg_strength,
cross_fade_duration,
device,
fix_duration,
infer_process,
load_model,
load_vocoder,
mel_spec_type,
nfe_step,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
speed,
sway_sampling_coef,
target_rms,
)
parser = argparse.ArgumentParser(
prog="python3 infer-cli.py",
description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
"-c",
"--config",
type=str,
default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
help="The configuration file, default see infer/examples/basic/basic.toml",
)
# Note: do not provide default values here, so that defaults are read from the config file
parser.add_argument(
"-m",
"--model",
type=str,
help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
)
parser.add_argument(
"-mc",
"--model_cfg",
type=str,
help="The path to F5-TTS model config file .yaml",
)
parser.add_argument(
"-p",
"--ckpt_file",
type=str,
help="The path to model checkpoint .pt, leave blank to use default",
)
parser.add_argument(
"-v",
"--vocab_file",
type=str,
help="The path to vocab file .txt, leave blank to use default",
)
parser.add_argument(
"-r",
"--ref_audio",
type=str,
help="The reference audio file.",
)
parser.add_argument(
"-s",
"--ref_text",
type=str,
help="The transcript/subtitle for the reference audio",
)
parser.add_argument(
"-t",
"--gen_text",
type=str,
help="The text to make model synthesize a speech",
)
parser.add_argument(
"-f",
"--gen_file",
type=str,
help="The file with text to generate, will ignore --gen_text",
)
parser.add_argument(
"-o",
"--output_dir",
type=str,
help="The path to output folder",
)
parser.add_argument(
"-w",
"--output_file",
type=str,
help="The name of output file",
)
parser.add_argument(
"--save_chunk",
action="store_true",
help="To save each audio chunks during inference",
)
parser.add_argument(
"--remove_silence",
action="store_true",
help="To remove long silence found in ouput",
)
parser.add_argument(
"--load_vocoder_from_local",
action="store_true",
help="To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz",
)
parser.add_argument(
"--vocoder_name",
type=str,
choices=["vocos", "bigvgan"],
help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
)
parser.add_argument(
"--target_rms",
type=float,
help=f"Target output speech loudness normalization value, default {target_rms}",
)
parser.add_argument(
"--cross_fade_duration",
type=float,
help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
)
parser.add_argument(
"--nfe_step",
type=int,
help=f"The number of function evaluation (denoising steps), default {nfe_step}",
)
parser.add_argument(
"--cfg_strength",
type=float,
help=f"Classifier-free guidance strength, default {cfg_strength}",
)
parser.add_argument(
"--sway_sampling_coef",
type=float,
help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
)
parser.add_argument(
"--speed",
type=float,
help=f"The speed of the generated audio, default {speed}",
)
parser.add_argument(
"--fix_duration",
type=float,
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
)
parser.add_argument(
"--device",
type=str,
help="Specify the device to run on",
)
args = parser.parse_args()
# config file
config = tomli.load(open(args.config, "rb"))
# command-line interface parameters
model = args.model or config.get("model", "F5TTS_v1_Base")
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
vocab_file = args.vocab_file or config.get("vocab_file", "")
ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
ref_text = (
args.ref_text
if args.ref_text is not None
else config.get("ref_text", "Some call me nature, others call me mother nature.")
)
gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
gen_file = args.gen_file or config.get("gen_file", "")
output_dir = args.output_dir or config.get("output_dir", "tests")
output_file = args.output_file or config.get(
"output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
)
save_chunk = args.save_chunk or config.get("save_chunk", False)
remove_silence = args.remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
target_rms = args.target_rms or config.get("target_rms", target_rms)
cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
speed = args.speed or config.get("speed", speed)
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
device = args.device or config.get("device", device)
# patches for pip pkg user
if "infer/examples/" in ref_audio:
ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
if "infer/examples/" in gen_file:
gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
if "voices" in config:
for voice in config["voices"]:
voice_ref_audio = config["voices"][voice]["ref_audio"]
if "infer/examples/" in voice_ref_audio:
config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
# ignore gen_text if gen_file provided
if gen_file:
gen_text = codecs.open(gen_file, "r", "utf-8").read()
# output path
wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
if save_chunk:
output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
if not os.path.exists(output_chunk_dir):
os.makedirs(output_chunk_dir)
# load vocoder
if vocoder_name == "vocos":
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
)
# load TTS model
model_cfg = OmegaConf.load(
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if model != "F5TTS_Base":
assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
# override for previous models
if model == "F5TTS_Base":
if vocoder_name == "vocos":
ckpt_step = 1200000
elif vocoder_name == "bigvgan":
model = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
if not ckpt_file:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
print(f"Using {model}...")
ema_model = load_model(
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
)
# inference process
def main():
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
if "voices" not in config:
voices = {"main": main_voice}
else:
voices = config["voices"]
voices["main"] = main_voice
for voice in voices:
print("Voice:", voice)
print("ref_audio ", voices[voice]["ref_audio"])
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
voices[voice]["ref_audio"], voices[voice]["ref_text"]
)
print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
generated_audio_segments = []
reg1 = r"(?=\[\w+\])"
chunks = re.split(reg1, gen_text)
reg2 = r"\[(\w+)\]"
for text in chunks:
if not text.strip():
continue
match = re.match(reg2, text)
if match:
voice = match[1]
else:
print("No voice tag found, using main.")
voice = "main"
if voice not in voices:
print(f"Voice {voice} not found, using main.")
voice = "main"
text = re.sub(reg2, "", text)
ref_audio_ = voices[voice]["ref_audio"]
ref_text_ = voices[voice]["ref_text"]
gen_text_ = text.strip()
print(f"Voice: {voice}")
audio_segment, final_sample_rate, spectrogram = infer_process(
ref_audio_,
ref_text_,
gen_text_,
ema_model,
vocoder,
mel_spec_type=vocoder_name,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
)
generated_audio_segments.append(audio_segment)
if save_chunk:
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,
final_sample_rate,
)
if generated_audio_segments:
final_wave = np.concatenate(generated_audio_segments)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with open(wave_path, "wb") as f:
sf.write(f.name, final_wave, final_sample_rate)
# Remove silence
if remove_silence:
remove_silence_for_generated_wav(f.name)
print(f.name)
if __name__ == "__main__":
main()

File diff suppressed because it is too large

View File

@@ -0,0 +1,205 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
# ---------------------- infer setting ---------------------- #
seed = None # int | None
exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
ckpt_step = 1250000
nfe_step = 32 # 16, 32
cfg_strength = 2.0
ode_method = "euler" # euler | midpoint
sway_sampling_coef = -1.0
speed = 1.0
target_rms = 0.1
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
output_dir = "tests"
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
# [write the origin_text into a file, e.g. tests/test_edit.txt]
# ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
# [result will be saved at same path of audio file]
# [--language "zho" for Chinese, "eng" for English]
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
parts_to_edit = [
[1.42, 2.44],
[4.04, 4.9],
] # start_ends of "nature" & "mother nature", in seconds
fix_duration = [
1.2,
1,
] # fix duration for "optimist" & "realist", in seconds
# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
# target_text = "对,那就是你,万人敬仰的太白金星。"
# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
# fix_duration = None # use origin text duration
# -------------------------------------------------#
use_ema = True
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Vocoder model
local = False
if mel_spec_type == "vocos":
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif mel_spec_type == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
# Audio
audio, sr = torchaudio.load(audio_to_edit)
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
offset = 0
audio_ = torch.zeros(1, 0)
edit_mask = torch.zeros(1, 0, dtype=torch.bool)
for part in parts_to_edit:
start, end = part
part_dur = end - start if fix_duration is None else fix_duration.pop(0)
part_dur = part_dur * target_sample_rate
start = start * target_sample_rate
audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
edit_mask = torch.cat(
(
edit_mask,
torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
),
dim=-1,
)
offset = end * target_sample_rate
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)
# Text
text_list = [target_text]
if tokenizer == "pinyin":
final_text_list = convert_char_to_pinyin(text_list)
else:
final_text_list = [text_list]
print(f"text : {text_list}")
print(f"pinyin: {final_text_list}")
# Duration
ref_audio_len = 0
duration = audio.shape[-1] // hop_length
# Inference
with torch.inference_mode():
generated, trajectory = model.sample(
cond=audio,
text=final_text_list,
duration=duration,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
seed=seed,
edit_mask=edit_mask,
)
print(f"Generated mel: {generated.shape}")
# Final result
generated = generated.to(torch.float32)
generated = generated[:, ref_audio_len:, :]
gen_mel_spec = generated.permute(0, 2, 1)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(gen_mel_spec).cpu()
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
print(f"Generated wav: {generated_wave.shape}")

View File

@@ -0,0 +1,610 @@
# A unified script for inference process
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
import os
import sys
from concurrent.futures import ThreadPoolExecutor
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
import hashlib
import re
import tempfile
from importlib.resources import files
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np
import torch
import torchaudio
import tqdm
from huggingface_hub import hf_hub_download
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
_ref_audio_cache = {}
_ref_text_cache = {}
device = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
# -----------------------------------------
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32 # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None
# -----------------------------------------
# chunk text into smaller pieces
def chunk_text(text, max_chars=135):
"""
Splits the input text into chunks, each with a maximum number of characters.
Args:
text (str): The text to be split.
max_chars (int): The maximum number of characters per chunk.
Returns:
List[str]: A list of text chunks.
"""
chunks = []
current_chunk = ""
# Split the text into sentences based on punctuation followed by whitespace
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
for sentence in sentences:
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
# load vocoder
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
if vocoder_name == "vocos":
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
from vocos.feature_extractors import EncodecFeatures
if isinstance(vocoder.feature_extractor, EncodecFeatures):
encodec_parameters = {
"feature_extractor.encodec." + key: value
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
}
state_dict.update(encodec_parameters)
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
elif vocoder_name == "bigvgan":
try:
from third_party.BigVGAN import bigvgan
except ImportError:
print("You need to follow the README to init submodule and change the BigVGAN source code.")
if is_local:
# download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
else:
vocoder = bigvgan.BigVGAN.from_pretrained(
"nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
)
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)
return vocoder
# load asr pipeline
asr_pipe = None
def initialize_asr_pipeline(device: str = device, dtype=None):
if dtype is None:
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
global asr_pipe
asr_pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=dtype,
device=device,
)
# transcribe
def transcribe(ref_audio, language=None):
global asr_pipe
if asr_pipe is None:
initialize_asr_pipeline(device=device)
return asr_pipe(
ref_audio,
chunk_length_s=30,
batch_size=128,
generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
return_timestamps=False,
)["text"].strip()
# load model checkpoint for inference
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
if dtype is None:
print(f'device: {device}', flush=True)
try:
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
except Exception as e:
# print(f"Error determining dtype: {e}", flush=True)
dtype = torch.float16 if "cuda" in device else torch.float32
model = model.to(dtype)
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path, device=device)
else:
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
if use_ema:
if ckpt_type == "safetensors":
checkpoint = {"ema_model_state_dict": checkpoint}
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
# patch for backward compatibility, 305e3ea
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
if key in checkpoint["model_state_dict"]:
del checkpoint["model_state_dict"][key]
model.load_state_dict(checkpoint["model_state_dict"])
else:
if ckpt_type == "safetensors":
checkpoint = {"model_state_dict": checkpoint}
model.load_state_dict(checkpoint["model_state_dict"])
del checkpoint
torch.cuda.empty_cache()
return model.to(device)
# load model for inference
def load_model(
model_cls,
model_cfg,
ckpt_path,
mel_spec_type=mel_spec_type,
vocab_file="",
ode_method=ode_method,
use_ema=True,
device=device,
):
if vocab_file == "":
vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
tokenizer = "custom"
print("\nvocab : ", vocab_file)
print("token : ", tokenizer)
print("model : ", ckpt_path, "\n")
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
return model
def remove_silence_edges(audio, silence_threshold=-42):
# Remove silence from the start
non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
audio = audio[non_silent_start_idx:]
# Remove silence from the end
non_silent_end_duration = audio.duration_seconds
for ms in reversed(audio):
if ms.dBFS > silence_threshold:
break
non_silent_end_duration -= 0.001
trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
return trimmed_audio
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
show_info("Converting audio...")
# Compute a hash of the reference audio file
with open(ref_audio_orig, "rb") as audio_file:
audio_data = audio_file.read()
audio_hash = hashlib.md5(audio_data).hexdigest()
global _ref_audio_cache
if audio_hash in _ref_audio_cache:
show_info("Using cached preprocessed reference audio...")
ref_audio = _ref_audio_cache[audio_hash]
else: # first pass, do preprocess
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
temp_path = f.name
aseg = AudioSegment.from_file(ref_audio_orig)
# 1. try to find long silence for clipping
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
aseg = non_silent_wave
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 12s, clipping short. (3)")
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
aseg.export(temp_path, format="wav")
ref_audio = temp_path
# Cache the processed reference audio
_ref_audio_cache[audio_hash] = ref_audio
if not ref_text.strip():
global _ref_text_cache
if audio_hash in _ref_text_cache:
# Use cached asr transcription
show_info("Using cached reference text...")
ref_text = _ref_text_cache[audio_hash]
else:
show_info("No reference text provided, transcribing reference audio...")
ref_text = transcribe(ref_audio)
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
_ref_text_cache[audio_hash] = ref_text
else:
show_info("Using custom reference text...")
# Ensure ref_text ends with a proper sentence-ending punctuation
if not ref_text.endswith(". ") and not ref_text.endswith(""):
if ref_text.endswith("."):
ref_text += " "
else:
ref_text += ". "
print("\nref_text ", ref_text)
return ref_audio, ref_text
# infer process: chunk text -> infer batches [i.e. infer_batch_process()]
def infer_process(
ref_audio,
ref_text,
gen_text,
model_obj,
vocoder,
mel_spec_type=mel_spec_type,
show_info=print,
progress=tqdm,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text)
print("\n")
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
return next(
infer_batch_process(
(audio, sr),
ref_text,
gen_text_batches,
model_obj,
vocoder,
mel_spec_type=mel_spec_type,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
)
)
# infer batches
def infer_batch_process(
ref_audio,
ref_text,
gen_text_batches,
model_obj,
vocoder,
mel_spec_type="vocos",
progress=tqdm,
target_rms=0.1,
cross_fade_duration=0.15,
nfe_step=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
speed=1,
fix_duration=None,
device=None,
streaming=False,
chunk_size=2048,
):
audio, sr = ref_audio
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
audio = audio.to(device)
generated_waves = []
spectrograms = []
# if ref_text ends with a single-byte (ASCII) character, append a space so it does not run
# directly into gen_text when the two are concatenated below
if len(ref_text[-1].encode("utf-8")) == 1:
ref_text = ref_text + " "
def process_batch(gen_text):
local_speed = speed
if len(gen_text.encode("utf-8")) < 10:
local_speed = 0.3  # give very short texts more frames by lowering the local speed
# Prepare the text
text_list = [ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)
ref_audio_len = audio.shape[-1] // hop_length
if fix_duration is not None:
duration = int(fix_duration * target_sample_rate / hop_length)
else:
# Calculate duration
ref_text_len = len(ref_text.encode("utf-8"))
gen_text_len = len(gen_text.encode("utf-8"))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
# inference
with torch.inference_mode():
generated, _ = model_obj.sample(
cond=audio,
text=final_text_list,
duration=duration,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
del _
generated = generated.to(torch.float32) # generated mel spectrogram
generated = generated[:, ref_audio_len:, :]
generated = generated.permute(0, 2, 1)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(generated)
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(generated)
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
# wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy()
if streaming:
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate
else:
generated_cpu = generated[0].cpu().numpy()
del generated
yield generated_wave, generated_cpu
if streaming:
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
for chunk in process_batch(gen_text):
yield chunk
else:
with ThreadPoolExecutor() as executor:
futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
for future in progress.tqdm(futures) if progress is not None else futures:
result = future.result()
if result:
generated_wave, generated_mel_spec = next(result)
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec)
if generated_waves:
if cross_fade_duration <= 0:
# Simply concatenate
final_wave = np.concatenate(generated_waves)
else:
# Combine all generated waves with cross-fading
final_wave = generated_waves[0]
for i in range(1, len(generated_waves)):
prev_wave = final_wave
next_wave = generated_waves[i]
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
if cross_fade_samples <= 0:
# No overlap possible, concatenate
final_wave = np.concatenate([prev_wave, next_wave])
continue
# Overlapping parts
prev_overlap = prev_wave[-cross_fade_samples:]
next_overlap = next_wave[:cross_fade_samples]
# Fade out and fade in
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
# Cross-faded overlap
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
# Combine
new_wave = np.concatenate(
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
)
final_wave = new_wave
# Create a combined spectrogram
combined_spectrogram = np.concatenate(spectrograms, axis=1)
yield final_wave, target_sample_rate, combined_spectrogram
else:
yield None, target_sample_rate, None
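# Streaming sketch (illustrative): with streaming=True, the generator instead yields
# (audio_chunk, target_sample_rate) tuples of at most chunk_size samples per batch, e.g.
#   for chunk, sr in infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, vocoder, streaming=True):
#       sink.write(chunk)  # `sink` is a placeholder for any playback or file buffer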
# remove silence from generated wav
def remove_silence_for_generated_wav(filename):
aseg = AudioSegment.from_file(filename)
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
non_silent_wave += non_silent_seg
aseg = non_silent_wave
aseg.export(filename, format="wav")
# save spectrogram
def save_spectrogram(spectrogram, path):
plt.figure(figsize=(12, 4))
plt.imshow(spectrogram, origin="lower", aspect="auto")
plt.colorbar()
plt.savefig(path)
plt.close()

View File

@@ -0,0 +1,8 @@
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.cfm import CFM
from f5_tts.model.trainer import Trainer
__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]

View File

@@ -0,0 +1,20 @@
## Backbones quick introduction
### unett.py
- flat unet transformer
- structure same as in e2-tts & voicebox paper except using rotary pos emb
- possible abs pos emb & convnextv2 blocks for embedded text before concat
### dit.py
- adaln-zero dit
- embedded timestep as condition
- concatted noised_input + masked_cond + embedded_text, linear proj in
- possible abs pos emb & convnextv2 blocks for embedded text before concat
- possible long skip connection (first layer to last layer)
### mmdit.py
- stable diffusion 3 block structure
- timestep as condition
- left stream: text embedded and applied an abs pos emb
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
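### usage sketch
A minimal, illustrative instantiation (hyperparameter values below are placeholders, not a released checkpoint's exact config):
```python
from f5_tts.model import DiT

backbone = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_num_embeds=256, text_dim=512, conv_layers=4)
```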

View File

@@ -0,0 +1,259 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
AdaLayerNorm_Final,
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
)
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using DiT blocks
class DiT(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=100,
text_num_embeds=256,
text_dim=None,
text_mask_padding=True,
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
attn_backend="torch", # "torch" | "flash_attn"
attn_mask_enabled=False,
long_skip_connection=False,
checkpoint_activations=False,
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
)
self.text_cond, self.text_uncond = None, None # text cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[
DiTBlock(
dim=dim,
heads=heads,
dim_head=dim_head,
ff_mult=ff_mult,
dropout=dropout,
qk_norm=qk_norm,
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
)
for _ in range(depth)
]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.checkpoint_activations = checkpoint_activations
self.initialize_weights()
def initialize_weights(self):
# Zero-out AdaLN layers in DiT blocks:
for block in self.transformer_blocks:
nn.init.constant_(block.attn_norm.linear.weight, 0)
nn.init.constant_(block.attn_norm.linear.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.norm_out.linear.weight, 0)
nn.init.constant_(self.norm_out.linear.bias, 0)
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def ckpt_wrapper(self, module):
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
def ckpt_forward(*inputs):
outputs = module(*inputs)
return outputs
return ckpt_forward
def get_input_embed(
self,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"],  # noised input audio  # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:
residual = x
for block in self.transformer_blocks:
if self.checkpoint_activations:
# https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
else:
x = block(x, t, mask=mask, rope=rope)
if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
x = self.norm_out(x, t)
output = self.proj_out(x)
return output
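# Shape sketch (illustrative; the module itself does not run this): with default mel_dim=100,
#   x, cond : float tensors [b, n, 100]  (noised input / masked cond mel frames)
#   text    : int tensor    [b, nt]      (batch-padded with -1, see list_str_to_idx)
#   time    : float tensor  [b] or scalar
#   out = dit(x, cond=cond, text=text, time=time)  # -> [b, n, 100] predicted flow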

View File

@@ -0,0 +1,212 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
AdaLayerNorm_Final,
ConvPositionEmbedding,
MMDiTBlock,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
# text embedding
class TextEmbedding(nn.Module):
def __init__(self, out_dim, text_num_embeds, mask_padding=True):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
self.precompute_max_pos = 1024
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b nt -> b nt d
# sinus pos emb
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
batch_text_len = text.shape[1]
pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
return text
# noised input & masked cond audio embedding
class AudioEmbedding(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.linear = nn.Linear(2 * in_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
if drop_audio_cond:
cond = torch.zeros_like(cond)
x = torch.cat((x, cond), dim=-1)
x = self.linear(x)
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using MM-DiT blocks
class MMDiT(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=100,
text_num_embeds=256,
text_mask_padding=True,
qk_norm=None,
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
self.text_cond, self.text_uncond = None, None # text cache
self.audio_embed = AudioEmbedding(mel_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[
MMDiTBlock(
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
ff_mult=ff_mult,
context_pre_only=i == depth - 1,
qk_norm=qk_norm,
)
for i in range(depth)
]
)
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.initialize_weights()
def initialize_weights(self):
# Zero-out AdaLN layers in MMDiT blocks:
for block in self.transformer_blocks:
nn.init.constant_(block.attn_norm_x.linear.weight, 0)
nn.init.constant_(block.attn_norm_x.linear.bias, 0)
nn.init.constant_(block.attn_norm_c.linear.weight, 0)
nn.init.constant_(block.attn_norm_c.linear.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.norm_out.linear.weight, 0)
nn.init.constant_(self.norm_out.linear.bias, 0)
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def get_input_embed(
self,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, drop_text=True)
c = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, drop_text=False)
c = self.text_cond
else:
c = self.text_embed(text, drop_text=drop_text)
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
return x, c
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"],  # noised input audio  # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
c = torch.cat((c_cond, c_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x, c = self.get_input_embed(
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
)
seq_len = x.shape[1]
text_len = text.shape[1]
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
rope_text = self.rotary_embed.forward_from_seq_len(text_len)
for block in self.transformer_blocks:
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
x = self.norm_out(x, t)
output = self.proj_out(x)
return output

View File

@@ -0,0 +1,273 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Literal
import torch
import torch.nn.functional as F
from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
Attention,
AttnProcessor,
ConvNeXtV2Block,
ConvPositionEmbedding,
FeedForward,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
)
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Flat UNet Transformer backbone
class UNetT(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=100,
text_num_embeds=256,
text_dim=None,
text_mask_padding=True,
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
skip_connect_type: Literal["add", "concat", "none"] = "concat",
):
super().__init__()
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
)
self.text_cond, self.text_uncond = None, None # text cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
# transformer layers & skip connections
self.dim = dim
self.skip_connect_type = skip_connect_type
needs_skip_proj = skip_connect_type == "concat"
self.depth = depth
self.layers = nn.ModuleList([])
for idx in range(depth):
is_later_half = idx >= (depth // 2)
attn_norm = RMSNorm(dim)
attn = Attention(
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
qk_norm=qk_norm,
)
ff_norm = RMSNorm(dim)
ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
self.layers.append(
nn.ModuleList(
[
skip_proj,
attn_norm,
attn,
ff_norm,
ff,
]
)
)
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def get_input_embed(
self,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"],  # noised input audio  # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
# prepend time t to input x, [b n d] -> [b n+1 d]
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
if mask is not None:
mask = F.pad(mask, (1, 0), value=1)
rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
# flat unet transformer
skip_connect_type = self.skip_connect_type
skips = []
for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
layer = idx + 1
# skip connection logic
is_first_half = layer <= (self.depth // 2)
is_later_half = not is_first_half
if is_first_half:
skips.append(x)
if is_later_half:
skip = skips.pop()
if skip_connect_type == "concat":
x = torch.cat((x, skip), dim=-1)
x = maybe_skip_proj(x)
elif skip_connect_type == "add":
x = x + skip
# attention and feedforward blocks
x = attn(attn_norm(x), rope=rope, mask=mask) + x
x = ff(ff_norm(x)) + x
assert len(skips) == 0
x = self.norm_out(x)[:, 1:, :] # unpack t from x
return self.proj_out(x)

View File

@@ -0,0 +1,302 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from random import random
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
default,
exists,
get_epss_timesteps,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
mask_from_frac_lengths,
)
class CFM(nn.Module):
def __init__(
self,
transformer: nn.Module,
sigma=0.0,
odeint_kwargs: dict = dict(
# atol = 1e-5,
# rtol = 1e-5,
method="euler" # 'midpoint'
),
audio_drop_prob=0.3,
cond_drop_prob=0.2,
num_channels=None,
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
vocab_char_map: dict[str, int] | None = None,
):
super().__init__()
self.frac_lengths_mask = frac_lengths_mask
# mel spec
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
self.num_channels = num_channels
# classifier-free guidance
self.audio_drop_prob = audio_drop_prob
self.cond_drop_prob = cond_drop_prob
# transformer
self.transformer = transformer
dim = transformer.dim
self.dim = dim
# conditional flow related
self.sigma = sigma
# sampling related
self.odeint_kwargs = odeint_kwargs
# vocab map for tokenization
self.vocab_char_map = vocab_char_map
@property
def device(self):
return next(self.parameters()).device
@torch.no_grad()
def sample(
self,
cond: float["b n d"] | float["b nw"], # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
duration: int | int["b"], # noqa: F821
*,
lens: int["b"] | None = None, # noqa: F821
steps=32,
cfg_strength=1.0,
sway_sampling_coef=None,
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
use_epss=True,
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
edit_mask=None,
):
self.eval()
# raw wave
if cond.ndim == 2:
cond = self.mel_spec(cond)
cond = cond.permute(0, 2, 1)
assert cond.shape[-1] == self.num_channels
cond = cond.to(next(self.parameters()).dtype)
batch, cond_seq_len, device = *cond.shape[:2], cond.device
if not exists(lens):
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
# text
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# duration
cond_mask = lens_to_mask(lens)
if edit_mask is not None:
cond_mask = cond_mask & edit_mask
if isinstance(duration, int):
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
duration = torch.maximum(
torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
) # duration at least text/audio prompt length plus one token, so something is generated
duration = duration.clamp(max=max_duration)
max_duration = duration.amax()
# duplicate test corner for inner time step observation
if duplicate_test:
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
if no_ref_audio:
cond = torch.zeros_like(cond)
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
cond_mask = cond_mask.unsqueeze(-1)
step_cond = torch.where(
cond_mask, cond, torch.zeros_like(cond)
) # allow direct control (cut cond audio) with lens passed in
if batch > 1:
mask = lens_to_mask(duration)
else: # save memory and speed up, as single inference need no mask currently
mask = None
# neural ode
def fn(t, x):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow (cond)
if cfg_strength < 1e-5:
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=False,
drop_text=False,
cache=True,
)
return pred
# predict flow (cond and uncond), for classifier-free guidance
pred_cfg = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
cfg_infer=True,
cache=True,
)
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
return pred + (pred - null_pred) * cfg_strength
# noise input
# to make sure batch inference results stay the same across different batch sizes, and match single inference
# (some difference may still remain, likely due to convolutional layers)
y0 = []
for dur in duration:
if exists(seed):
torch.manual_seed(seed)
y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
t_start = 0
# duplicate test corner for inner time step observation
if duplicate_test:
t_start = t_inter
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
else:
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
self.transformer.clear_cache()
sampled = trajectory[-1]
out = sampled
out = torch.where(cond_mask, cond, out)
if exists(vocoder):
out = out.permute(0, 2, 1)
out = vocoder(out)
return out, trajectory
def forward(
self,
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
*,
lens: int["b"] | None = None, # noqa: F821
noise_scheduler: str | None = None,
):
# handle raw wave
if inp.ndim == 2:
inp = self.mel_spec(inp)
inp = inp.permute(0, 2, 1)
assert inp.shape[-1] == self.num_channels
batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
# handle text as string
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
# get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
if exists(mask):
rand_span_mask &= mask
# mel is x1
x1 = inp
# x0 is gaussian noise
x0 = torch.randn_like(x1)
# time step
time = torch.rand((batch,), dtype=dtype, device=self.device)
# TODO. noise_scheduler
# sample xt (φ_t(x) in the paper)
t = time.unsqueeze(-1).unsqueeze(-1)
φ = (1 - t) * x0 + t * x1
flow = x1 - x0
# only predict what is within the random mask span for infilling
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
# transformer and cfg training with a drop rate
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
if random() < self.cond_drop_prob: # p_uncond in voicebox paper
drop_audio_cond = True
drop_text = True
else:
drop_text = False
# applying the mask uses more memory; if needed, adjust batch size or the batch sampler's long-sequence threshold
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
)
# flow matching loss
loss = F.mse_loss(pred, flow, reduction="none")
loss = loss[rand_span_mask]
return loss.mean(), cond, pred
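# Training-step sketch (illustrative; see trainer.py for the actual loop, values are placeholders):
#   model = CFM(transformer=DiT(dim=1024, depth=22, heads=16, text_dim=512, conv_layers=4),
#               mel_spec_kwargs=dict(mel_spec_type="vocos"), vocab_char_map=vocab_char_map)
#   loss, cond, pred = model(mel_batch, text_list, lens=mel_lengths)  # mel_batch: [b, n, 100]
#   loss.backward()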

View File

@@ -0,0 +1,330 @@
import json
from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from datasets import Dataset as Dataset_
from datasets import load_dataset as load_hf_dataset
from datasets import load_from_disk
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import default
class HFDataset(Dataset):
def __init__(
self,
hf_dataset: Dataset,
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
n_fft=1024,
win_length=1024,
mel_spec_type="vocos",
):
self.data = hf_dataset
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
def get_frame_len(self, index):
row = self.data[index]
audio = row["audio"]["array"]
sample_rate = row["audio"]["sampling_rate"]
return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
def __len__(self):
return len(self.data)
def __getitem__(self, index):
row = self.data[index]
audio = row["audio"]["array"]
# logger.info(f"Audio shape: {audio.shape}")
sample_rate = row["audio"]["sampling_rate"]
duration = audio.shape[-1] / sample_rate
if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))
audio_tensor = torch.from_numpy(audio).float()
if sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
audio_tensor = resampler(audio_tensor)
audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
mel_spec = self.mel_spectrogram(audio_tensor)
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
text = row["text"]
return dict(
mel_spec=mel_spec,
text=text,
)
class CustomDataset(Dataset):
def __init__(
self,
custom_dataset: Dataset,
durations=None,
target_sample_rate=24_000,
hop_length=256,
n_mel_channels=100,
n_fft=1024,
win_length=1024,
mel_spec_type="vocos",
preprocessed_mel=False,
mel_spec_module: nn.Module | None = None,
):
self.data = custom_dataset
self.durations = durations
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.n_fft = n_fft
self.win_length = win_length
self.mel_spec_type = mel_spec_type
self.preprocessed_mel = preprocessed_mel
if not preprocessed_mel:
self.mel_spectrogram = default(
mel_spec_module,
MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
)
def get_frame_len(self, index):
if (
self.durations is not None
): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
return self.durations[index] * self.target_sample_rate / self.hop_length
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
def __len__(self):
return len(self.data)
def __getitem__(self, index):
while True:
row = self.data[index]
audio_path = row["audio_path"]
text = row["text"]
duration = row["duration"]
# filter by given length
if 0.3 <= duration <= 30:
break # valid
index = (index + 1) % len(self.data)
if self.preprocessed_mel:
mel_spec = torch.tensor(row["mel_spec"])
else:
audio, source_sample_rate = torchaudio.load(audio_path)
# make sure mono input
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
# resample if necessary
if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
audio = resampler(audio)
# to mel spectrogram
mel_spec = self.mel_spectrogram(audio)
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
return {
"mel_spec": mel_spec,
"text": text,
}
# Dynamic Batch Sampler
class DynamicBatchSampler(Sampler[list[int]]):
"""Extension of Sampler that will do the following:
1. Change the batch size (essentially number of sequences)
in a batch to ensure that the total number of frames are less
than a certain threshold.
2. Make sure the padding efficiency in the batch is high.
3. Shuffle batches each epoch while maintaining reproducibility.
"""
def __init__(
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
):
self.sampler = sampler
self.frames_threshold = frames_threshold
self.max_samples = max_samples
self.random_seed = random_seed
self.epoch = 0
indices, batches = [], []
data_source = self.sampler.data_source
for idx in tqdm(
self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
):
indices.append((idx, data_source.get_frame_len(idx)))
indices.sort(key=lambda elem: elem[1])
batch = []
batch_frames = 0
for idx, frame_len in tqdm(
indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
):
if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
batch.append(idx)
batch_frames += frame_len
else:
if len(batch) > 0:
batches.append(batch)
if frame_len <= self.frames_threshold:
batch = [idx]
batch_frames = frame_len
else:
batch = []
batch_frames = 0
if not drop_residual and len(batch) > 0:
batches.append(batch)
del indices
self.batches = batches
# Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
self.drop_last = True
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler."""
self.epoch = epoch
def __iter__(self):
# Use both random_seed and epoch for deterministic but different shuffling per epoch
if self.random_seed is not None:
g = torch.Generator()
g.manual_seed(self.random_seed + self.epoch)
# Use PyTorch's random permutation for better reproducibility across PyTorch versions
indices = torch.randperm(len(self.batches), generator=g).tolist()
batches = [self.batches[i] for i in indices]
else:
batches = self.batches
return iter(batches)
def __len__(self):
return len(self.batches)
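# Usage sketch (illustrative; see trainer.py for the actual wiring, values are placeholders;
# SequentialSampler and DataLoader are from torch.utils.data):
#   sampler = DynamicBatchSampler(SequentialSampler(train_dataset), frames_threshold=3000,
#                                 max_samples=64, random_seed=666, drop_residual=False)
#   loader = DataLoader(train_dataset, batch_sampler=sampler, collate_fn=collate_fn)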
# Load dataset
def load_dataset(
dataset_name: str,
tokenizer: str = "pinyin",
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
) -> CustomDataset | HFDataset:
"""
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
- "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
"""
print("Loading dataset ...")
if dataset_type == "CustomDataset":
rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
if audio_type == "raw":
try:
train_dataset = load_from_disk(f"{rel_data_path}/raw")
except: # noqa: E722
train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
preprocessed_mel = False
elif audio_type == "mel":
train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
preprocessed_mel = True
with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(
train_dataset,
durations=durations,
preprocessed_mel=preprocessed_mel,
mel_spec_module=mel_spec_module,
**mel_spec_kwargs,
)
elif dataset_type == "CustomDatasetPath":
try:
train_dataset = load_from_disk(f"{dataset_name}/raw")
except: # noqa: E722
train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(
train_dataset, durations=durations, preprocessed_mel=False, **mel_spec_kwargs  # this branch loads raw audio, so mel is not preprocessed
)
elif dataset_type == "HFDataset":
print(
"You should manually modify the path of the HuggingFace dataset to your needs.\n"
+ "You may also need to modify the corresponding script, since different datasets may have different formats."
)
pre, post = dataset_name.split("_")
train_dataset = HFDataset(
load_hf_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
)
return train_dataset
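# Example (illustrative; the dataset name and mel kwargs are placeholders):
#   train_dataset = load_dataset(
#       "my_dataset", tokenizer="pinyin",
#       mel_spec_kwargs=dict(target_sample_rate=24000, n_mel_channels=100, hop_length=256, mel_spec_type="vocos"),
#   )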
# collation
def collate_fn(batch):
mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
max_mel_length = mel_lengths.amax()
padded_mel_specs = []
for spec in mel_specs:
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value=0)
padded_mel_specs.append(padded_spec)
mel_specs = torch.stack(padded_mel_specs)
text = [item["text"] for item in batch]
text_lengths = torch.LongTensor([len(item) for item in text])
return dict(
mel=mel_specs,
mel_lengths=mel_lengths, # records for padding mask
text=text,
text_lengths=text_lengths,
)

View File

@@ -0,0 +1,784 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
# flake8: noqa
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
from f5_tts.model.utils import is_package_available
# raw wav to mel spec
mel_basis_cache = {}
hann_window_cache = {}
def get_bigvgan_mel_spectrogram(
waveform,
n_fft=1024,
n_mel_channels=100,
target_sample_rate=24000,
hop_length=256,
win_length=1024,
fmin=0,
fmax=None,
center=False,
): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
device = waveform.device
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
hann_window_cache[key] = torch.hann_window(win_length).to(device)
mel_basis = mel_basis_cache[key]
hann_window = hann_window_cache[key]
padding = (n_fft - hop_length) // 2
waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
spec = torch.stft(
waveform,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
mel_spec = torch.matmul(mel_basis, spec)
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
return mel_spec
def get_vocos_mel_spectrogram(
waveform,
n_fft=1024,
n_mel_channels=100,
target_sample_rate=24000,
hop_length=256,
win_length=1024,
):
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(waveform.device)
if len(waveform.shape) == 3:
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
assert len(waveform.shape) == 2
mel = mel_stft(waveform)
mel = mel.clamp(min=1e-5).log()
return mel
class MelSpec(nn.Module):
def __init__(
self,
n_fft=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
mel_spec_type="vocos",
):
super().__init__()
assert mel_spec_type in ["vocos", "bigvgan"], "We only support two mel-spectrogram backends: vocos or bigvgan"
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.n_mel_channels = n_mel_channels
self.target_sample_rate = target_sample_rate
if mel_spec_type == "vocos":
self.extractor = get_vocos_mel_spectrogram
elif mel_spec_type == "bigvgan":
self.extractor = get_bigvgan_mel_spectrogram
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, wav):
if self.dummy.device != wav.device:
self.to(wav.device)
mel = self.extractor(
waveform=wav,
n_fft=self.n_fft,
n_mel_channels=self.n_mel_channels,
target_sample_rate=self.target_sample_rate,
hop_length=self.hop_length,
win_length=self.win_length,
)
return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
pos = (
start.unsqueeze(1)
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
)
# avoid extra long error.
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
# RMSNorm
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
def forward(self, x):
if self.native_rms_norm:
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.to(self.weight.dtype)
x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
else:
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.to(self.weight.dtype)
x = x * self.weight
return x
# AdaLayerNorm
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNorm for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNorm_Final(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
processor: JointAttnProcessor | AttnProcessor,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if qk_norm is None:
self.q_norm = None
self.k_norm = None
elif qk_norm == "rms_norm":
self.q_norm = RMSNorm(dim_head, eps=1e-6)
self.k_norm = RMSNorm(dim_head, eps=1e-6)
else:
raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
if self.context_dim is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if qk_norm is None:
self.c_q_norm = None
self.c_k_norm = None
elif qk_norm == "rms_norm":
self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
self.c_k_norm = RMSNorm(dim_head, eps=1e-6)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_dim is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, context_dim)
def forward(
self,
x: float["b n d"], # noised input x
c: float["b n d"] = None, # context c
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask=mask, rope=rope)
# Attention processor
if is_package_available("flash_attn"):
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_varlen_func, flash_attn_func
class AttnProcessor:
def __init__(
self,
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
attn_backend: str = "torch", # "torch" or "flash_attn"
attn_mask_enabled: bool = True,
):
if attn_backend == "flash_attn":
assert is_package_available("flash_attn"), "Please install flash-attn first."
self.pe_attn_head = pe_attn_head
self.attn_backend = attn_backend
self.attn_mask_enabled = attn_mask_enabled
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# qk norm
if attn.q_norm is not None:
query = attn.q_norm(query)
if attn.k_norm is not None:
key = attn.k_norm(key)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
if self.pe_attn_head is not None:
pn = self.pe_attn_head
query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
else:
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if self.attn_backend == "torch":
# mask. e.g. inference got a batch with different target durations, mask out the padding
if self.attn_mask_enabled and mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif self.attn_backend == "flash_attn":
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.attn_mask_enabled and mask is not None:
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
value, _, _, _, _ = unpad_input(value, mask)
x = flash_attn_varlen_func(
query,
key,
value,
q_cu_seqlens,
k_cu_seqlens,
q_max_seqlen_in_batch,
k_max_seqlen_in_batch,
)
x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
else:
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x
c: float["b nt d"] = None, # context c, here text
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
batch_size = c.shape[0]
# `sample` projections
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# qk norm
if attn.q_norm is not None:
query = attn.q_norm(query)
if attn.k_norm is not None:
key = attn.k_norm(key)
if attn.c_q_norm is not None:
c_query = attn.c_q_norm(c_query)
if attn.c_k_norm is not None:
c_key = attn.c_k_norm(c_key)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# joint attention
query = torch.cat([query, c_query], dim=2)
key = torch.cat([key, c_key], dim=2)
value = torch.cat([value, c_value], dim=2)
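        # x tokens come first in the joint sequence; they are split back out below via residual.shape[1]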
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
x, c = (
x[:, : residual.shape[1]],
x[:, residual.shape[1] :],
)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
# DiT Block
class DiTBlock(nn.Module):
def __init__(
self,
dim,
heads,
dim_head,
ff_mult=4,
dropout=0.1,
qk_norm=None,
pe_attn_head=None,
attn_backend="torch", # "torch" or "flash_attn"
attn_mask_enabled=True,
):
super().__init__()
self.attn_norm = AdaLayerNorm(dim)
self.attn = Attention(
processor=AttnProcessor(
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
qk_norm=qk_norm,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
_x: noised input related. (right part)
    context_pre_only: the last layer only does pre-norm + modulation, since no FFN follows for the context branch
"""
def __init__(
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
):
super().__init__()
if context_dim is None:
context_dim = dim
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
self.attn_norm_x = AdaLayerNorm(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=context_dim,
context_pre_only=context_pre_only,
qk_norm=qk_norm,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
else:
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
# attention
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
c_ff_output = self.ff_c(norm_c)
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]):
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time
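# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal
# smoke test of DiTBlock + TimestepEmbedding with default arguments, relying on
# the module-level `torch` import above. Dimensions are arbitrary placeholders;
# the shapes in the trailing comments are what the code above produces.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    _dim, _heads, _dim_head = 1024, 16, 64
    _block = DiTBlock(dim=_dim, heads=_heads, dim_head=_dim_head, ff_mult=2)
    _time_embed = TimestepEmbedding(_dim)
    _x = torch.randn(2, 128, _dim)  # noised input: (batch, seq_len, dim)
    _t = _time_embed(torch.rand(2))  # time conditioning: (batch, dim)
    _mask = torch.ones(2, 128, dtype=torch.bool)  # all positions valid
    _out = _block(_x, _t, mask=_mask)
    print(_out.shape)  # torch.Size([2, 128, 1024])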

View File

@@ -0,0 +1,439 @@
from __future__ import annotations
import gc
import math
import os
import torch
import torchaudio
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from ema_pytorch import EMA
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm
from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
# trainer
class Trainer:
def __init__(
self,
model: CFM,
epochs,
learning_rate,
num_warmup_updates=20000,
save_per_updates=1000,
keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
checkpoint_path=None,
batch_size_per_gpu=32,
batch_size_type: str = "sample",
max_samples=32,
grad_accumulation_steps=1,
max_grad_norm=1.0,
noise_scheduler: str | None = None,
duration_predictor: torch.nn.Module | None = None,
logger: str | None = "wandb", # "wandb" | "tensorboard" | None
wandb_project="test_f5-tts",
wandb_run_name="test_run",
wandb_resume_id: str = None,
log_samples: bool = False,
last_per_updates=None,
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict(),
bnb_optimizer: bool = False,
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
model_cfg_dict: dict = dict(), # training config
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
if logger == "wandb" and not wandb.api.api_key:
logger = None
self.log_samples = log_samples
self.accelerator = Accelerator(
log_with=logger if logger == "wandb" else None,
kwargs_handlers=[ddp_kwargs],
gradient_accumulation_steps=grad_accumulation_steps,
**accelerate_kwargs,
)
self.logger = logger
if self.logger == "wandb":
if exists(wandb_resume_id):
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
else:
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
if not model_cfg_dict:
model_cfg_dict = {
"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size_per_gpu": batch_size_per_gpu,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"noise_scheduler": noise_scheduler,
}
model_cfg_dict["gpus"] = self.accelerator.num_processes
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config=model_cfg_dict,
)
elif self.logger == "tensorboard":
from torch.utils.tensorboard import SummaryWriter
self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
self.model = model
if self.is_main:
self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
self.ema_model.to(self.accelerator.device)
print(f"Using logger: {logger}")
if grad_accumulation_steps > 1:
print(
"Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
)
self.epochs = epochs
self.num_warmup_updates = num_warmup_updates
self.save_per_updates = save_per_updates
self.keep_last_n_checkpoints = keep_last_n_checkpoints
self.last_per_updates = default(last_per_updates, save_per_updates)
self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")
self.batch_size_per_gpu = batch_size_per_gpu
self.batch_size_type = batch_size_type
self.max_samples = max_samples
self.grad_accumulation_steps = grad_accumulation_steps
self.max_grad_norm = max_grad_norm
# mel vocoder config
self.vocoder_name = mel_spec_type
self.is_local_vocoder = is_local_vocoder
self.local_vocoder_path = local_vocoder_path
self.noise_scheduler = noise_scheduler
self.duration_predictor = duration_predictor
if bnb_optimizer:
import bitsandbytes as bnb
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
else:
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
@property
def is_main(self):
return self.accelerator.is_main_process
def save_checkpoint(self, update, last=False):
self.accelerator.wait_for_everyone()
if self.is_main:
checkpoint = dict(
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
ema_model_state_dict=self.ema_model.state_dict(),
scheduler_state_dict=self.scheduler.state_dict(),
update=update,
)
if not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
if last:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
print(f"Saved last checkpoint at update {update}")
else:
if self.keep_last_n_checkpoints == 0:
return
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
if self.keep_last_n_checkpoints > 0:
# Updated logic to exclude pretrained model from rotation
checkpoints = [
f
for f in os.listdir(self.checkpoint_path)
if f.startswith("model_")
and not f.startswith("pretrained_") # Exclude pretrained models
and f.endswith(".pt")
and f != "model_last.pt"
]
checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
while len(checkpoints) > self.keep_last_n_checkpoints:
oldest_checkpoint = checkpoints.pop(0)
os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
print(f"Removed old checkpoint: {oldest_checkpoint}")
def load_checkpoint(self):
if (
not exists(self.checkpoint_path)
or not os.path.exists(self.checkpoint_path)
or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
):
return 0
self.accelerator.wait_for_everyone()
if "model_last.pt" in os.listdir(self.checkpoint_path):
latest_checkpoint = "model_last.pt"
else:
# Updated to consider pretrained models for loading but prioritize training checkpoints
all_checkpoints = [
f
for f in os.listdir(self.checkpoint_path)
if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
]
# First try to find regular training checkpoints
training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
if training_checkpoints:
latest_checkpoint = sorted(
training_checkpoints,
key=lambda x: int("".join(filter(str.isdigit, x))),
)[-1]
else:
# If no training checkpoints, use pretrained model
latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
from safetensors.torch import load_file
checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
checkpoint = {"ema_model_state_dict": checkpoint}
elif latest_checkpoint.endswith(".pt"):
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
checkpoint = torch.load(
f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
)
# patch for backward compatibility, 305e3ea
for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
if key in checkpoint["ema_model_state_dict"]:
del checkpoint["ema_model_state_dict"][key]
if self.is_main:
self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
if "update" in checkpoint or "step" in checkpoint:
# patch for backward compatibility, with before f992c4e
if "step" in checkpoint:
checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
if self.grad_accumulation_steps > 1 and self.is_main:
print(
"F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
)
# patch for backward compatibility, 305e3ea
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
if key in checkpoint["model_state_dict"]:
del checkpoint["model_state_dict"][key]
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
if self.scheduler:
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
update = checkpoint["update"]
else:
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "update", "step"]
}
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
update = 0
del checkpoint
gc.collect()
return update
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
if self.log_samples:
from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
vocoder = load_vocoder(
vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
)
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
log_samples_path = f"{self.checkpoint_path}/samples"
os.makedirs(log_samples_path, exist_ok=True)
if exists(resumable_with_seed):
generator = torch.Generator()
generator.manual_seed(resumable_with_seed)
else:
generator = None
if self.batch_size_type == "sample":
train_dataloader = DataLoader(
train_dataset,
collate_fn=collate_fn,
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
batch_size=self.batch_size_per_gpu,
shuffle=True,
generator=generator,
)
elif self.batch_size_type == "frame":
self.accelerator.even_batches = False
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(
sampler,
self.batch_size_per_gpu,
max_samples=self.max_samples,
random_seed=resumable_with_seed, # This enables reproducible shuffling
drop_residual=False,
)
train_dataloader = DataLoader(
train_dataset,
collate_fn=collate_fn,
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
batch_sampler=batch_sampler,
)
else:
raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
        # accelerator.prepare() dispatches batches to devices;
        # so the dataloader length calculated beforehand should account for the number of devices
warmup_updates = (
self.num_warmup_updates * self.accelerator.num_processes
) # consider a fixed warmup steps while using accelerate multi-gpu ddp
# otherwise by default with split_batches=False, warmup steps change with num_processes
total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
decay_updates = total_updates - warmup_updates
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
self.scheduler = SequentialLR(
self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
)
train_dataloader, self.scheduler = self.accelerator.prepare(
train_dataloader, self.scheduler
) # actual multi_gpu updates = single_gpu updates / gpu nums
start_update = self.load_checkpoint()
global_update = start_update
if exists(resumable_with_seed):
orig_epoch_step = len(train_dataloader)
start_step = start_update * self.grad_accumulation_steps
skipped_epoch = int(start_step // orig_epoch_step)
skipped_batch = start_step % orig_epoch_step
skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
else:
skipped_epoch = 0
for epoch in range(skipped_epoch, self.epochs):
self.model.train()
if exists(resumable_with_seed) and epoch == skipped_epoch:
progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
current_dataloader = skipped_dataloader
else:
progress_bar_initial = 0
current_dataloader = train_dataloader
# Set epoch for the batch sampler if it exists
if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
train_dataloader.batch_sampler.set_epoch(epoch)
progress_bar = tqdm(
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
desc=f"Epoch {epoch + 1}/{self.epochs}",
unit="update",
disable=not self.accelerator.is_local_main_process,
initial=progress_bar_initial,
)
for batch in current_dataloader:
with self.accelerator.accumulate(self.model):
text_inputs = batch["text"]
mel_spec = batch["mel"].permute(0, 2, 1)
mel_lengths = batch["mel_lengths"]
# TODO. add duration predictor training
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)
loss, cond, pred = self.model(
mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
)
self.accelerator.backward(loss)
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
if self.accelerator.sync_gradients:
if self.is_main:
self.ema_model.update()
global_update += 1
progress_bar.update(1)
progress_bar.set_postfix(update=str(global_update), loss=loss.item())
if self.accelerator.is_local_main_process:
self.accelerator.log(
{"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
)
if self.logger == "tensorboard":
self.writer.add_scalar("loss", loss.item(), global_update)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update)
if self.log_samples and self.accelerator.is_local_main_process:
ref_audio_len = mel_lengths[0]
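                # for sample logging, repeat the first item's text so the second half of the
                # doubled duration is newly generated speech conditioned on the reference mel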
infer_text = [
text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
]
with torch.inference_mode():
generated, _ = self.accelerator.unwrap_model(self.model).sample(
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
text=infer_text,
duration=ref_audio_len * 2,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
generated = generated.to(torch.float32)
gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
ref_mel_spec = batch["mel"][0].unsqueeze(0)
if self.vocoder_name == "vocos":
gen_audio = vocoder.decode(gen_mel_spec).cpu()
ref_audio = vocoder.decode(ref_mel_spec).cpu()
elif self.vocoder_name == "bigvgan":
gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
torchaudio.save(
f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
)
torchaudio.save(
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
)
self.model.train()
self.save_checkpoint(global_update, last=True)
self.accelerator.end_training()
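# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): how the
# Trainer above is typically wired. `model` must be an already-constructed CFM
# and `train_dataset` a Dataset whose items collate_fn can handle ("text",
# "mel", "mel_lengths"); the numeric values are placeholders, not settings
# taken from this repo.
# ------------------------------------------------------------------------------
def _example_training_run(model: CFM, train_dataset: Dataset) -> None:
    trainer = Trainer(
        model,
        epochs=10,
        learning_rate=7.5e-5,
        num_warmup_updates=20000,
        save_per_updates=10000,
        keep_last_n_checkpoints=3,
        checkpoint_path="ckpts/example_run",
        batch_size_per_gpu=38400,  # frames per GPU when batch_size_type="frame"
        batch_size_type="frame",
        max_samples=64,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        logger=None,  # or "wandb" / "tensorboard"
    )
    trainer.train(train_dataset, num_workers=16, resumable_with_seed=666)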

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
import os
import random
from collections import defaultdict
from importlib.resources import files
import jieba
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
# seed everything
def seed_everything(seed=0):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def is_package_available(package_name: str) -> bool:
try:
import importlib
package_exists = importlib.util.find_spec(package_name) is not None
return package_exists
except Exception:
return False
# tensor helpers
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
if not exists(length):
length = t.amax()
seq = torch.arange(length, device=t.device)
return seq[None, :] < t[:, None]
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device=start.device).long()
start_mask = seq[None, :] >= start[:, None]
end_mask = seq[None, :] < end[:, None]
return start_mask & end_mask
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
rand = torch.rand_like(frac_lengths)
start = (max_start * rand).long().clamp(min=0)
end = start + lengths
return mask_from_start_end_indices(seq_len, start, end)
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
if not exists(mask):
return t.mean(dim=1)
t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
num = t.sum(dim=1)
den = mask.float().sum(dim=1)
return num / den.clamp(min=1.0)
# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
return text
# char tokenizer, based on custom dataset's extracted .txt file
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value=-1,
) -> int["b nt"]: # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
# Get tokenizer
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
if tokenizer in ["pinyin", "char"]:
tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
with open(tokenizer_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
elif tokenizer == "byte":
vocab_char_map = None
vocab_size = 256
elif tokenizer == "custom":
with open(dataset_name, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
# convert char to pinyin
def convert_char_to_pinyin(text_list, polyphone=True):
if jieba.dt.initialized is False:
jieba.default_logger.setLevel(50) # CRITICAL
jieba.initialize()
final_text_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return (
"\u3100" <= c <= "\u9fff" # common chinese characters
)
for text in text_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_text_list.append(char_list)
return final_text_list
# filter func for dirty data with many repetitions
def repetition_found(text, length=2, tolerance=10):
pattern_count = defaultdict(int)
for i in range(len(text) - length + 1):
pattern = text[i : i + length]
pattern_count[pattern] += 1
for pattern, count in pattern_count.items():
if count > tolerance:
return True
return False
# get the empirically pruned step for sampling
def get_epss_timesteps(n, device, dtype):
dt = 1 / 32
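    # predefined indices are in units of dt = 1/32 over [0, 1]; an unsupported n falls back to uniform spacing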
predefined_timesteps = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = predefined_timesteps.get(n, [])
if not t:
return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
return dt * torch.tensor(t, device=device, dtype=dtype)
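# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): the
# tokenization and masking helpers above on toy inputs. Exact pinyin output
# depends on the installed jieba/pypinyin versions.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    print(convert_char_to_pinyin(["这是测试, hello world"]))
    # -> per-sentence token lists mixing pinyin (e.g. 'zhe4') and raw ASCII words
    print(lens_to_mask(torch.tensor([3, 5]), length=5))
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])
    print(get_epss_timesteps(7, device="cpu", dtype=torch.float32))
    # tensor([0.0000, 0.0625, 0.1250, 0.1875, 0.2500, 0.5000, 0.7500, 1.0000])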

View File

@@ -0,0 +1,3 @@
FROM nvcr.io/nvidia/tritonserver:24.12-py3
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
WORKDIR /workspace

View File

@@ -0,0 +1,69 @@
## Triton Inference Serving Best Practice for F5-TTS
### Quick Start
Directly launch the service using docker compose.
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
```
### Build Image
Build the docker image from scratch.
```sh
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```
### Create Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Export Models to TensorRT-LLM and Launch Server
Inside the docker container, build the TensorRT-LLM engines and launch the server with the script below. The official TensorRT-LLM [whisper example](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper) describes the general engine-build workflow.
```sh
bash run.sh 0 4 F5TTS_Base
```
### HTTP Client
```sh
python3 client_http.py
```
### Benchmark using Client-Server Mode
```sh
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
### Benchmark using Offline TRT-LLM Mode
```sh
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
```
### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE. RTF is total decoding wall-clock time divided by total generated audio duration (lower is better).
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline PyTorch |
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -0,0 +1,560 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
"""
import argparse
import json
import os
import time
from typing import Dict, List, Union
import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
torch.manual_seed(0)
def get_args():
    parser = argparse.ArgumentParser(description="benchmark F5-TTS inference (TRT-LLM or PyTorch backend)")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
parser.add_argument(
"--vocab-file",
required=True,
type=str,
help="vocab file",
)
parser.add_argument(
"--model-path",
required=True,
type=str,
help="model path, to load text embedding",
)
parser.add_argument(
"--tllm-model-dir",
required=True,
type=str,
help="tllm model dir",
)
parser.add_argument(
"--batch-size",
required=True,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
parser.add_argument(
"--vocoder",
default="vocos",
type=str,
help="vocoder name",
)
parser.add_argument(
"--vocoder-trt-engine-path",
default=None,
type=str,
help="vocoder trt engine path",
)
parser.add_argument("--enable-warmup", action="store_true")
parser.add_argument("--remove-input-padding", action="store_true")
parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
args = parser.parse_args()
return args
def padded_mel_batch(ref_mels, max_seq_len):
padded_ref_mels = []
for mel in ref_mels:
        # pad along the time (frame) dimension; mel is (T, n_mels)
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
return padded_ref_mels
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push("data_collator")
target_sample_rate = 24000
target_rms = 0.1
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
[],
[],
[],
[],
[],
)
for i, item in enumerate(batch):
item_id, prompt_text, target_text = (
item["id"],
item["prompt_text"],
item["target_text"],
)
ids.append(item_id)
reference_target_texts_list.append(prompt_text + target_text)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
if use_perf:
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
if use_perf:
torch.cuda.nvtx.range_pop()
ref_mel = ref_mel.squeeze()
ref_mel_len = ref_mel.shape[0]
assert ref_mel.shape[1] == 100
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
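        # estimate the total (reference + generated) mel length by scaling the reference
        # length with the target/prompt byte-length ratio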
estimated_reference_target_mel_len.append(
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
)
max_seq_len = max(estimated_reference_target_mel_len)
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if use_perf:
torch.cuda.nvtx.range_pop()
return {
"ids": ids,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
"text_pad_sequence": text_pad_sequence,
"estimated_reference_target_mel_len": estimated_reference_target_mel_len,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
# Initialize process group with explicit device IDs
dist.init_process_group(
"nccl",
)
return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: Union[List[str], List[List[str]]],
vocab_char_map: Dict[str, int], # {char: idx}
padding_value=-1,
):
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return list_idx_tensors
def load_vocoder(
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
if vocoder_name == "vocos":
if vocoder_trt_engine_path is not None:
vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
else:
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
from vocos.feature_extractors import EncodecFeatures
if isinstance(vocoder.feature_extractor, EncodecFeatures):
encodec_parameters = {
"feature_extractor.encodec." + key: value
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
}
state_dict.update(encodec_parameters)
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
elif vocoder_name == "bigvgan":
raise NotImplementedError("BigVGAN is not implemented yet")
return vocoder
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
if vocoder == "vocos":
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=24000,
n_fft=1024,
win_length=1024,
hop_length=256,
n_mels=100,
power=1,
center=True,
normalized=False,
norm=None,
).to(device)
mel = mel_stft(waveform.to(device))
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
class VocosTensorRT:
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
logger.info(f"Loading vae engine from {engine_path}")
self.engine_path = engine_path
with open(engine_path, "rb") as f:
engine_buffer = f.read()
self.session = Session.from_serialized_engine(engine_buffer)
self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
def decode(self, mels):
mels = mels.contiguous()
inputs = {"mel": mels}
output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
outputs = {
t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
}
ok = self.session.run(inputs, outputs, self.stream)
assert ok, "Runtime execution failed for vae session"
samples = outputs["waveform"]
return samples
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
tllm_model_dir = args.tllm_model_dir
config_file = os.path.join(tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
if args.backend_type == "trt":
model = F5TTS(
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
)
elif args.backend_type == "pytorch":
import sys
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
F5TTS_model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
pe_attn_head=1,
text_mask_padding=False,
)
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
vocoder = load_vocoder(
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
)
dataset = load_dataset(
"yuekai/seed_tts",
split=args.split_name,
trust_remote_code=True,
)
def add_estimated_duration(example):
prompt_audio_len = example["prompt_audio"]["array"].shape[0]
scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
estimated_duration = prompt_audio_len * scale_factor
example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
return example
dataset = dataset.map(add_estimated_duration)
dataset = dataset.sort("estimated_duration", reverse=True)
if args.use_perf:
# dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000
dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719
# dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002
# dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
dataset = datasets.concatenate_datasets(dataset_list_short)
if world_size > 1:
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
else:
# This would disable shuffling
sampler = None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
)
total_steps = len(dataset)
if args.enable_warmup:
for batch in dataloader:
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
if args.backend_type == "trt":
_ = model.sample(
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
)
elif args.backend_type == "pytorch":
with torch.inference_mode():
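                    # undo the +1 token-id offset added in data_collator for TRT-LLM and restore -1 padding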
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
total_mel_lens = torch.tensor(total_mel_lens, device=device)
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
steps=16,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
decoding_time = 0
vocoder_time = 0
total_duration = 0
if args.use_perf:
torch.cuda.cudart().cudaProfilerStart()
total_decoding_time = time.time()
for batch in dataloader:
if args.use_perf:
torch.cuda.nvtx.range_push("data sample")
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
if args.use_perf:
torch.cuda.nvtx.range_pop()
if args.backend_type == "trt":
generated, cost_time = model.sample(
text_pad_seq,
ref_mels,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
use_perf=args.use_perf,
)
elif args.backend_type == "pytorch":
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
start_time = time.time()
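                # same token-id offset reversal as in the warmup pass above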
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=16,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
cost_time = time.time() - start_time
decoding_time += cost_time
vocoder_start_time = time.time()
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
if args.vocoder == "vocos":
if args.use_perf:
torch.cuda.nvtx.range_push("vocoder decode")
generated_wave = vocoder.decode(gen_mel_spec).cpu()
if args.use_perf:
torch.cuda.nvtx.range_pop()
else:
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
# if ref_rms_list[i] < target_rms:
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
if rms < target_rms:
generated_wave = generated_wave * target_rms / rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",
generated_wave,
target_sample_rate,
)
total_duration += generated_wave.shape[1] / target_sample_rate
vocoder_time += time.time() - vocoder_start_time
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
total_decoding_time = time.time() - total_decoding_time
if rank == 0:
progress_bar.close()
rtf = total_decoding_time / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
s += f"batch size: {args.batch_size}\n"
print(s)
with open(f"{args.output_dir}/rtf.txt", "w") as f:
f.write(s)
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,470 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a dataset from HuggingFace and sends it to the server
for decoding, in parallel.
Usage:
num_task=2
# For offline F5-TTS
python3 client_grpc.py \
--server-addr localhost \
--model-name f5_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import os
import time
import types
from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import np_to_triton_dtype
def write_triton_stats(stats, summary_file):
with open(summary_file, "w") as summary_f:
model_stats = stats["model_stats"]
        # write a header note: the stats come from triton_client.get_inference_statistics()
        summary_f.write(
            "This log is parsed from triton_client.get_inference_statistics(), for better human readability. \n"
)
summary_f.write("To learn more about the log, please refer to: \n")
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
summary_f.write(
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
)
summary_f.write(
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
)
summary_f.write(
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
)
summary_f.write(
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
)
for model_state in model_stats:
if "last_inference" not in model_state:
continue
summary_f.write(f"model name is {model_state['name']} \n")
model_inference_stats = model_state["inference_stats"]
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
batch_size = int(batch["batch_size"])
compute_input = batch["compute_input"]
compute_output = batch["compute_output"]
compute_infer = batch["compute_infer"]
batch_count = int(compute_infer["count"])
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
)
summary_f.write(
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
)
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=8001,
help="Grpc port of the triton server, default is 8001",
)
parser.add_argument(
"--reference-audio",
type=str,
default=None,
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--huggingface-dataset",
type=str,
default="yuekai/seed_tts",
help="dataset name in huggingface dataset hub",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="dataset split name, default is 'test'",
)
parser.add_argument(
"--manifest-path",
type=str,
default=None,
help="Path to the manifest dir which includes wav.scp trans.txt files.",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
)
parser.add_argument(
"--num-tasks",
type=int,
default=1,
help="Number of concurrent tasks for sending",
)
parser.add_argument(
"--log-interval",
type=int,
default=5,
help="Controls how frequently we print the log.",
)
parser.add_argument(
"--compute-wer",
action="store_true",
default=False,
help="""True to compute WER.
""",
)
parser.add_argument(
"--log-dir",
type=str,
required=False,
default="./tmp",
help="log directory",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Inference batch_size per request for offline mode.",
)
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
return waveform, target_sample_rate
async def send(
manifest_item_list: list,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 24000,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
print(f"manifest_item_list: {manifest_item_list}")
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
duration = len(waveform) / sample_rate
lengths = np.array([[len(waveform)]], dtype=np.int32)
reference_text, target_text = item["reference_text"], item["target_text"]
estimated_target_duration = duration / len(reference_text) * len(target_text)
if padding_duration:
            # pad to the nearest multiple of padding_duration seconds
samples = np.zeros(
(
1,
padding_duration
* sample_rate
* ((int(estimated_target_duration + duration) // padding_duration) + 1),
),
dtype=np.float32,
)
samples[0, : len(waveform)] = waveform
else:
samples = waveform
samples = samples.reshape(1, -1).astype(np.float32)
inputs = [
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
input_data_numpy = np.array([reference_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[2].set_data_from_numpy(input_data_numpy)
input_data_numpy = np.array([target_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
sequence_id = 100000000 + i + task_id * 10
start = time.time()
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
audio = response.as_numpy("waveform").reshape(-1)
end = time.time() - start
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
actual_duration = len(audio) / save_sample_rate
latency_data.append((end, actual_duration))
total_duration += actual_duration
return total_duration, latency_data
def load_manifests(manifest_path):
with open(manifest_path, "r") as f:
manifest_list = []
for line in f:
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
manifest_list.append(
{
"audio_filepath": prompt_wav,
"reference_text": prompt_text,
"target_text": gt_text,
"target_audio_path": utt,
}
)
return manifest_list
def split_data(data, k):
n = len(data)
if n < k:
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
k = n
quotient = n // k
remainder = n % k
result = []
start = 0
for i in range(k):
if i < remainder:
end = start + quotient + 1
else:
end = start + quotient
result.append(data[start:end])
start = end
return result
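# Illustrative example (not part of the benchmark logic): splitting 5 items across 2 tasks
# gives the earlier tasks the extra elements, e.g.
#   split_data([a, b, c, d, e], 2) -> [[a, b, c], [d, e]]
# so every concurrent task receives a nearly equal share of the manifest.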
async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient
if args.reference_audio:
args.num_tasks = 1
args.log_interval = 1
manifest_item_list = [
{
"reference_text": args.reference_text,
"target_text": args.target_text,
"audio_filepath": args.reference_audio,
"target_audio_path": "test",
}
]
elif args.huggingface_dataset:
import datasets
dataset = datasets.load_dataset(
args.huggingface_dataset,
split=args.split_name,
trust_remote_code=True,
)
manifest_item_list = []
for i in range(len(dataset)):
manifest_item_list.append(
{
"audio_filepath": dataset[i]["prompt_audio"],
"reference_text": dataset[i]["prompt_text"],
"target_audio_path": dataset[i]["id"],
"target_text": dataset[i]["target_text"],
}
)
else:
manifest_item_list = load_manifests(args.manifest_path)
args.num_tasks = min(args.num_tasks, len(manifest_item_list))
manifest_item_list = split_data(manifest_item_list, args.num_tasks)
os.makedirs(args.log_dir, exist_ok=True)
tasks = []
start_time = time.time()
for i in range(args.num_tasks):
task = asyncio.create_task(
send(
manifest_item_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=24000,
)
)
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
end_time = time.time()
elapsed = end_time - start_time
total_duration = 0.0
latency_data = []
for ans in ans_list:
total_duration += ans[0]
latency_data += ans[1]
rtf = elapsed / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
s += f"average_latency_ms: {latency_ms:.2f}\n"
print(s)
    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    else:
        name = "single_wav"  # fallback when only --reference-audio is provided
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
f.write(s)
stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,143 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import numpy as np
import requests
import soundfile as sf
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-url",
type=str,
default="localhost:8000",
help="Address of the server",
)
parser.add_argument(
"--reference-audio",
type=str,
default="../../infer/examples/basic/basic_ref_en.wav",
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="Some call me nature, others call me mother nature.",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
help="",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
samples,
reference_text,
target_text,
sample_rate=24000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
lengths = np.array([[len(samples)]], dtype=np.int32)
samples = samples.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
{
"name": "reference_wav_len",
"shape": lengths.shape,
"datatype": "INT32",
"data": lengths.tolist(),
},
{"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
{"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
]
}
return data
def load_audio(wav_path, target_sample_rate=24000):
    assert target_sample_rate == 24000, "the target sample rate is hard-coded to 24 kHz in the server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
samples, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
samples = resample(samples, num_samples)
return samples, target_sample_rate
if __name__ == "__main__":
args = get_args()
server_url = args.server_url
if not server_url.startswith(("http://", "https://")):
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
assert sr == 24000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
rsp = requests.post(
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
)
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
sf.write(args.output_audio, audio, 24000, "PCM_16")
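# Example invocation of this HTTP client (a sketch; the script name is assumed and a Triton
# server exposing the f5_tts model over HTTP on localhost:8000 is expected to be running):
#   python3 client_http.py \
#       --server-url localhost:8000 \
#       --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
#       --reference-text "Some call me nature, others call me mother nature." \
#       --target-text "Hello, this is a test." \
#       --output-audio output.wav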

View File

@@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-f5-tts:24.12
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"

View File

@@ -0,0 +1,430 @@
import math
import os
import time
from functools import wraps
from typing import List, Optional
import tensorrt as trt
import tensorrt_llm
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
# Audio tensor case: batch, seq_len, feature_len
# position_ids case: batch, seq_len
assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
# Initialize a list to collect valid sequences
valid_sequences = []
for i in range(input_tensor.shape[0]):
valid_length = input_tensor_lengths[i]
valid_sequences.append(input_tensor[i, :valid_length])
# Concatenate all valid sequences along the batch dimension
output_tensor = torch.cat(valid_sequences, dim=0).contiguous()
return output_tensor
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text):
        # keep only tokens whose value is not -1 (drop the trailing padding)
text_mask = text != -1
text_pad_cut_off_index = text_mask.sum(dim=1).max()
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
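    # In formula form (r = theta_rescale_factor, d = dim):
    #   theta' = theta * r^(d / (d - 2))
    #   freqs[k] = 1 / theta'^(2k / d),  k = 0 .. d/2 - 1
    # so r > 1 stretches the rotary period, extending the usable sequence length.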
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def load_checkpoint(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True)
if use_ema:
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
dict_state = checkpoint["model_state_dict"]
text_embed_dict = {}
for key in dict_state.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
return text_embed_dict
class F5TTS(object):
def __init__(
self,
config,
debug_mode=True,
stream: Optional[torch.cuda.Stream] = None,
tllm_model_dir: Optional[str] = None,
model_path: Optional[str] = None,
vocab_size: Optional[int] = None,
):
self.dtype = config["pretrained_config"]["dtype"]
rank = tensorrt_llm.mpi_rank()
world_size = config["pretrained_config"]["mapping"]["world_size"]
cp_size = config["pretrained_config"]["mapping"]["cp_size"]
tp_size = config["pretrained_config"]["mapping"]["tp_size"]
pp_size = config["pretrained_config"]["mapping"]["pp_size"]
assert pp_size == 1
self.mapping = tensorrt_llm.Mapping(
world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
)
local_rank = rank % self.mapping.gpus_per_node
self.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(self.device)
self.stream = stream
if self.stream is None:
self.stream = torch.cuda.Stream(self.device)
torch.cuda.set_stream(self.stream)
engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine")
logger.info(f"Loading engine from {engine_file}")
with open(engine_file, "rb") as f:
engine_buffer = f.read()
assert engine_buffer is not None
self.session = Session.from_serialized_engine(engine_buffer)
self.debug_mode = debug_mode
self.inputs = {}
self.outputs = {}
self.buffer_allocated = False
expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]
found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
logger.error(
f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
)
logger.error(
f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
)
logger.error(f"Expected tensor names: {expected_tensor_names}")
logger.error(f"Found tensor names: {found_tensor_names}")
raise RuntimeError("Tensor names in engine are not the same as expected.")
if self.debug_mode:
self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))
self.max_mel_len = 4096
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
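        # Algebraically this reduces to time_step = 1 - cos(pi * t / 2): a "sway"-style
        # schedule that places smaller steps near t = 0 and larger ones near t = 1.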
delta_t = torch.diff(time_step)
        # WAR: the time-embedding frequency dimension is hard-coded to 256 here
tmp_dim = 256
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
for i in range(self.nfe_steps):
emb = time_step[i] * emb_factor
time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
self.time_expand = time_expand.to(self.device)
self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device)
def _tensor_dtype(self, name):
# return torch dtype given tensor name for convenience
dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
return dtype
def _setup(self, batch_size, seq_len):
for i in range(self.session.engine.num_io_tensors):
name = self.session.engine.get_tensor_name(i)
if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = list(self.session.engine.get_tensor_shape(name))
shape[0] = batch_size
shape[1] = seq_len
self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device)
self.buffer_allocated = True
def cuda_stream_guard(func):
"""Sync external stream and set current stream to the one bound to the session. Reset on exit."""
@wraps(func)
def wrapper(self, *args, **kwargs):
external_stream = torch.cuda.current_stream()
if external_stream != self.stream:
external_stream.synchronize()
torch.cuda.set_stream(self.stream)
ret = func(self, *args, **kwargs)
if external_stream != self.stream:
self.stream.synchronize()
torch.cuda.set_stream(external_stream)
return ret
return wrapper
@cuda_stream_guard
def forward(
self,
noise: torch.Tensor,
cond: torch.Tensor,
time_expand: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
input_lengths: torch.Tensor,
delta_t: torch.Tensor,
use_perf: bool = False,
):
if use_perf:
torch.cuda.nvtx.range_push("flow matching")
cfg_strength = 2.0
batch_size = noise.shape[0]
half_batch = batch_size // 2
noise_half = noise[:half_batch] # Store the initial half of noise
input_type = str_dtype_to_torch(self.dtype)
# Keep a copy of the initial tensors
cond = cond.to(input_type)
rope_cos = rope_cos.to(input_type)
rope_sin = rope_sin.to(input_type)
input_lengths = input_lengths.to(str_dtype_to_torch("int32"))
# Instead of iteratively updating noise within a single model context,
# we'll do a single forward pass for each iteration with fresh context setup
for i in range(self.nfe_steps):
# Re-setup the buffers for clean execution
self._setup(batch_size, noise.shape[1])
if not self.buffer_allocated:
raise RuntimeError("Buffer not allocated, please call setup first!")
# Re-create combined noises for this iteration
current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type)
# Get time step for this iteration
current_time = time_expand[:, i].to(input_type)
# Create fresh input dictionary for this iteration
current_inputs = {
"noise": current_noise,
"cond": cond,
"time": current_time,
"rope_cos": rope_cos,
"rope_sin": rope_sin,
"input_lengths": input_lengths,
}
# Update inputs and set shapes
self.inputs.clear() # Clear previous inputs
self.inputs.update(**current_inputs)
self.session.set_shapes(self.inputs)
if use_perf:
torch.cuda.nvtx.range_push(f"execute {i}")
ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
assert ok, "Failed to execute model"
# self.session.context.execute_async_v3(self.stream.cuda_stream)
if use_perf:
torch.cuda.nvtx.range_pop()
# Process results
t_scale = delta_t[i].unsqueeze(0).to(input_type)
# Extract predictions
pred_cond = self.outputs["denoised"][:half_batch]
pred_uncond = self.outputs["denoised"][half_batch:]
# Apply classifier-free guidance with safeguards
guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength
# Calculate update for noise
noise_half = noise_half + guidance * t_scale
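            # In equation form this is one explicit Euler step of the flow-matching ODE
            # with classifier-free guidance:
            #   x <- x + (v_cond + w * (v_cond - v_uncond)) * dt_i,  with w = cfg_strength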
if use_perf:
torch.cuda.nvtx.range_pop()
return noise_half
def sample(
self,
text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor,
ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False,
use_perf: bool = False,
):
if use_perf:
torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1]
text_pad_sequence_drop = torch.cat(
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
)
text_embedding_drop_list = []
for i in range(batch + 1):
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
text_embedding = text_embedding_drop_condition[:-1]
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
cat_mel_text_drop = torch.cat(
(
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
text_embedding_drop,
),
dim=-1,
)
time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
# Convert estimated_reference_target_mel_len to tensor
input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)
# combine above along the batch dimension
inputs = {
"noise": torch.cat((noise, noise), dim=0).contiguous(),
"cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(),
"time_expand": time_expand,
"rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
"rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
"input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(),
"delta_t": self.delta_t,
}
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_push("remove input padding")
if remove_input_padding:
max_seq_len = inputs["cond"].shape[1]
inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
# for time_expand, convert from B,D to B,T,D by repeat
inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_pop()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
if use_perf:
torch.cuda.nvtx.range_pop()
start_time = time.time()
denoised = self.forward(**inputs, use_perf=use_perf)
cost_time = time.time() - start_time
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_push("remove input padding output")
if remove_input_padding:
denoised_list = []
start_idx = 0
for i in range(batch):
denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]])
start_idx += inputs["input_lengths"][i]
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_pop()
return denoised_list, cost_time
return denoised, cost_time

View File

@@ -0,0 +1,278 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import jieba
import torch
import torch.nn.functional as F
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
from torch.utils.dlpack import from_dlpack, to_dlpack
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
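# Minimal usage sketch (assumes a plain-text vocab file with one token per line):
#   vocab_char_map, vocab_size = get_tokenizer("vocab.txt")
# vocab_char_map maps each token to its line index and vocab_size is the number of lines;
# index 0 is reserved as the filler/padding token, hence the "+ 1" offsets applied elsewhere.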
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
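# Behaviour sketch: each Chinese segment is expanded into per-character pinyin tokens
# (Style.TONE3 with tone sandhi), each preceded by a space, while pure-ASCII segments are
# kept as individual characters, so a mixed sentence becomes a flat list of pinyin
# syllables plus single Latin characters and punctuation.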
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors
class TritonPythonModel:
def initialize(self, args):
self.use_perf = True
self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 3000
self.head_dim = 64
parameters = json.loads(args["model_config"])["parameters"]
for key, value in parameters.items():
parameters[key] = value["string_value"]
self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)
self.tllm_model_dir = parameters["tllm_model_dir"]
config_file = os.path.join(self.tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
self.model = F5TTS(
config,
debug_mode=False,
tllm_model_dir=self.tllm_model_dir,
model_path=parameters["model_path"],
vocab_size=self.vocab_size,
)
self.vocoder = parameters["vocoder"]
assert self.vocoder in ["vocos", "bigvgan"]
if self.vocoder == "vocos":
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_audio_sample_rate,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
n_mels=self.n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(self.device)
self.compute_mel_fn = self.get_vocos_mel_spectrogram
elif self.vocoder == "bigvgan":
self.compute_mel_fn = self.get_bigvgan_mel_spectrogram
def get_vocos_mel_spectrogram(self, waveform):
mel = self.mel_stft(waveform)
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
def forward_vocoder(self, mel):
mel = mel.to(torch.float32).contiguous().cpu()
input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
inference_request = pb_utils.InferenceRequest(
model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
else:
waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def execute(self, requests):
(
reference_text_list,
target_text_list,
reference_target_texts_list,
estimated_reference_target_mel_len,
reference_mel_len,
) = [], [], [], [], []
mel_features_list = []
if self.use_perf:
torch.cuda.nvtx.range_push("preprocess")
for request in requests:
wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode("utf-8")
reference_text_list.append(reference_text)
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode("utf-8")
target_text_list.append(target_text)
text = reference_text + target_text
reference_target_texts_list.append(text)
wav = from_dlpack(wav_tensor.to_dlpack())
wav_len = from_dlpack(wav_lens.to_dlpack())
wav_len = wav_len.squeeze()
            assert wav.shape[0] == 1, "Only batch size 1 is supported for now."
wav = wav[:, :wav_len]
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms
if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav)
wav = wav.to(self.device)
if self.use_perf:
torch.cuda.nvtx.range_push("compute_mel")
mel_features = self.compute_mel_fn(wav)
if self.use_perf:
torch.cuda.nvtx.range_pop()
mel_features_list.append(mel_features)
reference_mel_len.append(mel_features.shape[1])
estimated_reference_target_mel_len.append(
int(
mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8")))
)
)
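            # i.e. estimated_len = ref_mel_len * (1 + len_utf8(target_text) / len_utf8(reference_text)):
            # the target duration is assumed to scale with the UTF-8 byte ratio of the two texts.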
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
batch = len(requests)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
for i, mel in enumerate(mel_features_list):
mel_features[i, : mel.shape[1], :] = mel
reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
            text_pad_sequence[i] += 1  # WAR: 0 is reserved for the filler/padding token (hard-coded in F5-TTS)
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf:
torch.cuda.nvtx.range_pop()
denoised, cost_time = self.model.sample(
text_pad_sequence,
mel_features,
reference_mel_len_tensor,
estimated_reference_target_mel_len,
remove_input_padding=False,
use_perf=self.use_perf,
)
if self.use_perf:
torch.cuda.nvtx.range_push("vocoder")
responses = []
for i in range(batch):
            ref_mel_len = reference_mel_len[i]
            estimated_mel_len = estimated_reference_target_mel_len[i]
            denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
responses.append(inference_response)
if self.use_perf:
torch.cuda.nvtx.range_pop()
return responses

View File

@@ -0,0 +1,81 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "f5_tts"
backend: "python"
max_batch_size: 4
dynamic_batching {
max_queue_delay_microseconds: 1000
}
parameters [
{
key: "vocab_file"
value: { string_value: "${vocab}"}
},
{
key: "model_path",
value: {string_value:"${model}"}
},
{
key: "tllm_model_dir",
value: {string_value:"${trtllm}"}
},
{
key: "reference_audio_sample_rate",
value: {string_value:"24000"}
},
{
key: "vocoder",
value: {string_value:"${vocoder}"}
}
]
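# Note: ${vocab}, ${model}, ${trtllm} and ${vocoder} above are template placeholders;
# they are expected to be substituted with concrete paths/values (e.g. by a sed or
# fill-template step in the build scripts) before Triton loads this model repository.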
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: True
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: True
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_GPU
}
]

View File

@@ -0,0 +1,32 @@
name: "vocoder"
backend: "tensorrt"
default_model_filename: "vocoder.plan"
max_batch_size: 4
input [
{
name: "mel"
data_type: TYPE_FP32
dims: [ 100, -1 ]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
dynamic_batching {
preferred_batch_size: [1, 2, 4]
max_queue_delay_microseconds: 1
}
instance_group [
{
count: 1
kind: KIND_GPU
}
]

View File

@@ -0,0 +1,199 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .baichuan.model import BaichuanForCausalLM
from .bert.model import (
BertForQuestionAnswering,
BertForSequenceClassification,
BertModel,
RobertaForQuestionAnswering,
RobertaForSequenceClassification,
RobertaModel,
)
from .bloom.model import BloomForCausalLM, BloomModel
from .chatglm.config import ChatGLMConfig
from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
from .cogvlm.config import CogVLMConfig
from .cogvlm.model import CogVLMForCausalLM
from .commandr.model import CohereForCausalLM
from .dbrx.config import DbrxConfig
from .dbrx.model import DbrxForCausalLM
from .deepseek_v1.model import DeepseekForCausalLM
from .deepseek_v2.model import DeepseekV2ForCausalLM
from .dit.model import DiT
from .eagle.model import EagleForCausalLM
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .f5tts.model import F5TTS
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
from .gemma.model import GemmaForCausalLM
from .gpt.config import GPTConfig
from .gpt.model import GPTForCausalLM, GPTModel
from .gptj.config import GPTJConfig
from .gptj.model import GPTJForCausalLM, GPTJModel
from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
from .grok.model import GrokForCausalLM
from .llama.config import LLaMAConfig
from .llama.model import LLaMAForCausalLM, LLaMAModel
from .mamba.model import MambaForCausalLM
from .medusa.config import MedusaConfig
from .medusa.model import MedusaForCausalLm
from .mllama.model import MLLaMAModel
from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
from .mpt.model import MPTForCausalLM, MPTModel
from .nemotron_nas.model import DeciLMForCausalLM
from .opt.model import OPTForCausalLM, OPTModel
from .phi.model import PhiForCausalLM, PhiModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .qwen.model import QWenForCausalLM
from .recurrentgemma.model import RecurrentGemmaForCausalLM
from .redrafter.model import ReDrafterForCausalLM
__all__ = [
"BertModel",
"BertForQuestionAnswering",
"BertForSequenceClassification",
"RobertaModel",
"RobertaForQuestionAnswering",
"RobertaForSequenceClassification",
"BloomModel",
"BloomForCausalLM",
"DiT",
"DeepseekForCausalLM",
"FalconConfig",
"DeepseekV2ForCausalLM",
"FalconForCausalLM",
"FalconModel",
"GPTConfig",
"GPTModel",
"GPTForCausalLM",
"OPTForCausalLM",
"OPTModel",
"LLaMAConfig",
"LLaMAForCausalLM",
"LLaMAModel",
"MedusaConfig",
"MedusaForCausalLm",
"ReDrafterForCausalLM",
"GPTJConfig",
"GPTJModel",
"GPTJForCausalLM",
"GPTNeoXModel",
"GPTNeoXForCausalLM",
"PhiModel",
"PhiConfig",
"Phi3Model",
"Phi3Config",
"PhiForCausalLM",
"Phi3ForCausalLM",
"ChatGLMConfig",
"ChatGLMForCausalLM",
"ChatGLMModel",
"BaichuanForCausalLM",
"QWenConfigQWenForCausalLM",
"QWenModel",
"EncoderModel",
"DecoderModel",
"PretrainedConfig",
"PretrainedModel",
"WhisperEncoder",
"MambaForCausalLM",
"MambaConfig",
"MPTForCausalLM",
"MPTModel",
"SkyworkForCausalLM",
"GemmaConfig",
"GemmaForCausalLM",
"DbrxConfig",
"DbrxForCausalLM",
"RecurrentGemmaForCausalLM",
"CogVLMConfig",
"CogVLMForCausalLM",
"EagleForCausalLM",
"SpeculativeDecodingMode",
"CohereForCausalLM",
"MLLaMAModel",
"F5TTS",
]
MODEL_MAP = {
"GPT2LMHeadModel": GPTForCausalLM,
"GPT2LMHeadCustomModel": GPTForCausalLM,
"GPTBigCodeForCausalLM": GPTForCausalLM,
"Starcoder2ForCausalLM": GPTForCausalLM,
"FuyuForCausalLM": GPTForCausalLM,
"Kosmos2ForConditionalGeneration": GPTForCausalLM,
"JAISLMHeadModel": GPTForCausalLM,
"GPTForCausalLM": GPTForCausalLM,
"NemotronForCausalLM": GPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"BloomForCausalLM": BloomForCausalLM,
"RWForCausalLM": FalconForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"PhiForCausalLM": PhiForCausalLM,
"Phi3ForCausalLM": Phi3ForCausalLM,
"Phi3VForCausalLM": Phi3ForCausalLM,
"Phi3SmallForCausalLM": Phi3ForCausalLM,
"PhiMoEForCausalLM": Phi3ForCausalLM,
"MambaForCausalLM": MambaForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"GLMModel": ChatGLMForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"ChatGLMForCausalLM": ChatGLMForCausalLM,
"LlamaForCausalLM": LLaMAForCausalLM,
"ExaoneForCausalLM": LLaMAForCausalLM,
"MistralForCausalLM": LLaMAForCausalLM,
"MixtralForCausalLM": LLaMAForCausalLM,
"ArcticForCausalLM": LLaMAForCausalLM,
"Grok1ModelForCausalLM": GrokForCausalLM,
"InternLMForCausalLM": LLaMAForCausalLM,
"InternLM2ForCausalLM": LLaMAForCausalLM,
"MedusaForCausalLM": MedusaForCausalLm,
"ReDrafterForCausalLM": ReDrafterForCausalLM,
"BaichuanForCausalLM": BaichuanForCausalLM,
"BaiChuanForCausalLM": BaichuanForCausalLM,
"SkyworkForCausalLM": LLaMAForCausalLM,
GEMMA_ARCHITECTURE: GemmaForCausalLM,
GEMMA2_ARCHITECTURE: GemmaForCausalLM,
"QWenLMHeadModel": QWenForCausalLM,
"QWenForCausalLM": QWenForCausalLM,
"Qwen2ForCausalLM": QWenForCausalLM,
"Qwen2MoeForCausalLM": QWenForCausalLM,
"Qwen2ForSequenceClassification": QWenForCausalLM,
"Qwen2VLForConditionalGeneration": QWenForCausalLM,
"WhisperEncoder": WhisperEncoder,
"EncoderModel": EncoderModel,
"DecoderModel": DecoderModel,
"DbrxForCausalLM": DbrxForCausalLM,
"RecurrentGemmaForCausalLM": RecurrentGemmaForCausalLM,
"CogVLMForCausalLM": CogVLMForCausalLM,
"DiT": DiT,
"DeepseekForCausalLM": DeepseekForCausalLM,
"DeciLMForCausalLM": DeciLMForCausalLM,
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"EagleForCausalLM": EagleForCausalLM,
"CohereForCausalLM": CohereForCausalLM,
"MllamaForConditionalGeneration": MLLaMAModel,
"BertForQuestionAnswering": BertForQuestionAnswering,
"BertForSequenceClassification": BertForSequenceClassification,
"BertModel": BertModel,
"RobertaModel": RobertaModel,
"RobertaForQuestionAnswering": RobertaForQuestionAnswering,
"RobertaForSequenceClassification": RobertaForSequenceClassification,
"F5TTS": F5TTS,
}

View File

@@ -0,0 +1,222 @@
from __future__ import annotations
import os
import sys
from collections import OrderedDict
import tensorrt as trt
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt
from ...functional import Tensor, concat
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
sys.path.append(parent_dir)
class InputEmbedding(Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x, cond):
x = self.proj(concat([x, cond], dim=-1))
return self.conv_pos_embed(x) + x
class F5TTS(PretrainedModel):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.dtype = str_dtype_to_trt(config.dtype)
self.time_embed = TimestepEmbedding(config.hidden_size)
self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
self.dim = config.hidden_size
self.depth = config.num_hidden_layers
self.transformer_blocks = ModuleList(
[
DiTBlock(
dim=self.dim,
heads=config.num_attention_heads,
dim_head=config.dim_head,
ff_mult=config.ff_mult,
dropout=config.dropout,
)
for _ in range(self.depth)
]
)
self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation
self.proj_out = Linear(config.hidden_size, config.mel_dim)
def forward(
self,
        noise,  # noised input audio
cond, # masked cond audio
time, # time step
rope_cos,
rope_sin,
input_lengths,
scale=1.0,
):
t = self.time_embed(time)
x = self.input_embed(noise, cond)
for block in self.transformer_blocks:
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
denoise = self.proj_out(self.norm_out(x, t))
denoise.mark_output("denoised", self.dtype)
return denoise
def prepare_inputs(self, **kwargs):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 3000
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
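        # Each *_range triple is interpreted as [min, opt, max] for the corresponding dynamic
        # dimension when building the TensorRT optimization profile.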
hidden_size = 512
concat_feature_dim = mel_size + hidden_size
freq_embed_dim = 256
head_dim = 64
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
if default_net().plugin_config.remove_input_padding:
noise = Tensor(
name="noise",
dtype=self.dtype,
shape=[-1, mel_size],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("n_mels", [mel_size]),
]
),
)
cond = Tensor(
name="cond",
dtype=self.dtype,
shape=[-1, concat_feature_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("embeded_length", [concat_feature_dim]),
]
),
)
time = Tensor(
name="time",
dtype=self.dtype,
shape=[-1, freq_embed_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("freq_dim", [freq_embed_dim]),
]
),
)
rope_cos = Tensor(
name="rope_cos",
dtype=self.dtype,
shape=[-1, head_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("head_dim", [head_dim]),
]
),
)
rope_sin = Tensor(
name="rope_sin",
dtype=self.dtype,
shape=[-1, head_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("head_dim", [head_dim]),
]
),
)
else:
noise = Tensor(
name="noise",
dtype=self.dtype,
shape=[-1, -1, mel_size],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("n_mels", [mel_size]),
]
),
)
cond = Tensor(
name="cond",
dtype=self.dtype,
shape=[-1, -1, concat_feature_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("embeded_length", [concat_feature_dim]),
]
),
)
time = Tensor(
name="time",
dtype=self.dtype,
shape=[-1, freq_embed_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("freq_dim", [freq_embed_dim]),
]
),
)
rope_cos = Tensor(
name="rope_cos",
dtype=self.dtype,
shape=[-1, -1, head_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("head_dim", [head_dim]),
]
),
)
rope_sin = Tensor(
name="rope_sin",
dtype=self.dtype,
shape=[-1, -1, head_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("head_dim", [head_dim]),
]
),
)
input_lengths = Tensor(
name="input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size", [batch_size_range])]),
)
return {
"noise": noise,
"cond": cond,
"time": time,
"rope_cos": rope_cos,
"rope_sin": rope_sin,
"input_lengths": input_lengths,
}

View File

@@ -0,0 +1,412 @@
from __future__ import annotations
import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt, trt_dtype_to_np
from ...functional import (
Tensor,
bert_attention,
cast,
chunk,
concat,
constant,
expand,
expand_dims,
expand_dims_like,
expand_mask,
gelu,
matmul,
permute,
shape,
silu,
slice,
softmax,
squeeze,
unsqueeze,
view,
)
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
from ...module import Module
class FeedForward(Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.project_in = Linear(dim, inner_dim)
self.ff = Linear(inner_dim, dim_out)
def forward(self, x):
return self.ff(gelu(self.project_in(x)))
class AdaLayerNormZero(Module):
def __init__(self, dim):
super().__init__()
self.linear = Linear(dim, dim * 6)
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
x = self.norm(x)
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
x = x * (ones + scale_msa) + shift_msa
else:
x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZero_Final(Module):
def __init__(self, dim):
super().__init__()
self.linear = Linear(dim, dim * 2)
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(silu(emb))
scale, shift = chunk(emb, 2, dim=1)
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
x = self.norm(x) * (ones + scale) + shift
else:
x = self.norm(x) * unsqueeze((ones + scale), 1)
x = x + unsqueeze(shift, 1)
return x
class ConvPositionEmbedding(Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.mish = Mish()
def forward(self, x, mask=None): # noqa: F722
if default_net().plugin_config.remove_input_padding:
x = unsqueeze(x, 0)
x = permute(x, [0, 2, 1])
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
out = permute(x, [0, 2, 1])
if default_net().plugin_config.remove_input_padding:
out = squeeze(out, 0)
return out
class Attention(Module):
def __init__(
self,
processor: AttnProcessor,
dim: int,
heads: int = 16,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim # hidden_size
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.attention_head_size = dim_head
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.tp_size = 1
self.num_attention_heads = heads // self.tp_size
self.num_attention_kv_heads = heads // self.tp_size # 8
self.dtype = str_dtype_to_trt("float32")
self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
self.to_q = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
self.to_k = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
self.to_v = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
if self.context_dim is not None:
self.to_k_c = Linear(context_dim, self.inner_dim)
self.to_v_c = Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = Linear(context_dim, self.inner_dim)
self.to_out = RowLinear(
self.tp_size * self.num_attention_heads * self.attention_head_size,
dim,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = Linear(self.inner_dim, dim)
def forward(
self,
x, # noised input x
rope_cos,
rope_sin,
input_lengths,
c=None, # context c
scale=1.0,
rope=None,
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
else:
return self.processor(
self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
)
def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
shape_tensor = concat(
[shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
)
if default_net().plugin_config.remove_input_padding:
assert tensor.ndim() == 2
x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
x1 = expand_dims(x1, 2)
x2 = expand_dims(x2, 2)
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
x2 = zero - x2
x = concat([x2, x1], 2)
out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
else:
assert tensor.ndim() == 3
x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
x1 = expand_dims(x1, 3)
x2 = expand_dims(x2, 3)
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
x2 = zero - x2
x = concat([x2, x1], 3)
out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
if default_net().plugin_config.remove_input_padding:
rot_dim = shape(rope_cos, -1) # 64
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
end_dim = shape(x, -1) - shape(rope_cos, -1)
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
else:
rot_dim = shape(rope_cos, 2) # 64
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
end_dim = shape(x, 2) - shape(rope_cos, 2)
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
return out
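# For reference, the rotation applied to each adjacent pair (x_{2i}, x_{2i+1}) of the first
# rot_dim features is the standard RoPE rotation:
#   (x_{2i} * cos - x_{2i+1} * sin,  x_{2i+1} * cos + x_{2i} * sin)
# features beyond rot_dim are passed through unchanged.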
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn,
x, # noised input x
rope_cos,
rope_sin,
input_lengths,
scale=1.0,
rope=None,
) -> torch.FloatTensor:
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# k,v,q all (2,1226,1024)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
# attention
inner_dim = key.shape[-1]
norm_factor = math.sqrt(attn.attention_head_size)
q_scaling = 1.0 / norm_factor
mask = None
if not default_net().plugin_config.remove_input_padding:
N = shape(x, 1)
B = shape(x, 0)
seq_len_2d = concat([1, N])
max_position_embeddings = 4096
# create position ids
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
mask = tmp_position_ids < tmp_input_lengths # BxL
mask = mask.cast("int32")
if default_net().plugin_config.bert_attention_plugin:
qkv = concat([query, key, value], dim=-1)
# TRT plugin mode
assert input_lengths is not None
if default_net().plugin_config.remove_input_padding:
qkv = qkv.view(concat([-1, 3 * inner_dim]))
max_input_length = constant(
np.zeros(
[
2048,
],
dtype=np.int32,
)
)
else:
max_input_length = None
context = bert_attention(
qkv,
input_lengths,
attn.num_attention_heads,
attn.attention_head_size,
q_scaling=q_scaling,
max_input_length=max_input_length,
)
else:
assert not default_net().plugin_config.remove_input_padding
def transpose_for_scores(x):
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
y = x.view(new_x_shape)
y = y.transpose(1, 2)
return y
def transpose_for_scores_k(x):
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
y = x.view(new_x_shape)
y = y.permute([0, 2, 3, 1])
return y
query = transpose_for_scores(query)
key = transpose_for_scores_k(key)
value = transpose_for_scores(value)
attention_scores = matmul(query, key, use_fp32_acc=False)
if mask is not None:
attention_mask = expand_mask(mask, shape(query, 2))
attention_mask = cast(attention_mask, attention_scores.dtype)
attention_scores = attention_scores + attention_mask
attention_probs = softmax(attention_scores, dim=-1)
context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
context = attn.to_out(context)
if mask is not None:
mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
mask = expand_dims_like(mask, context)
mask = cast(mask, context.dtype)
context = context * mask
return context
# DiT Block
class DiTBlock(Module):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
def forward(
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=None
): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
# norm ----> (2,1226,1024)
attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
# process attention output for input x
if default_net().plugin_config.remove_input_padding:
x = x + gate_msa * attn_output
else:
x = x + unsqueeze(gate_msa, 1) * attn_output
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
else:
norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
# norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
ff_output = self.ff(norm)
if default_net().plugin_config.remove_input_padding:
x = x + gate_mlp * ff_output
else:
x = x + unsqueeze(gate_mlp, 1) * ff_output
return x
class TimestepEmbedding(Module):
def __init__(self, dim, freq_embed_dim=256, dtype=None):
super().__init__()
# self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)
def forward(self, timestep):
t_freq = self.mlp1(timestep)
t_freq = silu(t_freq)
t_emb = self.mlp2(t_freq)
return t_emb

View File

@@ -0,0 +1,24 @@
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
# torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
packaging>=24.2

View File

@@ -0,0 +1,110 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_Base
if [ -z "$model" ]; then
echo "Model is none, using default model F5TTS_Base"
model=F5TTS_Base
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server"
rm -r $model_repo
cp -r ./model_repo_f5_tts $model_repo
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$model_repo
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Testing http client"
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "TRT-LLM: offline decoding benchmark test"
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -sf model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Native Pytorch: offline decoding benchmark test"
pip install -r requirements-pytorch.txt
batch_size=1
split_name=wenetspeech4tts
backend_type=pytorch
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -sf model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--split-name $split_name \
--enable-warmup \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
fi

View File

@@ -0,0 +1,248 @@
# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
# Copyright (c) 2020 Shimin Zhang
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window
support_clp_op = None
if th.__version__ >= "1.7.0":
from torch.fft import rfft as fft
support_clp_op = True
else:
from torch import rfft as fft
class STFT(th.nn.Module):
def __init__(
self,
win_len=1024,
win_hop=512,
fft_len=1024,
enframe_mode="continue",
win_type="hann",
win_sqrt=False,
pad_center=True,
):
"""
Implementation of STFT using 1D convolution and 1D transposed convolution.
The signal can be framed in two ways, `break` and `continue`:
the `break` method is a kaldi-like framing,
the `continue` method is a librosa-like framing.
More information about `perfect reconstruction`:
1. https://ww2.mathworks.cn/help/signal/ref/stft.html
2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html
Args:
win_len (int): Number of points in one frame. Defaults to 1024.
win_hop (int): Frame stride (hop size) in samples. Defaults to 512.
fft_len (int): Number of DFT points. Defaults to 1024.
enframe_mode (str, optional): `break` or `continue`. Defaults to 'continue'.
win_type (str, optional): The type of window to create. Defaults to 'hann'.
win_sqrt (bool, optional): Use the square root of the window. Defaults to False.
pad_center (bool, optional): Pad input for `perfect reconstruction`. Defaults to True.
"""
super(STFT, self).__init__()
assert enframe_mode in ["break", "continue"]
assert fft_len >= win_len
self.win_len = win_len
self.win_hop = win_hop
self.fft_len = fft_len
self.mode = enframe_mode
self.win_type = win_type
self.win_sqrt = win_sqrt
self.pad_center = pad_center
self.pad_amount = self.fft_len // 2
en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
self.register_buffer("en_k", en_k)
self.register_buffer("fft_k", fft_k)
self.register_buffer("ifft_k", ifft_k)
self.register_buffer("ola_k", ola_k)
def __init_kernel__(self):
"""
Generate the enframe kernel, fft kernel, ifft kernel and overlap-add kernel.
** enframe_kernel: a conv1d layer built from an identity matrix.
** fft_kernel: a linear layer used for the DFT matrix multiplication. In fact,
enframe_kernel and fft_kernel could be combined, but they are kept separate
for the sake of readability.
** ifft_kernel: the pseudo-inverse of fft_kernel.
** overlap-add kernel: just like enframe_kernel, but transposed.
Returns:
tuple: four kernels.
"""
enframed_kernel = th.eye(self.fft_len)[:, None, :]
if support_clp_op:
tmp = fft(th.eye(self.fft_len))
fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
else:
fft_kernel = fft(th.eye(self.fft_len), 1)
if self.mode == "break":
enframed_kernel = th.eye(self.win_len)[:, None, :]
fft_kernel = fft_kernel[: self.win_len]
fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
window = get_window(self.win_type, self.win_len)
self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
window = th.FloatTensor(window)
if self.mode == "continue":
left_pad = (self.fft_len - self.win_len) // 2
right_pad = left_pad + (self.fft_len - self.win_len) % 2
window = F.pad(window, (left_pad, right_pad))
if self.win_sqrt:
self.padded_window = window
window = th.sqrt(window)
else:
self.padded_window = window**2
fft_kernel = fft_kernel.T * window
ifft_kernel = ifft_kernel * window
ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
if self.mode == "continue":
ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel
def is_perfect(self):
"""
Whether the parameters win_len, win_hop and win_sqrt
obey the constant overlap-add (COLA) constraint.
Returns:
bool: True if the parameters obey COLA.
"""
return self.perfect_reconstruct and self.pad_center
def transform(self, inputs, return_type="complex"):
"""Take input data (audio) to STFT domain.
Args:
inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
return_type (str, optional): return (mag, phase) when `magphase`,
return (real, imag) when `realimag` and complex(real, imag) when `complex`.
Defaults to 'complex'.
Returns:
tuple or tensor: (mag, phase) when `magphase`, (real, imag) when
`realimag`, or a complex tensor when `complex` (the default). Each element
has shape [num_batch, num_frequencies, num_frames].
"""
assert return_type in ["magphase", "realimag", "complex"]
if inputs.dim() == 2:
inputs = th.unsqueeze(inputs, 1)
self.num_samples = inputs.size(-1)
if self.pad_center:
inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
outputs = th.transpose(enframe_inputs, 1, 2)
outputs = F.linear(outputs, self.fft_k)
outputs = th.transpose(outputs, 1, 2)
dim = self.fft_len // 2 + 1
real = outputs[:, :dim, :]
imag = outputs[:, dim:, :]
if return_type == "realimag":
return real, imag
elif return_type == "complex":
assert support_clp_op
return th.complex(real, imag)
else:
mags = th.sqrt(real**2 + imag**2)
phase = th.atan2(imag, real)
return mags, phase
def inverse(self, input1, input2=None, input_type="magphase"):
"""Call the inverse STFT (iSTFT), given tensors produced
by the `transform` function.
Args:
input1 (tensors): Magnitude/Real-part of STFT with shape
[num_batch, num_frequencies, num_frames]
input2 (tensors): Phase/Imag-part of STFT with shape
[num_batch, num_frequencies, num_frames]
input_type (str, optional): Mathematical meaning of the input tensors.
Defaults to 'magphase'.
Returns:
tensor: Reconstructed audio given magnitude and phase, of
shape [num_batch, num_samples].
"""
assert input_type in ["magphase", "realimag"]
if input_type == "realimag":
real, imag = None, None
if support_clp_op and th.is_complex(input1):
real, imag = input1.real, input1.imag
else:
real, imag = input1, input2
else:
real = input1 * th.cos(input2)
imag = input1 * th.sin(input2)
inputs = th.cat([real, imag], dim=1)
outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
t = t.to(inputs.device)
coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)
num_frames = input1.size(-1)
num_samples = num_frames * self.win_hop
rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples
outputs = outputs[..., rm_start:rm_end]
coff = coff[..., rm_start:rm_end]
coffidx = th.where(coff > 1e-8)
outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
return outputs.squeeze(dim=1)
def forward(self, inputs):
"""Take input data (audio) to STFT domain and then back to audio.
Args:
inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]
Returns:
tensor: Reconstructed audio given magnitude and phase.
Of shape [num_batch, num_samples]
"""
mag, phase = self.transform(inputs, return_type="magphase")
rec_wav = self.inverse(mag, phase)
return rec_wav
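# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the upstream file): a quick
# round-trip check of transform()/inverse(). The 24 kHz sample length and the
# 1024/256 window/hop values are assumptions chosen to mirror the vocoder
# settings used elsewhere in this repo.
if __name__ == "__main__":
    stft = STFT(win_len=1024, win_hop=256, fft_len=1024)
    wav = th.randn(1, 24000)  # 1 s of random mono audio at 24 kHz
    mag, phase = stft.transform(wav, return_type="magphase")
    rec = stft.inverse(mag, phase, input_type="magphase")
    n = min(wav.size(-1), rec.size(-1))
    print("max reconstruction error:", (wav[..., :n] - rec[..., :n]).abs().max().item())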

View File

@@ -0,0 +1,358 @@
import argparse
import json
import os
import re
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
import safetensors.torch
import torch
from tensorrt_llm import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
split_v = split(v, tensor_parallel, rank, dim=1)
return split_v.contiguous()
def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
split_v = split(v, tensor_parallel, rank, dim=0)
return split_v.contiguous()
FACEBOOK_DIT_NAME_MAPPING = {
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
}
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Base",
choices=[
"F5TTS_Base",
],
) # TODO: support F5TTS_v1_Base
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
parser.add_argument(
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
)
args = parser.parse_args()
return args
def convert_timm_dit(args, mapping, dtype="float32"):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt))
model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
}
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
def get_trtllm_name(timm_name):
for k, v in timm_to_trtllm_name.items():
m = re.match(k, timm_name)
if m is not None:
if "*" in v:
v = v.replace("*", m.groups()[0])
return v
return timm_name
weights = dict()
for name, param in model_params.items():
if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight":
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1)
else:
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)
assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.'
new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()}
import math
scale_factor = math.pow(64, -0.25)
for k, v in weights.items():
if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
weights[k] *= scale_factor
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
weights[k] *= scale_factor
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
weights[k] *= scale_factor
elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
weights[k] *= scale_factor
elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Weights loaded. Total time: {t}")
return weights
def save_config(args):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
config = {
"architecture": "F5TTS",
"dtype": args.dtype,
"hidden_size": 1024,
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.1,
"ff_mult": 2,
"mel_dim": 100,
"text_num_embeds": 256,
"text_dim": 512,
"conv_layers": 4,
"long_skip_connection": False,
"mapping": {
"world_size": args.cp_size * args.tp_size * args.pp_size,
"cp_size": args.cp_size,
"tp_size": args.tp_size,
"pp_size": args.pp_size,
},
}
if args.fp8_linear:
config["quantization"] = {
"quant_algo": "FP8",
# TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*",
}
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
json.dump(config, f, indent=4)
def convert_and_save(args, rank):
if rank == 0:
save_config(args)
mapping = Mapping(
world_size=args.cp_size * args.tp_size * args.pp_size,
rank=rank,
cp_size=args.cp_size,
tp_size=args.tp_size,
pp_size=args.pp_size,
)
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
def execute(workers, func, args):
if workers == 1:
for rank, f in enumerate(func):
f(args, rank)
else:
with ThreadPoolExecutor(max_workers=workers) as p:
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log."
def main():
args = parse_arguments()
world_size = args.cp_size * args.tp_size * args.pp_size
assert args.pp_size == 1, "PP is not supported yet."
tik = time.time()
if args.timm_ckpt is None:
return
print("start execute")
execute(args.workers, [convert_and_save] * world_size, args)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Total time of converting checkpoints: {t}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,138 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torch.nn as nn
from conv_stft import STFT
from huggingface_hub import hf_hub_download
from vocos import Vocos
opset_version = 17
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--vocoder",
type=str,
default="vocos",
choices=["vocos", "bigvgan"],
help="Vocoder to export",
)
parser.add_argument(
"--output-path",
type=str,
default="./vocos_vocoder.onnx",
help="Output path",
)
return parser.parse_args()
class ISTFTHead(nn.Module):
def __init__(self, n_fft: int, hop_length: int):
super().__init__()
self.out = None
self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft)
def forward(self, x: torch.Tensor):
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(mag, max=1e2)
real = mag * torch.cos(p)
imag = mag * torch.sin(p)
audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag")
return audio
class VocosVocoder(nn.Module):
def __init__(self, vocos_vocoder):
super(VocosVocoder, self).__init__()
self.vocos_vocoder = vocos_vocoder
istft_head_out = self.vocos_vocoder.head.out
n_fft = self.vocos_vocoder.head.istft.n_fft
hop_length = self.vocos_vocoder.head.istft.hop_length
istft_head_for_export = ISTFTHead(n_fft, hop_length)
istft_head_for_export.out = istft_head_out
self.vocos_vocoder.head = istft_head_for_export
def forward(self, mel):
waveform = self.vocos_vocoder.decode(mel)
return waveform
def export_VocosVocoder(vocos_vocoder, output_path, verbose):
vocos_vocoder = VocosVocoder(vocos_vocoder).cuda()
vocos_vocoder.eval()
dummy_batch_size = 8
dummy_input_length = 500
dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda()
with torch.no_grad():
dummy_waveform = vocos_vocoder(mel=dummy_mel)
print(dummy_waveform.shape)
dummy_input = dummy_mel
torch.onnx.export(
vocos_vocoder,
dummy_input,
output_path,
opset_version=opset_version,
do_constant_folding=True,
input_names=["mel"],
output_names=["waveform"],
dynamic_axes={
"mel": {0: "batch_size", 2: "input_length"},
"waveform": {0: "batch_size", 1: "output_length"},
},
verbose=verbose,
)
print("Exported to {}".format(output_path))
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
if vocoder_name == "vocos":
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
elif vocoder_name == "bigvgan":
raise NotImplementedError("BigVGAN is not supported yet")
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)
return vocoder
if __name__ == "__main__":
args = get_args()
vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None)
if args.vocoder == "vocos":
export_VocosVocoder(vocoder, args.output_path, verbose=False)

View File

@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1
ENGINE_PATH=$2
echo "ONNX_PATH: $ONNX_PATH"
echo "ENGINE_PATH: $ENGINE_PATH"
PRECISION="fp32"
MIN_BATCH_SIZE=1
OPT_BATCH_SIZE=1
MAX_BATCH_SIZE=8
MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"
${TRTEXEC} \
--minShapes="mel:${MEL_MIN_SHAPE}" \
--optShapes="mel:${MEL_OPT_SHAPE}" \
--maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH}

View File

@@ -0,0 +1,36 @@
#! /usr/bin/env python3
from argparse import ArgumentParser
from string import Template
def main(file_path, substitutions, in_place, participant_ids):
with open(file_path) as f:
pbtxt = Template(f.read())
sub_dict = {"max_queue_size": 0}
sub_dict["participant_ids"] = participant_ids
for sub in substitutions.split(","):
key, value = sub.split(":")
sub_dict[key] = value
pbtxt = pbtxt.safe_substitute(sub_dict)
if in_place:
with open(file_path, "w") as f:
f.write(pbtxt)
else:
print(pbtxt)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("file_path", help="path of the .pbtxt to modify")
parser.add_argument(
"substitutions",
help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
)
parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
args = parser.parse_args()
main(**vars(args))

View File

@@ -0,0 +1,33 @@
"""ADAPTIVE BATCH SIZE"""
print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
print(" -> least padding, gather wavs with accumulated frames in a batch\n")
# data
total_hours = 95282
mel_hop_length = 256
mel_sampling_rate = 24000
# target
wanted_max_updates = 1200000
# train params
gpus = 8
frames_per_gpu = 38400 # 8 * 38400 = 307200
grad_accum = 1
# intermediate
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
# steps_per_epoch = updates_per_epoch * grad_accum
# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
# others
print(f"total {total_hours:.0f} hours")
print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")

View File

@@ -0,0 +1,40 @@
import os
import sys
sys.path.append(os.getcwd())
import thop
import torch
from f5_tts.model import CFM, DiT
""" ~155M """
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
""" ~335M """
# FLOPs: 622.1 G, Params: 333.2 M
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# FLOPs: 363.4 G, Params: 335.8 M
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model = CFM(transformer=transformer)
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
duration = 20
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150
flops, params = thop.profile(
model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
)
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")

View File

@@ -0,0 +1,63 @@
import asyncio
import logging
import socket
import time
import numpy as np
import pyaudio
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
start_time = time.time()
first_chunk_time = None
async def play_audio_stream():
nonlocal first_chunk_time
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
try:
while True:
data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
if not data:
break
if data == b"END":
logger.info("End of audio received.")
break
audio_array = np.frombuffer(data, dtype=np.float32)
stream.write(audio_array.tobytes())
if first_chunk_time is None:
first_chunk_time = time.time()
finally:
stream.stop_stream()
stream.close()
p.terminate()
logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
try:
data_to_send = f"{text}".encode("utf-8")
await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
await play_audio_stream()
except Exception as e:
logger.error(f"Error in listen_to_F5TTS: {e}")
finally:
client_socket.close()
if __name__ == "__main__":
text_to_send = "As a Reader assistant, I'm familiar with new technology, which is key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components."
asyncio.run(listen_to_F5TTS(text_to_send))

View File

@@ -0,0 +1,268 @@
import argparse
import gc
import logging
import queue
import socket
import struct
import threading
import traceback
import wave
from importlib.resources import files
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
chunk_text,
infer_batch_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AudioFileWriterThread(threading.Thread):
"""Threaded file writer to avoid blocking the TTS streaming process."""
def __init__(self, output_file, sampling_rate):
super().__init__()
self.output_file = output_file
self.sampling_rate = sampling_rate
self.queue = queue.Queue()
self.stop_event = threading.Event()
self.audio_data = []
def run(self):
"""Process queued audio data and write it to a file."""
logger.info("AudioFileWriterThread started.")
with wave.open(self.output_file, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(self.sampling_rate)
while not self.stop_event.is_set() or not self.queue.empty():
try:
chunk = self.queue.get(timeout=0.1)
if chunk is not None:
chunk = np.int16(chunk * 32767)
self.audio_data.append(chunk)
wf.writeframes(chunk.tobytes())
except queue.Empty:
continue
def add_chunk(self, chunk):
"""Add a new chunk to the queue."""
self.queue.put(chunk)
def stop(self):
"""Stop writing and ensure all queued data is written."""
self.stop_event.set()
self.join()
logger.info("Audio writing completed.")
class TTSStreamingProcessor:
def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
self.device = device or (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
self.model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
self.vocoder = self.load_vocoder_model()
self.update_reference(ref_audio, ref_text)
self._warm_up()
self.file_writer_thread = None
self.first_package = True
def load_ema_model(self, ckpt_file, vocab_file, dtype):
return load_model(
self.model_cls,
self.model_arc,
ckpt_path=ckpt_file,
mel_spec_type=self.mel_spec_type,
vocab_file=vocab_file,
ode_method="euler",
use_ema=True,
device=self.device,
).to(self.device, dtype=dtype)
def load_vocoder_model(self):
return load_vocoder(vocoder_name=self.mel_spec_type, is_local=False, local_path=None, device=self.device)
def update_reference(self, ref_audio, ref_text):
self.ref_audio, self.ref_text = preprocess_ref_audio_text(ref_audio, ref_text)
self.audio, self.sr = torchaudio.load(self.ref_audio)
ref_audio_duration = self.audio.shape[-1] / self.sr
ref_text_byte_len = len(self.ref_text.encode("utf-8"))
self.max_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration))
self.few_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 2)
self.min_chars = int(ref_text_byte_len / (ref_audio_duration) * (25 - ref_audio_duration) / 4)
def _warm_up(self):
logger.info("Warming up the model...")
gen_text = "Warm-up text for the model."
for _ in infer_batch_process(
(self.audio, self.sr),
self.ref_text,
[gen_text],
self.model,
self.vocoder,
progress=None,
device=self.device,
streaming=True,
):
pass
logger.info("Warm-up completed.")
def generate_stream(self, text, conn):
text_batches = chunk_text(text, max_chars=self.max_chars)
if self.first_package:
text_batches = chunk_text(text_batches[0], max_chars=self.few_chars) + text_batches[1:]
text_batches = chunk_text(text_batches[0], max_chars=self.min_chars) + text_batches[1:]
self.first_package = False
audio_stream = infer_batch_process(
(self.audio, self.sr),
self.ref_text,
text_batches,
self.model,
self.vocoder,
progress=None,
device=self.device,
streaming=True,
chunk_size=2048,
)
# Reset the file writer thread
if self.file_writer_thread is not None:
self.file_writer_thread.stop()
self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate)
self.file_writer_thread.start()
for audio_chunk, _ in audio_stream:
if len(audio_chunk) > 0:
logger.info(f"Generated audio chunk of size: {len(audio_chunk)}")
# Send audio chunk via socket
conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk))
# Write to file asynchronously
self.file_writer_thread.add_chunk(audio_chunk)
logger.info("Finished sending audio stream.")
conn.sendall(b"END") # Send end signal
# Ensure all audio data is written before exiting
self.file_writer_thread.stop()
def handle_client(conn, processor):
try:
with conn:
conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
while True:
data = conn.recv(1024)
if not data:
processor.first_package = True
break
data_str = data.decode("utf-8").strip()
logger.info(f"Received text: {data_str}")
try:
processor.generate_stream(data_str, conn)
except Exception as inner_e:
logger.error(f"Error during processing: {inner_e}")
traceback.print_exc()
break
except Exception as e:
logger.error(f"Error handling client: {e}")
traceback.print_exc()
def start_server(host, port, processor):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((host, port))
s.listen()
logger.info(f"Server started on {host}:{port}")
while True:
conn, addr = s.accept()
logger.info(f"Connected by {addr}")
handle_client(conn, processor)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", default=9998)
parser.add_argument(
"--model",
default="F5TTS_v1_Base",
help="The model name, e.g. F5TTS_v1_Base",
)
parser.add_argument(
"--ckpt_file",
default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
help="Path to the model checkpoint file",
)
parser.add_argument(
"--vocab_file",
default="",
help="Path to the vocab file if customized",
)
parser.add_argument(
"--ref_audio",
default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
help="Reference audio to provide model with speaker characteristics",
)
parser.add_argument(
"--ref_text",
default="",
help="Reference audio subtitle, leave empty to auto-transcribe",
)
parser.add_argument("--device", default=None, help="Device to run the model on")
parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference")
args = parser.parse_args()
try:
# Initialize the processor with the model and vocoder
processor = TTSStreamingProcessor(
model=args.model,
ckpt_file=args.ckpt_file,
vocab_file=args.vocab_file,
ref_audio=args.ref_audio,
ref_text=args.ref_text,
device=args.device,
dtype=args.dtype,
)
# Start the server
start_server(args.host, args.port, processor)
except KeyboardInterrupt:
gc.collect()

View File

@@ -0,0 +1,92 @@
# Training
Check your FFmpeg installation:
```bash
ffmpeg -version
```
If it is not found, install it first (or skip this step if you know of other available backends).
## Prepare Dataset
Example data processing scripts are provided below; you may also tailor your own, along with a Dataset class in `src/f5_tts/model/dataset.py`.
### 1. Preparation scripts for some specific datasets
Download the corresponding dataset first, and fill in the path in the scripts.
```bash
# Prepare the Emilia dataset
python src/f5_tts/train/datasets/prepare_emilia.py
# Prepare the Wenetspeech4TTS dataset
python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
# Prepare the LibriTTS dataset
python src/f5_tts/train/datasets/prepare_libritts.py
# Prepare the LJSpeech dataset
python src/f5_tts/train/datasets/prepare_ljspeech.py
```
### 2. Create custom dataset with metadata.csv
For guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
```bash
python src/f5_tts/train/datasets/prepare_csv_wavs.py
```
## Training & Finetuning
Once your datasets are prepared, you can start the training process.
### 1. Training script used for pretrained model
```bash
# setup accelerate config, e.g. use multi-gpu ddp, fp16
# the config will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
# .yaml files are under src/f5_tts/configs directory
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
# it is also possible to override the accelerate and hydra configs
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
```
### 2. Finetuning practice
Discussion board for finetuning: [#57](https://github.com/SWivid/F5-TTS/discussions/57).
For Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
If you want to finetune with a variant version, e.g. *F5TTS_v1_Base_no_zero_init*, manually download the pretrained checkpoint from the model weight repository and fill in the path accordingly on the web interface.
If you use tensorboard as the logger, install it first with `pip install tensorboard`.
<ins>Setting `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (they have gone through only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off with the finetune Gradio option or `load_model(..., use_ema=False)`, and see if it gives better results.
### 3. W&B Logging
The `wandb/` dir will be created under the path where you run the training/finetuning scripts.
By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
To turn on wandb logging, you can either:
1. Manually log in with `wandb login`: learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Log in programmatically by setting an environment variable: get an API key at https://wandb.ai/authorize and set the environment variable as follows:
On Mac & Linux:
```
export WANDB_API_KEY=<YOUR WANDB API KEY>
```
On Windows:
```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```
Moreover, if you cannot access W&B and want to log metrics offline, set the environment variable as follows:
```
export WANDB_MODE=offline
```

View File

@@ -0,0 +1,283 @@
import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess # For invoking ffprobe
import sys
from contextlib import contextmanager
sys.path.append(os.getcwd())
import argparse
import csv
import json
from importlib.resources import files
from pathlib import Path
import torchaudio
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
def is_csv_wavs_format(input_dataset_dir):
fpath = Path(input_dataset_dir)
metadata = fpath / "metadata.csv"
wavs = fpath / "wavs"
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
# Configuration constants
BATCH_SIZE = 100 # Batch size for text conversion
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1) # Leave one CPU free
THREAD_NAME_PREFIX = "AudioProcessor"
CHUNK_SIZE = 100 # Number of files to process per worker batch
executor = None # Global executor for cleanup
@contextmanager
def graceful_exit():
"""Context manager for graceful shutdown on signals"""
def signal_handler(signum, frame):
print("\nReceived signal to terminate. Cleaning up...")
if executor is not None:
print("Shutting down executor...")
executor.shutdown(wait=False, cancel_futures=True)
sys.exit(1)
# Set up signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
yield
finally:
if executor is not None:
executor.shutdown(wait=False)
def process_audio_file(audio_path, text, polyphone):
"""Process a single audio file by checking its existence and extracting duration."""
if not Path(audio_path).exists():
print(f"audio {audio_path} not found, skipping")
return None
try:
audio_duration = get_audio_duration(audio_path)
if audio_duration <= 0:
raise ValueError(f"Duration {audio_duration} is non-positive.")
return (audio_path, text, audio_duration)
except Exception as e:
print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
return None
def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
"""Convert a list of texts to pinyin in batches."""
converted_texts = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
converted_texts.extend(converted_batch)
return converted_texts
def prepare_csv_wavs_dir(input_dir, num_workers=None):
global executor
assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
input_dir = Path(input_dir)
metadata_path = input_dir / "metadata.csv"
audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
polyphone = True
total_files = len(audio_path_text_pairs)
# Use provided worker count or calculate optimal number
worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
print(f"\nProcessing {total_files} audio files using {worker_count} workers...")
with graceful_exit():
# Initialize thread pool with optimized settings
with concurrent.futures.ThreadPoolExecutor(
max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
) as exec:
executor = exec
results = []
# Process files in chunks for better efficiency
for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE):
chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
# Submit futures in order
chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
# Iterate over futures in the original submission order to preserve ordering
for future in tqdm(
chunk_futures,
total=len(chunk),
desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
):
try:
result = future.result()
if result is not None:
results.append(result)
except Exception as e:
print(f"Error processing file: {e}")
executor = None
# Filter out failed results
processed = [res for res in results if res is not None]
if not processed:
raise RuntimeError("No valid audio files were processed!")
# Batch process text conversion
raw_texts = [item[1] for item in processed]
converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)
# Prepare final results
sub_result = []
durations = []
vocab_set = set()
for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
durations.append(duration)
vocab_set.update(list(conv_text))
return sub_result, durations, vocab_set
def get_audio_duration(audio_path, timeout=5):
"""
Get the duration of an audio file in seconds using ffmpeg's ffprobe.
Falls back to torchaudio.load() if ffprobe fails.
"""
try:
cmd = [
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
audio_path,
]
result = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
)
duration_str = result.stdout.strip()
if duration_str:
return float(duration_str)
raise ValueError("Empty duration string from ffprobe.")
except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
try:
audio, sample_rate = torchaudio.load(audio_path)
return audio.shape[1] / sample_rate
except Exception as e:
raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")
def read_audio_text_pairs(csv_file_path):
audio_text_pairs = []
parent = Path(csv_file_path).parent
with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.reader(csvfile, delimiter="|")
next(reader) # Skip the header row
for row in reader:
if len(row) >= 2:
audio_file = row[0].strip() # First column: audio file path
text = row[1].strip() # Second column: text
audio_file_path = parent / audio_file
audio_text_pairs.append((audio_file_path.as_posix(), text))
return audio_text_pairs
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
out_dir = Path(out_dir)
out_dir.mkdir(exist_ok=True, parents=True)
print(f"\nSaving to {out_dir} ...")
# Save dataset with improved batch size for better I/O performance
raw_arrow_path = out_dir / "raw.arrow"
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# Save durations to JSON
dur_json_path = out_dir / "duration.json"
with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# Handle vocab file - write only once based on finetune flag
voca_out_path = out_dir / "vocab.txt"
if is_finetune:
file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
shutil.copy2(file_vocab_finetune, voca_out_path)
else:
with open(voca_out_path.as_posix(), "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
dataset_name = out_dir.stem
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
if is_finetune:
assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
def cli():
try:
# Before processing, check if ffprobe is available.
if shutil.which("ffprobe") is None:
print(
"Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
)
# Usage examples in help text
parser = argparse.ArgumentParser(
description="Prepare and save dataset.",
epilog="""
Examples:
# For fine-tuning (default):
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path
# For pre-training:
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain
# With custom worker count:
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
""",
)
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
args = parser.parse_args()
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
except KeyboardInterrupt:
print("\nOperation cancelled by user. Cleaning up...")
if executor is not None:
executor.shutdown(wait=False, cancel_futures=True)
sys.exit(1)
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,228 @@
# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
# if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
# generate audio text map for Emilia ZH & EN
# evaluate for vocab size
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
out_zh = {
"ZH_B00041_S06226",
"ZH_B00042_S09204",
"ZH_B00065_S09430",
"ZH_B00065_S09431",
"ZH_B00066_S09327",
"ZH_B00066_S09328",
}
zh_filters = ["い", "て"]
# seems synthesized audios, or heavily code-switched
out_en = {
"EN_B00013_S00913",
"EN_B00042_S00120",
"EN_B00055_S04111",
"EN_B00061_S00693",
"EN_B00061_S01494",
"EN_B00061_S03375",
"EN_B00059_S00092",
"EN_B00111_S04300",
"EN_B00100_S03759",
"EN_B00087_S03811",
"EN_B00059_S00950",
"EN_B00089_S00946",
"EN_B00078_S05127",
"EN_B00070_S04089",
"EN_B00074_S09659",
"EN_B00061_S06983",
"EN_B00061_S07060",
"EN_B00059_S08397",
"EN_B00082_S06192",
"EN_B00091_S01238",
"EN_B00089_S07349",
"EN_B00070_S04343",
"EN_B00061_S02400",
"EN_B00076_S01262",
"EN_B00068_S06467",
"EN_B00076_S02943",
"EN_B00064_S05954",
"EN_B00061_S05386",
"EN_B00066_S06544",
"EN_B00076_S06944",
"EN_B00072_S08620",
"EN_B00076_S07135",
"EN_B00076_S09127",
"EN_B00065_S00497",
"EN_B00059_S06227",
"EN_B00063_S02859",
"EN_B00075_S01547",
"EN_B00061_S08286",
"EN_B00079_S02901",
"EN_B00092_S03643",
"EN_B00096_S08653",
"EN_B00063_S04297",
"EN_B00063_S04614",
"EN_B00079_S04698",
"EN_B00104_S01666",
"EN_B00061_S09504",
"EN_B00061_S09694",
"EN_B00065_S05444",
"EN_B00063_S06860",
"EN_B00065_S05725",
"EN_B00069_S07628",
"EN_B00083_S03875",
"EN_B00071_S07665",
"EN_B00071_S07665",
"EN_B00062_S04187",
"EN_B00065_S09873",
"EN_B00065_S09922",
"EN_B00084_S02463",
"EN_B00067_S05066",
"EN_B00106_S08060",
"EN_B00073_S06399",
"EN_B00073_S09236",
"EN_B00087_S00432",
"EN_B00085_S05618",
"EN_B00064_S01262",
"EN_B00072_S01739",
"EN_B00059_S03913",
"EN_B00069_S04036",
"EN_B00067_S05623",
"EN_B00060_S05389",
"EN_B00060_S07290",
"EN_B00062_S08995",
}
en_filters = ["ا", "い", "て"]
def deal_with_audio_dir(audio_dir):
audio_jsonl = audio_dir.with_suffix(".jsonl")
sub_result, durations = [], []
vocab_set = set()
bad_case_zh = 0
bad_case_en = 0
with open(audio_jsonl, "r") as f:
lines = f.readlines()
for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
obj = json.loads(line)
text = obj["text"]
if obj["language"] == "zh":
if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
bad_case_zh += 1
continue
else:
text = text.translate(
str.maketrans({",": "", "!": "", "?": ""})
) # not "。" cuz much code-switched
if obj["language"] == "en":
if (
obj["wav"].split("/")[1] in out_en
or any(f in text for f in en_filters)
or repetition_found(text, length=4)
):
bad_case_en += 1
continue
if tokenizer == "pinyin":
text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
duration = obj["duration"]
sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
def main():
assert tokenizer in ["pinyin", "char"]
result = []
duration_list = []
text_vocab_set = set()
total_bad_case_zh = 0
total_bad_case_en = 0
# process raw data
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for lang in langs:
dataset_path = Path(os.path.join(dataset_dir, lang))
[
futures.append(executor.submit(deal_with_audio_dir, audio_dir))
for audio_dir in dataset_path.iterdir()
if audio_dir.is_dir()
]
for future in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
total_bad_case_zh += bad_case_zh
total_bad_case_en += bad_case_en
executor.shutdown()
# save preprocessed dataset to disk
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
print(f"\nSaving to {save_dir} ...")
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
# dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
# if tokenizer == "pinyin":
# text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if "ZH" in langs:
print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs:
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"
polyphone = True
langs = ["ZH", "EN"]
dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
main()
# Emilia ZH & EN
# samples count 37837916 (after removal)
# pinyin vocab size 2543 (polyphone)
# total duration 95281.87 (hours)
# bad zh asr cnt 230435 (samples)
# bad en asr cnt 37217 (samples)
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

View File

@@ -0,0 +1,94 @@
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
# prepares Emilia dataset with the new format w/ Emilia-YODAS
import json
import os
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import repetition_found
# Define filters for exclusion
out_en = set()
en_filters = ["ا", "い", "て"]
def process_audio_directory(audio_dir):
sub_result, durations, vocab_set = [], [], set()
bad_case_en = 0
for file in audio_dir.iterdir():
if file.suffix == ".json":
with open(file, "r") as f:
obj = json.load(f)
text = obj["text"]
if any(f in text for f in en_filters) or repetition_found(text, length=4):
bad_case_en += 1
continue
duration = obj["duration"]
audio_file = file.with_suffix(".mp3")
if audio_file.exists():
sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set, bad_case_en
def main():
assert tokenizer in ["pinyin", "char"]
result, duration_list, text_vocab_set = [], [], set()
total_bad_case_en = 0
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
dataset_path = Path(dataset_dir)
for sub_dir in dataset_path.iterdir():
if sub_dir.is_dir():
futures.append(executor.submit(process_audio_directory, sub_dir))
for future in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set, bad_case_en = future.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
total_bad_case_en += bad_case_en
executor.shutdown()
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"For {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "char"
dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
dataset_name = f"Emilia_EN_{tokenizer}"
# save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
main()

View File

@@ -0,0 +1,94 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def deal_with_audio_dir(audio_dir):
sub_result, durations = [], []
vocab_set = set()
audio_lists = list(audio_dir.rglob("*.wav"))
for line in audio_lists:
text_path = line.with_suffix(".normalized.txt")
text = open(text_path, "r").read().strip()
duration = sf.info(line).duration
if duration < 0.4 or duration > 30:
continue
sub_result.append({"audio_path": str(line), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set
def main():
result = []
duration_list = []
text_vocab_set = set()
# process raw data
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for subset in tqdm(SUB_SET):
dataset_path = Path(os.path.join(dataset_dir, subset))
[
futures.append(executor.submit(deal_with_audio_dir, audio_dir))
for audio_dir in dataset_path.iterdir()
if audio_dir.is_dir()
]
for future in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set = future.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
executor.shutdown()
# save preprocessed dataset to disk
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
print(f"\nSaving to {save_dir} ...")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":
max_workers = 36
tokenizer = "char" # "pinyin" | "char"
SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"]
dataset_dir = "<SOME_PATH>/LibriTTS"
dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
main()
# For LibriTTS_100_360_500_char, sample count: 354218
# For LibriTTS_100_360_500_char, vocab size is: 78
# For LibriTTS_100_360_500_char, total 554.09 hours

View File

@@ -0,0 +1,67 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from importlib.resources import files
from pathlib import Path
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def main():
result = []
duration_list = []
text_vocab_set = set()
with open(meta_info, "r") as f:
lines = f.readlines()
for line in tqdm(lines):
uttr, text, norm_text = line.split("|")
norm_text = norm_text.strip()
wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav"
duration = sf.info(wav_path).duration
if duration < 0.4 or duration > 30:
continue
result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration})
duration_list.append(duration)
text_vocab_set.update(list(norm_text))
# save preprocessed dataset to disk
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
print(f"\nSaving to {save_dir} ...")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":
tokenizer = "char" # "pinyin" | "char"
dataset_dir = "<SOME_PATH>/LJSpeech-1.1"
dataset_name = f"LJSpeech_{tokenizer}"
meta_info = os.path.join(dataset_dir, "metadata.csv")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
main()

View File

@@ -0,0 +1,126 @@
# generate audio text map for WenetSpeech4TTS
# evaluate for vocab size
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
import torchaudio
from datasets import Dataset
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin
def deal_with_sub_path_files(dataset_path, sub_path):
print(f"Dealing with: {sub_path}")
text_dir = os.path.join(dataset_path, sub_path, "txts")
audio_dir = os.path.join(dataset_path, sub_path, "wavs")
text_files = os.listdir(text_dir)
audio_paths, texts, durations = [], [], []
for text_file in tqdm(text_files):
with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
first_line = file.readline().split("\t")
audio_nm = first_line[0]
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
text = first_line[1].strip()
audio_paths.append(audio_path)
if tokenizer == "pinyin":
texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
elif tokenizer == "char":
texts.append(text)
audio, sample_rate = torchaudio.load(audio_path)
durations.append(audio.shape[-1] / sample_rate)
return audio_paths, texts, durations
def main():
assert tokenizer in ["pinyin", "char"]
audio_path_list, text_list, duration_list = [], [], []
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for dataset_path in dataset_paths:
sub_items = os.listdir(dataset_path)
sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
for sub_path in sub_paths:
futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
for future in tqdm(futures, total=len(futures)):
audio_paths, texts, durations = future.result()
audio_path_list.extend(audio_paths)
text_list.extend(texts)
duration_list.extend(durations)
executor.shutdown()
if not os.path.exists("data"):
os.makedirs("data")
print(f"\nSaving to {save_dir} ...")
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") # arrow format
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump(
{"duration": duration_list}, f, ensure_ascii=False
) # dup a json separately saving duration in case for DynamicBatchSampler ease
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
text_vocab_set = set()
for text in tqdm(text_list):
text_vocab_set.update(list(text))
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
if tokenizer == "pinyin":
text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"
polyphone = True
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
dataset_name = (
["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
+ "_"
+ tokenizer
)
dataset_paths = [
"<SOME_PATH>/WenetSpeech4TTS/Basic",
"<SOME_PATH>/WenetSpeech4TTS/Standard",
"<SOME_PATH>/WenetSpeech4TTS/Premium",
][-dataset_choice:]
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n")
main()
# Results (if adding alphabets with accents and symbols):
# WenetSpeech4TTS Basic Standard Premium
# samples count 3932473 1941220 407494
# pinyin vocab size 1349 1348 1344 (no polyphone)
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

View File

@@ -0,0 +1,214 @@
import argparse
import os
import shutil
from importlib.resources import files
from cached_path import cached_path
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# -------------------------- Dataset Settings --------------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
# -------------------------- Argument Parsing --------------------------- #
def parse_args():
parser = argparse.ArgumentParser(description="Train CFM Model")
parser.add_argument(
"--exp_name",
type=str,
default="F5TTS_v1_Base",
choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
help="Experiment name",
)
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
parser.add_argument(
"--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
)
parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
)
parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
parser.add_argument(
"--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
)
parser.add_argument(
"--tokenizer_path",
type=str,
default=None,
help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
)
parser.add_argument(
"--log_samples",
action="store_true",
help="Log inferenced samples per ckpt save updates",
)
parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
parser.add_argument(
"--bnb_optimizer",
action="store_true",
help="Use 8-bit Adam optimizer from bitsandbytes",
)
return parser.parse_args()
# -------------------------- Training Settings -------------------------- #
def main():
args = parse_args()
checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
# Model parameters based on experiment name
if args.exp_name == "F5TTS_v1_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "F5TTS_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
text_mask_padding=False,
conv_layers=4,
pe_attn_head=1,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "E2TTS_Base":
wandb_resume_id = None
model_cls = UNetT
model_cfg = dict(
dim=1024,
depth=24,
heads=16,
ff_mult=4,
text_mask_padding=False,
pe_attn_head=1,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
else:
ckpt_path = args.pretrain
if args.finetune:
if not os.path.isdir(checkpoint_path):
os.makedirs(checkpoint_path, exist_ok=True)
file_checkpoint = os.path.basename(ckpt_path)
if not file_checkpoint.startswith("pretrained_"): # Change: Add 'pretrained_' prefix to copied model
file_checkpoint = "pretrained_" + file_checkpoint
file_checkpoint = os.path.join(checkpoint_path, file_checkpoint)
if not os.path.isfile(file_checkpoint):
shutil.copy2(ckpt_path, file_checkpoint)
print("copy checkpoint for finetune")
# Use the tokenizer and tokenizer_path provided in the command line arguments
tokenizer = args.tokenizer
if tokenizer == "custom":
if not args.tokenizer_path:
raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
tokenizer_path = args.tokenizer_path
else:
tokenizer_path = args.dataset_name
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
print("\nvocab : ", vocab_size)
print("\nvocoder : ", mel_spec_type)
mel_spec_kwargs = dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=mel_spec_kwargs,
vocab_char_map=vocab_char_map,
)
trainer = Trainer(
model,
args.epochs,
args.learning_rate,
num_warmup_updates=args.num_warmup_updates,
save_per_updates=args.save_per_updates,
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
checkpoint_path=checkpoint_path,
batch_size_per_gpu=args.batch_size_per_gpu,
batch_size_type=args.batch_size_type,
max_samples=args.max_samples,
grad_accumulation_steps=args.grad_accumulation_steps,
max_grad_norm=args.max_grad_norm,
logger=args.logger,
wandb_project=args.dataset_name,
wandb_run_name=args.exp_name,
wandb_resume_id=wandb_resume_id,
log_samples=args.log_samples,
last_per_updates=args.last_per_updates,
bnb_optimizer=args.bnb_optimizer,
)
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
trainer.train(
train_dataset,
resumable_with_seed=666, # seed for shuffling dataset
)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,77 @@
# training script.
import os
from importlib.resources import files
import hydra
from omegaconf import OmegaConf
from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(model_cfg):
model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
wandb_resume_id = None
# set text tokenizer
if tokenizer != "custom":
tokenizer_path = model_cfg.datasets.name
else:
tokenizer_path = model_cfg.model.tokenizer_path
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
# set model
model = CFM(
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=model_cfg.model.mel_spec,
vocab_char_map=vocab_char_map,
)
# init trainer
trainer = Trainer(
model,
epochs=model_cfg.optim.epochs,
learning_rate=model_cfg.optim.learning_rate,
num_warmup_updates=model_cfg.optim.num_warmup_updates,
save_per_updates=model_cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
batch_size_type=model_cfg.datasets.batch_size_type,
max_samples=model_cfg.datasets.max_samples,
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
max_grad_norm=model_cfg.optim.max_grad_norm,
logger=model_cfg.ckpts.logger,
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_updates=model_cfg.ckpts.last_per_updates,
log_samples=model_cfg.ckpts.log_samples,
bnb_optimizer=model_cfg.optim.bnb_optimizer,
mel_spec_type=mel_spec_type,
is_local_vocoder=model_cfg.model.vocoder.is_local,
local_vocoder_path=model_cfg.model.vocoder.local_path,
model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
)
train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
trainer.train(
train_dataset,
num_workers=model_cfg.datasets.num_workers,
resumable_with_seed=666, # seed for shuffling dataset
)
if __name__ == "__main__":
main()

45
mr_v100-f5-tts/README.md Normal file
View File

@@ -0,0 +1,45 @@
# F5-TTS
This project wraps the **F5-TTS** model and provides a simple Docker-based deployment. It accepts **SSML input** and outputs **raw PCM audio** for speech synthesis.
---
## Quickstart
### 1. Build the image
```bash
docker build -t tts:f5 . -f Dockerfile_f5
```
### 2. Start the service
```bash
docker run -it --rm \
-v /models/F5-TTS_Emilia-ZH-EN:/mnt/models \
-v /dev:/dev \
--device=/dev/iluvatar0:/dev/iluvatar0 \
-p 8080:80 \
-e MODEL_DIR=/mnt/models \
-e MODEL_NAME=model_1250000.safetensors \
tts:f5
```
Parameter notes:
- `MODEL_DIR`: directory containing the model (mounted into the container at `/mnt/models`)
- `MODEL_NAME`: filename of the model to load (typically a `.safetensors` file)
- `-p 8080:80`: maps the service port inside the container to host port `8080`
- `--device=/dev/iluvatar0:/dev/iluvatar0`: specifies the inference device (e.g. a GPU/accelerator card)
### 3. Test the service
```bash
curl --request POST "http://localhost:8080/tts" \
--header 'Content-Type: application/ssml+xml' \
--header 'User-Agent: curl' \
--data-raw '<speak version="1.0" xml:lang="zh">
<voice xml:lang="zh" xml:gender="Female" name="zh">
今天天气很好,不知道明天天气怎么样。
</voice>
</speak>' \
--output sound.pcm
```
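For reference, a minimal Python client sketch that posts the same SSML and wraps the returned raw PCM as a WAV file (this assumes the server streams 16 kHz mono 16-bit PCM, as configured in `f5_server.py`):
```python
import wave

import requests

SSML = """<speak version="1.0" xml:lang="zh">
  <voice xml:lang="zh" xml:gender="Female" name="zh">
    今天天气很好,不知道明天天气怎么样。
  </voice>
</speak>"""

resp = requests.post(
    "http://localhost:8080/tts",
    data=SSML.encode("utf-8"),
    headers={"Content-Type": "application/ssml+xml"},
)
resp.raise_for_status()

# The body is raw 16-bit mono PCM at 16 kHz; wrap it in a WAV container.
with wave.open("sound.wav", "wb") as wav:
    wav.setnchannels(1)
    wav.setsampwidth(2)      # int16 samples
    wav.setframerate(16000)  # TARGET_SR in f5_server.py
    wav.writeframes(resp.content)
```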
---

View File

@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,71 @@
---
license: mit
---
# Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis
[Audio samples](https://charactr-platform.github.io/vocos/) |
Paper [[abs]](https://arxiv.org/abs/2306.00814) [[pdf]](https://arxiv.org/pdf/2306.00814.pdf)
Vocos is a fast neural vocoder designed to synthesize audio waveforms from acoustic features. Trained using a Generative
Adversarial Network (GAN) objective, Vocos can generate waveforms in a single forward pass. Unlike other typical
GAN-based vocoders, Vocos does not model audio samples in the time domain. Instead, it generates spectral
coefficients, facilitating rapid audio reconstruction through inverse Fourier transform.
## Installation
To use Vocos only in inference mode, install it using:
```bash
pip install vocos
```
If you wish to train the model, install it with additional dependencies:
```bash
pip install vocos[train]
```
## Usage
### Reconstruct audio from mel-spectrogram
```python
import torch
from vocos import Vocos
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
mel = torch.randn(1, 100, 256) # B, C, T
audio = vocos.decode(mel)
```
Copy-synthesis from a file:
```python
import torchaudio
y, sr = torchaudio.load(YOUR_AUDIO_FILE)
if y.size(0) > 1: # mix to mono
y = y.mean(dim=0, keepdim=True)
y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000)
y_hat = vocos(y)
```
## Citation
If this code contributes to your research, please cite our work:
```
@article{siuzdak2023vocos,
title={Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis},
author={Siuzdak, Hubert},
journal={arXiv preprint arXiv:2306.00814},
year={2023}
}
```
## License
The code in this repository is released under the MIT license.

View File

@@ -0,0 +1,24 @@
feature_extractor:
class_path: vocos.feature_extractors.MelSpectrogramFeatures
init_args:
sample_rate: 24000
n_fft: 1024
hop_length: 256
n_mels: 100
padding: center
backbone:
class_path: vocos.models.VocosBackbone
init_args:
input_channels: 100
dim: 512
intermediate_dim: 1536
num_layers: 8
head:
class_path: vocos.heads.ISTFTHead
init_args:
dim: 512
n_fft: 1024
hop_length: 256
padding: center

BIN
mr_v100-f5-tts/charactr/vocos-mel-24khz/pytorch_model.bin (Stored with Git LFS) Normal file

Binary file not shown.

View File

@@ -0,0 +1,2 @@
torch==2.4.1+corex.4.3.0
accelerate==1.6.0

341
mr_v100-f5-tts/f5_server.py Normal file
View File

@@ -0,0 +1,341 @@
import os
model_dir = os.getenv("MODEL_DIR", "/mounted_model")
model_name = os.getenv("MODEL_NAME", "model.safetensors")
import logging
logging.basicConfig(
format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO"),
)
logger = logging.getLogger(__file__)
# enable custom patcher if available
patcher_path = os.path.join(model_dir, "custom_patcher.py")
if os.path.exists(patcher_path):
import shutil
shutil.copyfile(patcher_path, "custom_patcher.py")
try:
import custom_patcher
logger.info("Custom patcher has been applied.")
except ImportError:
logger.info("Failed to import custom_patcher. Ensure it is a valid Python module.")
else:
logger.info("No custom_patcher found.")
import torch
torch.set_num_threads(4)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
from torch import Tensor
from typing import Optional, List
import torch.nn.functional as F
# def custom_conv1d_forward(self, input: Tensor, debug=False) -> Tensor:
# with torch.amp.autocast(input.device.type, dtype=torch.float):
# return self._conv_forward(input, self.weight, self.bias)
# torch.nn.Conv1d.forward = custom_conv1d_forward
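# The override below replaces torch.nn.ConvTranspose1d.forward so the transposed
# convolution runs under fp16 autocast and its output is cast back to float32,
# presumably as a numerical/performance workaround for this accelerator backend.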
def conv_transpose1d_forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
if self.padding_mode != 'zeros':
raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
num_spatial_dims = 1
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
num_spatial_dims, self.dilation) # type: ignore[arg-type]
with torch.amp.autocast('cuda', dtype=torch.float16):
return F.conv_transpose1d(
input, self.weight, self.bias, self.stride, self.padding,
output_padding, self.groups, self.dilation).float()
torch.nn.ConvTranspose1d.forward = conv_transpose1d_forward
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
infer_batch_process,
)
from omegaconf import OmegaConf
from hydra.utils import get_class
import torch
import re
import numpy as np
import soundfile as sf
import torchaudio
from scipy import signal
import io
import time
from fastapi import FastAPI, Request, Response, Body, HTTPException
from fastapi import UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from contextlib import asynccontextmanager
import uvicorn
import os
import hashlib
import xml.etree.ElementTree as ET
from typing import Union
vocoder_dir = os.getenv('VOCODER_DIR', '/workspace/charactr/vocos-mel-24khz')
speed = float(os.getenv('SPEED', 1.0))
ema_model = None
vocoder = None
voice_dict = {}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
TARGET_SR = 16000
N_ZEROS = 20
# std_ref_audio_file = os.path.join(model_dir, 'ref_audio.wav')
# std_ref_text_file = os.path.join(model_dir, 'ref_text.txt')
std_ref_audio_file = '/workspace/ref_audio.wav'
std_ref_text_file = '/workspace/ref_text.txt'
std_ref_audio = None
std_ref_text = None
def init():
global ema_model, vocoder
global std_ref_audio, std_ref_text
logger.info(f'{device=}')
# load vocoder
vocoder_name = 'vocos'
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=True, local_path=vocoder_dir, device=device)
# load TTS model
model_cfg = OmegaConf.load('/workspace/F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml')
model_cls = get_class(f'f5_tts.model.{model_cfg.model.backbone}')
model_arc = model_cfg.model.arch
ckpt_file = os.path.join(model_dir, model_name)
vocab_file = os.path.join(model_dir, 'vocab.txt')
ema_model = load_model(
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
)
with open(std_ref_audio_file, 'rb') as f:
std_ref_audio = f.read()
with open(std_ref_text_file, 'r', encoding='utf-8') as f:
std_ref_text = f.read().strip()
@asynccontextmanager
async def lifespan(app: FastAPI):
init()
yield
pass
app = FastAPI(lifespan=lifespan)
@app.get("/health")
@app.get("/ready")
async def ready():
return JSONResponse(status_code=200, content={"message": "success"})
def encode_audio_key(audio_bytes: bytes) -> str:
return hashlib.md5(audio_bytes[:16000]).hexdigest()[:16]
@app.post("/register_voice")
async def register_voice(
audio: UploadFile = File(...),
text: str = Form(...)
):
global voice_dict
audio_bytes = await audio.read()
audio_key = encode_audio_key(audio_bytes)
# Ensure ref_text ends with a proper sentence-ending punctuation
if not text.endswith(". ") and not text.endswith("。"):
if text.endswith("."):
text += " "
else:
text += ". "
voice_dict[audio_key] = {
'ref_audio': audio_bytes,
'ref_text': text.strip()
}
# warmup
for _ in generate("流式语音合成,合成测试", audio_key, fast_infer=2):
logger.info("Warming up")
response = {
"status": "success",
"audio_key": audio_key
}
return JSONResponse(status_code=200, content=response)
symbols = """,.!?;:()[]{}<>,。!?;:【】《》……'"“”_—"""
def contains_words(text):
return any(char not in symbols for char in text)
def split_text(text, max_chars=135, cut_short_first=False):
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
sentences = [s.strip() for s in sentences if s.strip()]
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
else:
if current_chunk and contains_words(current_chunk):
chunks.append(current_chunk.strip())
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
if current_chunk and contains_words(current_chunk):
chunks.append(current_chunk.strip())
if cut_short_first:
first_sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", chunks[0])
first = first_sentences[0].strip()
rest = "".join(first_sentences[1:]).strip()
first_chunk = [first, rest] if rest else [first]
final_chunks = first_chunk + chunks[1:]
else:
final_chunks = chunks
return final_chunks
def audio_postprocess(audio: np.ndarray, ori_sr: int, target_sr: int) -> np.ndarray:
number_of_samples = int(len(audio) * float(target_sr) / ori_sr)
audio_resampled = signal.resample(audio, number_of_samples)
if audio.dtype == np.float32:
audio_resampled = np.clip(audio_resampled, -1.0, 1.0)
audio_resampled = (audio_resampled * 32767).astype(np.int16)
return audio_resampled
def generate(gen_text, ref_audio_key, fast_infer=0):
global voice_dict, ema_model, vocoder
ref_audio_ = voice_dict[ref_audio_key]['ref_audio']
ref_text_ = voice_dict[ref_audio_key]['ref_text']
nfe_step = 16
if fast_infer >= 1:
nfe_step = 7
# nonuniform_step = True
# if fast_infer >= 2:
# ref_audio_ = voice_dict[ref_audio_key].get('ref_audio_slice', ref_audio_)
# ref_text_ = voice_dict[ref_audio_key].get('ref_text_slice', ref_text_)
audio, sr = torchaudio.load(io.BytesIO(ref_audio_))
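# Estimate how many characters fit in one generation: the reference's bytes-per-second
# rate scaled by the room left in a ~22 s synthesis window after the reference audio.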
max_chars = int(len(ref_text_.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
gen_text_batches = split_text(gen_text, max_chars=max_chars, cut_short_first=(fast_infer > 0))
for gen_audio, gen_sr in infer_batch_process(
(audio, sr),
ref_text_,
gen_text_batches,
ema_model,
vocoder,
device=device,
streaming=True,
chunk_size=int(24e6),
nfe_step=nfe_step,
speed=speed,
):
yield audio_postprocess(gen_audio, gen_sr, TARGET_SR).tobytes()
def generate_with_audio(gen_text, ref_audio, ref_text, fast_infer=0):
global ema_model, vocoder
if not contains_words(gen_text):
audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
yield audio
return
nfe_step = 16
if fast_infer >= 1:
nfe_step = 7
audio, sr = torchaudio.load(io.BytesIO(ref_audio))
max_chars = min(int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr)), 135)
gen_text_batches = split_text(gen_text, max_chars=max_chars, cut_short_first=(fast_infer > 0))
for gen_audio, gen_sr in infer_batch_process(
(audio, sr),
ref_text,
gen_text_batches,
ema_model,
vocoder,
device=device,
streaming=True,
chunk_size=int(24e6),
nfe_step=nfe_step,
speed=speed,
):
yield audio_postprocess(gen_audio, gen_sr, TARGET_SR).tobytes()
@app.post("/synthesize")
async def synthesize(request: Request):
data = await request.json()
text = data['text']
audio_key = data['audio_key']
fast_infer = data.get('fast_infer', 0)
if fast_infer == True:
fast_infer = 2
else:
fast_infer = int(fast_infer)
# logger.info(f"Synthesizing text: {text}, audio_key: {audio_key}, fast_infer: {fast_infer}")
if not contains_words(text):
audio = np.zeros(N_ZEROS, dtype=np.int16).tobytes()
return Response(audio, media_type='audio/wav')
global voice_dict
if audio_key not in voice_dict:
raise HTTPException(status_code=400, detail="Invalid audio key")
return StreamingResponse(generate(text, audio_key, fast_infer), media_type="audio/wav")
xml_namespace = "{http://www.w3.org/XML/1998/namespace}"
@app.post("/tts")
def predict(ssml: str = Body(...), fast_infer: Union[bool, int] = 0):
try:
root = ET.fromstring(ssml)
voice_element = root.find(".//voice")
if voice_element is not None:
transcription = voice_element.text.strip()
language = voice_element.get(f'{xml_namespace}lang', "zh").strip()
# voice_name = voice_element.get("name", "zh-f-soft-1").strip()
else:
return JSONResponse(status_code=400, content={"message": "Invalid SSML format: <voice> element not found."})
except ET.ParseError as e:
return JSONResponse(status_code=400, content={"message": "Invalid SSML format", "Exception": str(e)})
fast_infer = int(fast_infer)
return StreamingResponse(
generate_with_audio(transcription, std_ref_audio, std_ref_text, fast_infer),
media_type="audio/wav"
)
@app.get("/health_check")
async def health_check():
try:
a = torch.ones(10, 20, dtype=torch.float32, device='cuda')
b = torch.ones(20, 10, dtype=torch.float32, device='cuda')
c = torch.matmul(a, b)
if c.sum() == 10 * 20 * 10:
return {"status": "ok"}
else:
raise HTTPException(status_code=503)
except Exception as e:
logger.error(f"health_check failed: {e}")
raise HTTPException(status_code=503)
if __name__ == "__main__":
uvicorn.run("f5_server:app", host="0.0.0.0", port=80)

3
mr_v100-f5-tts/launch.sh Executable file
View File

@@ -0,0 +1,3 @@
#!/bin/bash
python3 f5_server.py

3
mr_v100-f5-tts/launch_f5.sh Executable file
View File

@@ -0,0 +1,3 @@
#!/bin/bash
python3 f5_server.py

Binary file not shown.

View File

@@ -0,0 +1 @@
而这条街道,没有半分“不谐”之感,实属难得。

View File

@@ -0,0 +1,2 @@
fastapi
uvicorn[standard]

BIN
mr_v100-gpt-sovits/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -0,0 +1,19 @@
FROM git.modelhub.org.cn:9443/enginex-iluvatar/mr100_corex:4.3.0
WORKDIR /workspace
RUN apt-get update && \
apt-get install -y ffmpeg libsox-dev redis && \
rm -rf /var/lib/apt/lists/*
COPY wav /workspace/wav
COPY GPT-SoVITS /workspace/GPT-SoVITS
COPY constraints_gsv.txt /workspace/
RUN pip install -r GPT-SoVITS/extra-req.txt --no-deps \
&& pip install -r GPT-SoVITS/requirements.txt -c constraints_gsv.txt
#COPY launch_gsv.sh /workspace/
#ENTRYPOINT ["/bin/bash", "launch_gsv.sh"]
COPY launch.sh /workspace/
ENTRYPOINT ["/bin/bash", "launch.sh"]

View File

@@ -0,0 +1,198 @@
GPT_SoVITS/pretrained_models/*
tools/asr/models/*
tools/uvr5/uvr5_weights/*
.git
.DS_Store
.vscode
*.pyc
env
runtime
.idea
output
logs
SoVITS_weights*/
GPT_weights*/
TEMP
weight.json
ffmpeg*
ffprobe*
cfg.json
speakers.json
ref_audios
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc

195
mr_v100-gpt-sovits/GPT-SoVITS/.gitignore vendored Normal file
View File

@@ -0,0 +1,195 @@
.DS_Store
.vscode
__pycache__
*.pyc
env
runtime
.idea
output
logs
SoVITS_weights*/
GPT_weights*/
TEMP
weight.json
ffmpeg*
ffprobe*
cfg.json
speakers.json
ref_audios
tools/AP_BWE_main/24kto48k/*
!tools/AP_BWE_main/24kto48k/readme.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc

View File

@@ -0,0 +1,15 @@
ci:
autoupdate_schedule: monthly
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
hooks:
# Run the linter.
- id: ruff
types_or: [ python, pyi ]
args: [ --fix , "--exit-zero" ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
args: [ --line-length, "120", --target-version, "py310" ]

Some files were not shown because too many files have changed in this diff