#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
#
# Separate a stereo recording into "vocals" and "accompaniment" stems using
# two ONNX models, with STFT/ISTFT computed by kaldi-native-fbank.

import time

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch

from separate import load_audio

# Model input/output signatures, as printed by OnnxModel.__init__:
"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])

----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
"""

# Sample rate used for trimming the input and for writing the output files.
# NOTE(review): load_audio() also returns a sample rate; this script assumes
# it equals 44100 -- confirm against separate.load_audio.
SAMPLE_RATE = 44100
# Only the first MAX_SECONDS seconds of the input are processed.
MAX_SECONDS = 10

N_FFT = 4096
HOP_LENGTH = 1024
# The models consume the lowest 1024 frequency bins, in splits of 512 frames
# (see the NodeArg shapes above).
NUM_MODEL_BINS = 1024
FRAMES_PER_SPLIT = 512
# Small constant that keeps the ratio masks finite when both stems are zero.
EPS = 1e-10


class OnnxModel:
    """Thin wrapper around a single-input, single-output ONNX model."""

    def __init__(self, filename):
        """Load ``filename`` on the CPU provider and print its signature.

        Args:
          filename: Path to the ONNX model file.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        print(f"----------inputs for {filename}----------")
        for i in self.model.get_inputs():
            print(i)

        print(f"----------outputs for {filename}----------")
        for i in self.model.get_outputs():
            print(i)
        print("--------------------")

    def __call__(self, x):
        """Run the model on a batch of magnitude-spectrogram splits.

        Args:
          x: A float32 torch tensor of shape (2, num_splits, 512, 1024),
             i.e. (num_channels, num_splits, num_frames, num_bins).
        Returns:
          The first model output converted to a torch tensor.
        """
        spec = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
            },
        )[0]

        return torch.from_numpy(spec)


def _padding_to_multiple(length, multiple):
    """Return how many items must be appended so length % multiple == 0."""
    # (-length) % multiple is 0 when length is already aligned, unlike
    # `multiple - length % multiple`, which would return `multiple`.
    return (-length) % multiple


def _stft_to_arrays(stft_result):
    """Return (real, imag) float32 arrays of shape (num_frames, num_bins)."""
    num_frames = stft_result.num_frames
    real = np.array(stft_result.real, dtype=np.float32).reshape(num_frames, -1)
    imag = np.array(stft_result.imag, dtype=np.float32).reshape(num_frames, -1)
    return real, imag


def _magnitude(real, imag):
    """Return the magnitude of the first NUM_MODEL_BINS bins as a tensor."""
    real = real[:, :NUM_MODEL_BINS]
    imag = imag[:, :NUM_MODEL_BINS]
    return (real.square() + imag.square()).sqrt()


def _masked_istft(knf_istft, mask, orig_real, orig_imag):
    """Apply a ratio mask to one channel's STFT and return its waveform.

    Args:
      knf_istft: The knf.IStft instance matching the forward STFT.
      mask: Tensor of shape (num_splits, 512, 1024) for this channel.
      orig_real: float32 array (num_frames, num_bins), real part of the STFT.
      orig_imag: float32 array (num_frames, num_bins), imaginary part.
    Returns:
      Whatever knf_istft returns for the masked STFT.
    """
    num_frames, num_bins = orig_real.shape

    # Undo the split/padding done before inference.
    mask = mask.reshape(-1, NUM_MODEL_BINS)[:num_frames, :]
    # Bins above NUM_MODEL_BINS were never seen by the model; mask them to 0.
    mask = torch.nn.functional.pad(mask, (0, num_bins - NUM_MODEL_BINS, 0, 0))

    masked_real = mask * torch.from_numpy(orig_real)
    masked_imag = mask * torch.from_numpy(orig_imag)

    result = knf.StftResult(
        real=masked_real.reshape(-1).tolist(),
        imag=masked_imag.reshape(-1).tolist(),
        num_frames=num_frames,
    )
    return knf_istft(result)


def main():
    """Separate ./qi-feng-le.mp3 and write ./onnx-{vocals,accompaniment}.wav."""
    vocals = OnnxModel("./2stems/vocals.onnx")
    accompaniment = OnnxModel("./2stems/accompaniment.onnx")

    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    waveform = waveform[: SAMPLE_RATE * MAX_SECONDS, :]

    stft_config = knf.StftConfig(
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        win_length=N_FFT,
        center=False,
        window_type="hann",
    )
    knf_stft = knf.Stft(stft_config)
    knf_istft = knf.IStft(stft_config)

    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments.
    start = time.perf_counter()

    stft_result_c0 = knf_stft(waveform[:, 0].tolist())
    stft_result_c1 = knf_stft(waveform[:, 1].tolist())
    print("c0 stft", stft_result_c0.num_frames)

    orig_real0, orig_imag0 = _stft_to_arrays(stft_result_c0)
    orig_real1, orig_imag1 = _stft_to_arrays(stft_result_c1)

    real0 = torch.from_numpy(orig_real0)
    imag0 = torch.from_numpy(orig_imag0)
    real1 = torch.from_numpy(orig_real1)
    imag1 = torch.from_numpy(orig_imag1)
    # (num_frames, n_fft/2+1)
    print("real0", real0.shape)

    # keep only the first 1024 bins
    stft0 = _magnitude(real0, imag0)
    stft1 = _magnitude(real1, imag1)

    # pad the frame axis to a multiple of 512 (no padding if already aligned)
    padding = _padding_to_multiple(real0.shape[0], FRAMES_PER_SPLIT)
    print("padding", padding)
    if padding > 0:
        stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
        stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))

    stft0 = stft0.reshape(1, -1, FRAMES_PER_SPLIT, NUM_MODEL_BINS)
    stft1 = stft1.reshape(1, -1, FRAMES_PER_SPLIT, NUM_MODEL_BINS)

    stft_01 = torch.cat([stft0, stft1], dim=0)

    print("stft_01", stft_01.shape, stft_01.dtype)

    vocals_spec = vocals(stft_01)
    accompaniment_spec = accompaniment(stft_01)
    # (num_channels, num_splits, 512, 1024)

    # Turn the two model outputs into ratio masks that sum to 1 per bin.
    sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + EPS

    vocals_spec = (vocals_spec**2 + EPS / 2) / sum_spec
    accompaniment_spec = (accompaniment_spec**2 + EPS / 2) / sum_spec

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        wav0 = _masked_istft(knf_istft, spec[0], orig_real0, orig_imag0)
        wav1 = _masked_istft(knf_istft, spec[1], orig_real1, orig_imag1)

        wav = np.array([wav0, wav1], dtype=np.float32)
        wav = np.transpose(wav)
        # now wav is (num_samples, num_channels)

        sf.write(f"./onnx-{name}.wav", wav, SAMPLE_RATE)

        print(f"Saved to ./onnx-{name}.wav")

    end = time.perf_counter()
    elapsed_seconds = end - start
    audio_duration = waveform.shape[0] / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
    main()