#!/usr/bin/env python3
# Copyright    2025  Xiaomi Corp.        (authors: Fangjun Kuang)
"""Run a Matcha-TTS ONNX acoustic model with two ONNX vocoders.

The script converts a fixed English sentence to phonemes, generates a mel
spectrogram with the acoustic model, synthesizes audio with both a Vocos
vocoder (followed by an inverse STFT) and a HiFi-GAN vocoder, writes the
results to ``vocos.wav`` and ``hifigan-v2.wav``, and prints the real-time
factor (RTF) of every stage.
"""

import datetime as dt

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf

try:
    from piper_phonemize import phonemize_espeak
except Exception as ex:
    # Chain the original exception so that the real cause stays visible
    # in the traceback.
    raise RuntimeError(
        f"{ex}\nPlease run\n"
        "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
    ) from ex


def _create_session_options() -> ort.SessionOptions:
    """Return session options restricted to a single thread."""
    session_opts = ort.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1
    return session_opts


def _create_session(
    filename: str, session_opts: ort.SessionOptions
) -> ort.InferenceSession:
    """Load the ONNX model in ``filename`` on the CPU execution provider."""
    return ort.InferenceSession(
        filename,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )


def _print_model_io(title: str, model: ort.InferenceSession) -> None:
    """Print the inputs and outputs of ``model`` under a ``title`` banner."""
    print(f"----------{title}----------")
    for node in model.get_inputs():
        print(node)

    print("-----")

    for node in model.get_outputs():
        print(node)
    print()


class OnnxVocosModel:
    """Wrapper around a Vocos ONNX vocoder that predicts an STFT."""

    def __init__(
        self,
        filename: str,
    ):
        """
        Args:
          filename:
            Path to the Vocos ONNX model.
        """
        self.session_opts = _create_session_options()
        self.model = _create_session(filename, self.session_opts)

        _print_model_io("vocos", self.model)

    def __call__(self, x: np.ndarray):
        """
        Args:
          x: (N, feat_dim, num_frames). Only N == 1 is supported.
        Returns:
          mag: (N, n_fft/2+1, num_frames)
          x: (N, n_fft/2+1, num_frames)
          y: (N, n_fft/2+1, num_frames)

          The complex spectrum is mag * (x + j*y)
        """
        assert x.ndim == 3, x.shape
        assert x.shape[0] == 1, x.shape

        output_names = [node.name for node in self.model.get_outputs()[:3]]

        # Use distinct names for the outputs so that the input ``x`` is
        # not shadowed.
        mag, spec_x, spec_y = self.model.run(
            output_names,
            {
                self.model.get_inputs()[0].name: x,
            },
        )

        return mag, spec_x, spec_y


class OnnxHifiGANModel:
    """Wrapper around a HiFi-GAN ONNX vocoder that predicts a waveform."""

    def __init__(
        self,
        filename: str,
    ):
        """
        Args:
          filename:
            Path to the HiFi-GAN ONNX model.
        """
        self.session_opts = _create_session_options()
        self.model = _create_session(filename, self.session_opts)

        _print_model_io("hifigan", self.model)

    def __call__(self, x: np.ndarray):
        """
        Args:
          x: (N, feat_dim, num_frames). Only N == 1 is supported.
        Returns:
          audio: (N, num_samples)
        """
        assert x.ndim == 3, x.shape
        assert x.shape[0] == 1, x.shape

        audio = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                self.model.get_inputs()[0].name: x,
            },
        )[0]
        # audio: (batch_size, num_samples)

        return audio


def load_tokens(filename: str) -> dict:
    """Read a token table that maps each token to an integer ID.

    Every non-blank line is either ``<token> <id>`` or just ``<id>``.
    A line with only an ID belongs to the space token, because stripping
    the line removes the token itself.

    Args:
      filename:
        Path to the token table, encoded in UTF-8.
    Returns:
      A dict mapping token (str) to ID (int).
    Raises:
      ValueError: If a line has more than two fields or its ID is not an
        integer.
    """
    token2id = {}
    with open(filename, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g., a trailing newline at the
                # end of the file.
                continue

            if len(fields) == 1:
                token, idx = " ", fields[0]
            elif len(fields) == 2:
                token, idx = fields
            else:
                raise ValueError(
                    f"{filename}:{line_no}: expected '<token> <id>' or "
                    f"'<id>', got {line!r}"
                )

            token2id[token] = int(idx)
    return token2id


class OnnxModel:
    """Wrapper around a Matcha-TTS ONNX acoustic model."""

    def __init__(
        self,
        filename: str,
        tokens: str,
    ):
        """
        Args:
          filename:
            Path to the acoustic model in ONNX format. Its custom
            metadata must contain ``sample_rate``.
          tokens:
            Path to the token table. See :func:`load_tokens`.
        """
        self.token2id = load_tokens(tokens)
        self.session_opts = _create_session_options()
        self.model = _create_session(filename, self.session_opts)

        metadata = self.model.get_modelmeta().custom_metadata_map
        print(f"{metadata}")
        self.sample_rate = int(metadata["sample_rate"])

        _print_model_io("matcha", self.model)

    def __call__(self, x: np.ndarray):
        """
        Args:
          x: (N, num_tokens), token IDs. Only N == 1 is supported.
        Returns:
          mel: (N, feat_dim, num_frames)
        """
        assert x.ndim == 2, x.shape
        assert x.shape[0] == 1, x.shape

        x_lengths = np.array([x.shape[1]], dtype=np.int64)

        noise_scale = np.array([1.0], dtype=np.float32)
        length_scale = np.array([1.0], dtype=np.float32)

        inputs = self.model.get_inputs()
        mel = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                inputs[0].name: x,
                inputs[1].name: x_lengths,
                inputs[2].name: noise_scale,
                inputs[3].name: length_scale,
            },
        )[0]
        # mel: (batch_size, feat_dim, num_frames)

        return mel


def main():
    """Synthesize a fixed sentence with both vocoders and report RTFs."""
    am = OnnxModel(
        filename="./matcha-icefall-en_US-ljspeech/model-steps-3.onnx",
        tokens="./matcha-icefall-en_US-ljspeech/tokens.txt",
    )
    vocoder = OnnxHifiGANModel("./hifigan_v2.onnx")
    vocos = OnnxVocosModel("./mel_spec_22khz_univ.onnx")
    sample_rate = am.sample_rate

    text = (
        "Today as always, men fall into two groups: slaves and free men. "
        "Whoever does not have two-thirds of his day for himself, is a "
        "slave, whatever he may be: a statesman, a businessman, an "
        "official, or a scholar."
    )
    tokens_list = phonemize_espeak(text, "en-us")
    print(tokens_list)

    # phonemize_espeak() returns one list of tokens per sentence.
    tokens = [t for sentence in tokens_list for t in sentence]

    token_ids = []
    for t in tokens:
        if t not in am.token2id:
            print(f"Skip OOV '{t}'")
            continue
        token_ids.append(am.token2id[t])

    # Put the "_" token before, between, and after all tokens.
    token_ids2 = [am.token2id["_"]] * (len(token_ids) * 2 + 1)
    token_ids2[1::2] = token_ids
    token_ids = token_ids2

    x = np.array([token_ids], dtype=np.int64)

    mel_start_t = dt.datetime.now()
    mel = am(x)
    mel_end_t = dt.datetime.now()

    print("mel", mel.shape)
    # e.g., mel: (1, 80, 78)

    vocos_start_t = dt.datetime.now()
    mag, spec_x, spec_y = vocos(mel)
    # knf.StftResult expects flattened lists, so move the frame axis
    # first before flattening.
    stft_result = knf.StftResult(
        real=(mag * spec_x)[0].transpose().reshape(-1).tolist(),
        imag=(mag * spec_y)[0].transpose().reshape(-1).tolist(),
        num_frames=mag.shape[2],
    )
    config = knf.StftConfig(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        window_type="hann",
        center=True,
        pad_mode="reflect",
        normalized=False,
    )
    istft = knf.IStft(config)
    audio_vocos = istft(stft_result)
    vocos_end_t = dt.datetime.now()

    audio_vocos = np.array(audio_vocos)
    print("vocos max/min", np.max(audio_vocos), np.min(audio_vocos))

    sf.write("vocos.wav", audio_vocos, sample_rate, "PCM_16")

    hifigan_start_t = dt.datetime.now()
    audio_hifigan = vocoder(mel)
    hifigan_end_t = dt.datetime.now()

    audio_hifigan = audio_hifigan.squeeze()
    print("hifigan max/min", np.max(audio_hifigan), np.min(audio_hifigan))

    sf.write("hifigan-v2.wav", audio_hifigan, sample_rate, "PCM_16")

    am_t = (mel_end_t - mel_start_t).total_seconds()
    vocos_t = (vocos_end_t - vocos_start_t).total_seconds()
    hifigan_t = (hifigan_end_t - hifigan_start_t).total_seconds()

    vocos_duration = audio_vocos.shape[-1] / sample_rate
    hifigan_duration = audio_hifigan.shape[-1] / sample_rate

    # The acoustic model is shared by both vocoders, so its RTF is
    # measured against the mean duration of the two generated waveforms.
    mean_audio_duration = (
        (audio_vocos.shape[-1] + audio_hifigan.shape[-1]) / 2 / sample_rate
    )
    rtf_am = am_t / mean_audio_duration

    rtf_vocos = vocos_t * sample_rate / audio_vocos.shape[-1]
    rtf_hifigan = hifigan_t * sample_rate / audio_hifigan.shape[-1]

    print(f"Audio duration for vocos {vocos_duration:.3f} s")
    print(f"Audio duration for hifigan {hifigan_duration:.3f} s")
    print(f"Mean audio duration: {mean_audio_duration:.3f} s")
    print(f"RTF for acoustic model {rtf_am:.3f}")
    print(f"RTF for vocos {rtf_vocos:.3f}")
    print(f"RTF for hifigan {rtf_hifigan:.3f}")


if __name__ == "__main__":
    main()