From 55a44793e66ffd76757ee7f286c3b2dd16b0be4f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 22 May 2025 15:09:38 +0800 Subject: [PATCH] Export spleeter model to onnx for source separation (#2237) --- .../workflows/export-spleeter-to-onnx.yaml | 117 +++++++++ scripts/spleeter/.gitignore | 2 + scripts/spleeter/__init__.py | 0 scripts/spleeter/convert_to_pb.py | 89 +++++++ scripts/spleeter/convert_to_torch.py | 240 ++++++++++++++++++ scripts/spleeter/export_onnx.py | 94 +++++++ scripts/spleeter/run.sh | 41 +++ scripts/spleeter/separate.py | 170 +++++++++++++ scripts/spleeter/separate_onnx.py | 197 ++++++++++++++ scripts/spleeter/unet.py | 150 +++++++++++ 10 files changed, 1100 insertions(+) create mode 100644 .github/workflows/export-spleeter-to-onnx.yaml create mode 100644 scripts/spleeter/.gitignore create mode 100644 scripts/spleeter/__init__.py create mode 100755 scripts/spleeter/convert_to_pb.py create mode 100755 scripts/spleeter/convert_to_torch.py create mode 100755 scripts/spleeter/export_onnx.py create mode 100755 scripts/spleeter/run.sh create mode 100755 scripts/spleeter/separate.py create mode 100755 scripts/spleeter/separate_onnx.py create mode 100644 scripts/spleeter/unet.py diff --git a/.github/workflows/export-spleeter-to-onnx.yaml b/.github/workflows/export-spleeter-to-onnx.yaml new file mode 100644 index 00000000..f1993ce7 --- /dev/null +++ b/.github/workflows/export-spleeter-to-onnx.yaml @@ -0,0 +1,117 @@ +name: export-spleeter-to-onnx + +on: + push: + branches: + - spleeter-2 + workflow_dispatch: + +concurrency: + group: export-spleeter-to-onnx-${{ github.ref }} + cancel-in-progress: true + +jobs: + export-spleeter-to-onnx: + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' + name: export spleeter to ONNX + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest] + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v4 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + shell: bash + run: | + pip install tensorflow torch "numpy<2" onnx==1.17.0 onnxruntime==1.17.1 onnxmltools + + - name: Run + shell: bash + run: | + cd scripts/spleeter + ./run.sh + + echo "---" + ls -lh 2stems + echo "---" + ls -lh 2stems/*.onnx + echo "---" + + mv -v 2stems/*.onnx ../.. + + - name: Collect models + shell: bash + run: | + mkdir sherpa-onnx-spleeter-2stems + mkdir sherpa-onnx-spleeter-2stems-int8 + mkdir sherpa-onnx-spleeter-2stems-fp16 + + mv -v vocals.onnx sherpa-onnx-spleeter-2stems/ + mv -v accompaniment.onnx sherpa-onnx-spleeter-2stems/ + + mv -v vocals.int8.onnx sherpa-onnx-spleeter-2stems-int8/ + mv -v accompaniment.int8.onnx sherpa-onnx-spleeter-2stems-int8/ + + mv -v vocals.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/ + mv -v accompaniment.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/ + + tar cjvf sherpa-onnx-spleeter-2stems.tar.bz2 sherpa-onnx-spleeter-2stems + tar cjvf sherpa-onnx-spleeter-2stems-int8.tar.bz2 sherpa-onnx-spleeter-2stems-int8 + tar cjvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2 sherpa-onnx-spleeter-2stems-fp16 + + ls -lh *.tar.bz2 + + - name: Release + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + file: ./*.tar.bz2 + overwrite: true + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: source-separation-models + + - name: Publish to huggingface + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + uses: nick-fields/retry@v3 + with: + max_attempts: 20 + timeout_seconds: 200 + shell: bash + command: | + git config --global user.email "csukuangfj@gmail.com" + git config --global user.name "Fangjun Kuang" + + export GIT_LFS_SKIP_SMUDGE=1 + export GIT_CLONE_PROTECTION_ACTIVE=false + + names=( + sherpa-onnx-spleeter-2stems + sherpa-onnx-spleeter-2stems-int8 + sherpa-onnx-spleeter-2stems-fp16 + ) + for d in ${names[@]}; do + rm -rf huggingface + git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface + cp -v $d/*onnx huggingface + + cd huggingface + git lfs track "*.onnx" + git status + git add . + ls -lh + git status + git commit -m "add models" + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main + cd .. + done diff --git a/scripts/spleeter/.gitignore b/scripts/spleeter/.gitignore new file mode 100644 index 00000000..880b1c04 --- /dev/null +++ b/scripts/spleeter/.gitignore @@ -0,0 +1,2 @@ +2stems.tar.gz +2stems diff --git a/scripts/spleeter/__init__.py b/scripts/spleeter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/scripts/spleeter/convert_to_pb.py b/scripts/spleeter/convert_to_pb.py new file mode 100755 index 00000000..7be256cc --- /dev/null +++ b/scripts/spleeter/convert_to_pb.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +# Code in this file is modified from +# https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc +# +# Please see ./run.sh for usages +import argparse + +import tensorflow as tf + + +def freeze_graph(model_dir, output_node_names, output_filename): + """Extract the sub graph defined by the output nodes and convert all its + variables into constant + + Args: + model_dir: + the root folder containing the checkpoint state file + output_node_names: + a string, containing all the output node's names, comma separated + output_filename: + Filename to save the graph. + """ + if not tf.compat.v1.gfile.Exists(model_dir): + raise AssertionError( + "Export directory doesn't exists. Please specify an export " + "directory: %s" % model_dir + ) + + if not output_node_names: + print("You need to supply the name of a node to --output_node_names.") + return -1 + + # We retrieve our checkpoint fullpath + checkpoint = tf.train.get_checkpoint_state(model_dir) + input_checkpoint = checkpoint.model_checkpoint_path + + # We precise the file fullname of our freezed graph + output_graph = output_filename + + # We clear devices to allow TensorFlow to control on which device it will load operations + clear_devices = True + + # We start a session using a temporary fresh Graph + with tf.compat.v1.Session(graph=tf.Graph()) as sess: + # We import the meta graph in the current default Graph + saver = tf.compat.v1.train.import_meta_graph( + input_checkpoint + ".meta", clear_devices=clear_devices + ) + + # We restore the weights + saver.restore(sess, input_checkpoint) + + # We use a built-in TF helper to export variables to constants + output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( + sess, # The session is used to retrieve the weights + tf.compat.v1.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes + output_node_names.split( + "," + ), # The output node names are used to select the usefull nodes + ) + + # Finally we serialize and dump the output graph to the filesystem + with tf.compat.v1.gfile.GFile(output_graph, "wb") as f: + f.write(output_graph_def.SerializeToString()) + print("%d ops in the final graph." % len(output_graph_def.node)) + + return output_graph_def + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-dir", type=str, default="", help="Model folder to export" + ) + parser.add_argument( + "--output-node-names", + type=str, + default="vocals_spectrogram/mul,accompaniment_spectrogram/mul", + help="The name of the output nodes, comma separated.", + ) + + parser.add_argument( + "--output-filename", + type=str, + ) + args = parser.parse_args() + + freeze_graph(args.model_dir, args.output_node_names, args.output_filename) diff --git a/scripts/spleeter/convert_to_torch.py b/scripts/spleeter/convert_to_torch.py new file mode 100755 index 00000000..dc6e7580 --- /dev/null +++ b/scripts/spleeter/convert_to_torch.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +# Please see ./run.sh for usage + +import argparse + +import numpy as np +import tensorflow as tf +import torch + +from unet import UNet + + +def load_graph(frozen_graph_filename): + # This function is modified from + # https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc + + # We load the protobuf file from the disk and parse it to retrieve the + # unserialized graph_def + with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + + # Then, we import the graph_def into a new Graph and returns it + with tf.Graph().as_default() as graph: + # The name var will prefix every op/nodes in your graph + # Since we load everything in a new graph, this is not needed + # tf.import_graph_def(graph_def, name="prefix") + tf.import_graph_def(graph_def, name="") + return graph + + +def generate_waveform(): + np.random.seed(20230821) + waveform = np.random.rand(60 * 44100).astype(np.float32) + + # (num_samples, num_channels) + waveform = waveform.reshape(-1, 2) + return waveform + + +def get_param(graph, name): + with tf.compat.v1.Session(graph=graph) as sess: + constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"] + for constant_op in constant_ops: + if constant_op.name != name: + continue + + value = sess.run(constant_op.outputs[0]) + return torch.from_numpy(value) + + +@torch.no_grad() +def main(name): + graph = load_graph(f"./2stems/frozen_{name}_model.pb") + # for op in graph.get_operations(): + # print(op.name) + x = graph.get_tensor_by_name("waveform:0") + # y = graph.get_tensor_by_name("Reshape:0") + y0 = graph.get_tensor_by_name("strided_slice_3:0") + # y1 = graph.get_tensor_by_name("leaky_re_lu_5/LeakyRelu:0") + # y1 = graph.get_tensor_by_name("conv2d_5/BiasAdd:0") + # y1 = graph.get_tensor_by_name("conv2d_transpose/BiasAdd:0") + # y1 = graph.get_tensor_by_name("re_lu/Relu:0") + # y1 = graph.get_tensor_by_name("batch_normalization_6/cond/FusedBatchNorm_1:0") + # y1 = graph.get_tensor_by_name("concatenate/concat:0") + # y1 = graph.get_tensor_by_name("concatenate_1/concat:0") + # y1 = graph.get_tensor_by_name("concatenate_4/concat:0") + # y1 = graph.get_tensor_by_name("batch_normalization_11/cond/FusedBatchNorm_1:0") + # y1 = graph.get_tensor_by_name("conv2d_6/Sigmoid:0") + y1 = graph.get_tensor_by_name(f"{name}_spectrogram/mul:0") + + unet = UNet() + unet.eval() + + # For the conv2d in tensorflow, weight shape is (kernel_h, kernel_w, in_channel, out_channel) + # default input shape is NHWC + + # For the conv2d in torch, weight shape is (out_channel, in_channel, kernel_h, kernel_w) + # default input shape is NCHW + state_dict = unet.state_dict() + # print(list(state_dict.keys())) + + if name == "vocals": + state_dict["conv.weight"] = get_param(graph, "conv2d/kernel").permute( + 3, 2, 0, 1 + ) + state_dict["conv.bias"] = get_param(graph, "conv2d/bias") + + state_dict["bn.weight"] = get_param(graph, "batch_normalization/gamma") + state_dict["bn.bias"] = get_param(graph, "batch_normalization/beta") + state_dict["bn.running_mean"] = get_param( + graph, "batch_normalization/moving_mean" + ) + state_dict["bn.running_var"] = get_param( + graph, "batch_normalization/moving_variance" + ) + + conv_offset = 0 + bn_offset = 0 + else: + state_dict["conv.weight"] = get_param(graph, "conv2d_7/kernel").permute( + 3, 2, 0, 1 + ) + state_dict["conv.bias"] = get_param(graph, "conv2d_7/bias") + + state_dict["bn.weight"] = get_param(graph, "batch_normalization_12/gamma") + state_dict["bn.bias"] = get_param(graph, "batch_normalization_12/beta") + state_dict["bn.running_mean"] = get_param( + graph, "batch_normalization_12/moving_mean" + ) + state_dict["bn.running_var"] = get_param( + graph, "batch_normalization_12/moving_variance" + ) + conv_offset = 7 + bn_offset = 12 + + for i in range(1, 6): + state_dict[f"conv{i}.weight"] = get_param( + graph, f"conv2d_{i+conv_offset}/kernel" + ).permute(3, 2, 0, 1) + state_dict[f"conv{i}.bias"] = get_param(graph, f"conv2d_{i+conv_offset}/bias") + if i >= 5: + continue + state_dict[f"bn{i}.weight"] = get_param( + graph, f"batch_normalization_{i+bn_offset}/gamma" + ) + state_dict[f"bn{i}.bias"] = get_param( + graph, f"batch_normalization_{i+bn_offset}/beta" + ) + state_dict[f"bn{i}.running_mean"] = get_param( + graph, f"batch_normalization_{i+bn_offset}/moving_mean" + ) + state_dict[f"bn{i}.running_var"] = get_param( + graph, f"batch_normalization_{i+bn_offset}/moving_variance" + ) + + if name == "vocals": + state_dict["up1.weight"] = get_param(graph, "conv2d_transpose/kernel").permute( + 3, 2, 0, 1 + ) + state_dict["up1.bias"] = get_param(graph, "conv2d_transpose/bias") + + state_dict["bn5.weight"] = get_param(graph, "batch_normalization_6/gamma") + state_dict["bn5.bias"] = get_param(graph, "batch_normalization_6/beta") + state_dict["bn5.running_mean"] = get_param( + graph, "batch_normalization_6/moving_mean" + ) + state_dict["bn5.running_var"] = get_param( + graph, "batch_normalization_6/moving_variance" + ) + conv_offset = 0 + bn_offset = 0 + else: + state_dict["up1.weight"] = get_param( + graph, "conv2d_transpose_6/kernel" + ).permute(3, 2, 0, 1) + state_dict["up1.bias"] = get_param(graph, "conv2d_transpose_6/bias") + + state_dict["bn5.weight"] = get_param(graph, "batch_normalization_18/gamma") + state_dict["bn5.bias"] = get_param(graph, "batch_normalization_18/beta") + state_dict["bn5.running_mean"] = get_param( + graph, "batch_normalization_18/moving_mean" + ) + state_dict["bn5.running_var"] = get_param( + graph, "batch_normalization_18/moving_variance" + ) + conv_offset = 6 + bn_offset = 12 + + for i in range(1, 6): + state_dict[f"up{i+1}.weight"] = get_param( + graph, f"conv2d_transpose_{i+conv_offset}/kernel" + ).permute(3, 2, 0, 1) + + state_dict[f"up{i+1}.bias"] = get_param( + graph, f"conv2d_transpose_{i+conv_offset}/bias" + ) + + state_dict[f"bn{5+i}.weight"] = get_param( + graph, f"batch_normalization_{6+i+bn_offset}/gamma" + ) + state_dict[f"bn{5+i}.bias"] = get_param( + graph, f"batch_normalization_{6+i+bn_offset}/beta" + ) + state_dict[f"bn{5+i}.running_mean"] = get_param( + graph, f"batch_normalization_{6+i+bn_offset}/moving_mean" + ) + state_dict[f"bn{5+i}.running_var"] = get_param( + graph, f"batch_normalization_{6+i+bn_offset}/moving_variance" + ) + + if name == "vocals": + state_dict["up7.weight"] = get_param(graph, "conv2d_6/kernel").permute( + 3, 2, 0, 1 + ) + state_dict["up7.bias"] = get_param(graph, "conv2d_6/bias") + else: + state_dict["up7.weight"] = get_param(graph, "conv2d_13/kernel").permute( + 3, 2, 0, 1 + ) + state_dict["up7.bias"] = get_param(graph, "conv2d_13/bias") + + unet.load_state_dict(state_dict) + + with tf.compat.v1.Session(graph=graph) as sess: + y0_out, y1_out = sess.run([y0, y1], feed_dict={x: generate_waveform()}) + # y0_out = sess.run(y0, feed_dict={x: generate_waveform()}) + # y1_out = sess.run(y1, feed_dict={x: generate_waveform()}) + # print(y0_out.shape) + # print(y1_out.shape) + + # for the batchnormalization in tensorflow, + # default input shape is NHWC + + # for the batchnormalization in torch, + # default input shape is NCHW + + # NHWC to NCHW + torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2)) + + # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape) + assert torch.allclose( + torch_y1_out, torch.from_numpy(y1_out).permute(0, 3, 1, 2), atol=1e-1 + ), ((torch_y1_out - torch.from_numpy(y1_out).permute(0, 3, 1, 2)).abs().max()) + torch.save(unet.state_dict(), f"2stems/{name}.pt") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--name", + type=str, + required=True, + choices=["vocals", "accompaniment"], + ) + args = parser.parse_args() + print(vars(args)) + main(args.name) diff --git a/scripts/spleeter/export_onnx.py b/scripts/spleeter/export_onnx.py new file mode 100755 index 00000000..adc26048 --- /dev/null +++ b/scripts/spleeter/export_onnx.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) + +import onnx +import onnxmltools +import torch +from onnxmltools.utils.float16_converter import convert_float_to_float16 +from onnxruntime.quantization import QuantType, quantize_dynamic + +from unet import UNet + + +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +def add_meta_data(filename, prefix): + meta_data = { + "model_type": "spleeter", + "sample_rate": 41000, + "version": 1, + "model_url": "https://github.com/deezer/spleeter", + "stems": 2, + "comment": prefix, + "model_name": "2stems.tar.gz", + } + model = onnx.load(filename) + + print(model.metadata_props) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + print("--------------------") + + print(model.metadata_props) + + onnx.save(model, filename) + + +def export(model, prefix): + num_splits = 1 + x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32) + + filename = f"./2stems/{prefix}.onnx" + torch.onnx.export( + model, + x, + filename, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "num_splits"}, + }, + opset_version=13, + ) + + add_meta_data(filename, prefix) + + filename_int8 = f"./2stems/{prefix}.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + weight_type=QuantType.QUInt8, + ) + + filename_fp16 = f"./2stems/{prefix}.fp16.onnx" + export_onnx_fp16(filename, filename_fp16) + + +@torch.no_grad() +def main(): + vocals = UNet() + state_dict = torch.load("./2stems/vocals.pt", map_location="cpu") + vocals.load_state_dict(state_dict) + vocals.eval() + + accompaniment = UNet() + state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu") + accompaniment.load_state_dict(state_dict) + accompaniment.eval() + + export(vocals, "vocals") + export(accompaniment, "accompaniment") + + +if __name__ == "__main__": + main() diff --git a/scripts/spleeter/run.sh b/scripts/spleeter/run.sh new file mode 100755 index 00000000..d6651ed0 --- /dev/null +++ b/scripts/spleeter/run.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + + +if [ ! -f 2stems.tar.gz ]; then + curl -SL -O https://github.com/deezer/spleeter/releases/download/v1.4.0/2stems.tar.gz +fi + +if [ ! -d ./2stems ]; then + mkdir -p 2stems + cd 2stems + tar xvf ../2stems.tar.gz + cd .. +fi + +ls -lh + +ls -lh 2stems + +if [ ! -f 2stems/frozen_vocals_model.pb ]; then + python3 ./convert_to_pb.py \ + --model-dir ./2stems \ + --output-node-names vocals_spectrogram/mul \ + --output-filename ./2stems/frozen_vocals_model.pb +fi + +ls -lh 2stems + +if [ ! -f 2stems/frozen_accompaniment_model.pb ]; then + python3 ./convert_to_pb.py \ + --model-dir ./2stems \ + --output-node-names accompaniment_spectrogram/mul \ + --output-filename ./2stems/frozen_accompaniment_model.pb +fi + +ls -lh 2stems + +python3 ./convert_to_torch.py --name vocals +python3 ./convert_to_torch.py --name accompaniment +python3 ./export_onnx.py + +ls -lh 2stems diff --git a/scripts/spleeter/separate.py b/scripts/spleeter/separate.py new file mode 100755 index 00000000..2a83b7ef --- /dev/null +++ b/scripts/spleeter/separate.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +# Please see ./run.sh for usage + +from typing import Optional + +import ffmpeg +import numpy as np +import soundfile as sf +import torch +from pydub import AudioSegment + +from unet import UNet + + +def load_audio(filename, sample_rate: Optional[int] = 44100): + probe = ffmpeg.probe(filename) + if "streams" not in probe or len(probe["streams"]) == 0: + raise ValueError("No stream was found with ffprobe") + + metadata = next( + stream for stream in probe["streams"] if stream["codec_type"] == "audio" + ) + n_channels = metadata["channels"] + + if sample_rate is None: + sample_rate = metadata["sample_rate"] + + process = ( + ffmpeg.input(filename) + .output("pipe:", format="f32le", ar=sample_rate) + .run_async(pipe_stdout=True, pipe_stderr=True) + ) + buffer, _ = process.communicate() + waveform = np.frombuffer(buffer, dtype=" 2: + waveform = waveform[:, :2] + + return waveform, sample_rate + + +@torch.no_grad() +def main(): + vocals = UNet() + vocals.eval() + state_dict = torch.load("./2stems/vocals.pt", map_location="cpu") + vocals.load_state_dict(state_dict) + + accompaniment = UNet() + accompaniment.eval() + state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu") + accompaniment.load_state_dict(state_dict) + + # + # waveform, sample_rate = load_audio("./audio_example.mp3") + + # You can download the following two mp3 from + # https://huggingface.co/spaces/csukuangfj/music-source-separation/tree/main/examples + waveform, sample_rate = load_audio("./qi-feng-le.mp3") + # waveform, sample_rate = load_audio("./Yesterday_Once_More-Carpenters.mp3") + assert waveform.shape[1] == 2, waveform.shape + + waveform = torch.nn.functional.pad(waveform, (0, 0, 0, 4096)) + + # torch.stft requires a 2-D input of shape (N, T), so we transpose waveform + stft = torch.stft( + waveform.t(), + n_fft=4096, + hop_length=1024, + window=torch.hann_window(4096, periodic=True), + center=False, + onesided=True, + return_complex=True, + ) + print("stft", stft.shape) + + # stft: (2, 2049, 465) + # stft is a complex tensor + y = stft.permute(2, 1, 0) + print("y0", y.shape) + # (465, 2049, 2) + + y = y[:, :1024, :] + # (465, 1024, 2) + + tensor_size = y.shape[0] - int(y.shape[0] / 512) * 512 + pad_size = 512 - tensor_size + y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size)) + # (512, 1024, 2) + print("y1", y.shape, y.dtype) + + num_splits = int(y.shape[0] / 512) + y = y.reshape([num_splits, 512] + list(y.shape[1:])) + # y: (1, 512, 1024, 2) + print("y2", y.shape, y.dtype) + + y = y.abs() + y = y.permute(0, 3, 1, 2) + # (1, 2, 512, 1024) + print("y3", y.shape, y.dtype) + + vocals_spec = vocals(y) + accompaniment_spec = accompaniment(y) + + sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10 + print( + "vocals_spec", + vocals_spec.shape, + accompaniment_spec.shape, + sum_spec.shape, + vocals_spec.dtype, + ) + + vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec + # (1, 2, 512, 1024) + + accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec + # (1, 2, 512, 1024) + + for name, spec in zip( + ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec] + ): + spec = torch.nn.functional.pad(spec, (0, 2049 - 1024, 0, 0, 0, 0, 0, 0)) + # (1, 2, 512, 2049) + + spec = spec.permute(0, 2, 3, 1) + # (1, 512, 2049, 2) + print("here00", spec.shape) + + spec = spec.reshape(-1, spec.shape[2], spec.shape[3]) + # (512, 2049, 2) + + print("here2", spec.shape) + # (512, 2049, 2) + + spec = spec[: stft.shape[2], :, :] + # (465, 2049, 2) + print("here 3", spec.shape, stft.shape) + + spec = spec.permute(2, 1, 0) + # (2, 2049, 465) + + masked_stft = spec * stft + + wave = torch.istft( + masked_stft, + 4096, + 1024, + window=torch.hann_window(4096, periodic=True), + onesided=True, + ) * (2 / 3) + + print(wave.shape, wave.dtype) + sf.write(f"{name}.wav", wave.t(), 44100) + + wave = (wave.t() * 32768).to(torch.int16) + sound = AudioSegment( + data=wave.numpy().tobytes(), sample_width=2, frame_rate=44100, channels=2 + ) + sound.export(f"{name}.mp3", format="mp3", bitrate="128k") + + +if __name__ == "__main__": + main() diff --git a/scripts/spleeter/separate_onnx.py b/scripts/spleeter/separate_onnx.py new file mode 100755 index 00000000..28ed3760 --- /dev/null +++ b/scripts/spleeter/separate_onnx.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) +import time + +import kaldi_native_fbank as knf +import numpy as np +import onnxruntime as ort +import soundfile as sf +import torch + +from separate import load_audio + +""" +----------inputs for ./2stems/vocals.onnx---------- +NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024]) +----------outputs for ./2stems/vocals.onnx---------- +NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024]) + +----------inputs for ./2stems/accompaniment.onnx---------- +NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024]) +----------outputs for ./2stems/accompaniment.onnx---------- +NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024]) + +""" + + +class OnnxModel: + def __init__(self, filename): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + print(f"----------inputs for {filename}----------") + for i in self.model.get_inputs(): + print(i) + + print(f"----------outputs for {filename}----------") + + for i in self.model.get_outputs(): + print(i) + print("--------------------") + + def __call__(self, x): + """ + Args: + x: (num_splits, 2, 512, 1024) + """ + spec = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + )[0] + + return torch.from_numpy(spec) + + +def main(): + vocals = OnnxModel("./2stems/vocals.onnx") + accompaniment = OnnxModel("./2stems/accompaniment.onnx") + + waveform, sample_rate = load_audio("./qi-feng-le.mp3") + waveform = waveform[: 44100 * 10, :] + + stft_config = knf.StftConfig( + n_fft=4096, + hop_length=1024, + win_length=4096, + center=False, + window_type="hann", + ) + knf_stft = knf.Stft(stft_config) + knf_istft = knf.IStft(stft_config) + + start = time.time() + + stft_result_c0 = knf_stft(waveform[:, 0].tolist()) + stft_result_c1 = knf_stft(waveform[:, 1].tolist()) + print("c0 stft", stft_result_c0.num_frames) + + orig_real0 = np.array(stft_result_c0.real, dtype=np.float32).reshape( + stft_result_c0.num_frames, -1 + ) + orig_imag0 = np.array(stft_result_c0.imag, dtype=np.float32).reshape( + stft_result_c0.num_frames, -1 + ) + + orig_real1 = np.array(stft_result_c1.real, dtype=np.float32).reshape( + stft_result_c1.num_frames, -1 + ) + orig_imag1 = np.array(stft_result_c1.imag, dtype=np.float32).reshape( + stft_result_c1.num_frames, -1 + ) + + real0 = torch.from_numpy(orig_real0) + imag0 = torch.from_numpy(orig_imag0) + real1 = torch.from_numpy(orig_real1) + imag1 = torch.from_numpy(orig_imag1) + # (num_frames, n_fft/2_1) + print("real0", real0.shape) + + # keep only the first 1024 bins + real0 = real0[:, :1024] + imag0 = imag0[:, :1024] + real1 = real1[:, :1024] + imag1 = imag1[:, :1024] + + stft0 = (real0.square() + imag0.square()).sqrt() + stft1 = (real1.square() + imag1.square()).sqrt() + + # pad it to multiple of 512 + padding = 512 - real0.shape[0] % 512 + print("padding", padding) + if padding > 0: + stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding)) + stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding)) + stft0 = stft0.reshape(-1, 1, 512, 1024) + stft1 = stft1.reshape(-1, 1, 512, 1024) + + stft_01 = torch.cat([stft0, stft1], axis=1) + + print("stft_01", stft_01.shape, stft_01.dtype) + + vocals_spec = vocals(stft_01) + accompaniment_spec = accompaniment(stft_01) + # (num_splits, num_channels, 512, 1024) + + sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10 + + vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec + accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec + + for name, spec in zip( + ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec] + ): + spec_c0 = spec[:, 0, :, :] + spec_c1 = spec[:, 1, :, :] + + spec_c0 = spec_c0.reshape(-1, 1024) + spec_c1 = spec_c1.reshape(-1, 1024) + + spec_c0 = spec_c0[: stft_result_c0.num_frames, :] + spec_c1 = spec_c1[: stft_result_c0.num_frames, :] + + spec_c0 = torch.nn.functional.pad(spec_c0, (0, 2049 - 1024, 0, 0)) + spec_c1 = torch.nn.functional.pad(spec_c1, (0, 2049 - 1024, 0, 0)) + + spec_c0_real = spec_c0 * orig_real0 + spec_c0_imag = spec_c0 * orig_imag0 + + spec_c1_real = spec_c1 * orig_real1 + spec_c1_imag = spec_c1 * orig_imag1 + + result0 = knf.StftResult( + real=spec_c0_real.reshape(-1).tolist(), + imag=spec_c0_imag.reshape(-1).tolist(), + num_frames=orig_real0.shape[0], + ) + + result1 = knf.StftResult( + real=spec_c1_real.reshape(-1).tolist(), + imag=spec_c1_imag.reshape(-1).tolist(), + num_frames=orig_real1.shape[0], + ) + + wav0 = knf_istft(result0) + wav1 = knf_istft(result1) + + wav = np.array([wav0, wav1], dtype=np.float32) + wav = np.transpose(wav) + # now wav is (num_samples, num_channels) + + sf.write(f"./onnx-{name}.wav", wav, 44100) + + print(f"Saved to ./onnx-{name}.wav") + + end = time.time() + elapsed_seconds = end - start + audio_duration = waveform.shape[0] / sample_rate + real_time_factor = elapsed_seconds / audio_duration + + print(f"Elapsed seconds: {elapsed_seconds:.3f}") + print(f"Audio duration in seconds: {audio_duration:.3f}") + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") + + +if __name__ == "__main__": + main() diff --git a/scripts/spleeter/unet.py b/scripts/spleeter/unet.py new file mode 100644 index 00000000..cfcabb6c --- /dev/null +++ b/scripts/spleeter/unet.py @@ -0,0 +1,150 @@ +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +import torch + + +class UNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(2, 16, kernel_size=5, stride=(2, 2), padding=0) + self.bn = torch.nn.BatchNorm2d( + 16, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + # + self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=5, stride=(2, 2), padding=0) + self.bn1 = torch.nn.BatchNorm2d( + 32, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5, stride=(2, 2), padding=0) + self.bn2 = torch.nn.BatchNorm2d( + 64, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=5, stride=(2, 2), padding=0) + self.bn3 = torch.nn.BatchNorm2d( + 128, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=5, stride=(2, 2), padding=0) + self.bn4 = torch.nn.BatchNorm2d( + 256, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.conv5 = torch.nn.Conv2d(256, 512, kernel_size=5, stride=(2, 2), padding=0) + + self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2) + self.bn5 = torch.nn.BatchNorm2d( + 256, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.up2 = torch.nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2) + self.bn6 = torch.nn.BatchNorm2d( + 128, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.up3 = torch.nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2) + self.bn7 = torch.nn.BatchNorm2d( + 64, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.up4 = torch.nn.ConvTranspose2d(128, 32, kernel_size=5, stride=2) + self.bn8 = torch.nn.BatchNorm2d( + 32, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.up5 = torch.nn.ConvTranspose2d(64, 16, kernel_size=5, stride=2) + self.bn9 = torch.nn.BatchNorm2d( + 16, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + self.up6 = torch.nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2) + self.bn10 = torch.nn.BatchNorm2d( + 1, track_running_stats=True, eps=1e-3, momentum=0.01 + ) + + # output logit is False, so we need self.up7 + self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3) + + def forward(self, x): + in_x = x + # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024) + x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0) + conv1 = self.conv(x) + batch1 = self.bn(conv1) + rel1 = torch.nn.functional.leaky_relu(batch1, negative_slope=0.2) + + x = torch.nn.functional.pad(rel1, (1, 2, 1, 2), "constant", 0) + conv2 = self.conv1(x) # (3, 32, 128, 256) + batch2 = self.bn1(conv2) + rel2 = torch.nn.functional.leaky_relu( + batch2, negative_slope=0.2 + ) # (3, 32, 128, 256) + + x = torch.nn.functional.pad(rel2, (1, 2, 1, 2), "constant", 0) + conv3 = self.conv2(x) # (3, 64, 64, 128) + batch3 = self.bn2(conv3) + rel3 = torch.nn.functional.leaky_relu( + batch3, negative_slope=0.2 + ) # (3, 64, 64, 128) + + x = torch.nn.functional.pad(rel3, (1, 2, 1, 2), "constant", 0) + conv4 = self.conv3(x) # (3, 128, 32, 64) + batch4 = self.bn3(conv4) + rel4 = torch.nn.functional.leaky_relu( + batch4, negative_slope=0.2 + ) # (3, 128, 32, 64) + + x = torch.nn.functional.pad(rel4, (1, 2, 1, 2), "constant", 0) + conv5 = self.conv4(x) # (3, 256, 16, 32) + batch5 = self.bn4(conv5) + rel6 = torch.nn.functional.leaky_relu( + batch5, negative_slope=0.2 + ) # (3, 256, 16, 32) + + x = torch.nn.functional.pad(rel6, (1, 2, 1, 2), "constant", 0) + conv6 = self.conv5(x) # (3, 512, 8, 16) + + up1 = self.up1(conv6) + up1 = up1[:, :, 1:-2, 1:-2] # (3, 256, 16, 32) + up1 = torch.nn.functional.relu(up1) + batch7 = self.bn5(up1) + merge1 = torch.cat([conv5, batch7], axis=1) # (3, 512, 16, 32) + + up2 = self.up2(merge1) + up2 = up2[:, :, 1:-2, 1:-2] + up2 = torch.nn.functional.relu(up2) + batch8 = self.bn6(up2) + + merge2 = torch.cat([conv4, batch8], axis=1) # (3, 256, 32, 64) + + up3 = self.up3(merge2) + up3 = up3[:, :, 1:-2, 1:-2] + up3 = torch.nn.functional.relu(up3) + batch9 = self.bn7(up3) + + merge3 = torch.cat([conv3, batch9], axis=1) # (3, 128, 64, 128) + + up4 = self.up4(merge3) + up4 = up4[:, :, 1:-2, 1:-2] + up4 = torch.nn.functional.relu(up4) + batch10 = self.bn8(up4) + + merge4 = torch.cat([conv2, batch10], axis=1) # (3, 64, 128, 256) + + up5 = self.up5(merge4) + up5 = up5[:, :, 1:-2, 1:-2] + up5 = torch.nn.functional.relu(up5) + batch11 = self.bn9(up5) + + merge5 = torch.cat([conv1, batch11], axis=1) # (3, 32, 256, 512) + + up6 = self.up6(merge5) + up6 = up6[:, :, 1:-2, 1:-2] + up6 = torch.nn.functional.relu(up6) + batch12 = self.bn10(up6) # (3, 1, 512, 1024) = (T, 1, 512, 1024) + + up7 = self.up7(batch12) + up7 = torch.sigmoid(up7) # (3, 2, 512, 1024) + + return up7 * in_x