Export spleeter model to onnx for source separation (#2237)
.github/workflows/export-spleeter-to-onnx.yaml (new file, vendored, 117 lines)
@@ -0,0 +1,117 @@
name: export-spleeter-to-onnx

on:
  push:
    branches:
      - spleeter-2
  workflow_dispatch:

concurrency:
  group: export-spleeter-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-spleeter-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export spleeter to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        shell: bash
        run: |
          pip install tensorflow torch "numpy<2" onnx==1.17.0 onnxruntime==1.17.1 onnxmltools

      - name: Run
        shell: bash
        run: |
          cd scripts/spleeter
          ./run.sh

          echo "---"
          ls -lh 2stems
          echo "---"
          ls -lh 2stems/*.onnx
          echo "---"

          mv -v 2stems/*.onnx ../..

      - name: Collect models
        shell: bash
        run: |
          mkdir sherpa-onnx-spleeter-2stems
          mkdir sherpa-onnx-spleeter-2stems-int8
          mkdir sherpa-onnx-spleeter-2stems-fp16

          mv -v vocals.onnx sherpa-onnx-spleeter-2stems/
          mv -v accompaniment.onnx sherpa-onnx-spleeter-2stems/

          mv -v vocals.int8.onnx sherpa-onnx-spleeter-2stems-int8/
          mv -v accompaniment.int8.onnx sherpa-onnx-spleeter-2stems-int8/

          mv -v vocals.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/
          mv -v accompaniment.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/

          tar cjvf sherpa-onnx-spleeter-2stems.tar.bz2 sherpa-onnx-spleeter-2stems
          tar cjvf sherpa-onnx-spleeter-2stems-int8.tar.bz2 sherpa-onnx-spleeter-2stems-int8
          tar cjvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2 sherpa-onnx-spleeter-2stems-fp16

          ls -lh *.tar.bz2

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: source-separation-models

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            names=(
              sherpa-onnx-spleeter-2stems
              sherpa-onnx-spleeter-2stems-int8
              sherpa-onnx-spleeter-2stems-fp16
            )
            for d in ${names[@]}; do
              rm -rf huggingface
              git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
              cp -v $d/*onnx huggingface

              cd huggingface
              git lfs track "*.onnx"
              git status
              git add .
              ls -lh
              git status
              git commit -m "add models"
              git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
              cd ..
            done
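Note: once the workflow has published the tarballs, a downloaded model can be sanity-checked with onnxruntime. The sketch below is not part of the PR; it assumes sherpa-onnx-spleeter-2stems.tar.bz2 has been extracted locally and reads back the metadata written by add_meta_data() in scripts/spleeter/export_onnx.py.

# Sanity-check sketch for a released model (assumes the tarball is extracted).
import onnxruntime as ort

sess = ort.InferenceSession(
    "sherpa-onnx-spleeter-2stems/vocals.onnx",
    providers=["CPUExecutionProvider"],
)

# Metadata written by add_meta_data() in scripts/spleeter/export_onnx.py
print(sess.get_modelmeta().custom_metadata_map)

for node in sess.get_inputs():
    print(node.name, node.shape)  # x ['num_splits', 2, 512, 1024]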
scripts/spleeter/.gitignore (new file, vendored, 2 lines)
@@ -0,0 +1,2 @@
2stems.tar.gz
2stems
scripts/spleeter/__init__.py (new file, 0 lines)
scripts/spleeter/convert_to_pb.py (new executable file, 89 lines)
@@ -0,0 +1,89 @@
#!/usr/bin/env python3

# Code in this file is modified from
# https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
#
# Please see ./run.sh for usage
import argparse

import tensorflow as tf


def freeze_graph(model_dir, output_node_names, output_filename):
    """Extract the subgraph defined by the output nodes and convert all its
    variables into constants.

    Args:
      model_dir:
        The root folder containing the checkpoint state file.
      output_node_names:
        A string containing all the output node names, comma separated.
      output_filename:
        Filename to save the graph to.
    """
    if not tf.compat.v1.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exist. Please specify an export "
            "directory: %s" % model_dir
        )

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve the full path of our checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # We specify the full filename of our frozen graph
    output_graph = output_filename

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.compat.v1.train.import_meta_graph(
            input_checkpoint + ".meta", clear_devices=clear_devices
        )

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.compat.v1.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            output_node_names.split(","),  # The output node names select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.compat.v1.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir", type=str, default="", help="Model folder to export"
    )
    parser.add_argument(
        "--output-node-names",
        type=str,
        default="vocals_spectrogram/mul,accompaniment_spectrogram/mul",
        help="The names of the output nodes, comma separated.",
    )

    parser.add_argument(
        "--output-filename",
        type=str,
    )
    args = parser.parse_args()

    freeze_graph(args.model_dir, args.output_node_names, args.output_filename)
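When choosing values for --output-node-names, it helps to list the nodes of a checkpoint or frozen graph first. A minimal sketch, assuming the same tf.compat.v1 API used above and a frozen graph produced by this script:

# List node names and op types in a frozen *.pb graph (sketch, not in the PR).
import tensorflow as tf


def list_nodes(frozen_graph_filename):
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    for node in graph_def.node:
        print(node.name, node.op)


if __name__ == "__main__":
    list_nodes("./2stems/frozen_vocals_model.pb")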
scripts/spleeter/convert_to_torch.py (new executable file, 240 lines)
@@ -0,0 +1,240 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

# Please see ./run.sh for usage

import argparse

import numpy as np
import tensorflow as tf
import torch

from unet import UNet


def load_graph(frozen_graph_filename):
    # This function is modified from
    # https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then we import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        # The name var will prefix every op/node in the graph.
        # Since we load everything in a new graph, this is not needed.
        # tf.import_graph_def(graph_def, name="prefix")
        tf.import_graph_def(graph_def, name="")
    return graph


def generate_waveform():
    np.random.seed(20230821)
    waveform = np.random.rand(60 * 44100).astype(np.float32)

    # (num_samples, num_channels)
    waveform = waveform.reshape(-1, 2)
    return waveform


def get_param(graph, name):
    with tf.compat.v1.Session(graph=graph) as sess:
        constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
        for constant_op in constant_ops:
            if constant_op.name != name:
                continue

            value = sess.run(constant_op.outputs[0])
            return torch.from_numpy(value)


@torch.no_grad()
def main(name):
    graph = load_graph(f"./2stems/frozen_{name}_model.pb")
    # for op in graph.get_operations():
    #     print(op.name)
    x = graph.get_tensor_by_name("waveform:0")
    # y = graph.get_tensor_by_name("Reshape:0")
    y0 = graph.get_tensor_by_name("strided_slice_3:0")
    # y1 = graph.get_tensor_by_name("leaky_re_lu_5/LeakyRelu:0")
    # y1 = graph.get_tensor_by_name("conv2d_5/BiasAdd:0")
    # y1 = graph.get_tensor_by_name("conv2d_transpose/BiasAdd:0")
    # y1 = graph.get_tensor_by_name("re_lu/Relu:0")
    # y1 = graph.get_tensor_by_name("batch_normalization_6/cond/FusedBatchNorm_1:0")
    # y1 = graph.get_tensor_by_name("concatenate/concat:0")
    # y1 = graph.get_tensor_by_name("concatenate_1/concat:0")
    # y1 = graph.get_tensor_by_name("concatenate_4/concat:0")
    # y1 = graph.get_tensor_by_name("batch_normalization_11/cond/FusedBatchNorm_1:0")
    # y1 = graph.get_tensor_by_name("conv2d_6/Sigmoid:0")
    y1 = graph.get_tensor_by_name(f"{name}_spectrogram/mul:0")

    unet = UNet()
    unet.eval()

    # For conv2d in TensorFlow, the weight shape is (kernel_h, kernel_w, in_channel, out_channel)
    # and the default input shape is NHWC.

    # For conv2d in torch, the weight shape is (out_channel, in_channel, kernel_h, kernel_w)
    # and the default input shape is NCHW.
    state_dict = unet.state_dict()
    # print(list(state_dict.keys()))

    if name == "vocals":
        state_dict["conv.weight"] = get_param(graph, "conv2d/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["conv.bias"] = get_param(graph, "conv2d/bias")

        state_dict["bn.weight"] = get_param(graph, "batch_normalization/gamma")
        state_dict["bn.bias"] = get_param(graph, "batch_normalization/beta")
        state_dict["bn.running_mean"] = get_param(
            graph, "batch_normalization/moving_mean"
        )
        state_dict["bn.running_var"] = get_param(
            graph, "batch_normalization/moving_variance"
        )

        conv_offset = 0
        bn_offset = 0
    else:
        state_dict["conv.weight"] = get_param(graph, "conv2d_7/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["conv.bias"] = get_param(graph, "conv2d_7/bias")

        state_dict["bn.weight"] = get_param(graph, "batch_normalization_12/gamma")
        state_dict["bn.bias"] = get_param(graph, "batch_normalization_12/beta")
        state_dict["bn.running_mean"] = get_param(
            graph, "batch_normalization_12/moving_mean"
        )
        state_dict["bn.running_var"] = get_param(
            graph, "batch_normalization_12/moving_variance"
        )
        conv_offset = 7
        bn_offset = 12

    for i in range(1, 6):
        state_dict[f"conv{i}.weight"] = get_param(
            graph, f"conv2d_{i+conv_offset}/kernel"
        ).permute(3, 2, 0, 1)
        state_dict[f"conv{i}.bias"] = get_param(graph, f"conv2d_{i+conv_offset}/bias")
        if i >= 5:
            continue
        state_dict[f"bn{i}.weight"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/gamma"
        )
        state_dict[f"bn{i}.bias"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/beta"
        )
        state_dict[f"bn{i}.running_mean"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/moving_mean"
        )
        state_dict[f"bn{i}.running_var"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/moving_variance"
        )

    if name == "vocals":
        state_dict["up1.weight"] = get_param(graph, "conv2d_transpose/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up1.bias"] = get_param(graph, "conv2d_transpose/bias")

        state_dict["bn5.weight"] = get_param(graph, "batch_normalization_6/gamma")
        state_dict["bn5.bias"] = get_param(graph, "batch_normalization_6/beta")
        state_dict["bn5.running_mean"] = get_param(
            graph, "batch_normalization_6/moving_mean"
        )
        state_dict["bn5.running_var"] = get_param(
            graph, "batch_normalization_6/moving_variance"
        )
        conv_offset = 0
        bn_offset = 0
    else:
        state_dict["up1.weight"] = get_param(
            graph, "conv2d_transpose_6/kernel"
        ).permute(3, 2, 0, 1)
        state_dict["up1.bias"] = get_param(graph, "conv2d_transpose_6/bias")

        state_dict["bn5.weight"] = get_param(graph, "batch_normalization_18/gamma")
        state_dict["bn5.bias"] = get_param(graph, "batch_normalization_18/beta")
        state_dict["bn5.running_mean"] = get_param(
            graph, "batch_normalization_18/moving_mean"
        )
        state_dict["bn5.running_var"] = get_param(
            graph, "batch_normalization_18/moving_variance"
        )
        conv_offset = 6
        bn_offset = 12

    for i in range(1, 6):
        state_dict[f"up{i+1}.weight"] = get_param(
            graph, f"conv2d_transpose_{i+conv_offset}/kernel"
        ).permute(3, 2, 0, 1)

        state_dict[f"up{i+1}.bias"] = get_param(
            graph, f"conv2d_transpose_{i+conv_offset}/bias"
        )

        state_dict[f"bn{5+i}.weight"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/gamma"
        )
        state_dict[f"bn{5+i}.bias"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/beta"
        )
        state_dict[f"bn{5+i}.running_mean"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/moving_mean"
        )
        state_dict[f"bn{5+i}.running_var"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/moving_variance"
        )

    if name == "vocals":
        state_dict["up7.weight"] = get_param(graph, "conv2d_6/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up7.bias"] = get_param(graph, "conv2d_6/bias")
    else:
        state_dict["up7.weight"] = get_param(graph, "conv2d_13/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up7.bias"] = get_param(graph, "conv2d_13/bias")

    unet.load_state_dict(state_dict)

    with tf.compat.v1.Session(graph=graph) as sess:
        y0_out, y1_out = sess.run([y0, y1], feed_dict={x: generate_waveform()})
        # y0_out = sess.run(y0, feed_dict={x: generate_waveform()})
        # y1_out = sess.run(y1, feed_dict={x: generate_waveform()})
        # print(y0_out.shape)
        # print(y1_out.shape)

    # For batch normalization in TensorFlow, the default input shape is NHWC.

    # For batch normalization in torch, the default input shape is NCHW.

    # NHWC to NCHW
    torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))

    # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
    assert torch.allclose(
        torch_y1_out, torch.from_numpy(y1_out).permute(0, 3, 1, 2), atol=1e-1
    ), ((torch_y1_out - torch.from_numpy(y1_out).permute(0, 3, 1, 2)).abs().max())
    torch.save(unet.state_dict(), f"2stems/{name}.pt")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name",
        type=str,
        required=True,
        choices=["vocals", "accompaniment"],
    )
    args = parser.parse_args()
    print(vars(args))
    main(args.name)
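The permute(3, 2, 0, 1) calls above implement the TF-to-torch weight layout conversion described in the comments. A small self-contained check of that rule (a sketch with random weights, not part of the PR):

# Verify: TF conv2d weights (kh, kw, in, out) + NHWC input give the same
# result as torch Conv2d weights (out, in, kh, kw) + NCHW input.
import numpy as np
import tensorflow as tf
import torch

kh = kw = 5
cin, cout = 2, 16
w = np.random.rand(kh, kw, cin, cout).astype(np.float32)  # TF layout
x = np.random.rand(1, 64, 64, cin).astype(np.float32)     # NHWC

y_tf = tf.nn.conv2d(x, w, strides=1, padding="VALID").numpy()

conv = torch.nn.Conv2d(cin, cout, kernel_size=5, bias=False)
with torch.no_grad():
    conv.weight.copy_(torch.from_numpy(w).permute(3, 2, 0, 1))  # -> (out, in, kh, kw)
    y_pt = conv(torch.from_numpy(x).permute(0, 3, 1, 2))        # NHWC -> NCHW

np.testing.assert_allclose(y_tf, y_pt.permute(0, 2, 3, 1).numpy(), atol=1e-4)
print("TF and torch conv2d agree after the layout permutes")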
scripts/spleeter/export_onnx.py (new executable file, 94 lines)
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnx
import onnxmltools
import torch
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic

from unet import UNet


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def add_meta_data(filename, prefix):
    meta_data = {
        "model_type": "spleeter",
        "sample_rate": 44100,
        "version": 1,
        "model_url": "https://github.com/deezer/spleeter",
        "stems": 2,
        "comment": prefix,
        "model_name": "2stems.tar.gz",
    }
    model = onnx.load(filename)

    print(model.metadata_props)

    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    print("--------------------")

    print(model.metadata_props)

    onnx.save(model, filename)


def export(model, prefix):
    num_splits = 1
    x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)

    filename = f"./2stems/{prefix}.onnx"
    torch.onnx.export(
        model,
        x,
        filename,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "num_splits"},
        },
        opset_version=13,
    )

    add_meta_data(filename, prefix)

    filename_int8 = f"./2stems/{prefix}.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    filename_fp16 = f"./2stems/{prefix}.fp16.onnx"
    export_onnx_fp16(filename, filename_fp16)


@torch.no_grad()
def main():
    vocals = UNet()
    state_dict = torch.load("./2stems/vocals.pt", map_location="cpu")
    vocals.load_state_dict(state_dict)
    vocals.eval()

    accompaniment = UNet()
    state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu")
    accompaniment.load_state_dict(state_dict)
    accompaniment.eval()

    export(vocals, "vocals")
    export(accompaniment, "accompaniment")


if __name__ == "__main__":
    main()
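Because convert_float_to_float16 is called with keep_io_types=True, the fp16 export still takes and returns float32 tensors. A hedged sketch comparing the fp32 and fp16 exports on random input (assumes both files produced by export() above exist in ./2stems):

# Compare fp32 and fp16 exports to gauge the precision loss (sketch).
import numpy as np
import onnxruntime as ort

x = np.random.rand(1, 2, 512, 1024).astype(np.float32)


def run(path):
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    return sess.run(["y"], {"x": x})[0]


y_fp32 = run("./2stems/vocals.onnx")
y_fp16 = run("./2stems/vocals.fp16.onnx")
print("max abs diff:", np.abs(y_fp32 - y_fp16).max())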
scripts/spleeter/run.sh (new executable file, 41 lines)
@@ -0,0 +1,41 @@
#!/usr/bin/env bash


if [ ! -f 2stems.tar.gz ]; then
  curl -SL -O https://github.com/deezer/spleeter/releases/download/v1.4.0/2stems.tar.gz
fi

if [ ! -d ./2stems ]; then
  mkdir -p 2stems
  cd 2stems
  tar xvf ../2stems.tar.gz
  cd ..
fi

ls -lh

ls -lh 2stems

if [ ! -f 2stems/frozen_vocals_model.pb ]; then
  python3 ./convert_to_pb.py \
    --model-dir ./2stems \
    --output-node-names vocals_spectrogram/mul \
    --output-filename ./2stems/frozen_vocals_model.pb
fi

ls -lh 2stems

if [ ! -f 2stems/frozen_accompaniment_model.pb ]; then
  python3 ./convert_to_pb.py \
    --model-dir ./2stems \
    --output-node-names accompaniment_spectrogram/mul \
    --output-filename ./2stems/frozen_accompaniment_model.pb
fi

ls -lh 2stems

python3 ./convert_to_torch.py --name vocals
python3 ./convert_to_torch.py --name accompaniment
python3 ./export_onnx.py

ls -lh 2stems
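The same checkpoint-to-ONNX pipeline can also be driven from Python when individual steps need to be rerun; a sketch (assumes the scripts above are in the current directory and ./2stems is already extracted):

# Drive the run.sh pipeline via subprocess, step by step (sketch).
import subprocess

for stem in ("vocals", "accompaniment"):
    subprocess.run(
        [
            "python3", "./convert_to_pb.py",
            "--model-dir", "./2stems",
            "--output-node-names", f"{stem}_spectrogram/mul",
            "--output-filename", f"./2stems/frozen_{stem}_model.pb",
        ],
        check=True,
    )
    subprocess.run(["python3", "./convert_to_torch.py", "--name", stem], check=True)

subprocess.run(["python3", "./export_onnx.py"], check=True)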
scripts/spleeter/separate.py (new executable file, 170 lines)
@@ -0,0 +1,170 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

# Please see ./run.sh for usage

from typing import Optional

import ffmpeg
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment

from unet import UNet


def load_audio(filename, sample_rate: Optional[int] = 44100):
    probe = ffmpeg.probe(filename)
    if "streams" not in probe or len(probe["streams"]) == 0:
        raise ValueError("No stream was found with ffprobe")

    metadata = next(
        stream for stream in probe["streams"] if stream["codec_type"] == "audio"
    )
    n_channels = metadata["channels"]

    if sample_rate is None:
        sample_rate = metadata["sample_rate"]

    process = (
        ffmpeg.input(filename)
        .output("pipe:", format="f32le", ar=sample_rate)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, _ = process.communicate()
    waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)

    waveform = torch.from_numpy(np.copy(waveform)).to(torch.float32)
    if n_channels == 1:
        waveform = waveform.tile(1, 2)

    if n_channels > 2:
        waveform = waveform[:, :2]

    return waveform, sample_rate


@torch.no_grad()
def main():
    vocals = UNet()
    vocals.eval()
    state_dict = torch.load("./2stems/vocals.pt", map_location="cpu")
    vocals.load_state_dict(state_dict)

    accompaniment = UNet()
    accompaniment.eval()
    state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu")
    accompaniment.load_state_dict(state_dict)

    # waveform, sample_rate = load_audio("./audio_example.mp3")

    # You can download the following two mp3 files from
    # https://huggingface.co/spaces/csukuangfj/music-source-separation/tree/main/examples
    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    # waveform, sample_rate = load_audio("./Yesterday_Once_More-Carpenters.mp3")
    assert waveform.shape[1] == 2, waveform.shape

    waveform = torch.nn.functional.pad(waveform, (0, 0, 0, 4096))

    # torch.stft requires a 2-D input of shape (N, T), so we transpose waveform
    stft = torch.stft(
        waveform.t(),
        n_fft=4096,
        hop_length=1024,
        window=torch.hann_window(4096, periodic=True),
        center=False,
        onesided=True,
        return_complex=True,
    )
    print("stft", stft.shape)

    # stft: (2, 2049, 465)
    # stft is a complex tensor
    y = stft.permute(2, 1, 0)
    print("y0", y.shape)
    # (465, 2049, 2)

    y = y[:, :1024, :]
    # (465, 1024, 2)

    tensor_size = y.shape[0] - int(y.shape[0] / 512) * 512
    pad_size = 512 - tensor_size
    y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))
    # (512, 1024, 2)
    print("y1", y.shape, y.dtype)

    num_splits = int(y.shape[0] / 512)
    y = y.reshape([num_splits, 512] + list(y.shape[1:]))
    # y: (1, 512, 1024, 2)
    print("y2", y.shape, y.dtype)

    y = y.abs()
    y = y.permute(0, 3, 1, 2)
    # (1, 2, 512, 1024)
    print("y3", y.shape, y.dtype)

    vocals_spec = vocals(y)
    accompaniment_spec = accompaniment(y)

    sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
    print(
        "vocals_spec",
        vocals_spec.shape,
        accompaniment_spec.shape,
        sum_spec.shape,
        vocals_spec.dtype,
    )

    vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
    # (1, 2, 512, 1024)

    accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec
    # (1, 2, 512, 1024)

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        spec = torch.nn.functional.pad(spec, (0, 2049 - 1024, 0, 0, 0, 0, 0, 0))
        # (1, 2, 512, 2049)

        spec = spec.permute(0, 2, 3, 1)
        # (1, 512, 2049, 2)
        print("here00", spec.shape)

        spec = spec.reshape(-1, spec.shape[2], spec.shape[3])
        # (512, 2049, 2)

        print("here2", spec.shape)
        # (512, 2049, 2)

        spec = spec[: stft.shape[2], :, :]
        # (465, 2049, 2)
        print("here 3", spec.shape, stft.shape)

        spec = spec.permute(2, 1, 0)
        # (2, 2049, 465)

        masked_stft = spec * stft

        wave = torch.istft(
            masked_stft,
            4096,
            1024,
            window=torch.hann_window(4096, periodic=True),
            onesided=True,
        ) * (2 / 3)

        print(wave.shape, wave.dtype)
        sf.write(f"{name}.wav", wave.t(), 44100)

        wave = (wave.t() * 32768).to(torch.int16)
        sound = AudioSegment(
            data=wave.numpy().tobytes(), sample_width=2, frame_rate=44100, channels=2
        )
        sound.export(f"{name}.mp3", format="mp3", bitrate="128k")


if __name__ == "__main__":
    main()
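The two masks computed above form a soft Wiener-style partition: for every time-frequency bin, the vocals and accompaniment masks sum to exactly 1, since (v**2 + eps/2) + (a**2 + eps/2) equals the shared denominator v**2 + a**2 + eps. A quick check on random tensors:

# The masks partition the mixture spectrogram: they sum to 1 everywhere.
import torch

v = torch.rand(1, 2, 512, 1024)  # stand-in for vocals_spec
a = torch.rand(1, 2, 512, 1024)  # stand-in for accompaniment_spec
den = v**2 + a**2 + 1e-10
v_mask = (v**2 + 1e-10 / 2) / den
a_mask = (a**2 + 1e-10 / 2) / den
assert torch.allclose(v_mask + a_mask, torch.ones_like(den))
print("masks sum to 1 everywhere")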
scripts/spleeter/separate_onnx.py (new executable file, 197 lines)
@@ -0,0 +1,197 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import time

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch

from separate import load_audio

"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])

----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])

"""


class OnnxModel:
    def __init__(self, filename):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        print(f"----------inputs for {filename}----------")
        for i in self.model.get_inputs():
            print(i)

        print(f"----------outputs for {filename}----------")

        for i in self.model.get_outputs():
            print(i)
        print("--------------------")

    def __call__(self, x):
        """
        Args:
          x: (num_splits, 2, 512, 1024)
        """
        spec = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
            },
        )[0]

        return torch.from_numpy(spec)


def main():
    vocals = OnnxModel("./2stems/vocals.onnx")
    accompaniment = OnnxModel("./2stems/accompaniment.onnx")

    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    waveform = waveform[: 44100 * 10, :]

    stft_config = knf.StftConfig(
        n_fft=4096,
        hop_length=1024,
        win_length=4096,
        center=False,
        window_type="hann",
    )
    knf_stft = knf.Stft(stft_config)
    knf_istft = knf.IStft(stft_config)

    start = time.time()

    stft_result_c0 = knf_stft(waveform[:, 0].tolist())
    stft_result_c1 = knf_stft(waveform[:, 1].tolist())
    print("c0 stft", stft_result_c0.num_frames)

    orig_real0 = np.array(stft_result_c0.real, dtype=np.float32).reshape(
        stft_result_c0.num_frames, -1
    )
    orig_imag0 = np.array(stft_result_c0.imag, dtype=np.float32).reshape(
        stft_result_c0.num_frames, -1
    )

    orig_real1 = np.array(stft_result_c1.real, dtype=np.float32).reshape(
        stft_result_c1.num_frames, -1
    )
    orig_imag1 = np.array(stft_result_c1.imag, dtype=np.float32).reshape(
        stft_result_c1.num_frames, -1
    )

    real0 = torch.from_numpy(orig_real0)
    imag0 = torch.from_numpy(orig_imag0)
    real1 = torch.from_numpy(orig_real1)
    imag1 = torch.from_numpy(orig_imag1)
    # (num_frames, n_fft/2 + 1)
    print("real0", real0.shape)

    # keep only the first 1024 bins
    real0 = real0[:, :1024]
    imag0 = imag0[:, :1024]
    real1 = real1[:, :1024]
    imag1 = imag1[:, :1024]

    stft0 = (real0.square() + imag0.square()).sqrt()
    stft1 = (real1.square() + imag1.square()).sqrt()

    # pad it to a multiple of 512
    padding = 512 - real0.shape[0] % 512
    print("padding", padding)
    if padding > 0:
        stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
        stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
    stft0 = stft0.reshape(-1, 1, 512, 1024)
    stft1 = stft1.reshape(-1, 1, 512, 1024)

    stft_01 = torch.cat([stft0, stft1], axis=1)

    print("stft_01", stft_01.shape, stft_01.dtype)

    vocals_spec = vocals(stft_01)
    accompaniment_spec = accompaniment(stft_01)
    # (num_splits, num_channels, 512, 1024)

    sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10

    vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
    accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        spec_c0 = spec[:, 0, :, :]
        spec_c1 = spec[:, 1, :, :]

        spec_c0 = spec_c0.reshape(-1, 1024)
        spec_c1 = spec_c1.reshape(-1, 1024)

        spec_c0 = spec_c0[: stft_result_c0.num_frames, :]
        spec_c1 = spec_c1[: stft_result_c0.num_frames, :]

        spec_c0 = torch.nn.functional.pad(spec_c0, (0, 2049 - 1024, 0, 0))
        spec_c1 = torch.nn.functional.pad(spec_c1, (0, 2049 - 1024, 0, 0))

        spec_c0_real = spec_c0 * orig_real0
        spec_c0_imag = spec_c0 * orig_imag0

        spec_c1_real = spec_c1 * orig_real1
        spec_c1_imag = spec_c1 * orig_imag1

        result0 = knf.StftResult(
            real=spec_c0_real.reshape(-1).tolist(),
            imag=spec_c0_imag.reshape(-1).tolist(),
            num_frames=orig_real0.shape[0],
        )

        result1 = knf.StftResult(
            real=spec_c1_real.reshape(-1).tolist(),
            imag=spec_c1_imag.reshape(-1).tolist(),
            num_frames=orig_real1.shape[0],
        )

        wav0 = knf_istft(result0)
        wav1 = knf_istft(result1)

        wav = np.array([wav0, wav1], dtype=np.float32)
        wav = np.transpose(wav)
        # now wav is (num_samples, num_channels)

        sf.write(f"./onnx-{name}.wav", wav, 44100)

        print(f"Saved to ./onnx-{name}.wav")

    end = time.time()
    elapsed_seconds = end - start
    audio_duration = waveform.shape[0] / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
    main()
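Since the exported graphs declare num_splits as a dynamic axis, longer recordings can be processed in a single run by stacking 512-frame chunks along the first dimension; a sketch with random input, using the model path produced by run.sh:

# Demonstrate the dynamic num_splits axis of the exported model (sketch).
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "./2stems/vocals.onnx", providers=["CPUExecutionProvider"]
)
for n in (1, 3):
    x = np.random.rand(n, 2, 512, 1024).astype(np.float32)
    (y,) = sess.run(["y"], {"x": x})
    print(n, "->", y.shape)  # (n, 2, 512, 1024)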
scripts/spleeter/unet.py (new file, 150 lines)
@@ -0,0 +1,150 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import torch


class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 16, kernel_size=5, stride=(2, 2), padding=0)
        self.bn = torch.nn.BatchNorm2d(
            16, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=5, stride=(2, 2), padding=0)
        self.bn1 = torch.nn.BatchNorm2d(
            32, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5, stride=(2, 2), padding=0)
        self.bn2 = torch.nn.BatchNorm2d(
            64, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=5, stride=(2, 2), padding=0)
        self.bn3 = torch.nn.BatchNorm2d(
            128, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=5, stride=(2, 2), padding=0)
        self.bn4 = torch.nn.BatchNorm2d(
            256, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv5 = torch.nn.Conv2d(256, 512, kernel_size=5, stride=(2, 2), padding=0)

        self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2)
        self.bn5 = torch.nn.BatchNorm2d(
            256, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up2 = torch.nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2)
        self.bn6 = torch.nn.BatchNorm2d(
            128, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up3 = torch.nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2)
        self.bn7 = torch.nn.BatchNorm2d(
            64, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up4 = torch.nn.ConvTranspose2d(128, 32, kernel_size=5, stride=2)
        self.bn8 = torch.nn.BatchNorm2d(
            32, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up5 = torch.nn.ConvTranspose2d(64, 16, kernel_size=5, stride=2)
        self.bn9 = torch.nn.BatchNorm2d(
            16, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up6 = torch.nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2)
        self.bn10 = torch.nn.BatchNorm2d(
            1, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        # output logit is False, so we need self.up7
        self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)

    def forward(self, x):
        in_x = x
        # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
        x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
        conv1 = self.conv(x)
        batch1 = self.bn(conv1)
        rel1 = torch.nn.functional.leaky_relu(batch1, negative_slope=0.2)

        x = torch.nn.functional.pad(rel1, (1, 2, 1, 2), "constant", 0)
        conv2 = self.conv1(x)  # (3, 32, 128, 256)
        batch2 = self.bn1(conv2)
        rel2 = torch.nn.functional.leaky_relu(
            batch2, negative_slope=0.2
        )  # (3, 32, 128, 256)

        x = torch.nn.functional.pad(rel2, (1, 2, 1, 2), "constant", 0)
        conv3 = self.conv2(x)  # (3, 64, 64, 128)
        batch3 = self.bn2(conv3)
        rel3 = torch.nn.functional.leaky_relu(
            batch3, negative_slope=0.2
        )  # (3, 64, 64, 128)

        x = torch.nn.functional.pad(rel3, (1, 2, 1, 2), "constant", 0)
        conv4 = self.conv3(x)  # (3, 128, 32, 64)
        batch4 = self.bn3(conv4)
        rel4 = torch.nn.functional.leaky_relu(
            batch4, negative_slope=0.2
        )  # (3, 128, 32, 64)

        x = torch.nn.functional.pad(rel4, (1, 2, 1, 2), "constant", 0)
        conv5 = self.conv4(x)  # (3, 256, 16, 32)
        batch5 = self.bn4(conv5)
        rel6 = torch.nn.functional.leaky_relu(
            batch5, negative_slope=0.2
        )  # (3, 256, 16, 32)

        x = torch.nn.functional.pad(rel6, (1, 2, 1, 2), "constant", 0)
        conv6 = self.conv5(x)  # (3, 512, 8, 16)

        up1 = self.up1(conv6)
        up1 = up1[:, :, 1:-2, 1:-2]  # (3, 256, 16, 32)
        up1 = torch.nn.functional.relu(up1)
        batch7 = self.bn5(up1)
        merge1 = torch.cat([conv5, batch7], axis=1)  # (3, 512, 16, 32)

        up2 = self.up2(merge1)
        up2 = up2[:, :, 1:-2, 1:-2]
        up2 = torch.nn.functional.relu(up2)
        batch8 = self.bn6(up2)

        merge2 = torch.cat([conv4, batch8], axis=1)  # (3, 256, 32, 64)

        up3 = self.up3(merge2)
        up3 = up3[:, :, 1:-2, 1:-2]
        up3 = torch.nn.functional.relu(up3)
        batch9 = self.bn7(up3)

        merge3 = torch.cat([conv3, batch9], axis=1)  # (3, 128, 64, 128)

        up4 = self.up4(merge3)
        up4 = up4[:, :, 1:-2, 1:-2]
        up4 = torch.nn.functional.relu(up4)
        batch10 = self.bn8(up4)

        merge4 = torch.cat([conv2, batch10], axis=1)  # (3, 64, 128, 256)

        up5 = self.up5(merge4)
        up5 = up5[:, :, 1:-2, 1:-2]
        up5 = torch.nn.functional.relu(up5)
        batch11 = self.bn9(up5)

        merge5 = torch.cat([conv1, batch11], axis=1)  # (3, 32, 256, 512)

        up6 = self.up6(merge5)
        up6 = up6[:, :, 1:-2, 1:-2]
        up6 = torch.nn.functional.relu(up6)
        batch12 = self.bn10(up6)  # (3, 1, 512, 1024) = (T, 1, 512, 1024)

        up7 = self.up7(batch12)
        up7 = torch.sigmoid(up7)  # (3, 2, 512, 1024)

        return up7 * in_x
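The network ends by multiplying a sigmoid mask onto its input, so output and input share the (num_splits, 2, 512, 1024) layout; a quick shape check:

# Shape check for the UNet above (sketch).
import torch

from unet import UNet

net = UNet()
net.eval()
with torch.no_grad():
    x = torch.rand(1, 2, 512, 1024)
    y = net(x)
print(y.shape)  # torch.Size([1, 2, 512, 1024])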