#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""Greedy-search speech recognition with a four-part ONNX model.

The script loads four ONNX graphs (preprocess, encode, uncached decode and
cached decode), runs them on ``./1.wav`` and prints the recognized text
together with the real-time factor (RTF).
"""

import datetime as dt

import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf


def display(sess, name):
    """Print the input and output descriptions of an inference session.

    Args:
      sess: An ``ort.InferenceSession``.
      name: A label used in the printed section headers.
    """
    print(f"=========={name} Input==========")
    for i in sess.get_inputs():
        print(i)
    print(f"=========={name} Output==========")
    for i in sess.get_outputs():
        print(i)


class OnnxModel:
    """Wrap the four ONNX sessions that make up the recognizer."""

    def __init__(
        self,
        preprocess: str,
        encode: str,
        uncached_decode: str,
        cached_decode: str,
    ):
        """Load every sub-model and print its inputs/outputs.

        Args:
          preprocess: Path to the preprocessing model.
          encode: Path to the encoder model.
          uncached_decode: Path to the decoder used for the first step
            (it takes no previous states).
          cached_decode: Path to the decoder used for all later steps
            (it consumes the states of the previous step).
        """
        self.init_preprocess(preprocess)
        display(self.preprocess, "preprocess")

        self.init_encode(encode)
        display(self.encode, "encode")

        self.init_uncached_decode(uncached_decode)
        display(self.uncached_decode, "uncached_decode")

        self.init_cached_decode(cached_decode)
        display(self.cached_decode, "cached_decode")

    @staticmethod
    def _create_session(model_path: str) -> "ort.InferenceSession":
        """Create a single-threaded CPU inference session for model_path."""
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        return ort.InferenceSession(
            model_path,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_preprocess(self, preprocess):
        """Load the preprocessing model into ``self.preprocess``."""
        self.preprocess = self._create_session(preprocess)

    def init_encode(self, encode):
        """Load the encoder model into ``self.encode``."""
        self.encode = self._create_session(encode)

    def init_uncached_decode(self, uncached_decode):
        """Load the first-step decoder into ``self.uncached_decode``."""
        self.uncached_decode = self._create_session(uncached_decode)

    def init_cached_decode(self, cached_decode):
        """Load the later-step decoder into ``self.cached_decode``."""
        self.cached_decode = self._create_session(cached_decode)

    def run_preprocess(self, audio):
        """Run the preprocessing model on raw audio samples.

        Args:
          audio:
            (batch_size, num_samples), float32
        Returns:
          A tensor of shape (batch_size, T, dim), float32
        """
        return self.preprocess.run(
            [self.preprocess.get_outputs()[0].name],
            {self.preprocess.get_inputs()[0].name: audio},
        )[0]

    def run_encode(self, features):
        """Run the encoder on the preprocessed features.

        Args:
          features:
            (batch_size, T, dim)
        Returns:
          A tensor of shape (batch_size, T, dim)
        """
        # NOTE(review): only one length value is passed, so this presumably
        # works for batch_size == 1 only -- confirm before batching.
        features_len = np.array([features.shape[1]], dtype=np.int32)

        inputs = self.encode.get_inputs()
        return self.encode.run(
            [self.encode.get_outputs()[0].name],
            {
                inputs[0].name: features,
                inputs[1].name: features_len,
            },
        )[0]

    def run_uncached_decode(self, token: int, token_len: int, encoder_out: np.ndarray):
        """Run the first decoding step, which has no previous states.

        Args:
          token:
            The current token
          token_len:
            Number of predicted tokens so far
          encoder_out:
            A tensor of shape (batch_size, T, dim)
        Returns:
          A tuple:
            - a tensor of shape (batch_size, 1, dim)
            - a list of states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        out_names = [o.name for o in self.uncached_decode.get_outputs()]
        inputs = self.uncached_decode.get_inputs()

        out = self.uncached_decode.run(
            out_names,
            {
                inputs[0].name: token_tensor,
                inputs[1].name: encoder_out,
                inputs[2].name: token_len_tensor,
            },
        )

        # The first output is the logits; everything after it is state.
        logits = out[0]
        states = out[1:]

        return logits, states

    def run_cached_decode(
        self, token: int, token_len: int, encoder_out: np.ndarray, states
    ):
        """Run one decoding step that reuses the states of the previous step.

        Args:
          token:
            The current token
          token_len:
            Number of predicted tokens so far
          encoder_out:
            A tensor of shape (batch_size, T, dim)
          states:
            Previous states, in the order of the model's inputs that follow
            the first three.
        Returns:
          A tuple:
            - a tensor of shape (batch_size, 1, dim)
            - a list of states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        out_names = [o.name for o in self.cached_decode.get_outputs()]
        inputs = self.cached_decode.get_inputs()

        # Inputs 0..2 are token, encoder_out and token_len; every remaining
        # input takes the state at the same relative position.
        states_inputs = {
            state_input.name: states[i] for i, state_input in enumerate(inputs[3:])
        }

        out = self.cached_decode.run(
            out_names,
            {
                inputs[0].name: token_tensor,
                inputs[1].name: encoder_out,
                inputs[2].name: token_len_tensor,
                **states_inputs,
            },
        )

        logits = out[0]
        states = out[1:]

        return logits, states


def main():
    """Recognize ``./1.wav`` with greedy search and print the text and RTF."""
    wave = "./1.wav"
    expected_sample_rate = 16000

    id2token = dict()
    token2id = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for line in f:
            # Each line is "<token>\t<id>". Split on the last tab so that a
            # token containing a tab still parses; int() ignores the newline.
            t, idx = line.rsplit("\t", 1)
            id2token[int(idx)] = t
            token2id[t] = int(idx)

    model = OnnxModel(
        preprocess="./preprocess.onnx",
        encode="./encode.int8.onnx",
        uncached_decode="./uncached_decode.int8.onnx",
        cached_decode="./cached_decode.int8.onnx",
    )

    audio, sample_rate = sf.read(wave, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != expected_sample_rate:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=expected_sample_rate,
        )
        sample_rate = expected_sample_rate
    audio = audio[None]  # (1, num_samples)
    print("audio.shape", audio.shape)  # e.g. (1, 159414)

    start_t = dt.datetime.now()

    features = model.run_preprocess(audio)  # e.g. (1, 413, 288)
    print("features", features.shape)

    # Start-of-sentence and end-of-sentence tokens.
    # NOTE(review): the lookups used empty-string keys, which makes sos == eos
    # and looks like "<s>" / "</s>" that lost their angle-bracket text.
    # Confirm these names against tokens.txt.
    sos = token2id["<s>"]
    eos = token2id["</s>"]

    tokens = [sos]

    encoder_out = model.run_encode(features)
    print("encoder_out.shape", encoder_out.shape)  # e.g. (1, 413, 288)

    logits, states = model.run_uncached_decode(
        token=tokens[-1],
        token_len=len(tokens),
        encoder_out=encoder_out,
    )

    print("logits.shape", logits.shape)  # e.g. (1, 1, 32768)
    print("len(states)", len(states))  # e.g. 24

    # Allow at most 6 tokens per second of audio.
    max_len = int((audio.shape[-1] / expected_sample_rate) * 6)

    for _ in range(max_len):
        # Greedy search: take the highest-scoring token. Convert the NumPy
        # scalar so that ``tokens`` holds plain Python ints only.
        token = int(logits.squeeze().argmax())
        if token == eos:
            break
        tokens.append(token)

        logits, states = model.run_cached_decode(
            token=tokens[-1],
            token_len=len(tokens),
            encoder_out=encoder_out,
            states=states,
        )

    tokens = tokens[1:]  # remove sos
    words = [id2token[i] for i in tokens]

    # U+2581 (UTF-8 bytes e2 96 81) appears inside the token strings; it is
    # turned into a space to recover word boundaries.
    underline = "▁"
    text = "".join(words).replace(underline, " ").strip()

    end_t = dt.datetime.now()
    t = (end_t - start_t).total_seconds()
    # Real-time factor: processing time divided by audio duration.
    rtf = t * expected_sample_rate / audio.shape[-1]

    print(text)
    print("RTF:", rtf)


if __name__ == "__main__":
    main()