Add C++ runtime for spleeter about source separation (#2242)

This commit is contained in:
Fangjun Kuang
2025-05-23 22:30:57 +08:00
committed by GitHub
parent 55a44793e6
commit 716ba8317b
28 changed files with 1267 additions and 72 deletions

View File

@@ -12,15 +12,14 @@ from separate import load_audio
"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
"""
@@ -123,16 +122,16 @@ def main():
if padding > 0:
stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
stft0 = stft0.reshape(-1, 1, 512, 1024)
stft1 = stft1.reshape(-1, 1, 512, 1024)
stft0 = stft0.reshape(1, -1, 512, 1024)
stft1 = stft1.reshape(1, -1, 512, 1024)
stft_01 = torch.cat([stft0, stft1], axis=1)
stft_01 = torch.cat([stft0, stft1], axis=0)
print("stft_01", stft_01.shape, stft_01.dtype)
vocals_spec = vocals(stft_01)
accompaniment_spec = accompaniment(stft_01)
# (num_splits, num_channels, 512, 1024)
# (num_channels, num_splits, 512, 1024)
sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10
@@ -142,8 +141,8 @@ def main():
for name, spec in zip(
["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
):
spec_c0 = spec[:, 0, :, :]
spec_c1 = spec[:, 1, :, :]
spec_c0 = spec[0]
spec_c1 = spec[1]
spec_c0 = spec_c0.reshape(-1, 1024)
spec_c1 = spec_c1.reshape(-1, 1024)