Add C++ runtime for Spleeter source separation (#2242)
This commit is contained in:
@@ -217,8 +217,8 @@ def main(name):
|
||||
# for the batchnormalization in torch,
|
||||
# default input shape is NCHW
|
||||
|
||||
# NHWC to NCHW
|
||||
torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
|
||||
torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2))
|
||||
torch_y1_out = torch_y1_out.permute(1, 0, 2, 3)
|
||||
|
||||
# print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
|
||||
assert torch.allclose(
|
||||
|
||||
@@ -46,7 +46,7 @@ def add_meta_data(filename, prefix):
|
||||
|
||||
def export(model, prefix):
|
||||
num_splits = 1
|
||||
x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)
|
||||
x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32)
|
||||
|
||||
filename = f"./2stems/{prefix}.onnx"
|
||||
torch.onnx.export(
|
||||
@@ -56,7 +56,7 @@ def export(model, prefix):
|
||||
input_names=["x"],
|
||||
output_names=["y"],
|
||||
dynamic_axes={
|
||||
"x": {0: "num_splits"},
|
||||
"x": {1: "num_splits"},
|
||||
},
|
||||
opset_version=13,
|
||||
)
|
||||
|
||||
@@ -101,13 +101,17 @@ def main():
|
||||
print("y2", y.shape, y.dtype)
|
||||
|
||||
y = y.abs()
|
||||
y = y.permute(0, 3, 1, 2)
|
||||
# (1, 2, 512, 1024)
|
||||
|
||||
y = y.permute(3, 0, 1, 2)
|
||||
# (2, 1, 512, 1024)
|
||||
print("y3", y.shape, y.dtype)
|
||||
|
||||
vocals_spec = vocals(y)
|
||||
accompaniment_spec = accompaniment(y)
|
||||
|
||||
vocals_spec = vocals_spec.permute(1, 0, 2, 3)
|
||||
accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3)
|
||||
|
||||
sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
|
||||
print(
|
||||
"vocals_spec",
|
||||
|
||||
@@ -12,15 +12,14 @@ from separate import load_audio
|
||||
|
||||
"""
|
||||
----------inputs for ./2stems/vocals.onnx----------
|
||||
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
|
||||
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
|
||||
----------outputs for ./2stems/vocals.onnx----------
|
||||
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
|
||||
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
|
||||
|
||||
----------inputs for ./2stems/accompaniment.onnx----------
|
||||
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
|
||||
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
|
||||
----------outputs for ./2stems/accompaniment.onnx----------
|
||||
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
|
||||
|
||||
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
|
||||
"""
|
||||
|
||||
|
||||
@@ -123,16 +122,16 @@ def main():
|
||||
if padding > 0:
|
||||
stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
|
||||
stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
|
||||
stft0 = stft0.reshape(-1, 1, 512, 1024)
|
||||
stft1 = stft1.reshape(-1, 1, 512, 1024)
|
||||
stft0 = stft0.reshape(1, -1, 512, 1024)
|
||||
stft1 = stft1.reshape(1, -1, 512, 1024)
|
||||
|
||||
stft_01 = torch.cat([stft0, stft1], axis=1)
|
||||
stft_01 = torch.cat([stft0, stft1], axis=0)
|
||||
|
||||
print("stft_01", stft_01.shape, stft_01.dtype)
|
||||
|
||||
vocals_spec = vocals(stft_01)
|
||||
accompaniment_spec = accompaniment(stft_01)
|
||||
# (num_splits, num_channels, 512, 1024)
|
||||
# (num_channels, num_splits, 512, 1024)
|
||||
|
||||
sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10
|
||||
|
||||
@@ -142,8 +141,8 @@ def main():
|
||||
for name, spec in zip(
|
||||
["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
|
||||
):
|
||||
spec_c0 = spec[:, 0, :, :]
|
||||
spec_c1 = spec[:, 1, :, :]
|
||||
spec_c0 = spec[0]
|
||||
spec_c1 = spec[1]
|
||||
|
||||
spec_c0 = spec_c0.reshape(-1, 1024)
|
||||
spec_c1 = spec_c1.reshape(-1, 1024)
|
||||
|
||||
@@ -67,6 +67,14 @@ class UNet(torch.nn.Module):
|
||||
self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: (num_audio_channels, num_splits, 512, 1024)
|
||||
Returns:
|
||||
y: (num_audio_channels, num_splits, 512, 1024)
|
||||
"""
|
||||
x = x.permute(1, 0, 2, 3)
|
||||
|
||||
in_x = x
|
||||
# in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
|
||||
x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
|
||||
@@ -147,4 +155,5 @@ class UNet(torch.nn.Module):
|
||||
up7 = self.up7(batch12)
|
||||
up7 = torch.sigmoid(up7) # (3, 2, 512, 1024)
|
||||
|
||||
return up7 * in_x
|
||||
ans = up7 * in_x
|
||||
return ans.permute(1, 0, 2, 3)
|
||||
|
||||
Reference in New Issue
Block a user