Add C++ runtime for spleeter about source separation (#2242)
This commit is contained in:
@@ -67,6 +67,14 @@ class UNet(torch.nn.Module):
|
||||
self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: (num_audio_channels, num_splits, 512, 1024)
|
||||
Returns:
|
||||
y: (num_audio_channels, num_splits, 512, 1024)
|
||||
"""
|
||||
x = x.permute(1, 0, 2, 3)
|
||||
|
||||
in_x = x
|
||||
# in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
|
||||
x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
|
||||
@@ -147,4 +155,5 @@ class UNet(torch.nn.Module):
|
||||
up7 = self.up7(batch12)
|
||||
up7 = torch.sigmoid(up7) # (3, 2, 512, 1024)
|
||||
|
||||
return up7 * in_x
|
||||
ans = up7 * in_x
|
||||
return ans.permute(1, 0, 2, 3)
|
||||
|
||||
Reference in New Issue
Block a user