#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

"""Denoise a wave file with a streaming GTCRN speech-enhancement ONNX model.

The model is fed one STFT frame at a time together with its recurrent
caches, and the enhanced spectrum is converted back to a waveform with
an inverse STFT.
"""

from typing import Tuple

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read an audio file as contiguous float32 mono samples.

    Args:
      filename: Path to the audio file.

    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a 1-D
      contiguous float32 array taken from the first channel.
    """
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    samples = np.ascontiguousarray(data)
    return samples, sample_rate


class OnnxModel:
    """Streaming wrapper around ``./gtcrn_simple.onnx``.

    STFT parameters (sample rate, FFT size, hop, window) are read from
    the model's custom metadata so they always match what the model
    was exported with.
    """

    def __init__(self):
        session_opts = ort.SessionOptions()
        # Single-threaded inference: the model is small and streaming.
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            "./gtcrn_simple.onnx",
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        self.sample_rate = int(meta["sample_rate"])
        self.n_fft = int(meta["n_fft"])
        self.hop_length = int(meta["hop_length"])
        self.window_length = int(meta["window_length"])

        # The model was exported with a sqrt-Hann analysis/synthesis window.
        assert meta["window_type"] == "hann_sqrt", meta["window_type"]
        self.window = torch.hann_window(self.window_length).pow(0.5)

    def get_init_states(self):
        """Return zero-initialized streaming caches.

        Returns:
          A tuple ``(conv_cache, tra_cache, inter_cache)`` of float32
          zero arrays whose shapes come from the model metadata.
        """
        meta = self.model.get_modelmeta().custom_metadata_map
        conv_cache_shape = list(map(int, meta["conv_cache_shape"].split(",")))
        tra_cache_shape = list(map(int, meta["tra_cache_shape"].split(",")))
        inter_cache_shape = list(map(int, meta["inter_cache_shape"].split(",")))

        # Fix: keep the cache arrays in their own variables instead of
        # clobbering the shape variable (the original returned an array
        # misleadingly named ``conv_cache_shape``).
        conv_cache = np.zeros(conv_cache_shape, dtype=np.float32)
        tra_cache = np.zeros(tra_cache_shape, dtype=np.float32)
        inter_cache = np.zeros(inter_cache_shape, dtype=np.float32)

        return conv_cache, tra_cache, inter_cache

    def __call__(self, x, states):
        """Process a single STFT frame.

        Args:
          x: Array of shape (1, n_fft/2+1, 1, 2) — real/imag spectrum
             of one frame.
          states: Tuple ``(conv_cache, tra_cache, inter_cache)`` from
             :meth:`get_init_states` or a previous call.

        Returns:
          A tuple ``(o, next_states)`` where ``o`` has shape
          (1, n_fft/2+1, 1, 2) and ``next_states`` is the updated
          cache tuple to pass to the next call.
        """
        # Hoist the I/O name lookups: the original called
        # get_outputs()/get_inputs() eight times per frame.
        output_names = [o.name for o in self.model.get_outputs()[:4]]
        input_names = [i.name for i in self.model.get_inputs()[:4]]

        out, next_conv_cache, next_tra_cache, next_inter_cache = self.model.run(
            output_names,
            {
                input_names[0]: x,
                input_names[1]: states[0],
                input_names[2]: states[1],
                input_names[3]: states[2],
            },
        )
        return out, (next_conv_cache, next_tra_cache, next_inter_cache)


def main():
    model = OnnxModel()

    filename = "./inp_16k.wav"
    wave, sample_rate = load_audio(filename)

    # Resample on the fly if the file does not match the model's rate.
    if sample_rate != model.sample_rate:
        import librosa

        wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=model.sample_rate)
        sample_rate = model.sample_rate

    stft_config = knf.StftConfig(
        n_fft=model.n_fft,
        hop_length=model.hop_length,
        win_length=model.window_length,
        window=model.window.tolist(),
    )
    stft = knf.Stft(stft_config)
    stft_result = stft(wave)

    num_frames = stft_result.num_frames
    # (num_frames, n_fft/2+1) real and imaginary parts
    real = np.array(stft_result.real, dtype=np.float32).reshape(num_frames, -1)
    imag = np.array(stft_result.imag, dtype=np.float32).reshape(num_frames, -1)

    states = model.get_init_states()
    outputs = []

    # Stream the model one frame at a time, threading the caches through.
    for i in range(num_frames):
        x_real = real[i : i + 1]
        x_imag = imag[i : i + 1]
        x = np.vstack([x_real, x_imag]).transpose()  # (freq, 2)
        x = np.expand_dims(x, axis=0)  # (1, freq, 2)
        x = np.expand_dims(x, axis=2)  # (1, freq, 1, 2)

        o, states = model(x, states)
        outputs.append(o)

    outputs = np.concatenate(outputs, axis=2)  # (1, freq, num_frames, 2)
    outputs = outputs.squeeze(0).transpose(1, 0, 2)  # (num_frames, freq, 2)
    enhanced_real = outputs[:, :, 0]
    enhanced_imag = outputs[:, :, 1]

    enhanced_stft_result = knf.StftResult(
        real=enhanced_real.reshape(-1).tolist(),
        imag=enhanced_imag.reshape(-1).tolist(),
        num_frames=enhanced_real.shape[0],
    )

    istft = knf.IStft(stft_config)
    enhanced = istft(enhanced_stft_result)

    sf.write("./enhanced_16k.wav", enhanced, model.sample_rate)


if __name__ == "__main__":
    main()