#!/usr/bin/env python3
# Copyright      2025  Xiaomi Corp.        (authors: Fangjun Kuang)
"""Greedy-search test driver for an ONNX encoder/decoder speech model.

The script loads an encoder and a decoder with onnxruntime, computes
128-bin fbank features for a wave file, runs the encoder once, and then
runs the decoder step by step (greedy search) until the end-of-text token
is produced or the step limit is reached.
"""

import argparse
import time
from pathlib import Path
from typing import Dict, List, Tuple

import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf

# Sample rate handed to the fbank extractor; other rates are resampled.
_SAMPLE_RATE = 16000

# Languages that have a dedicated "<|xx|>" prompt token in this script.
_SUPPORTED_LANGS = ("en", "es", "de", "fr")

# Upper bound on the decoder position index used during greedy search.
_MAX_DECODE_STEPS = 200

# Shape of the initial (empty) decoder memories: one (1, 0, dim) tensor
# per decoder memory input.
_NUM_DECODER_MEMS = 6
_DECODER_MEM_DIM = 1024


def get_args():
    """Parse and return the command-line arguments of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder", type=str, required=True, help="Path to encoder.onnx"
    )
    parser.add_argument(
        "--decoder", type=str, required=True, help="Path to decoder.onnx"
    )
    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")
    parser.add_argument(
        "--source-lang",
        type=str,
        help="Language of the input wav. Valid values are: en, de, es, fr",
    )
    parser.add_argument(
        "--target-lang",
        type=str,
        help="Language of the recognition result. Valid values are: en, de, es, fr",
    )
    parser.add_argument(
        "--use-pnc",
        type=int,
        default=1,
        help="1 to enable cases and punctuations. 0 to disable that",
    )
    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
    return parser.parse_args()


def display(sess, model):
    """Print the input and output descriptions of an onnxruntime session.

    Args:
      sess: The onnxruntime inference session to describe.
      model: A label used in the printed banners, e.g. "encoder".
    """
    print(f"=========={model} Input==========")
    for i in sess.get_inputs():
        print(i)
    # The banner format mirrors the input banner above.
    print(f"=========={model} Output==========")
    for i in sess.get_outputs():
        print(i)


class OnnxModel:
    """Holds the encoder and decoder sessions and runs them."""

    def __init__(
        self,
        encoder: str,
        decoder: str,
    ):
        """
        Args:
          encoder: Path to encoder.onnx
          decoder: Path to decoder.onnx
        """
        self.init_encoder(encoder)
        display(self.encoder, "encoder")
        self.init_decoder(decoder)
        display(self.decoder, "decoder")

    @staticmethod
    def _create_session(model_path: str) -> ort.InferenceSession:
        """Create a single-threaded CPU inference session for model_path."""
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        return ort.InferenceSession(
            model_path,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_encoder(self, encoder):
        """Load the encoder and set self.normalize_type."""
        self.encoder = self._create_session(encoder)

        meta = self.encoder.get_modelmeta().custom_metadata_map
        # NOTE: the normalization type is hard-coded here and is not read
        # from the model metadata; the metadata is only printed.
        self.normalize_type = "per_feature"
        print(meta)

    def init_decoder(self, decoder):
        """Load the decoder."""
        self.decoder = self._create_session(decoder)

    def run_encoder(self, x: np.ndarray, x_lens: np.ndarray):
        """
        Args:
          x: (N, T, C), np.float
          x_lens: (N,), np.int64
        Returns:
          enc_states: (N, T, C)
          enc_lens: (N,), np.int64
          enc_masks: (N, T), np.bool
        """
        inputs = self.encoder.get_inputs()
        output_names = [o.name for o in self.encoder.get_outputs()[:3]]

        enc_states, enc_lens, enc_masks = self.encoder.run(
            output_names,
            {
                inputs[0].name: x,
                inputs[1].name: x_lens,
            },
        )
        return enc_states, enc_lens, enc_masks

    def run_decoder(
        self,
        decoder_input_ids: np.ndarray,
        decoder_mems_list: List[np.ndarray],
        enc_states: np.ndarray,
        enc_mask: np.ndarray,
    ):
        """
        The decoder inputs are fed positionally in this order:
        decoder_input_ids, every entry of decoder_mems_list, enc_states,
        enc_mask. The requested outputs are the logits followed by one
        updated memory per entry of decoder_mems_list.

        Args:
          decoder_input_ids: (N, num_tokens), int32
          decoder_mems_list: a list of tensors, each of which is (N, num_tokens, C)
          enc_states: (N, T, C), float
          enc_mask: (N, T), bool
        Returns:
          logits: (1, 1, vocab_size), float
          new_decoder_mems_list: a list with the same length as
            decoder_mems_list
        """
        inputs = self.decoder.get_inputs()
        num_mems = len(decoder_mems_list)

        # Output 0 is the logits; outputs 1..num_mems are the new memories.
        output_names = [o.name for o in self.decoder.get_outputs()[: num_mems + 1]]

        feed = {inputs[0].name: decoder_input_ids}
        for i, mem in enumerate(decoder_mems_list, start=1):
            feed[inputs[i].name] = mem
        feed[inputs[num_mems + 1].name] = enc_states
        feed[inputs[num_mems + 2].name] = enc_mask

        (logits, *new_decoder_mems_list) = self.decoder.run(output_names, feed)
        return logits, new_decoder_mems_list


def create_fbank():
    """Return an online fbank extractor with 128 mel bins.

    Dithering and DC-offset removal are disabled, a hann window is used and
    the librosa-compatible mel filter bank is selected.
    """
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.remove_dc_offset = False
    opts.frame_opts.window_type = "hann"

    opts.mel_opts.low_freq = 0
    opts.mel_opts.num_bins = 128

    opts.mel_opts.is_librosa = True

    fbank = knf.OnlineFbank(opts)
    return fbank


def compute_features(audio, fbank):
    """Compute fbank features for a 1-D waveform.

    Args:
      audio: A 1-D array of samples at 16 kHz.
      fbank: The extractor returned by create_fbank().
    Returns:
      A 2-D array with one row per ready frame.
    """
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(_SAMPLE_RATE, audio)
    frames = [np.array(fbank.get_frame(i)) for i in range(fbank.num_frames_ready)]
    return np.stack(frames)


def _load_tokens(filename: str) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Read tokens.txt and return the (id2token, token2id) mappings."""
    id2token = dict()
    token2id = dict()
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line: nothing to parse.
                continue

            if len(fields) == 2:
                t, idx = fields[0], int(fields[1])
                if line[0] == " ":
                    # split() dropped the leading space of the token.
                    t = " " + t
            else:
                # Only the id survived split(), so the token is a space.
                t = " "
                idx = int(fields[0])

            id2token[idx] = t
            token2id[t] = idx
    return id2token, token2id


def _load_audio(filename: str) -> np.ndarray:
    """Return the first channel of filename as float32 samples at 16 kHz."""
    audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != _SAMPLE_RATE:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=_SAMPLE_RATE,
        )
    return audio


def _normalize_per_feature(features: np.ndarray) -> np.ndarray:
    """Normalize every feature dimension of a (T, C) matrix over time.

    The statistics are reduced over axis 0 (frames), so each of the C
    feature bins ends up with zero mean and roughly unit variance.
    """
    mean = features.mean(axis=0, keepdims=True)
    stddev = features.std(axis=0, keepdims=True) + 1e-5
    return (features - mean) / stddev


def _build_prompt(token2id: Dict[str, int], source_lang, target_lang, use_pnc):
    """Return the initial decoder input ids, ending with position 0.

    A language that is not one of en/es/de/fr falls back to English.
    """
    if source_lang not in _SUPPORTED_LANGS:
        source_lang = "en"
    if target_lang not in _SUPPORTED_LANGS:
        target_lang = "en"

    names = [
        "<|startofcontext|>",
        "<|startoftranscript|>",
        "<|emo:undefined|>",
        f"<|{source_lang}|>",
        f"<|{target_lang}|>",
        "<|pnc|>" if use_pnc else "<|nopnc|>",
        "<|noitn|>",
        "<|notimestamp|>",
        "<|nodiarize|>",
    ]
    decoder_input_ids = [token2id[name] for name in names]

    # The last entry is the position index passed to the decoder, not a
    # token id; main() passes the step index in the same slot later on.
    decoder_input_ids.append(0)
    return decoder_input_ids


def _greedy_search(model, decoder_input_ids, enc_states, enc_masks, eos):
    """Run greedy search and return the predicted token ids without eos."""
    decoder_mems_list = [
        np.zeros((1, 0, _DECODER_MEM_DIM), dtype=np.float32)
        for _ in range(_NUM_DECODER_MEMS)
    ]

    logits, decoder_mems_list = model.run_decoder(
        np.array([decoder_input_ids], dtype=np.int32),
        decoder_mems_list,
        enc_states,
        enc_masks,
    )
    print("decoder_input_ids", decoder_input_ids)

    first = int(logits.argmax())
    if first == eos:
        # The model ended the transcript immediately: nothing to emit.
        return []

    tokens = [first]
    for i in range(1, _MAX_DECODE_STEPS):
        logits, decoder_mems_list = model.run_decoder(
            np.array([[tokens[-1], i]], dtype=np.int32),
            decoder_mems_list,
            enc_states,
            enc_masks,
        )
        t = int(logits.argmax())
        if t == eos:
            break
        tokens.append(t)
    return tokens


def main():
    """Decode the wave file given on the command line and print the text."""
    args = get_args()
    assert Path(args.encoder).is_file(), args.encoder
    assert Path(args.decoder).is_file(), args.decoder
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    id2token, token2id = _load_tokens(args.tokens)

    model = OnnxModel(args.encoder, args.decoder)

    fbank = create_fbank()
    audio = _load_audio(args.wav)

    features = compute_features(audio, fbank)
    if model.normalize_type != "":
        assert model.normalize_type == "per_feature", model.normalize_type
        features = _normalize_per_feature(features)

    features = np.expand_dims(features, axis=0)
    # features.shape: (1, 291, 128)

    features_len = np.array([features.shape[1]], dtype=np.int64)

    enc_states, _, enc_masks = model.run_encoder(features, features_len)

    decoder_input_ids = _build_prompt(
        token2id, args.source_lang, args.target_lang, args.use_pnc
    )

    tokens = _greedy_search(
        model,
        decoder_input_ids,
        enc_states,
        enc_masks,
        eos=token2id["<|endoftext|>"],
    )

    print("len(tokens)", len(tokens))
    print("tokens", tokens)
    text = "".join(id2token[i] for i in tokens)
    print("text:", text)


if __name__ == "__main__":
    main()