197 lines
5.9 KiB
Python
Executable File
197 lines
5.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
import time
|
|
|
|
import kaldi_native_fbank as knf
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
import soundfile as sf
|
|
import torch
|
|
|
|
from separate import load_audio
|
|
|
|
"""
|
|
----------inputs for ./2stems/vocals.onnx----------
|
|
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
|
|
----------outputs for ./2stems/vocals.onnx----------
|
|
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
|
|
|
|
----------inputs for ./2stems/accompaniment.onnx----------
|
|
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
|
|
----------outputs for ./2stems/accompaniment.onnx----------
|
|
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
|
|
"""
|
|
|
|
|
|
class OnnxModel:
    """Wrapper around a CPU ONNX Runtime session that runs the model's first
    input -> first output and exchanges data as torch tensors.

    For the 2-stem models used in this script, both the input ``x`` and the
    output ``y`` are float tensors of shape (2, num_splits, 512, 1024).
    """

    def __init__(self, filename, num_threads=1):
        """Load ``filename`` on the CPU and print its input/output signatures.

        Args:
          filename: Path to the ``.onnx`` model file.
          num_threads: Value used for both ``inter_op_num_threads`` and
            ``intra_op_num_threads``. Defaults to 1 (single-threaded).
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = num_threads
        session_opts.intra_op_num_threads = num_threads

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        # Resolve the I/O names once instead of querying them on every call.
        self.input_name = self.model.get_inputs()[0].name
        self.output_name = self.model.get_outputs()[0].name

        print(f"----------inputs for {filename}----------")
        for i in self.model.get_inputs():
            print(i)

        print(f"----------outputs for {filename}----------")
        for i in self.model.get_outputs():
            print(i)
        print("--------------------")

    def __call__(self, x):
        """Run the model on ``x`` and return its first output.

        Args:
          x: Input of shape (2, num_splits, 512, 1024), i.e.
             (num_channels, num_splits, frames_per_split, num_bins).
             Either a torch.Tensor or anything ``np.asarray`` accepts. It is
             converted to a contiguous float32 numpy array before inference.
        Returns:
          The model's first output wrapped as a torch.Tensor. For the 2-stem
          models this is (2, num_splits, 512, 1024).
        """
        if isinstance(x, torch.Tensor):
            # detach/cpu so tensors that require grad or live on a GPU work.
            x = x.detach().cpu().numpy()
        x = np.ascontiguousarray(x, dtype=np.float32)

        spec = self.model.run([self.output_name], {self.input_name: x})[0]

        return torch.from_numpy(spec)
|
|
|
|
|
|
def _stft_channel(knf_stft, samples):
    """Return the STFT of one channel as float32 ``(real, imag)`` numpy arrays.

    Both arrays have shape (num_frames, num_bins). num_bins is however many
    bins ``knf_stft`` emits per frame (n_fft // 2 + 1 = 2049 for main()).
    """
    result = knf_stft(samples.tolist())
    real = np.array(result.real, dtype=np.float32).reshape(result.num_frames, -1)
    imag = np.array(result.imag, dtype=np.float32).reshape(result.num_frames, -1)
    return real, imag


def _masked_istft(knf_istft, mask, real, imag):
    """Scale a complex STFT by a real-valued mask and invert it.

    Args:
      knf_istft: A ``knf.IStft`` instance.
      mask: (num_frames, num_bins) torch tensor of per-bin gains.
      real, imag: (num_frames, num_bins) float32 numpy arrays of the STFT.
    Returns:
      The samples produced by ``knf_istft``.
    """
    result = knf.StftResult(
        real=(mask * torch.from_numpy(real)).reshape(-1).tolist(),
        imag=(mask * torch.from_numpy(imag)).reshape(-1).tolist(),
        num_frames=real.shape[0],
    )
    return knf_istft(result)


def main():
    """Separate the first 10 seconds of ./qi-feng-le.mp3 into vocals and
    accompaniment with the 2-stem ONNX models.

    Writes ./onnx-vocals.wav and ./onnx-accompaniment.wav, then prints the
    real-time factor of the STFT -> model -> mask -> ISTFT pipeline.
    """
    # The models consume (2, num_splits, split_frames, model_bins) magnitudes.
    split_frames = 512
    model_bins = 1024
    eps = 1e-10

    vocals = OnnxModel("./2stems/vocals.onnx")
    accompaniment = OnnxModel("./2stems/accompaniment.onnx")

    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    sample_rate = int(sample_rate)
    # Only the first 10 seconds are processed.
    # NOTE(review): assumes waveform is (num_samples, num_channels) with at
    # least 2 channels, as the [:, 0] / [:, 1] indexing below requires.
    waveform = waveform[: sample_rate * 10, :]

    stft_config = knf.StftConfig(
        n_fft=4096,
        hop_length=1024,
        win_length=4096,
        center=False,
        window_type="hann",
    )
    knf_stft = knf.Stft(stft_config)
    knf_istft = knf.IStft(stft_config)

    start = time.time()

    # One (real, imag) pair per channel, each (num_frames, n_fft/2+1).
    channels = [_stft_channel(knf_stft, waveform[:, c]) for c in (0, 1)]
    num_frames, num_bins = channels[0][0].shape
    print("c0 stft", num_frames)
    print("real0", channels[0][0].shape)

    # Magnitudes of the lowest `model_bins` bins: (2, num_frames, model_bins).
    mags = torch.stack(
        [
            (
                torch.from_numpy(real)[:, :model_bins].square()
                + torch.from_numpy(imag)[:, :model_bins].square()
            ).sqrt()
            for real, imag in channels
        ]
    )

    # Pad the frame axis to a multiple of split_frames. `-n % k` is 0 when n
    # is already aligned; `k - n % k` would add a whole useless extra split.
    padding = -num_frames % split_frames
    print("padding", padding)
    if padding > 0:
        mags = torch.nn.functional.pad(mags, (0, 0, 0, padding))
    stft_01 = mags.reshape(2, -1, split_frames, model_bins)

    print("stft_01", stft_01.shape, stft_01.dtype)

    vocals_spec = vocals(stft_01)
    accompaniment_spec = accompaniment(stft_01)
    # (num_channels, num_splits, split_frames, model_bins)

    # Soft masks: each stem's power divided by the total power of both stems.
    # eps / 2 per stem keeps the masks summing to 1 in silent bins.
    sum_spec = vocals_spec.square() + accompaniment_spec.square() + eps
    masks = {
        "vocals": (vocals_spec.square() + eps / 2) / sum_spec,
        "accompaniment": (accompaniment_spec.square() + eps / 2) / sum_spec,
    }

    for name, mask in masks.items():
        wavs = []
        for c, (real, imag) in enumerate(channels):
            # Undo the split/padding, then zero-extend the mask from
            # model_bins up to the STFT's full bin count.
            m = mask[c].reshape(-1, model_bins)[:num_frames]
            m = torch.nn.functional.pad(m, (0, num_bins - model_bins))
            wavs.append(_masked_istft(knf_istft, m, real, imag))

        # (num_samples, num_channels)
        wav = np.array(wavs, dtype=np.float32).T

        filename = f"./onnx-{name}.wav"
        sf.write(filename, wav, sample_rate)
        print(f"Saved to {filename}")

    end = time.time()
    elapsed_seconds = end - start
    audio_duration = waveform.shape[0] / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
|
|
|
|
|
|
if __name__ == "__main__":
    # main() runs only when this file is executed as a script, not on import.
    main()
|