update README
This commit is contained in:
33
mlu_370-f5-tts/F5-TTS/src/f5_tts/scripts/count_max_epoch.py
Normal file
33
mlu_370-f5-tts/F5-TTS/src/f5_tts/scripts/count_max_epoch.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""ADAPTIVE BATCH SIZE"""
|
||||
|
||||
# Estimate how many epochs are needed to reach a wanted number of optimizer
# updates when every GPU is fed a fixed budget of mel frames per batch.
print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
print(" -> least padding, gather wavs with accumulated frames in a batch\n")

# data
total_hours = 95282  # total amount of audio, in hours
mel_hop_length = 256  # audio samples advanced per mel frame
mel_sampling_rate = 24000  # audio samples per second

# target
wanted_max_updates = 1200000

# train params
gpus = 8
frames_per_gpu = 38400  # 8 * 38400 = 307200
grad_accum = 1

# intermediate
# Frames consumed by one optimizer update across all GPUs and accumulation steps.
mini_batch_frames = frames_per_gpu * grad_accum * gpus
# frames -> samples (x hop length) -> seconds (/ sampling rate) -> hours (/ 3600).
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
# One pass over the data divided by the audio covered per update.
updates_per_epoch = total_hours / mini_batch_hours
# steps_per_epoch = updates_per_epoch * grad_accum

# result
epochs = wanted_max_updates / updates_per_epoch
# NOTE: ":.0f" rounds to the nearest integer, so the suggested epoch count can
# land slightly below wanted_max_updates when the fractional part is < 0.5.
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")

# others
print(f"total {total_hours:.0f} hours")
print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
|
||||
@@ -0,0 +1,40 @@
|
||||
import os
import sys


# Add the current working directory to the module search path before the
# project import below (presumably so `f5_tts` resolves when the script is run
# from the source root -- confirm against how the script is launched).
sys.path.append(os.getcwd())

import thop
import torch

from f5_tts.model import CFM, DiT


""" ~155M """
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)

""" ~335M """
# FLOPs: 622.1 G, Params: 333.2 M
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# FLOPs: 363.4 G, Params: 335.8 M
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)


# Wrap the backbone selected above in the CFM model that gets profiled.
model = CFM(transformer=transformer)
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
duration = 20
# Number of mel frames for `duration` of audio: duration * rate / hop = 1875.
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150

# Profile a single forward call on dummy inputs: a random float tensor of shape
# (1, frame_length, n_mel_channels) and an all-zero int64 tensor of shape
# (1, text_length).
# NOTE(review): thop's first return value is commonly documented as MACs rather
# than FLOPs -- confirm the "FLOPs" label below is the intended unit.
flops, params = thop.profile(
    model,
    inputs=(
        torch.randn(1, frame_length, n_mel_channels),
        torch.zeros(1, text_length, dtype=torch.long),
    ),
)
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")
|
||||
Reference in New Issue
Block a user