Add C++ runtime for spleeter about source separation (#2242)

This commit is contained in:
Fangjun Kuang
2025-05-23 22:30:57 +08:00
committed by GitHub
parent 55a44793e6
commit 716ba8317b
28 changed files with 1267 additions and 72 deletions

View File

@@ -217,8 +217,8 @@ def main(name):
# for the batchnormalization in torch,
# default input shape is NCHW
# NHWC to NCHW
torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2))
torch_y1_out = torch_y1_out.permute(1, 0, 2, 3)
# print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
assert torch.allclose(

View File

@@ -46,7 +46,7 @@ def add_meta_data(filename, prefix):
def export(model, prefix):
num_splits = 1
x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)
x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32)
filename = f"./2stems/{prefix}.onnx"
torch.onnx.export(
@@ -56,7 +56,7 @@ def export(model, prefix):
input_names=["x"],
output_names=["y"],
dynamic_axes={
"x": {0: "num_splits"},
"x": {1: "num_splits"},
},
opset_version=13,
)

View File

@@ -101,13 +101,17 @@ def main():
print("y2", y.shape, y.dtype)
y = y.abs()
y = y.permute(0, 3, 1, 2)
# (1, 2, 512, 1024)
y = y.permute(3, 0, 1, 2)
# (2, 1, 512, 1024)
print("y3", y.shape, y.dtype)
vocals_spec = vocals(y)
accompaniment_spec = accompaniment(y)
vocals_spec = vocals_spec.permute(1, 0, 2, 3)
accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3)
sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
print(
"vocals_spec",

View File

@@ -12,15 +12,14 @@ from separate import load_audio
"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
"""
@@ -123,16 +122,16 @@ def main():
if padding > 0:
stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
stft0 = stft0.reshape(-1, 1, 512, 1024)
stft1 = stft1.reshape(-1, 1, 512, 1024)
stft0 = stft0.reshape(1, -1, 512, 1024)
stft1 = stft1.reshape(1, -1, 512, 1024)
stft_01 = torch.cat([stft0, stft1], axis=1)
stft_01 = torch.cat([stft0, stft1], axis=0)
print("stft_01", stft_01.shape, stft_01.dtype)
vocals_spec = vocals(stft_01)
accompaniment_spec = accompaniment(stft_01)
# (num_splits, num_channels, 512, 1024)
# (num_channels, num_splits, 512, 1024)
sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10
@@ -142,8 +141,8 @@ def main():
for name, spec in zip(
["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
):
spec_c0 = spec[:, 0, :, :]
spec_c1 = spec[:, 1, :, :]
spec_c0 = spec[0]
spec_c1 = spec[1]
spec_c0 = spec_c0.reshape(-1, 1024)
spec_c1 = spec_c1.reshape(-1, 1024)

View File

@@ -67,6 +67,14 @@ class UNet(torch.nn.Module):
self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
def forward(self, x):
"""
Args:
x: (num_audio_channels, num_splits, 512, 1024)
Returns:
y: (num_audio_channels, num_splits, 512, 1024)
"""
x = x.permute(1, 0, 2, 3)
in_x = x
# in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
@@ -147,4 +155,5 @@ class UNet(torch.nn.Module):
up7 = self.up7(batch12)
up7 = torch.sigmoid(up7) # (3, 2, 512, 1024)
return up7 * in_x
ans = up7 * in_x
return ans.permute(1, 0, 2, 3)