init muxi

This commit is contained in:
2025-09-12 11:39:55 +08:00
commit 96ef2da601
602 changed files with 591073 additions and 0 deletions

View File

@@ -0,0 +1,92 @@
# Training
Check your FFmpeg installation:
```bash
ffmpeg -version
```
If not found, install it first (or skip assuming you know of other backends available).
## Prepare Dataset
Example data processing scripts, and you may tailor your own one along with a Dataset class in `src/f5_tts/model/dataset.py`.
### 1. Some specific Datasets preparing scripts
Download corresponding dataset first, and fill in the path in scripts.
```bash
# Prepare the Emilia dataset
python src/f5_tts/train/datasets/prepare_emilia.py
# Prepare the Wenetspeech4TTS dataset
python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
# Prepare the LibriTTS dataset
python src/f5_tts/train/datasets/prepare_libritts.py
# Prepare the LJSpeech dataset
python src/f5_tts/train/datasets/prepare_ljspeech.py
```
### 2. Create custom dataset with metadata.csv
Use guidance see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
```bash
python src/f5_tts/train/datasets/prepare_csv_wavs.py
```
## Training & Finetuning
Once your datasets are prepared, you can start the training process.
### 1. Training script used for pretrained model
```bash
# setup accelerate config, e.g. use multi-gpu ddp, fp16
# will be to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
# .yaml files are under src/f5_tts/configs directory
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
# possible to overwrite accelerate and hydra config
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
```
### 2. Finetuning practice
Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
If want to finetune with a variant version e.g. *F5TTS_v1_Base_no_zero_init*, manually download pretrained checkpoint from model weight repository and fill in the path correspondingly on web interface.
If use tensorboard as logger, install it first with `pip install tensorboard`.
<ins>The `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off with finetune gradio option or `load_model(..., use_ema=False)`, see if offer better results.
### 3. W&B Logging
The `wandb/` dir will be created under path you run training/finetuning scripts.
By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
To turn on wandb logging, you can either:
1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:
On Mac & Linux:
```
export WANDB_API_KEY=<YOUR WANDB API KEY>
```
On Windows:
```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```
Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:
```
export WANDB_MODE=offline
```

View File

@@ -0,0 +1,283 @@
import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess # For invoking ffprobe
import sys
from contextlib import contextmanager
sys.path.append(os.getcwd())
import argparse
import csv
import json
from importlib.resources import files
from pathlib import Path
import torchaudio
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
def is_csv_wavs_format(input_dataset_dir):
fpath = Path(input_dataset_dir)
metadata = fpath / "metadata.csv"
wavs = fpath / "wavs"
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
# Configuration constants
BATCH_SIZE = 100 # Batch size for text conversion
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1) # Leave one CPU free
THREAD_NAME_PREFIX = "AudioProcessor"
CHUNK_SIZE = 100 # Number of files to process per worker batch
executor = None # Global executor for cleanup
@contextmanager
def graceful_exit():
"""Context manager for graceful shutdown on signals"""
def signal_handler(signum, frame):
print("\nReceived signal to terminate. Cleaning up...")
if executor is not None:
print("Shutting down executor...")
executor.shutdown(wait=False, cancel_futures=True)
sys.exit(1)
# Set up signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
yield
finally:
if executor is not None:
executor.shutdown(wait=False)
def process_audio_file(audio_path, text, polyphone):
"""Process a single audio file by checking its existence and extracting duration."""
if not Path(audio_path).exists():
print(f"audio {audio_path} not found, skipping")
return None
try:
audio_duration = get_audio_duration(audio_path)
if audio_duration <= 0:
raise ValueError(f"Duration {audio_duration} is non-positive.")
return (audio_path, text, audio_duration)
except Exception as e:
print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
return None
def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
"""Convert a list of texts to pinyin in batches."""
converted_texts = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
converted_texts.extend(converted_batch)
return converted_texts
def prepare_csv_wavs_dir(input_dir, num_workers=None):
global executor
assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
input_dir = Path(input_dir)
metadata_path = input_dir / "metadata.csv"
audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
polyphone = True
total_files = len(audio_path_text_pairs)
# Use provided worker count or calculate optimal number
worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
print(f"\nProcessing {total_files} audio files using {worker_count} workers...")
with graceful_exit():
# Initialize thread pool with optimized settings
with concurrent.futures.ThreadPoolExecutor(
max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
) as exec:
executor = exec
results = []
# Process files in chunks for better efficiency
for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE):
chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
# Submit futures in order
chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
# Iterate over futures in the original submission order to preserve ordering
for future in tqdm(
chunk_futures,
total=len(chunk),
desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
):
try:
result = future.result()
if result is not None:
results.append(result)
except Exception as e:
print(f"Error processing file: {e}")
executor = None
# Filter out failed results
processed = [res for res in results if res is not None]
if not processed:
raise RuntimeError("No valid audio files were processed!")
# Batch process text conversion
raw_texts = [item[1] for item in processed]
converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)
# Prepare final results
sub_result = []
durations = []
vocab_set = set()
for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
durations.append(duration)
vocab_set.update(list(conv_text))
return sub_result, durations, vocab_set
def get_audio_duration(audio_path, timeout=5):
"""
Get the duration of an audio file in seconds using ffmpeg's ffprobe.
Falls back to torchaudio.load() if ffprobe fails.
"""
try:
cmd = [
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
audio_path,
]
result = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
)
duration_str = result.stdout.strip()
if duration_str:
return float(duration_str)
raise ValueError("Empty duration string from ffprobe.")
except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
try:
audio, sample_rate = torchaudio.load(audio_path)
return audio.shape[1] / sample_rate
except Exception as e:
raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")
def read_audio_text_pairs(csv_file_path):
audio_text_pairs = []
parent = Path(csv_file_path).parent
with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.reader(csvfile, delimiter="|")
next(reader) # Skip the header row
for row in reader:
if len(row) >= 2:
audio_file = row[0].strip() # First column: audio file path
text = row[1].strip() # Second column: text
audio_file_path = parent / audio_file
audio_text_pairs.append((audio_file_path.as_posix(), text))
return audio_text_pairs
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
out_dir = Path(out_dir)
out_dir.mkdir(exist_ok=True, parents=True)
print(f"\nSaving to {out_dir} ...")
# Save dataset with improved batch size for better I/O performance
raw_arrow_path = out_dir / "raw.arrow"
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# Save durations to JSON
dur_json_path = out_dir / "duration.json"
with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# Handle vocab file - write only once based on finetune flag
voca_out_path = out_dir / "vocab.txt"
if is_finetune:
file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
shutil.copy2(file_vocab_finetune, voca_out_path)
else:
with open(voca_out_path.as_posix(), "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
dataset_name = out_dir.stem
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
if is_finetune:
assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
def cli():
try:
# Before processing, check if ffprobe is available.
if shutil.which("ffprobe") is None:
print(
"Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
)
# Usage examples in help text
parser = argparse.ArgumentParser(
description="Prepare and save dataset.",
epilog="""
Examples:
# For fine-tuning (default):
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path
# For pre-training:
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain
# With custom worker count:
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
""",
)
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
args = parser.parse_args()
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
except KeyboardInterrupt:
print("\nOperation cancelled by user. Cleaning up...")
if executor is not None:
executor.shutdown(wait=False, cancel_futures=True)
sys.exit(1)
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,228 @@
# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
# if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
# generate audio text map for Emilia ZH & EN
# evaluate for vocab size
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
out_zh = {
"ZH_B00041_S06226",
"ZH_B00042_S09204",
"ZH_B00065_S09430",
"ZH_B00065_S09431",
"ZH_B00066_S09327",
"ZH_B00066_S09328",
}
zh_filters = ["", ""]
# seems synthesized audios, or heavily code-switched
out_en = {
"EN_B00013_S00913",
"EN_B00042_S00120",
"EN_B00055_S04111",
"EN_B00061_S00693",
"EN_B00061_S01494",
"EN_B00061_S03375",
"EN_B00059_S00092",
"EN_B00111_S04300",
"EN_B00100_S03759",
"EN_B00087_S03811",
"EN_B00059_S00950",
"EN_B00089_S00946",
"EN_B00078_S05127",
"EN_B00070_S04089",
"EN_B00074_S09659",
"EN_B00061_S06983",
"EN_B00061_S07060",
"EN_B00059_S08397",
"EN_B00082_S06192",
"EN_B00091_S01238",
"EN_B00089_S07349",
"EN_B00070_S04343",
"EN_B00061_S02400",
"EN_B00076_S01262",
"EN_B00068_S06467",
"EN_B00076_S02943",
"EN_B00064_S05954",
"EN_B00061_S05386",
"EN_B00066_S06544",
"EN_B00076_S06944",
"EN_B00072_S08620",
"EN_B00076_S07135",
"EN_B00076_S09127",
"EN_B00065_S00497",
"EN_B00059_S06227",
"EN_B00063_S02859",
"EN_B00075_S01547",
"EN_B00061_S08286",
"EN_B00079_S02901",
"EN_B00092_S03643",
"EN_B00096_S08653",
"EN_B00063_S04297",
"EN_B00063_S04614",
"EN_B00079_S04698",
"EN_B00104_S01666",
"EN_B00061_S09504",
"EN_B00061_S09694",
"EN_B00065_S05444",
"EN_B00063_S06860",
"EN_B00065_S05725",
"EN_B00069_S07628",
"EN_B00083_S03875",
"EN_B00071_S07665",
"EN_B00071_S07665",
"EN_B00062_S04187",
"EN_B00065_S09873",
"EN_B00065_S09922",
"EN_B00084_S02463",
"EN_B00067_S05066",
"EN_B00106_S08060",
"EN_B00073_S06399",
"EN_B00073_S09236",
"EN_B00087_S00432",
"EN_B00085_S05618",
"EN_B00064_S01262",
"EN_B00072_S01739",
"EN_B00059_S03913",
"EN_B00069_S04036",
"EN_B00067_S05623",
"EN_B00060_S05389",
"EN_B00060_S07290",
"EN_B00062_S08995",
}
en_filters = ["ا", "", ""]
def deal_with_audio_dir(audio_dir):
audio_jsonl = audio_dir.with_suffix(".jsonl")
sub_result, durations = [], []
vocab_set = set()
bad_case_zh = 0
bad_case_en = 0
with open(audio_jsonl, "r") as f:
lines = f.readlines()
for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
obj = json.loads(line)
text = obj["text"]
if obj["language"] == "zh":
if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
bad_case_zh += 1
continue
else:
text = text.translate(
str.maketrans({",": "", "!": "", "?": ""})
) # not "。" cuz much code-switched
if obj["language"] == "en":
if (
obj["wav"].split("/")[1] in out_en
or any(f in text for f in en_filters)
or repetition_found(text, length=4)
):
bad_case_en += 1
continue
if tokenizer == "pinyin":
text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
duration = obj["duration"]
sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
def main():
assert tokenizer in ["pinyin", "char"]
result = []
duration_list = []
text_vocab_set = set()
total_bad_case_zh = 0
total_bad_case_en = 0
# process raw data
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for lang in langs:
dataset_path = Path(os.path.join(dataset_dir, lang))
[
futures.append(executor.submit(deal_with_audio_dir, audio_dir))
for audio_dir in dataset_path.iterdir()
if audio_dir.is_dir()
]
for futures in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
total_bad_case_zh += bad_case_zh
total_bad_case_en += bad_case_en
executor.shutdown()
# save preprocessed dataset to disk
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
print(f"\nSaving to {save_dir} ...")
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
# dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
# if tokenizer == "pinyin":
# text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if "ZH" in langs:
print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs:
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"
polyphone = True
langs = ["ZH", "EN"]
dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
main()
# Emilia ZH & EN
# samples count 37837916 (after removal)
# pinyin vocab size 2543 (polyphone)
# total duration 95281.87 (hours)
# bad zh asr cnt 230435 (samples)
# bad eh asr cnt 37217 (samples)
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

View File

@@ -0,0 +1,94 @@
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
# prepares Emilia dataset with the new format w/ Emilia-YODAS
import json
import os
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import repetition_found
# Define filters for exclusion
out_en = set()
en_filters = ["ا", "", ""]
def process_audio_directory(audio_dir):
sub_result, durations, vocab_set = [], [], set()
bad_case_en = 0
for file in audio_dir.iterdir():
if file.suffix == ".json":
with open(file, "r") as f:
obj = json.load(f)
text = obj["text"]
if any(f in text for f in en_filters) or repetition_found(text, length=4):
bad_case_en += 1
continue
duration = obj["duration"]
audio_file = file.with_suffix(".mp3")
if audio_file.exists():
sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set, bad_case_en
def main():
assert tokenizer in ["pinyin", "char"]
result, duration_list, text_vocab_set = [], [], set()
total_bad_case_en = 0
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
dataset_path = Path(dataset_dir)
for sub_dir in dataset_path.iterdir():
if sub_dir.is_dir():
futures.append(executor.submit(process_audio_directory, sub_dir))
for future in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set, bad_case_en = future.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
total_bad_case_en += bad_case_en
executor.shutdown()
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"For {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "char"
dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
dataset_name = f"Emilia_EN_{tokenizer}"
# save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
main()

View File

@@ -0,0 +1,94 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def deal_with_audio_dir(audio_dir):
sub_result, durations = [], []
vocab_set = set()
audio_lists = list(audio_dir.rglob("*.wav"))
for line in audio_lists:
text_path = line.with_suffix(".normalized.txt")
text = open(text_path, "r").read().strip()
duration = sf.info(line).duration
if duration < 0.4 or duration > 30:
continue
sub_result.append({"audio_path": str(line), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set
def main():
result = []
duration_list = []
text_vocab_set = set()
# process raw data
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for subset in tqdm(SUB_SET):
dataset_path = Path(os.path.join(dataset_dir, subset))
[
futures.append(executor.submit(deal_with_audio_dir, audio_dir))
for audio_dir in dataset_path.iterdir()
if audio_dir.is_dir()
]
for future in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set = future.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
executor.shutdown()
# save preprocessed dataset to disk
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
print(f"\nSaving to {save_dir} ...")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":
max_workers = 36
tokenizer = "char" # "pinyin" | "char"
SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"]
dataset_dir = "<SOME_PATH>/LibriTTS"
dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
main()
# For LibriTTS_100_360_500_char, sample count: 354218
# For LibriTTS_100_360_500_char, vocab size is: 78
# For LibriTTS_100_360_500_char, total 554.09 hours

View File

@@ -0,0 +1,67 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from importlib.resources import files
from pathlib import Path
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def main():
result = []
duration_list = []
text_vocab_set = set()
with open(meta_info, "r") as f:
lines = f.readlines()
for line in tqdm(lines):
uttr, text, norm_text = line.split("|")
norm_text = norm_text.strip()
wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav"
duration = sf.info(wav_path).duration
if duration < 0.4 or duration > 30:
continue
result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration})
duration_list.append(duration)
text_vocab_set.update(list(norm_text))
# save preprocessed dataset to disk
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
print(f"\nSaving to {save_dir} ...")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":
tokenizer = "char" # "pinyin" | "char"
dataset_dir = "<SOME_PATH>/LJSpeech-1.1"
dataset_name = f"LJSpeech_{tokenizer}"
meta_info = os.path.join(dataset_dir, "metadata.csv")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
main()

View File

@@ -0,0 +1,126 @@
# generate audio text map for WenetSpeech4TTS
# evaluate for vocab size
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
import torchaudio
from datasets import Dataset
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin
def deal_with_sub_path_files(dataset_path, sub_path):
print(f"Dealing with: {sub_path}")
text_dir = os.path.join(dataset_path, sub_path, "txts")
audio_dir = os.path.join(dataset_path, sub_path, "wavs")
text_files = os.listdir(text_dir)
audio_paths, texts, durations = [], [], []
for text_file in tqdm(text_files):
with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
first_line = file.readline().split("\t")
audio_nm = first_line[0]
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
text = first_line[1].strip()
audio_paths.append(audio_path)
if tokenizer == "pinyin":
texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
elif tokenizer == "char":
texts.append(text)
audio, sample_rate = torchaudio.load(audio_path)
durations.append(audio.shape[-1] / sample_rate)
return audio_paths, texts, durations
def main():
assert tokenizer in ["pinyin", "char"]
audio_path_list, text_list, duration_list = [], [], []
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for dataset_path in dataset_paths:
sub_items = os.listdir(dataset_path)
sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
for sub_path in sub_paths:
futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
for future in tqdm(futures, total=len(futures)):
audio_paths, texts, durations = future.result()
audio_path_list.extend(audio_paths)
text_list.extend(texts)
duration_list.extend(durations)
executor.shutdown()
if not os.path.exists("data"):
os.makedirs("data")
print(f"\nSaving to {save_dir} ...")
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") # arrow format
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump(
{"duration": duration_list}, f, ensure_ascii=False
) # dup a json separately saving duration in case for DynamicBatchSampler ease
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
text_vocab_set = set()
for text in tqdm(text_list):
text_vocab_set.update(list(text))
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
if tokenizer == "pinyin":
text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"
polyphone = True
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
dataset_name = (
["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
+ "_"
+ tokenizer
)
dataset_paths = [
"<SOME_PATH>/WenetSpeech4TTS/Basic",
"<SOME_PATH>/WenetSpeech4TTS/Standard",
"<SOME_PATH>/WenetSpeech4TTS/Premium",
][-dataset_choice:]
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n")
main()
# Results (if adding alphabets with accents and symbols):
# WenetSpeech4TTS Basic Standard Premium
# samples count 3932473 1941220 407494
# pinyin vocab size 1349 1348 1344 (no polyphone)
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

View File

@@ -0,0 +1,214 @@
import argparse
import os
import shutil
from importlib.resources import files
from cached_path import cached_path
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# -------------------------- Dataset Settings --------------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
# -------------------------- Argument Parsing --------------------------- #
def parse_args():
parser = argparse.ArgumentParser(description="Train CFM Model")
parser.add_argument(
"--exp_name",
type=str,
default="F5TTS_v1_Base",
choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
help="Experiment name",
)
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
parser.add_argument(
"--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
)
parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
)
parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
parser.add_argument(
"--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
)
parser.add_argument(
"--tokenizer_path",
type=str,
default=None,
help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
)
parser.add_argument(
"--log_samples",
action="store_true",
help="Log inferenced samples per ckpt save updates",
)
parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
parser.add_argument(
"--bnb_optimizer",
action="store_true",
help="Use 8-bit Adam optimizer from bitsandbytes",
)
return parser.parse_args()
# -------------------------- Training Settings -------------------------- #
def main():
args = parse_args()
checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
# Model parameters based on experiment name
if args.exp_name == "F5TTS_v1_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "F5TTS_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
text_mask_padding=False,
conv_layers=4,
pe_attn_head=1,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "E2TTS_Base":
wandb_resume_id = None
model_cls = UNetT
model_cfg = dict(
dim=1024,
depth=24,
heads=16,
ff_mult=4,
text_mask_padding=False,
pe_attn_head=1,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
else:
ckpt_path = args.pretrain
if args.finetune:
if not os.path.isdir(checkpoint_path):
os.makedirs(checkpoint_path, exist_ok=True)
file_checkpoint = os.path.basename(ckpt_path)
if not file_checkpoint.startswith("pretrained_"): # Change: Add 'pretrained_' prefix to copied model
file_checkpoint = "pretrained_" + file_checkpoint
file_checkpoint = os.path.join(checkpoint_path, file_checkpoint)
if not os.path.isfile(file_checkpoint):
shutil.copy2(ckpt_path, file_checkpoint)
print("copy checkpoint for finetune")
# Use the tokenizer and tokenizer_path provided in the command line arguments
tokenizer = args.tokenizer
if tokenizer == "custom":
if not args.tokenizer_path:
raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
tokenizer_path = args.tokenizer_path
else:
tokenizer_path = args.dataset_name
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
print("\nvocab : ", vocab_size)
print("\nvocoder : ", mel_spec_type)
mel_spec_kwargs = dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=mel_spec_kwargs,
vocab_char_map=vocab_char_map,
)
trainer = Trainer(
model,
args.epochs,
args.learning_rate,
num_warmup_updates=args.num_warmup_updates,
save_per_updates=args.save_per_updates,
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
checkpoint_path=checkpoint_path,
batch_size_per_gpu=args.batch_size_per_gpu,
batch_size_type=args.batch_size_type,
max_samples=args.max_samples,
grad_accumulation_steps=args.grad_accumulation_steps,
max_grad_norm=args.max_grad_norm,
logger=args.logger,
wandb_project=args.dataset_name,
wandb_run_name=args.exp_name,
wandb_resume_id=wandb_resume_id,
log_samples=args.log_samples,
last_per_updates=args.last_per_updates,
bnb_optimizer=args.bnb_optimizer,
)
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
trainer.train(
train_dataset,
resumable_with_seed=666, # seed for shuffling dataset
)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,77 @@
# training script.
import os
from importlib.resources import files
import hydra
from omegaconf import OmegaConf
from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(model_cfg):
model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
wandb_resume_id = None
# set text tokenizer
if tokenizer != "custom":
tokenizer_path = model_cfg.datasets.name
else:
tokenizer_path = model_cfg.model.tokenizer_path
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
# set model
model = CFM(
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=model_cfg.model.mel_spec,
vocab_char_map=vocab_char_map,
)
# init trainer
trainer = Trainer(
model,
epochs=model_cfg.optim.epochs,
learning_rate=model_cfg.optim.learning_rate,
num_warmup_updates=model_cfg.optim.num_warmup_updates,
save_per_updates=model_cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
batch_size_type=model_cfg.datasets.batch_size_type,
max_samples=model_cfg.datasets.max_samples,
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
max_grad_norm=model_cfg.optim.max_grad_norm,
logger=model_cfg.ckpts.logger,
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_updates=model_cfg.ckpts.last_per_updates,
log_samples=model_cfg.ckpts.log_samples,
bnb_optimizer=model_cfg.optim.bnb_optimizer,
mel_spec_type=mel_spec_type,
is_local_vocoder=model_cfg.model.vocoder.is_local,
local_vocoder_path=model_cfg.model.vocoder.local_path,
model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
)
train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
trainer.train(
train_dataset,
num_workers=model_cfg.datasets.num_workers,
resumable_with_seed=666, # seed for shuffling dataset
)
if __name__ == "__main__":
main()