update README

This commit is contained in:
2025-09-10 10:47:02 +08:00
parent 5088f0b50a
commit ff78032400
603 changed files with 21 additions and 23 deletions

View File

@@ -0,0 +1,33 @@
"""ADAPTIVE BATCH SIZE"""
print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
print(" -> least padding, gather wavs with accumulated frames in a batch\n")
# data
total_hours = 95282
mel_hop_length = 256
mel_sampling_rate = 24000
# target
wanted_max_updates = 1200000
# train params
gpus = 8
frames_per_gpu = 38400 # 8 * 38400 = 307200
grad_accum = 1
# intermediate
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
# steps_per_epoch = updates_per_epoch * grad_accum
# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
# others
print(f"total {total_hours:.0f} hours")
print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")

View File

@@ -0,0 +1,40 @@
import os
import sys
sys.path.append(os.getcwd())
import thop
import torch
from f5_tts.model import CFM, DiT
""" ~155M """
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
""" ~335M """
# FLOPs: 622.1 G, Params: 333.2 M
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# FLOPs: 363.4 G, Params: 335.8 M
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model = CFM(transformer=transformer)
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
duration = 20
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150
flops, params = thop.profile(
model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
)
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")