561 lines
21 KiB
Python
561 lines
21 KiB
Python
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||
# 2025 (authors: Yuekai Zhang)
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
|
||
""" Example Usage
|
||
torchrun --nproc_per_node=1 \
|
||
benchmark.py --output-dir $log_dir \
|
||
--batch-size $batch_size \
|
||
--enable-warmup \
|
||
--split-name $split_name \
|
||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||
--backend-type $backend_type \
|
||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import time
|
||
from typing import Dict, List, Union
|
||
|
||
import datasets
|
||
import jieba
|
||
import tensorrt as trt
|
||
import torch
|
||
import torch.distributed as dist
|
||
import torch.nn.functional as F
|
||
import torchaudio
|
||
from datasets import load_dataset
|
||
from f5_tts_trtllm import F5TTS
|
||
from huggingface_hub import hf_hub_download
|
||
from pypinyin import Style, lazy_pinyin
|
||
from tensorrt_llm._utils import trt_dtype_to_torch
|
||
from tensorrt_llm.logger import logger
|
||
from tensorrt_llm.runtime.session import Session, TensorInfo
|
||
from torch.nn.utils.rnn import pad_sequence
|
||
from torch.utils.data import DataLoader, DistributedSampler
|
||
from tqdm import tqdm
|
||
from vocos import Vocos
|
||
|
||
|
||
# Seed torch's global RNG (CPU and all CUDA devices) once, at import time.
_GLOBAL_SEED = 0
torch.manual_seed(_GLOBAL_SEED)
|
||
|
||
|
||
def get_args():
    """Parse the benchmark's command-line options.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="extract speech code")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
    parser.add_argument(
        "--vocab-file",
        required=True,
        type=str,
        help="vocab file",
    )
    parser.add_argument(
        "--model-path",
        required=True,
        type=str,
        help="model path, to load text embedding",
    )
    parser.add_argument(
        "--tllm-model-dir",
        required=True,
        type=str,
        help="tllm model dir",
    )
    parser.add_argument(
        "--batch-size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
    parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
    parser.add_argument(
        "--vocoder",
        default="vocos",
        type=str,
        help="vocoder name",
    )
    parser.add_argument(
        "--vocoder-trt-engine-path",
        default=None,
        type=str,
        help="vocoder trt engine path",
    )
    parser.add_argument("--enable-warmup", action="store_true")
    parser.add_argument("--remove-input-padding", action="store_true")
    parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
    parser.add_argument(
        "--backend-type",
        type=str,
        # argparse never validates a default against `choices`; the former default
        # "triton" matched no branch in main() and left the model unbound.
        default="trt",
        choices=["trt", "pytorch"],
        help="backend type",
    )
    return parser.parse_args()
|
||
|
||
|
||
def padded_mel_batch(ref_mels, max_seq_len):
    """Zero-pad each mel along its frame axis to ``max_seq_len`` and stack them.

    Args:
        ref_mels: non-empty sequence of 2-D tensors; data_collator passes mels of
            shape (frames, n_mels).
        max_seq_len: target number of frames; must be >= every ``mel.shape[0]``.

    Returns:
        Tensor of shape (len(ref_mels), max_seq_len, mel.shape[1]).

    Raises:
        ValueError: if a mel has more than ``max_seq_len`` frames.
    """
    padded_ref_mels = []
    for mel in ref_mels:
        pad_frames = max_seq_len - mel.shape[0]
        if pad_frames < 0:
            # F.pad with a negative amount silently crops; refuse instead of losing frames.
            raise ValueError(f"mel has {mel.shape[0]} frames, more than max_seq_len={max_seq_len}")
        # Pad spec is (last_left, last_right, second_last_left, second_last_right):
        # the mel-bin axis is untouched and the frame axis is right-padded.
        padded_ref_mels.append(F.pad(mel, (0, 0, 0, pad_frames), value=0))
    return torch.stack(padded_ref_mels)
|
||
|
||
|
||
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
    """Turn a list of dataset rows into padded F5-TTS inputs.

    Each row must provide "id", "prompt_text", "target_text" and "prompt_audio"
    (a dict with "array" and "sampling_rate").

    Per row:
      - the prompt audio is boosted to an RMS of at least 0.1;
      - it is resampled to 24 kHz;
      - it is converted to a (frames, 100) log-mel spectrogram on ``device``;
      - the total (prompt + target) mel length is estimated by scaling the prompt
        frames by the target/prompt UTF-8 byte ratio.

    Text handling:
      - prompt + target text is converted to pinyin/char token ids;
      - ids are shifted by +1, because 0 is reserved for padding;
      - each row is padded with 0 up to its own estimated length;
      - positions beyond that, up to the batch maximum, are filled with -1.

    Args:
        batch: list of dataset rows.
        vocab_char_map: token -> index map from get_tokenizer.
        device: device for the mel computation and the text tensor.
        use_perf: wrap the work in NVTX ranges.

    Returns:
        dict with "ids", "ref_mel_batch" (B, max_len, 100), "ref_mel_len_batch"
        (LongTensor of prompt frame counts), "text_pad_sequence" (B, max_len) and
        "estimated_reference_target_mel_len" (list of ints).
    """
    if use_perf:
        torch.cuda.nvtx.range_push("data_collator")
    target_sample_rate = 24000
    target_rms = 0.1
    ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
        [],
        [],
        [],
        [],
        [],
    )
    for i, item in enumerate(batch):
        item_id, prompt_text, target_text = (
            item["id"],
            item["prompt_text"],
            item["target_text"],
        )
        ids.append(item_id)
        reference_target_texts_list.append(prompt_text + target_text)

        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
        # Boost quiet prompts to target_rms; skip all-zero audio, where the scale
        # factor target_rms / 0 would turn the waveform into inf/NaN.
        if 0 < ref_rms < target_rms:
            ref_audio_org = ref_audio_org * target_rms / ref_rms

        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org

        if use_perf:
            torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
        # Honor the caller's device (it was previously hard-coded to "cuda").
        ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device=device)
        if use_perf:
            torch.cuda.nvtx.range_pop()
        # Drop only the leading batch axis so a one-frame prompt keeps its frame axis.
        ref_mel = ref_mel.squeeze(0)
        ref_mel_len = ref_mel.shape[0]
        assert ref_mel.shape[1] == 100

        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)

        # max(..., 1) avoids ZeroDivisionError when the prompt text is empty.
        prompt_byte_len = max(len(prompt_text.encode("utf-8")), 1)
        estimated_reference_target_mel_len.append(
            int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / prompt_byte_len))
        )

    max_seq_len = max(estimated_reference_target_mel_len)
    ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
    ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)

    pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
    text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)

    for i, item in enumerate(text_pad_sequence):
        text_pad_sequence[i] = F.pad(
            item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
        )
        text_pad_sequence[i] += 1  # WAR: 0 is reserved for padding token, hard coding in F5-TTS
    text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
    text_pad_sequence = F.pad(
        text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
    )
    if use_perf:
        torch.cuda.nvtx.range_pop()
    return {
        "ids": ids,
        "ref_mel_batch": ref_mel_batch,
        "ref_mel_len_batch": ref_mel_len_batch,
        "text_pad_sequence": text_pad_sequence,
        "estimated_reference_target_mel_len": estimated_reference_target_mel_len,
    }
|
||
|
||
|
||
def init_distributed():
    """Initialize the NCCL process group from torchrun-style environment variables.

    Reads WORLD_SIZE (default 1), LOCAL_RANK (default 0) and RANK (default 0),
    makes GPU LOCAL_RANK the current CUDA device, then initializes the default
    process group with the NCCL backend.

    Returns:
        (world_size, local_rank, rank) as ints.
    """
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    # No init_method or device IDs are passed: rendezvous uses the default env://
    # method. NCCL is expected to pick up the device selected by set_device above.
    dist.init_process_group(
        "nccl",
    )
    return world_size, local_rank, rank
|
||
|
||
|
||
def get_tokenizer(vocab_file_path: str):
    """Load a vocabulary file into a token -> index map.

    The file holds one token per line, and a token's index is its 0-based line
    number. Only the line terminator is stripped, so whitespace tokens such as a
    single space survive. If a token appears twice, the later line wins.

    Args:
        vocab_file_path: path to a UTF-8 ``vocab.txt``.

    Returns:
        (vocab_char_map, vocab_size): the token -> index dict and its number of
        entries.
    """
    vocab_char_map = {}
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            # Strip only a real trailing newline. `line[:-1]` would chop a genuine
            # character off a last line that has no newline.
            token = line[:-1] if line.endswith("\n") else line
            vocab_char_map[token] = i
    vocab_size = len(vocab_char_map)
    return vocab_char_map, vocab_size
|
||
|
||
|
||
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
    """Convert raw texts into token lists (characters and pinyin syllables).

    Each text is first normalized with ``custom_trans``:
      - ASCII ';' becomes ',';
      - curly quotes become straight quotes.

    The text is then segmented with ``jieba.cut``, and each segment is handled
    as follows:
      - ASCII-only segments (1 UTF-8 byte per character) are emitted character by
        character. A " " separator goes first when the segment has more than one
        character, unless the previous token is one of  :'".
      - If ``polyphone`` is true and every character of the segment is 3 UTF-8
        bytes, the whole segment goes through ``lazy_pinyin`` (TONE3 style, tone
        sandhi), presumably so multi-character words get context-dependent
        readings. Each character in the U+3100..U+9FFF range gets a " " before
        its syllable.
      - Otherwise characters are handled one at a time:
          * code points < 256 are kept as-is;
          * characters in U+3100..U+9FFF become " " followed by their own pinyin;
          * anything else is kept as-is.

    Args:
        reference_target_texts_list: list of strings to convert.
        polyphone: enable whole-segment pinyin conversion for all-3-byte segments.

    Returns:
        One list of string tokens per input text, in input order.
    """
    final_reference_target_texts_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return "\u3100" <= c <= "\u9fff"  # common chinese characters

    for text in reference_target_texts_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                # Separate a multi-char ASCII word from the previous token with a
                # space, unless that token already acts as a separator.
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                # NOTE(review): seg_ is indexed by character position, which assumes
                # lazy_pinyin returned exactly one item per character of seg. Confirm
                # this for 3-byte non-Chinese characters, which pypinyin may group
                # into a single item.
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_reference_target_texts_list.append(char_list)

    return final_reference_target_texts_list
|
||
|
||
|
||
def list_str_to_idx(
    text: Union[List[str], List[List[str]]],
    vocab_char_map: Dict[str, int],  # {char: idx}
    padding_value=-1,
):
    """Map each token sequence to a 1-D int64 tensor of vocabulary indices.

    Tokens missing from ``vocab_char_map`` map to index 0. A plain string is
    iterated character by character.

    Args:
        text: batch of token sequences (pinyin or char style).
        vocab_char_map: token -> index mapping.
        padding_value: unused. It is kept for signature compatibility; padding is
            done by the caller.

    Returns:
        List of unpadded 1-D LongTensors, one per input sequence.
    """
    # The explicit dtype makes an empty sequence an int64 tensor. Plain
    # torch.tensor([]) is float32 and would not mix with the other rows.
    return [torch.tensor([vocab_char_map.get(c, 0) for c in t], dtype=torch.long) for t in text]
|
||
|
||
|
||
def load_vocoder(
    vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
    """Load the vocoder used to turn mel spectrograms into waveforms.

    Args:
        vocoder_name: "vocos" is supported; "bigvgan" raises NotImplementedError.
        is_local: load the Vocos config/weights from ``local_path`` instead of the
            Hugging Face Hub.
        local_path: directory containing ``config.yaml`` and ``pytorch_model.bin``.
        device: device the PyTorch Vocos model is moved to.
        hf_cache_dir: cache directory passed to ``hf_hub_download``.
        vocoder_trt_engine_path: if given, a TensorRT engine is used instead of
            the PyTorch model.

    Returns:
        A VocosTensorRT wrapper, or a Vocos model in eval mode on ``device``.

    Raises:
        NotImplementedError: for "bigvgan".
        ValueError: for any other unknown vocoder name.
    """
    if vocoder_name == "vocos":
        if vocoder_trt_engine_path is not None:
            vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
        else:
            if is_local:
                print(f"Load vocos from local path {local_path}")
                config_path = f"{local_path}/config.yaml"
                model_path = f"{local_path}/pytorch_model.bin"
            else:
                print("Download Vocos from huggingface charactr/vocos-mel-24khz")
                repo_id = "charactr/vocos-mel-24khz"
                config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
                model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
            vocoder = Vocos.from_hparams(config_path)
            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
            from vocos.feature_extractors import EncodecFeatures

            # The checkpoint lacks Encodec feature-extractor weights; copy them from
            # the freshly built model so load_state_dict sees a complete dict.
            if isinstance(vocoder.feature_extractor, EncodecFeatures):
                encodec_parameters = {
                    "feature_extractor.encodec." + key: value
                    for key, value in vocoder.feature_extractor.encodec.state_dict().items()
                }
                state_dict.update(encodec_parameters)
            vocoder.load_state_dict(state_dict)
            vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not implemented yet")
    else:
        # Previously this fell through to `return vocoder` with the name unbound
        # (UnboundLocalError); fail with a clear message instead.
        raise ValueError(f"Unsupported vocoder: {vocoder_name!r}")
    return vocoder
|
||
|
||
|
||
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
    """Compute a log-mel spectrogram with the Vocos 24 kHz front-end settings.

    Uses n_fft=1024, hop_length=256 and 100 mel bins, takes the magnitude
    (power=1), clamps at 1e-5 and applies the natural log.

    Args:
        waveform: audio tensor. data_collator passes (1, samples) at 24 kHz;
            other layouts are not exercised here (TODO confirm if reused).
        vocoder: feature configuration; only "vocos" is supported.
        device: device on which the transform runs.

    Returns:
        Log-mel tensor with axes 1 and 2 swapped, i.e. (1, frames, 100) for a
        (1, samples) input.

    Raises:
        NotImplementedError: for any vocoder other than "vocos".
    """
    if vocoder != "vocos":
        # Previously returned None silently, which crashed later in callers.
        raise NotImplementedError(f"mel_spectrogram is not implemented for vocoder {vocoder!r}")
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=24000,
        n_fft=1024,
        win_length=1024,
        hop_length=256,
        n_mels=100,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(device)
    mel = mel_stft(waveform.to(device))
    mel = mel.clamp(min=1e-5).log()
    return mel.transpose(1, 2)
|
||
|
||
|
||
class VocosTensorRT:
    """Run a serialized Vocos vocoder TensorRT engine through a TensorRT-LLM Session."""

    def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
        """Deserialize the engine stored at ``engine_path``.

        Args:
            engine_path: path to the serialized TensorRT engine file.
            stream: raw CUDA stream handle to run on. Defaults to the handle of
                torch's current CUDA stream at construction time.
        """
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
        logger.info(f"Loading vocoder engine from {engine_path}")
        self.engine_path = engine_path
        with open(engine_path, "rb") as f:
            engine_buffer = f.read()
        self.session = Session.from_serialized_engine(engine_buffer)
        self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream

    def decode(self, mels):
        """Convert mel spectrograms to waveforms with the TensorRT engine.

        Args:
            mels: CUDA tensor bound to the engine input named "mel".

        Returns:
            The engine output tensor named "waveform".

        Raises:
            RuntimeError: if the TensorRT session fails to execute.
        """
        # The input is declared to the engine as FLOAT below, so make the tensor
        # match (a no-op for float32 input).
        mels = mels.to(torch.float32).contiguous()
        inputs = {"mel": mels}
        output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
        outputs = {
            t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
        }
        ok = self.session.run(inputs, outputs, self.stream)
        # Raise rather than assert so the check survives `python -O`.
        if not ok:
            raise RuntimeError("Runtime execution failed for vocoder session")
        return outputs["waveform"]
|
||
|
||
|
||
def main():
    """Benchmark F5-TTS inference (TensorRT-LLM or PyTorch backend) plus the vocoder.

    Synthesizes every utterance of the selected seed_tts split and writes
    ``<output_dir>/<id>.wav``. On rank 0 it prints a timing summary (RTF, DiT
    time, vocoder time) and writes the same text to ``<output_dir>/rtf.txt``.
    """
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark")
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)

    tllm_model_dir = args.tllm_model_dir
    config_file = os.path.join(tllm_model_dir, "config.json")
    with open(config_file) as f:
        config = json.load(f)
    if args.backend_type == "trt":
        model = F5TTS(
            config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
        )
    elif args.backend_type == "pytorch":
        import sys

        # Make the in-repo f5_tts package importable relative to this script.
        sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
        from f5_tts.infer.utils_infer import load_model
        from f5_tts.model import DiT

        F5TTS_model_cfg = dict(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            conv_layers=4,
            pe_attn_head=1,
            text_mask_padding=False,
        )
        model = load_model(DiT, F5TTS_model_cfg, args.model_path)
    else:
        # Otherwise `model` would be unbound and the script would fail much later.
        raise ValueError(f"Unsupported backend type: {args.backend_type!r}")

    vocoder = load_vocoder(
        vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
    )

    dataset = load_dataset(
        "yuekai/seed_tts",
        split=args.split_name,
        trust_remote_code=True,
    )

    def add_estimated_duration(example):
        # Rough output duration in seconds; used below only as the sort key.
        prompt_audio_len = example["prompt_audio"]["array"].shape[0]
        # max(..., 1) avoids ZeroDivisionError for an empty prompt text.
        scale_factor = 1 + len(example["target_text"]) / max(len(example["prompt_text"]), 1)
        estimated_duration = prompt_audio_len * scale_factor
        example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
        return example

    dataset = dataset.map(add_estimated_duration)
    dataset = dataset.sort("estimated_duration", reverse=True)
    if args.use_perf:
        # Profile a fixed workload: row 24 of the sorted split, repeated 8 times.
        dataset_list_short = [dataset.select([24]) for _ in range(8)]  # seq_len 719
        dataset = datasets.concatenate_datasets(dataset_list_short)
    if world_size > 1:
        # DistributedSampler shuffles by default, which would undo the length sort.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    else:
        # This would disable shuffling
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, vocab_char_map, device=device, use_perf=args.use_perf),
    )

    total_steps = len(dataset)

    if args.enable_warmup:
        for batch in dataloader:
            ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
            text_pad_seq = batch["text_pad_sequence"].to(device)
            total_mel_lens = batch["estimated_reference_target_mel_len"]
            if args.backend_type == "trt":
                _ = model.sample(
                    text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
                )
            elif args.backend_type == "pytorch":
                with torch.inference_mode():
                    # Undo data_collator's +1 shift. Both padding values (0 and -1)
                    # end up as -1.
                    text_pad_seq -= 1
                    text_pad_seq[text_pad_seq == -2] = -1
                    total_mel_lens = torch.tensor(total_mel_lens, device=device)
                    # Same call as the timed loop, including lens=, so warmup
                    # exercises the same path.
                    generated, _ = model.sample(
                        cond=ref_mels,
                        text=text_pad_seq,
                        duration=total_mel_lens,
                        lens=ref_mel_lens,
                        steps=16,
                        cfg_strength=2.0,
                        sway_sampling_coef=-1,
                    )

    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    decoding_time = 0
    vocoder_time = 0
    total_duration = 0
    target_rms = 0.1
    target_sample_rate = 24_000
    if args.use_perf:
        torch.cuda.cudart().cudaProfilerStart()
    total_decoding_time = time.time()
    for batch in dataloader:
        if args.use_perf:
            torch.cuda.nvtx.range_push("data sample")
        ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
        text_pad_seq = batch["text_pad_sequence"].to(device)
        total_mel_lens = batch["estimated_reference_target_mel_len"]

        if args.use_perf:
            torch.cuda.nvtx.range_pop()
        if args.backend_type == "trt":
            generated, cost_time = model.sample(
                text_pad_seq,
                ref_mels,
                ref_mel_lens,
                total_mel_lens,
                remove_input_padding=args.remove_input_padding,
                use_perf=args.use_perf,
            )
        elif args.backend_type == "pytorch":
            total_mel_lens = torch.tensor(total_mel_lens, device=device)
            with torch.inference_mode():
                start_time = time.time()
                # Undo data_collator's +1 shift. Both padding values (0 and -1)
                # end up as -1.
                text_pad_seq -= 1
                text_pad_seq[text_pad_seq == -2] = -1
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=text_pad_seq,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=16,
                    cfg_strength=2.0,
                    sway_sampling_coef=-1,
                )
                # CUDA work is asynchronous. Wait for it so cost_time measures the
                # DiT itself, not just kernel launches.
                torch.cuda.synchronize(device)
                cost_time = time.time() - start_time
        decoding_time += cost_time
        vocoder_start_time = time.time()
        for i, gen in enumerate(generated):
            # Keep only the generated (non-prompt) frames of this row.
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
            if args.vocoder == "vocos":
                if args.use_perf:
                    torch.cuda.nvtx.range_push("vocoder decode")
                generated_wave = vocoder.decode(gen_mel_spec).cpu()
                if args.use_perf:
                    torch.cuda.nvtx.range_pop()
            else:
                generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

            rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
            # Boost quiet output; skip silent output to avoid dividing by zero.
            if 0 < rms < target_rms:
                generated_wave = generated_wave * target_rms / rms
            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )
            total_duration += generated_wave.shape[1] / target_sample_rate
        vocoder_time += time.time() - vocoder_start_time
        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))
    total_decoding_time = time.time() - total_decoding_time
    if args.use_perf:
        # Pair the cudaProfilerStart above.
        torch.cuda.cudart().cudaProfilerStop()

    if rank == 0:
        progress_bar.close()
        # Guard against an empty split, which would otherwise divide by zero.
        rtf = total_decoding_time / total_duration if total_duration > 0 else float("nan")
        s = f"RTF: {rtf:.4f}\n"
        s += f"total_duration: {total_duration:.3f} seconds\n"
        s += f"({total_duration / 3600:.2f} hours)\n"
        s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
        s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
        s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
        s += f"batch size: {args.batch_size}\n"
        print(s)

        with open(f"{args.output_dir}/rtf.txt", "w") as f:
            f.write(s)

    dist.barrier()
    dist.destroy_process_group()
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the benchmark only when executed directly, not on import.
    main()
|