init muxi
This commit is contained in:
92
metaX-C500-f5-tts/F5-TTS/src/f5_tts/train/README.md
Normal file
92
metaX-C500-f5-tts/F5-TTS/src/f5_tts/train/README.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# Training
|
||||
|
||||
Check your FFmpeg installation:
|
||||
```bash
|
||||
ffmpeg -version
|
||||
```
|
||||
If it is not found, install it first (or skip this step if you know another audio backend is available).
|
||||
|
||||
## Prepare Dataset
|
||||
|
||||
Example data processing scripts are listed below; you may also tailor your own, together with a Dataset class in `src/f5_tts/model/dataset.py`.
|
||||
|
||||
### 1. Some specific Datasets preparing scripts
|
||||
Download the corresponding dataset first, and fill in its path in the script.
|
||||
|
||||
```bash
|
||||
# Prepare the Emilia dataset
|
||||
python src/f5_tts/train/datasets/prepare_emilia.py
|
||||
|
||||
# Prepare the Wenetspeech4TTS dataset
|
||||
python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
|
||||
|
||||
# Prepare the LibriTTS dataset
|
||||
python src/f5_tts/train/datasets/prepare_libritts.py
|
||||
|
||||
# Prepare the LJSpeech dataset
|
||||
python src/f5_tts/train/datasets/prepare_ljspeech.py
|
||||
```
|
||||
|
||||
### 2. Create custom dataset with metadata.csv
|
||||
For usage guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
|
||||
|
||||
```bash
|
||||
python src/f5_tts/train/datasets/prepare_csv_wavs.py
|
||||
```
|
||||
|
||||
## Training & Finetuning
|
||||
|
||||
Once your datasets are prepared, you can start the training process.
|
||||
|
||||
### 1. Training script used for pretrained model
|
||||
|
||||
```bash
|
||||
# setup accelerate config, e.g. use multi-gpu ddp, fp16
|
||||
# the config will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
|
||||
accelerate config
|
||||
|
||||
# .yaml files are under src/f5_tts/configs directory
|
||||
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
|
||||
|
||||
# possible to overwrite accelerate and hydra config
|
||||
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
|
||||
```
|
||||
|
||||
### 2. Finetuning practice
|
||||
Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
|
||||
|
||||
Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
|
||||
|
||||
If you want to finetune a variant version, e.g. *F5TTS_v1_Base_no_zero_init*, manually download the pretrained checkpoint from the model weight repository and fill in the corresponding path in the web interface.
|
||||
|
||||
If you use TensorBoard as the logger, install it first with `pip install tensorboard`.
|
||||
|
||||
<ins>Setting `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (which have gone through only a few updates, so the EMA weights are still dominated by the pretrained ones). Try turning it off with the finetune Gradio option or `load_model(..., use_ema=False)` and see if that gives better results.
|
||||
|
||||
### 3. W&B Logging
|
||||
|
||||
The `wandb/` directory will be created under the path from which you run the training/finetuning scripts.
|
||||
|
||||
By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
|
||||
|
||||
To turn on wandb logging, you can either:
|
||||
|
||||
1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
|
||||
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:
|
||||
|
||||
On Mac & Linux:
|
||||
|
||||
```
|
||||
export WANDB_API_KEY=<YOUR WANDB API KEY>
|
||||
```
|
||||
|
||||
On Windows:
|
||||
|
||||
```
|
||||
set WANDB_API_KEY=<YOUR WANDB API KEY>
|
||||
```
|
||||
Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:
|
||||
|
||||
```
|
||||
export WANDB_MODE=offline
|
||||
```
|
||||
@@ -0,0 +1,283 @@
|
||||
import concurrent.futures
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess # For invoking ffprobe
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
import torchaudio
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
|
||||
|
||||
|
||||
def is_csv_wavs_format(input_dataset_dir):
    """Tell whether a directory follows the ``metadata.csv`` + ``wavs/`` layout.

    Returns True only when ``metadata.csv`` is a regular file and ``wavs`` is a
    directory, both located directly inside ``input_dataset_dir``.
    """
    root = Path(input_dataset_dir)
    # is_file()/is_dir() already return False for paths that do not exist.
    has_metadata = (root / "metadata.csv").is_file()
    has_wavs = (root / "wavs").is_dir()
    return has_metadata and has_wavs
|
||||
|
||||
|
||||
# Configuration constants
|
||||
BATCH_SIZE = 100 # Batch size for text conversion
|
||||
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1) # Leave one CPU free
|
||||
THREAD_NAME_PREFIX = "AudioProcessor"
|
||||
CHUNK_SIZE = 100 # Number of files to process per worker batch
|
||||
|
||||
executor = None # Global executor for cleanup
|
||||
|
||||
|
||||
@contextmanager
def graceful_exit():
    """Context manager for graceful shutdown on SIGINT/SIGTERM.

    While the context is active, SIGINT and SIGTERM cancel pending work on the
    module-level ``executor`` (if one is set) and exit the process with status 1.
    On leaving the context the previously installed handlers are restored, so
    the temporary handlers do not leak into the rest of the program, and any
    executor still registered is shut down without waiting.
    """

    def signal_handler(signum, frame):
        print("\nReceived signal to terminate. Cleaning up...")
        if executor is not None:
            print("Shutting down executor...")
            executor.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)

    # signal.signal() returns the handler it replaces; keep it for restoration.
    previous_handlers = {
        sig: signal.signal(sig, signal_handler) for sig in (signal.SIGINT, signal.SIGTERM)
    }

    try:
        yield
    finally:
        for sig, handler in previous_handlers.items():
            # None means the old handler was not installed from Python and
            # cannot be passed back to signal.signal().
            if handler is not None:
                signal.signal(sig, handler)
        if executor is not None:
            executor.shutdown(wait=False)
|
||||
|
||||
|
||||
def process_audio_file(audio_path, text, polyphone):
    """Validate one audio file and measure its duration.

    Returns ``(audio_path, text, duration)`` on success, or ``None`` (after
    printing a message) when the file is missing, its duration cannot be
    determined, or the duration is not positive. ``polyphone`` is accepted
    but not used here.
    """
    source = Path(audio_path)
    if not source.exists():
        print(f"audio {audio_path} not found, skipping")
        return None

    try:
        seconds = get_audio_duration(audio_path)
        if seconds <= 0:
            raise ValueError(f"Duration {seconds} is non-positive.")
    except Exception as err:
        # Any failure is treated as a corrupt file and skipped.
        print(f"Warning: Failed to process {audio_path} due to error: {err}. Skipping corrupt file.")
        return None

    return (audio_path, text, seconds)
|
||||
|
||||
|
||||
def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
    """Run ``convert_char_to_pinyin`` over ``texts`` slice by slice.

    The input is cut into consecutive slices of ``batch_size`` items; the
    converted slices are concatenated, so the output order matches the input.
    """
    return [
        converted
        for start in range(0, len(texts), batch_size)
        for converted in convert_char_to_pinyin(texts[start : start + batch_size], polyphone=polyphone)
    ]
|
||||
|
||||
|
||||
def prepare_csv_wavs_dir(input_dir, num_workers=None):
    """Scan a ``metadata.csv`` + ``wavs/`` dataset and build training records.

    Args:
        input_dir: Dataset root; must satisfy ``is_csv_wavs_format``.
        num_workers: Number of threads used to probe audio durations. ``None``
            picks ``min(MAX_WORKERS, number of files)``. Values below 1 are
            clamped to 1.

    Returns:
        A tuple ``(records, durations, vocab_set)``: ``records`` is a list of
        ``{"audio_path", "text", "duration"}`` dicts whose text has been passed
        through ``convert_char_to_pinyin``; ``durations`` is the matching list
        of durations; ``vocab_set`` is the set of items found in the converted
        texts.

    Raises:
        AssertionError: If ``input_dir`` is not in csv_wavs format.
        RuntimeError: If no audio file could be processed.
    """
    global executor
    assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
    input_dir = Path(input_dir)
    metadata_path = input_dir / "metadata.csv"
    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())

    polyphone = True
    total_files = len(audio_path_text_pairs)
    total_chunks = (total_files + CHUNK_SIZE - 1) // CHUNK_SIZE

    # Use provided worker count or calculate optimal number.
    # ThreadPoolExecutor rejects max_workers < 1, which would otherwise happen
    # for an empty metadata file and hide the RuntimeError raised below.
    requested_workers = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
    worker_count = max(1, requested_workers)
    print(f"\nProcessing {total_files} audio files using {worker_count} workers...")

    results = []
    with graceful_exit():
        # Initialize thread pool with optimized settings
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
        ) as pool:
            # Publish the pool so the signal handler can cancel pending work.
            executor = pool
            try:
                # Process files in chunks for better efficiency
                for i in range(0, total_files, CHUNK_SIZE):
                    chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
                    # Submit futures in order
                    chunk_futures = [
                        pool.submit(process_audio_file, audio_path, text, polyphone)
                        for audio_path, text in chunk
                    ]

                    # Iterate over futures in the original submission order to preserve ordering
                    for future in tqdm(
                        chunk_futures,
                        total=len(chunk),
                        desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{total_chunks}",
                    ):
                        try:
                            result = future.result()
                        except Exception as e:
                            print(f"Error processing file: {e}")
                            continue
                        # process_audio_file returns None for skipped files.
                        if result is not None:
                            results.append(result)
            finally:
                # Always unregister the pool, even if processing failed midway.
                executor = None

    if not results:
        raise RuntimeError("No valid audio files were processed!")

    # Batch process text conversion
    raw_texts = [text for _, text, _ in results]
    converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)

    # Prepare final results
    sub_result = []
    durations = []
    vocab_set = set()

    for (audio_path, _, duration), conv_text in zip(results, converted_texts):
        sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
        durations.append(duration)
        vocab_set.update(list(conv_text))

    return sub_result, durations, vocab_set
|
||||
|
||||
|
||||
def get_audio_duration(audio_path, timeout=5):
    """
    Get the duration of an audio file in seconds using ffmpeg's ffprobe.
    Falls back to torchaudio.load() if ffprobe fails or cannot be started.

    Args:
        audio_path: Path of the audio file to inspect.
        timeout: Seconds to wait for ffprobe before giving up on it.

    Returns:
        The duration in seconds as a float.

    Raises:
        RuntimeError: If both ffprobe and torchaudio fail for the file.
    """
    cmd = [
        "ffprobe",
        "-v",
        "error",
        "-show_entries",
        "format=duration",
        "-of",
        "default=noprint_wrappers=1:nokey=1",
        audio_path,
    ]
    try:
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
        )
        duration_str = result.stdout.strip()
        if duration_str:
            return float(duration_str)
        raise ValueError("Empty duration string from ffprobe.")
    # OSError covers a missing or non-executable ffprobe binary: subprocess
    # raises FileNotFoundError/PermissionError for that, not SubprocessError.
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
        try:
            audio, sample_rate = torchaudio.load(audio_path)
            return audio.shape[1] / sample_rate
        except Exception as fallback_error:
            raise RuntimeError(
                f"Both ffprobe and torchaudio failed for {audio_path}: {fallback_error}"
            ) from fallback_error
|
||||
|
||||
|
||||
def read_audio_text_pairs(csv_file_path):
    """Read ``audio_file|text`` rows from a pipe-delimited metadata file.

    The first row is treated as a header and skipped. Rows with fewer than two
    columns are ignored; extra columns are ignored too. Audio paths are resolved
    relative to the directory containing the CSV file.

    Args:
        csv_file_path: Path to the metadata file (UTF-8, optional BOM).

    Returns:
        A list of ``(audio_path, text)`` tuples with POSIX-style audio paths.
        An empty file yields an empty list.
    """
    parent = Path(csv_file_path).parent
    audio_text_pairs = []

    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")
        # Skip the header row; the default keeps an empty file from raising StopIteration.
        next(reader, None)
        for row in reader:
            if len(row) < 2:
                continue
            audio_file = row[0].strip()  # First column: audio file path
            text = row[1].strip()  # Second column: text
            audio_text_pairs.append(((parent / audio_file).as_posix(), text))

    return audio_text_pairs
|
||||
|
||||
|
||||
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    """Write the prepared dataset files into ``out_dir``.

    Creates ``out_dir`` if needed and writes three files:
    ``raw.arrow`` (one record per item of ``result``), ``duration.json``
    (``{"duration": duration_list}``) and ``vocab.txt``.

    Args:
        out_dir: Destination directory.
        result: Iterable of record dicts to write to ``raw.arrow``.
        duration_list: Durations written to ``duration.json`` and summed for the summary.
        text_vocab_set: Vocabulary items; written one per line when not fine-tuning.
        is_finetune: If true, ``vocab.txt`` is a copy of ``PRETRAINED_VOCAB_PATH``
            instead of being generated from ``text_vocab_set``.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    # Save dataset with improved batch size for better I/O performance
    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        # Flush rows still buffered below the writer batch size.
        writer.finalize()

    # Save durations to JSON
    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # Handle vocab file - write only once based on finetune flag
    voca_out_path = out_dir / "vocab.txt"
    if is_finetune:
        shutil.copy2(PRETRAINED_VOCAB_PATH.as_posix(), voca_out_path)
    else:
        # Explicit UTF-8: the vocabulary may hold non-ASCII symbols that the
        # platform's default encoding cannot represent.
        with open(voca_out_path.as_posix(), "w", encoding="utf-8") as f:
            f.writelines(f"{vocab}\n" for vocab in sorted(text_vocab_set))

    # .name (not .stem) so directory names containing dots are reported in full.
    dataset_name = out_dir.name
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
|
||||
|
||||
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
    """Prepare the csv_wavs dataset in ``inp_dir`` and save it under ``out_dir``.

    When ``is_finetune`` is true the pretrained ``vocab.txt`` must exist, since
    it is what gets saved as the dataset vocabulary. ``num_workers`` is passed
    through to ``prepare_csv_wavs_dir``.
    """
    # Only fine-tuning depends on the pretrained vocabulary file.
    assert not is_finetune or PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
    prepared = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
    records, durations, vocab = prepared
    save_prepped_dataset(out_dir, records, durations, vocab, is_finetune)
|
||||
|
||||
|
||||
def cli():
    """Command-line entry point: parse arguments, then prepare and save the dataset.

    Positional arguments are the input and output directories. ``--pretrain``
    switches from fine-tune mode (the default) to pre-train mode and
    ``--workers`` sets the number of worker threads. On Ctrl-C any registered
    executor is cancelled and the process exits with status 1.
    """
    # Before processing, check if ffprobe is available.
    if shutil.which("ffprobe") is None:
        print(
            "Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
        )

    # Usage examples in help text
    parser = argparse.ArgumentParser(
        description="Prepare and save dataset.",
        # Keep the epilog's line breaks; the default formatter would reflow
        # the examples into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # For fine-tuning (default):
    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path

    # For pre-training:
    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain

    # With custom worker count:
    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
""",
    )
    parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
    parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
    parser.add_argument(
        "--workers",
        type=int,
        help=f"Number of worker threads (default: min({MAX_WORKERS}, number of files))",
    )
    args = parser.parse_args()

    try:
        prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
    except KeyboardInterrupt:
        print("\nOperation cancelled by user. Cleaning up...")
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -0,0 +1,228 @@
|
||||
# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
|
||||
# if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
|
||||
|
||||
# generate audio text map for Emilia ZH & EN
|
||||
# evaluate for vocab size
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
|
||||
|
||||
|
||||
out_zh = {
|
||||
"ZH_B00041_S06226",
|
||||
"ZH_B00042_S09204",
|
||||
"ZH_B00065_S09430",
|
||||
"ZH_B00065_S09431",
|
||||
"ZH_B00066_S09327",
|
||||
"ZH_B00066_S09328",
|
||||
}
|
||||
zh_filters = ["い", "て"]
|
||||
# seems synthesized audios, or heavily code-switched
|
||||
out_en = {
|
||||
"EN_B00013_S00913",
|
||||
"EN_B00042_S00120",
|
||||
"EN_B00055_S04111",
|
||||
"EN_B00061_S00693",
|
||||
"EN_B00061_S01494",
|
||||
"EN_B00061_S03375",
|
||||
"EN_B00059_S00092",
|
||||
"EN_B00111_S04300",
|
||||
"EN_B00100_S03759",
|
||||
"EN_B00087_S03811",
|
||||
"EN_B00059_S00950",
|
||||
"EN_B00089_S00946",
|
||||
"EN_B00078_S05127",
|
||||
"EN_B00070_S04089",
|
||||
"EN_B00074_S09659",
|
||||
"EN_B00061_S06983",
|
||||
"EN_B00061_S07060",
|
||||
"EN_B00059_S08397",
|
||||
"EN_B00082_S06192",
|
||||
"EN_B00091_S01238",
|
||||
"EN_B00089_S07349",
|
||||
"EN_B00070_S04343",
|
||||
"EN_B00061_S02400",
|
||||
"EN_B00076_S01262",
|
||||
"EN_B00068_S06467",
|
||||
"EN_B00076_S02943",
|
||||
"EN_B00064_S05954",
|
||||
"EN_B00061_S05386",
|
||||
"EN_B00066_S06544",
|
||||
"EN_B00076_S06944",
|
||||
"EN_B00072_S08620",
|
||||
"EN_B00076_S07135",
|
||||
"EN_B00076_S09127",
|
||||
"EN_B00065_S00497",
|
||||
"EN_B00059_S06227",
|
||||
"EN_B00063_S02859",
|
||||
"EN_B00075_S01547",
|
||||
"EN_B00061_S08286",
|
||||
"EN_B00079_S02901",
|
||||
"EN_B00092_S03643",
|
||||
"EN_B00096_S08653",
|
||||
"EN_B00063_S04297",
|
||||
"EN_B00063_S04614",
|
||||
"EN_B00079_S04698",
|
||||
"EN_B00104_S01666",
|
||||
"EN_B00061_S09504",
|
||||
"EN_B00061_S09694",
|
||||
"EN_B00065_S05444",
|
||||
"EN_B00063_S06860",
|
||||
"EN_B00065_S05725",
|
||||
"EN_B00069_S07628",
|
||||
"EN_B00083_S03875",
|
||||
"EN_B00071_S07665",
|
||||
"EN_B00071_S07665",
|
||||
"EN_B00062_S04187",
|
||||
"EN_B00065_S09873",
|
||||
"EN_B00065_S09922",
|
||||
"EN_B00084_S02463",
|
||||
"EN_B00067_S05066",
|
||||
"EN_B00106_S08060",
|
||||
"EN_B00073_S06399",
|
||||
"EN_B00073_S09236",
|
||||
"EN_B00087_S00432",
|
||||
"EN_B00085_S05618",
|
||||
"EN_B00064_S01262",
|
||||
"EN_B00072_S01739",
|
||||
"EN_B00059_S03913",
|
||||
"EN_B00069_S04036",
|
||||
"EN_B00067_S05623",
|
||||
"EN_B00060_S05389",
|
||||
"EN_B00060_S07290",
|
||||
"EN_B00062_S08995",
|
||||
}
|
||||
en_filters = ["ا", "い", "て"]
|
||||
|
||||
|
||||
def deal_with_audio_dir(audio_dir):
    """Build training records from one Emilia audio directory.

    Reads the ``.jsonl`` file that sits next to ``audio_dir`` (same name with a
    ``.jsonl`` suffix), drops entries that are blacklisted, contain filtered
    characters, or trip ``repetition_found``, and converts the remaining text to
    pinyin when the module-level ``tokenizer`` is ``"pinyin"``.

    Args:
        audio_dir: ``pathlib.Path`` of the audio directory.

    Returns:
        ``(sub_result, durations, vocab_set, bad_case_zh, bad_case_en)`` where
        ``sub_result`` is a list of ``{"audio_path", "text", "duration"}`` dicts,
        ``durations`` the matching durations, ``vocab_set`` the set of items in
        the stored texts, and the last two values count the dropped zh/en entries.
    """
    # NOTE(review): `tokenizer` and `polyphone` are module globals assigned in the
    # `__main__` section; worker processes only see them if they inherit module state.
    audio_jsonl = audio_dir.with_suffix(".jsonl")
    sub_result, durations = [], []
    vocab_set = set()
    bad_case_zh = 0
    bad_case_en = 0
    # Map ASCII punctuation to its full-width form for Chinese text.
    zh_punctuation = str.maketrans({",": "，", "!": "！", "?": "？"})

    # Transcripts contain non-ASCII text, so do not rely on the locale encoding.
    with open(audio_jsonl, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
        obj = json.loads(line)
        text = obj["text"]
        if obj["language"] == "zh":
            if (
                obj["wav"].split("/")[1] in out_zh
                or any(marker in text for marker in zh_filters)
                or repetition_found(text)
            ):
                bad_case_zh += 1
                continue
            text = text.translate(zh_punctuation)  # not "。" cuz much code-switched
        if obj["language"] == "en":
            if (
                obj["wav"].split("/")[1] in out_en
                or any(marker in text for marker in en_filters)
                or repetition_found(text, length=4)
            ):
                bad_case_en += 1
                continue
        if tokenizer == "pinyin":
            text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
        duration = obj["duration"]
        sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
        durations.append(duration)
        vocab_set.update(list(text))
    return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
|
||||
|
||||
|
||||
def main():
    """Process every Emilia language subset and write the prepared dataset.

    Submits one ``deal_with_audio_dir`` job per sub-directory of each language
    in ``langs``, merges the results, then writes ``raw.arrow``,
    ``duration.json`` and ``vocab.txt`` into ``save_dir`` and prints a summary.
    Reads its configuration (``tokenizer``, ``max_workers``, ``langs``,
    ``dataset_dir``, ``save_dir``, ``dataset_name``) from module globals.
    """
    assert tokenizer in ["pinyin", "char"]
    result = []
    duration_list = []
    text_vocab_set = set()
    total_bad_case_zh = 0
    total_bad_case_en = 0

    # process raw data
    # The context manager shuts the pool down even if a worker raises.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for lang in langs:
            dataset_path = Path(os.path.join(dataset_dir, lang))
            futures.extend(
                executor.submit(deal_with_audio_dir, audio_dir)
                for audio_dir in dataset_path.iterdir()
                if audio_dir.is_dir()
            )
        for future in tqdm(futures, total=len(futures)):
            sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
            result.extend(sub_result)
            duration_list.extend(durations)
            text_vocab_set.update(vocab_set)
            total_bad_case_zh += bad_case_zh
            total_bad_case_en += bad_case_en

    # save preprocessed dataset to disk
    os.makedirs(f"{save_dir}", exist_ok=True)
    print(f"\nSaving to {save_dir} ...")

    # Dataset.from_dict(...).save_to_disk(...) was tried here but ran out of memory,
    # hence the streaming ArrowWriter.
    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        # Flush rows still buffered below the writer batch size.
        writer.finalize()

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # if tokenizer == "pinyin":
    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
    # Explicit UTF-8: the vocabulary holds non-ASCII symbols.
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{vocab}\n" for vocab in sorted(text_vocab_set))

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
    if "ZH" in langs:
        print(f"Bad zh transcription case: {total_bad_case_zh}")
    if "EN" in langs:
        print(f"Bad en transcription case: {total_bad_case_en}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
max_workers = 32
|
||||
|
||||
tokenizer = "pinyin" # "pinyin" | "char"
|
||||
polyphone = True
|
||||
|
||||
langs = ["ZH", "EN"]
|
||||
dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
|
||||
dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
|
||||
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
|
||||
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
|
||||
|
||||
main()
|
||||
|
||||
# Emilia ZH & EN
|
||||
# samples count 37837916 (after removal)
|
||||
# pinyin vocab size 2543 (polyphone)
|
||||
# total duration 95281.87 (hours)
|
||||
# bad zh asr cnt 230435 (samples)
|
||||
# bad en asr cnt 37217 (samples)
|
||||
|
||||
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
|
||||
# please be careful if using pretrained model, make sure the vocab.txt is same
|
||||
@@ -0,0 +1,94 @@
|
||||
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
|
||||
# prepares Emilia dataset with the new format w/ Emilia-YODAS
|
||||
|
||||
import json
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import repetition_found
|
||||
|
||||
|
||||
# Define filters for exclusion
|
||||
out_en = set()
|
||||
en_filters = ["ا", "い", "て"]
|
||||
|
||||
|
||||
def process_audio_directory(audio_dir):
    """Collect training records from the ``.json`` metadata files in ``audio_dir``.

    Each ``.json`` file supplies ``text`` and ``duration``; the audio is expected
    in a sibling file with the same name and an ``.mp3`` suffix. Entries whose
    text contains a filtered character or trips ``repetition_found`` are counted
    as bad cases and skipped; entries without an audio file are skipped silently.

    Args:
        audio_dir: ``pathlib.Path`` of the directory to scan (non-recursive).

    Returns:
        ``(sub_result, durations, vocab_set, bad_case_en)``.
    """
    sub_result, durations, vocab_set = [], [], set()
    bad_case_en = 0

    for file in audio_dir.iterdir():
        if file.suffix != ".json":
            continue
        # Explicit UTF-8 so transcripts decode the same on every platform.
        with open(file, "r", encoding="utf-8") as f:
            obj = json.load(f)
        text = obj["text"]
        if any(marker in text for marker in en_filters) or repetition_found(text, length=4):
            bad_case_en += 1
            continue

        duration = obj["duration"]
        audio_file = file.with_suffix(".mp3")
        if audio_file.exists():
            sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
            durations.append(duration)
            vocab_set.update(list(text))

    return sub_result, durations, vocab_set, bad_case_en
|
||||
|
||||
|
||||
def main():
    """Process every sub-directory of ``dataset_dir`` and write the prepared dataset.

    Submits one ``process_audio_directory`` job per sub-directory, merges the
    results, then writes ``raw.arrow``, ``duration.json`` and ``vocab.txt`` into
    ``save_dir`` and prints a summary. Reads its configuration (``tokenizer``,
    ``max_workers``, ``dataset_dir``, ``save_dir``, ``dataset_name``) from
    module globals.
    """
    # NOTE(review): only the tokenizer name is validated here; this function
    # stores the text as returned by process_audio_directory.
    assert tokenizer in ["pinyin", "char"]
    result, duration_list, text_vocab_set = [], [], set()
    total_bad_case_en = 0

    # The context manager shuts the pool down even if a worker raises.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_audio_directory, sub_dir)
            for sub_dir in Path(dataset_dir).iterdir()
            if sub_dir.is_dir()
        ]

        for future in tqdm(futures, total=len(futures)):
            sub_result, durations, vocab_set, bad_case_en = future.result()
            result.extend(sub_result)
            duration_list.extend(durations)
            text_vocab_set.update(vocab_set)
            total_bad_case_en += bad_case_en

    os.makedirs(f"{save_dir}", exist_ok=True)

    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        # Flush rows still buffered below the writer batch size.
        writer.finalize()

    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # Explicit UTF-8: the vocabulary may hold non-ASCII symbols.
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{vocab}\n" for vocab in sorted(text_vocab_set))

    print(f"For {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
    print(f"Bad en transcription case: {total_bad_case_en}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
max_workers = 32
|
||||
tokenizer = "char"
|
||||
dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
|
||||
dataset_name = f"Emilia_EN_{tokenizer}"
|
||||
# save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
|
||||
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
|
||||
|
||||
print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
|
||||
main()
|
||||
@@ -0,0 +1,94 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def deal_with_audio_dir(audio_dir):
    """Collect training records for every ``.wav`` file under ``audio_dir``.

    For each wav file (searched recursively) the transcript is read from the
    sibling ``<name>.normalized.txt`` file and the duration from the audio
    header. Clips shorter than 0.4 s or longer than 30 s are skipped.

    Args:
        audio_dir: ``pathlib.Path`` of the directory to scan.

    Returns:
        ``(sub_result, durations, vocab_set)`` where ``sub_result`` is a list of
        ``{"audio_path", "text", "duration"}`` dicts, ``durations`` the matching
        durations and ``vocab_set`` the set of characters in the kept transcripts.
    """
    sub_result, durations = [], []
    vocab_set = set()

    for wav_path in audio_dir.rglob("*.wav"):
        text_path = wav_path.with_suffix(".normalized.txt")
        # read_text closes the file (the bare open().read() leaked the handle)
        # and decodes as UTF-8 regardless of the platform default.
        text = text_path.read_text(encoding="utf-8").strip()
        duration = sf.info(wav_path).duration
        if duration < 0.4 or duration > 30:
            continue
        sub_result.append({"audio_path": str(wav_path), "text": text, "duration": duration})
        durations.append(duration)
        vocab_set.update(list(text))
    return sub_result, durations, vocab_set
|
||||
|
||||
|
||||
def main():
    """Process the configured LibriTTS subsets and write the prepared dataset.

    Submits one ``deal_with_audio_dir`` job per sub-directory of each subset in
    ``SUB_SET``, merges the results, then writes ``raw.arrow``, ``duration.json``
    and ``vocab.txt`` into ``save_dir`` and prints a summary. Reads its
    configuration (``max_workers``, ``SUB_SET``, ``dataset_dir``, ``save_dir``,
    ``dataset_name``) from module globals.
    """
    result = []
    duration_list = []
    text_vocab_set = set()

    # process raw data
    # The context manager shuts the pool down even if a worker raises.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = []

        for subset in tqdm(SUB_SET):
            dataset_path = Path(os.path.join(dataset_dir, subset))
            futures.extend(
                executor.submit(deal_with_audio_dir, audio_dir)
                for audio_dir in dataset_path.iterdir()
                if audio_dir.is_dir()
            )
        for future in tqdm(futures, total=len(futures)):
            sub_result, durations, vocab_set = future.result()
            result.extend(sub_result)
            duration_list.extend(durations)
            text_vocab_set.update(vocab_set)

    # save preprocessed dataset to disk
    os.makedirs(f"{save_dir}", exist_ok=True)
    print(f"\nSaving to {save_dir} ...")

    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        # Flush rows still buffered below the writer batch size.
        writer.finalize()

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # Explicit UTF-8: transcripts may contain non-ASCII characters.
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{vocab}\n" for vocab in sorted(text_vocab_set))

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
max_workers = 36
|
||||
|
||||
tokenizer = "char" # "pinyin" | "char"
|
||||
|
||||
SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"]
|
||||
dataset_dir = "<SOME_PATH>/LibriTTS"
|
||||
dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "")
|
||||
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
|
||||
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
|
||||
main()
|
||||
|
||||
# For LibriTTS_100_360_500_char, sample count: 354218
|
||||
# For LibriTTS_100_360_500_char, vocab size is: 78
|
||||
# For LibriTTS_100_360_500_char, total 554.09 hours
|
||||
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
    """Read the LJSpeech metadata file and write the prepared dataset.

    Each metadata line has three ``|``-separated fields (utterance id, text,
    normalized text); the normalized text is what gets stored. Clips shorter
    than 0.4 s or longer than 30 s are skipped. Writes ``raw.arrow``,
    ``duration.json`` and ``vocab.txt`` into ``save_dir`` and prints a summary.
    Reads its configuration (``meta_info``, ``dataset_dir``, ``save_dir``,
    ``dataset_name``) from module globals.
    """
    result = []
    duration_list = []
    text_vocab_set = set()

    # Explicit UTF-8 so the transcripts decode the same on every platform.
    with open(meta_info, "r", encoding="utf-8") as f:
        lines = f.readlines()

    wav_dir = Path(dataset_dir) / "wavs"
    for line in tqdm(lines):
        # A blank (e.g. trailing) line would break the three-way unpack below.
        if not line.strip():
            continue
        uttr, text, norm_text = line.split("|")
        norm_text = norm_text.strip()
        wav_path = wav_dir / f"{uttr}.wav"
        duration = sf.info(wav_path).duration
        if duration < 0.4 or duration > 30:
            continue
        result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration})
        duration_list.append(duration)
        text_vocab_set.update(list(norm_text))

    # save preprocessed dataset to disk
    os.makedirs(f"{save_dir}", exist_ok=True)
    print(f"\nSaving to {save_dir} ...")

    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        # Flush rows still buffered below the writer batch size.
        writer.finalize()

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{vocab}\n" for vocab in sorted(text_vocab_set))

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = "char" # "pinyin" | "char"
|
||||
|
||||
dataset_dir = "<SOME_PATH>/LJSpeech-1.1"
|
||||
dataset_name = f"LJSpeech_{tokenizer}"
|
||||
meta_info = os.path.join(dataset_dir, "metadata.csv")
|
||||
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
|
||||
print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
|
||||
|
||||
main()
|
||||
@@ -0,0 +1,126 @@
|
||||
# generate audio text map for WenetSpeech4TTS
|
||||
# evaluate for vocab size
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
|
||||
import torchaudio
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
def deal_with_sub_path_files(dataset_path, sub_path):
    """Collect audio paths, transcripts and durations for one sub-directory.

    Each file in ``<dataset_path>/<sub_path>/txts`` is read for its first
    line only, which is split on tabs: field 0 is the audio name and field 1
    the transcript. The audio is looked up as
    ``<dataset_path>/<sub_path>/wavs/<audio name>.wav``.

    The module-level ``tokenizer`` and ``polyphone`` settings are read here.

    Returns:
        Three lists: audio paths, texts (converted with
        ``convert_char_to_pinyin`` when ``tokenizer == "pinyin"``) and
        durations computed as ``number of samples / sample rate``.
    """
    print(f"Dealing with: {sub_path}")

    text_dir = os.path.join(dataset_path, sub_path, "txts")
    audio_dir = os.path.join(dataset_path, sub_path, "wavs")

    audio_paths, texts, durations = [], [], []
    for entry in tqdm(os.listdir(text_dir)):
        # Only the first line of each transcript file is used.
        with open(os.path.join(text_dir, entry), "r", encoding="utf-8") as handle:
            fields = handle.readline().split("\t")
        wav_file = os.path.join(audio_dir, fields[0] + ".wav")
        transcript = fields[1].strip()

        audio_paths.append(wav_file)

        if tokenizer == "pinyin":
            texts.extend(convert_char_to_pinyin([transcript], polyphone=polyphone))
        elif tokenizer == "char":
            texts.append(transcript)

        waveform, rate = torchaudio.load(wav_file)
        durations.append(waveform.shape[-1] / rate)

    return audio_paths, texts, durations
|
||||
|
||||
|
||||
def main():
    """Build the WenetSpeech4TTS audio/text map and write it under ``save_dir``.

    Reads the module-level settings ``tokenizer``, ``max_workers``,
    ``dataset_paths``, ``save_dir`` and ``dataset_name``. Every sub-directory
    of every dataset path is handed to ``deal_with_sub_path_files`` in a
    worker process; the merged result is then written as:

    * ``<save_dir>/raw``           - the dataset (arrow format)
    * ``<save_dir>/duration.json`` - the durations only
    * ``<save_dir>/vocab.txt``     - one vocabulary symbol per line
    """
    assert tokenizer in ["pinyin", "char"]

    audio_path_list, text_list, duration_list = [], [], []

    # The context manager shuts the pool down even when a worker raises,
    # so no worker processes are left behind on failure.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for dataset_path in dataset_paths:
            sub_items = os.listdir(dataset_path)
            sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
            for sub_path in sub_paths:
                futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
        for future in tqdm(futures, total=len(futures)):
            audio_paths, texts, durations = future.result()
            audio_path_list.extend(audio_paths)
            text_list.extend(texts)
            duration_list.extend(durations)

    # Create the directory that is actually written to, rather than a
    # "data" folder relative to the current working directory.
    os.makedirs(save_dir, exist_ok=True)

    print(f"\nSaving to {save_dir} ...")
    dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
    dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")  # arrow format

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
    text_vocab_set = set()
    for text in tqdm(text_list):
        text_vocab_set.update(text)

    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    if tokenizer == "pinyin":
        text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])

    # The vocabulary holds non-ASCII symbols, so the encoding is fixed
    # instead of depending on the platform default.
    with open(f"{save_dir}/vocab.txt", "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")
    print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script configuration. These names are module-level globals that main()
    # and deal_with_sub_path_files() read directly.
    max_workers = 32

    tokenizer = "pinyin"  # "pinyin" | "char"
    polyphone = True
    dataset_choice = 1  # 1: Premium, 2: Standard, 3: Basic

    tier_names = ("WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic")
    dataset_name = f"{tier_names[dataset_choice - 1]}_{tokenizer}"

    # The slice keeps the last `dataset_choice` entries, so choice 1 uses
    # Premium only, 2 adds Standard and 3 adds Basic as well.
    all_tier_paths = [
        "<SOME_PATH>/WenetSpeech4TTS/Basic",
        "<SOME_PATH>/WenetSpeech4TTS/Standard",
        "<SOME_PATH>/WenetSpeech4TTS/Premium",
    ]
    dataset_paths = all_tier_paths[-dataset_choice:]

    package_parent = str(files("f5_tts").joinpath("../../"))
    save_dir = package_parent + f"/data/{dataset_name}"
    print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n")

    main()
|
||||
|
||||
# Results (if adding alphabets with accents and symbols):
|
||||
# WenetSpeech4TTS Basic Standard Premium
|
||||
# samples count 3932473 1941220 407494
|
||||
# pinyin vocab size 1349 1348 1344 (no polyphone)
|
||||
# - - 1459 (polyphone)
|
||||
# char vocab size 5264 5219 5042
|
||||
|
||||
# vocab size may differ slightly because of the jieba tokenizer and pypinyin (e.g. how polyphonic characters are handled)
|
||||
# if you use a pretrained model, make sure your vocab.txt is identical to the model's
|
||||
214
metaX-C500-f5-tts/F5-TTS/src/f5_tts/train/finetune_cli.py
Normal file
214
metaX-C500-f5-tts/F5-TTS/src/f5_tts/train/finetune_cli.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from importlib.resources import files
|
||||
|
||||
from cached_path import cached_path
|
||||
|
||||
from f5_tts.model import CFM, DiT, Trainer, UNetT
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
# -------------------------- Dataset Settings --------------------------- #
# Mel-spectrogram settings. main() bundles all six values into
# `mel_spec_kwargs`, which is passed to both the CFM model and load_dataset().
target_sample_rate = 24000
n_mel_channels = 100  # also passed to the transformer as `mel_dim`
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
|
||||
|
||||
|
||||
# -------------------------- Argument Parsing --------------------------- #
|
||||
def parse_args(argv=None):
    """Parse the command-line options of the finetuning script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the arguments from ``sys.argv``,
            which is what the script itself relies on. Passing a list lets
            other code and tests call the parser directly.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Train CFM Model")

    parser.add_argument(
        "--exp_name",
        type=str,
        default="F5TTS_v1_Base",
        choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
        help="Experiment name",
    )
    parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
    parser.add_argument(
        "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
    )
    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
    parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
    parser.add_argument(
        "--keep_last_n_checkpoints",
        type=int,
        default=-1,
        help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
    )
    parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
    parser.add_argument("--finetune", action="store_true", help="Use Finetune")
    parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
    parser.add_argument(
        "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
    )
    parser.add_argument(
        "--log_samples",
        action="store_true",
        help="Log inferenced samples per ckpt save updates",
    )
    parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
    parser.add_argument(
        "--bnb_optimizer",
        action="store_true",
        help="Use 8-bit Adam optimizer from bitsandbytes",
    )

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
# -------------------------- Training Settings -------------------------- #
|
||||
|
||||
|
||||
def _experiment_settings(exp_name):
    """Return the settings tied to one ``--exp_name`` choice.

    Returns:
        A tuple ``(model_cls, model_cfg, default_ckpt)``: the backbone class,
        its keyword arguments, and the checkpoint location used for
        finetuning when ``--pretrain`` is not given.

    Raises:
        ValueError: if ``exp_name`` is not one of the known experiments.
    """
    if exp_name == "F5TTS_v1_Base":
        model_cfg = dict(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            conv_layers=4,
        )
        return DiT, model_cfg, "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"

    if exp_name == "F5TTS_Base":
        model_cfg = dict(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            text_mask_padding=False,
            conv_layers=4,
            pe_attn_head=1,
        )
        return DiT, model_cfg, "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"

    if exp_name == "E2TTS_Base":
        model_cfg = dict(
            dim=1024,
            depth=24,
            heads=16,
            ff_mult=4,
            text_mask_padding=False,
            pe_attn_head=1,
        )
        return UNetT, model_cfg, "hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"

    # Fail with a clear message instead of an unbound-variable error later on.
    raise ValueError(f"Unknown experiment name: {exp_name!r}")


def _stage_pretrained_checkpoint(ckpt_path, checkpoint_path):
    """Copy ``ckpt_path`` into ``checkpoint_path`` under a ``pretrained_`` name.

    The copy is skipped when the target file already exists.
    """
    os.makedirs(checkpoint_path, exist_ok=True)

    file_checkpoint = os.path.basename(ckpt_path)
    if not file_checkpoint.startswith("pretrained_"):  # Change: Add 'pretrained_' prefix to copied model
        file_checkpoint = "pretrained_" + file_checkpoint
    file_checkpoint = os.path.join(checkpoint_path, file_checkpoint)
    if not os.path.isfile(file_checkpoint):
        shutil.copy2(ckpt_path, file_checkpoint)
        print("copy checkpoint for finetune")


def main():
    """Run training or finetuning as configured on the command line.

    Steps: parse the arguments, pick the backbone for ``--exp_name``, stage
    the pretrained checkpoint when ``--finetune`` is set, load the tokenizer,
    build the CFM model and the Trainer, load the dataset and start training.

    Raises:
        ValueError: if ``--tokenizer custom`` is used without
            ``--tokenizer_path``, or if the experiment name is unknown.
    """
    args = parse_args()

    checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))

    # Model parameters based on experiment name
    model_cls, model_cfg, default_ckpt = _experiment_settings(args.exp_name)
    wandb_resume_id = None

    if args.finetune:
        # cached_path is only consulted when no checkpoint path was supplied.
        if args.pretrain is None:
            ckpt_path = str(cached_path(default_ckpt))
        else:
            ckpt_path = args.pretrain
        _stage_pretrained_checkpoint(ckpt_path, checkpoint_path)

    # Use the tokenizer and tokenizer_path provided in the command line arguments
    tokenizer = args.tokenizer
    if tokenizer == "custom":
        if not args.tokenizer_path:
            raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
        tokenizer_path = args.tokenizer_path
    else:
        tokenizer_path = args.dataset_name

    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    print("\nvocab : ", vocab_size)
    print("\nvocoder : ", mel_spec_type)

    mel_spec_kwargs = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    trainer = Trainer(
        model,
        args.epochs,
        args.learning_rate,
        num_warmup_updates=args.num_warmup_updates,
        save_per_updates=args.save_per_updates,
        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
        checkpoint_path=checkpoint_path,
        batch_size_per_gpu=args.batch_size_per_gpu,
        batch_size_type=args.batch_size_type,
        max_samples=args.max_samples,
        grad_accumulation_steps=args.grad_accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        logger=args.logger,
        wandb_project=args.dataset_name,
        wandb_run_name=args.exp_name,
        wandb_resume_id=wandb_resume_id,
        log_samples=args.log_samples,
        last_per_updates=args.last_per_updates,
        bnb_optimizer=args.bnb_optimizer,
        # Forwarded so the trainer is told the same mel-spec type the model
        # and dataset are configured with (train.py passes it as well).
        mel_spec_type=mel_spec_type,
    )

    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)

    trainer.train(
        train_dataset,
        resumable_with_seed=666,  # seed for shuffling dataset
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start finetuning only when this file is executed as a script.
    main()
|
||||
1904
metaX-C500-f5-tts/F5-TTS/src/f5_tts/train/finetune_gradio.py
Normal file
1904
metaX-C500-f5-tts/F5-TTS/src/f5_tts/train/finetune_gradio.py
Normal file
File diff suppressed because it is too large
Load Diff
77
metaX-C500-f5-tts/F5-TTS/src/f5_tts/train/train.py
Normal file
77
metaX-C500-f5-tts/F5-TTS/src/f5_tts/train/train.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# training script.
|
||||
|
||||
import os
|
||||
from importlib.resources import files
|
||||
|
||||
import hydra
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.model import CFM, Trainer
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
# NOTE: module-level side effect - this runs as soon as the module is loaded, before main() is called.
os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(model_cfg):
    """Train a CFM model as described by the hydra config ``model_cfg``.

    The config supplies the backbone class and its arguments, the tokenizer,
    the mel-spectrogram settings, and the optimizer, checkpoint and dataset
    options that are forwarded to ``Trainer`` and ``trainer.train``.
    """
    model_section = model_cfg.model
    backbone_cls = hydra.utils.get_class(f"f5_tts.model.{model_section.backbone}")
    backbone_kwargs = model_section.arch
    tokenizer = model_section.tokenizer
    mel_spec_cfg = model_section.mel_spec
    mel_spec_type = mel_spec_cfg.mel_spec_type

    dataset_section = model_cfg.datasets
    exp_name = f"{model_section.name}_{mel_spec_type}_{model_section.tokenizer}_{dataset_section.name}"
    wandb_resume_id = None

    # set text tokenizer: a custom tokenizer brings its own path, otherwise
    # the dataset name is used to look the tokenizer up
    tokenizer_path = model_section.tokenizer_path if tokenizer == "custom" else dataset_section.name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    # set model
    transformer = backbone_cls(**backbone_kwargs, text_num_embeds=vocab_size, mel_dim=mel_spec_cfg.n_mel_channels)
    model = CFM(
        transformer=transformer,
        mel_spec_kwargs=mel_spec_cfg,
        vocab_char_map=vocab_char_map,
    )

    # init trainer
    optim_section = model_cfg.optim
    ckpt_section = model_cfg.ckpts
    trainer = Trainer(
        model,
        epochs=optim_section.epochs,
        learning_rate=optim_section.learning_rate,
        num_warmup_updates=optim_section.num_warmup_updates,
        save_per_updates=ckpt_section.save_per_updates,
        keep_last_n_checkpoints=ckpt_section.keep_last_n_checkpoints,
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{ckpt_section.save_dir}")),
        batch_size_per_gpu=dataset_section.batch_size_per_gpu,
        batch_size_type=dataset_section.batch_size_type,
        max_samples=dataset_section.max_samples,
        grad_accumulation_steps=optim_section.grad_accumulation_steps,
        max_grad_norm=optim_section.max_grad_norm,
        logger=ckpt_section.logger,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_updates=ckpt_section.last_per_updates,
        log_samples=ckpt_section.log_samples,
        bnb_optimizer=optim_section.bnb_optimizer,
        mel_spec_type=mel_spec_type,
        is_local_vocoder=model_section.vocoder.is_local,
        local_vocoder_path=model_section.vocoder.local_path,
        model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
    )

    train_dataset = load_dataset(dataset_section.name, tokenizer, mel_spec_kwargs=mel_spec_cfg)
    trainer.train(
        train_dataset,
        num_workers=dataset_section.num_workers,
        resumable_with_seed=666,  # seed for shuffling dataset
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator supplies model_cfg.
    main()
|
||||
Reference in New Issue
Block a user