Add C++ runtime for spleeter about source separation (#2242)

This commit is contained in:
Fangjun Kuang
2025-05-23 22:30:57 +08:00
committed by GitHub
parent 55a44793e6
commit 716ba8317b
28 changed files with 1267 additions and 72 deletions

View File

@@ -67,6 +67,14 @@ class UNet(torch.nn.Module):
self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
def forward(self, x):
"""
Args:
x: (num_audio_channels, num_splits, 512, 1024)
Returns:
y: (num_audio_channels, num_splits, 512, 1024)
"""
x = x.permute(1, 0, 2, 3)
in_x = x
# in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
@@ -147,4 +155,5 @@ class UNet(torch.nn.Module):
up7 = self.up7(batch12)
up7 = torch.sigmoid(up7) # (3, 2, 512, 1024)
return up7 * in_x
ans = up7 * in_x
return ans.permute(1, 0, 2, 3)