Add C++ runtime for spleeter about source separation (#2242)
This commit is contained in:
@@ -3,7 +3,7 @@ name: export-spleeter-to-onnx
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- spleeter-2
|
||||
- spleeter-cpp-2
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
|
||||
@@ -56,6 +56,7 @@ def get_binaries():
|
||||
"sherpa-onnx-offline-denoiser",
|
||||
"sherpa-onnx-offline-language-identification",
|
||||
"sherpa-onnx-offline-punctuation",
|
||||
"sherpa-onnx-offline-source-separation",
|
||||
"sherpa-onnx-offline-speaker-diarization",
|
||||
"sherpa-onnx-offline-tts",
|
||||
"sherpa-onnx-offline-tts-play",
|
||||
|
||||
@@ -217,8 +217,8 @@ def main(name):
|
||||
# for the batchnormalization in torch,
|
||||
# default input shape is NCHW
|
||||
|
||||
# NHWC to NCHW
|
||||
torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
|
||||
torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2))
|
||||
torch_y1_out = torch_y1_out.permute(1, 0, 2, 3)
|
||||
|
||||
# print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
|
||||
assert torch.allclose(
|
||||
|
||||
@@ -46,7 +46,7 @@ def add_meta_data(filename, prefix):
|
||||
|
||||
def export(model, prefix):
|
||||
num_splits = 1
|
||||
x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)
|
||||
x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32)
|
||||
|
||||
filename = f"./2stems/{prefix}.onnx"
|
||||
torch.onnx.export(
|
||||
@@ -56,7 +56,7 @@ def export(model, prefix):
|
||||
input_names=["x"],
|
||||
output_names=["y"],
|
||||
dynamic_axes={
|
||||
"x": {0: "num_splits"},
|
||||
"x": {1: "num_splits"},
|
||||
},
|
||||
opset_version=13,
|
||||
)
|
||||
|
||||
@@ -101,13 +101,17 @@ def main():
|
||||
print("y2", y.shape, y.dtype)
|
||||
|
||||
y = y.abs()
|
||||
y = y.permute(0, 3, 1, 2)
|
||||
# (1, 2, 512, 1024)
|
||||
|
||||
y = y.permute(3, 0, 1, 2)
|
||||
# (2, 1, 512, 1024)
|
||||
print("y3", y.shape, y.dtype)
|
||||
|
||||
vocals_spec = vocals(y)
|
||||
accompaniment_spec = accompaniment(y)
|
||||
|
||||
vocals_spec = vocals_spec.permute(1, 0, 2, 3)
|
||||
accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3)
|
||||
|
||||
sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
|
||||
print(
|
||||
"vocals_spec",
|
||||
|
||||
@@ -12,15 +12,14 @@ from separate import load_audio
|
||||
|
||||
"""
|
||||
----------inputs for ./2stems/vocals.onnx----------
|
||||
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
|
||||
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
|
||||
----------outputs for ./2stems/vocals.onnx----------
|
||||
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
|
||||
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
|
||||
|
||||
----------inputs for ./2stems/accompaniment.onnx----------
|
||||
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
|
||||
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
|
||||
----------outputs for ./2stems/accompaniment.onnx----------
|
||||
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
|
||||
|
||||
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
|
||||
"""
|
||||
|
||||
|
||||
@@ -123,16 +122,16 @@ def main():
|
||||
if padding > 0:
|
||||
stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
|
||||
stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
|
||||
stft0 = stft0.reshape(-1, 1, 512, 1024)
|
||||
stft1 = stft1.reshape(-1, 1, 512, 1024)
|
||||
stft0 = stft0.reshape(1, -1, 512, 1024)
|
||||
stft1 = stft1.reshape(1, -1, 512, 1024)
|
||||
|
||||
stft_01 = torch.cat([stft0, stft1], axis=1)
|
||||
stft_01 = torch.cat([stft0, stft1], axis=0)
|
||||
|
||||
print("stft_01", stft_01.shape, stft_01.dtype)
|
||||
|
||||
vocals_spec = vocals(stft_01)
|
||||
accompaniment_spec = accompaniment(stft_01)
|
||||
# (num_splits, num_channels, 512, 1024)
|
||||
# (num_channels, num_splits, 512, 1024)
|
||||
|
||||
sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10
|
||||
|
||||
@@ -142,8 +141,8 @@ def main():
|
||||
for name, spec in zip(
|
||||
["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
|
||||
):
|
||||
spec_c0 = spec[:, 0, :, :]
|
||||
spec_c1 = spec[:, 1, :, :]
|
||||
spec_c0 = spec[0]
|
||||
spec_c1 = spec[1]
|
||||
|
||||
spec_c0 = spec_c0.reshape(-1, 1024)
|
||||
spec_c1 = spec_c1.reshape(-1, 1024)
|
||||
|
||||
@@ -67,6 +67,14 @@ class UNet(torch.nn.Module):
|
||||
self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: (num_audio_channels, num_splits, 512, 1024)
|
||||
Returns:
|
||||
y: (num_audio_channels, num_splits, 512, 1024)
|
||||
"""
|
||||
x = x.permute(1, 0, 2, 3)
|
||||
|
||||
in_x = x
|
||||
# in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
|
||||
x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
|
||||
@@ -147,4 +155,5 @@ class UNet(torch.nn.Module):
|
||||
up7 = self.up7(batch12)
|
||||
up7 = torch.sigmoid(up7) # (3, 2, 512, 1024)
|
||||
|
||||
return up7 * in_x
|
||||
ans = up7 * in_x
|
||||
return ans.permute(1, 0, 2, 3)
|
||||
|
||||
@@ -50,6 +50,13 @@ set(sources
|
||||
offline-rnn-lm.cc
|
||||
offline-sense-voice-model-config.cc
|
||||
offline-sense-voice-model.cc
|
||||
|
||||
offline-source-separation-impl.cc
|
||||
offline-source-separation-model-config.cc
|
||||
offline-source-separation-spleeter-model-config.cc
|
||||
offline-source-separation-spleeter-model.cc
|
||||
offline-source-separation.cc
|
||||
|
||||
offline-stream.cc
|
||||
offline-tdnn-ctc-model.cc
|
||||
offline-tdnn-model-config.cc
|
||||
@@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
|
||||
add_executable(sherpa-onnx-offline-source-separation sherpa-onnx-offline-source-separation.cc)
|
||||
add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
|
||||
add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc)
|
||||
|
||||
@@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
sherpa-onnx-offline-language-identification
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-punctuation
|
||||
sherpa-onnx-offline-source-separation
|
||||
sherpa-onnx-online-punctuation
|
||||
sherpa-onnx-vad
|
||||
)
|
||||
|
||||
40
sherpa-onnx/csrc/offline-source-separation-impl.cc
Normal file
40
sherpa-onnx/csrc/offline-source-separation-impl.cc
Normal file
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-impl.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflineSourceSeparationImpl>
|
||||
OfflineSourceSeparationImpl::Create(
|
||||
const OfflineSourceSeparationConfig &config) {
|
||||
// TODO(fangjun): Support other models
|
||||
return std::make_unique<OfflineSourceSeparationSpleeterImpl>(config);
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
std::unique_ptr<OfflineSourceSeparationImpl>
|
||||
OfflineSourceSeparationImpl::Create(
|
||||
Manager *mgr, const OfflineSourceSeparationConfig &config) {
|
||||
// TODO(fangjun): Support other models
|
||||
return std::make_unique<OfflineSourceSeparationSpleeterImpl>(mgr, config);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template std::unique_ptr<OfflineSourceSeparationImpl>
|
||||
OfflineSourceSeparationImpl::Create(
|
||||
AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
template std::unique_ptr<OfflineSourceSeparationImpl>
|
||||
OfflineSourceSeparationImpl::Create(
|
||||
NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
35
sherpa-onnx/csrc/offline-source-separation-impl.h
Normal file
35
sherpa-onnx/csrc/offline-source-separation-impl.h
Normal file
@@ -0,0 +1,35 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-impl.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSourceSeparationImpl {
|
||||
public:
|
||||
static std::unique_ptr<OfflineSourceSeparationImpl> Create(
|
||||
const OfflineSourceSeparationConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
static std::unique_ptr<OfflineSourceSeparationImpl> Create(
|
||||
Manager *mgr, const OfflineSourceSeparationConfig &config);
|
||||
|
||||
virtual ~OfflineSourceSeparationImpl() = default;
|
||||
|
||||
virtual OfflineSourceSeparationOutput Process(
|
||||
const OfflineSourceSeparationInput &input) const = 0;
|
||||
|
||||
virtual int32_t GetOutputSampleRate() const = 0;
|
||||
|
||||
virtual int32_t GetNumberOfStems() const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
|
||||
38
sherpa-onnx/csrc/offline-source-separation-model-config.cc
Normal file
38
sherpa-onnx/csrc/offline-source-separation-model-config.cc
Normal file
@@ -0,0 +1,38 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers the command-line options of this config, including those of the
// nested spleeter config, with the given option parser.
void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) {
  // Options of the nested model config come first.
  spleeter.Register(po);

  ParseOptions &opts = *po;

  opts.Register("num-threads", &num_threads,
                "Number of threads to run the neural network");

  opts.Register("debug", &debug,
                "true to print model information while loading it.");

  opts.Register("provider", &provider,
                "Specify a provider to use: cpu, cuda, coreml");
}
|
||||
|
||||
bool OfflineSourceSeparationModelConfig::Validate() const {
|
||||
return spleeter.Validate();
|
||||
}
|
||||
|
||||
// Returns a human readable, single-line description of this config.
std::string OfflineSourceSeparationModelConfig::ToString() const {
  const char *debug_str = debug ? "True" : "False";

  std::ostringstream os;
  os << "OfflineSourceSeparationModelConfig("
     << "spleeter=" << spleeter.ToString() << ", "
     << "num_threads=" << num_threads << ", "
     << "debug=" << debug_str << ", "
     << "provider=\"" << provider << "\")";

  return os.str();
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
41
sherpa-onnx/csrc/offline-source-separation-model-config.h
Normal file
41
sherpa-onnx/csrc/offline-source-separation-model-config.h
Normal file
@@ -0,0 +1,41 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineSourceSeparationModelConfig {
|
||||
OfflineSourceSeparationSpleeterModelConfig spleeter;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
OfflineSourceSeparationModelConfig() = default;
|
||||
|
||||
OfflineSourceSeparationModelConfig(
|
||||
const OfflineSourceSeparationSpleeterModelConfig &spleeter,
|
||||
int32_t num_threads, bool debug, const std::string &provider)
|
||||
: spleeter(spleeter),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
|
||||
276
sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
Normal file
276
sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
Normal file
@@ -0,0 +1,276 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
|
||||
|
||||
#include "Eigen/Dense"
|
||||
#include "kaldi-native-fbank/csrc/istft.h"
|
||||
#include "kaldi-native-fbank/csrc/stft.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-source-separation.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/resample.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl {
|
||||
public:
|
||||
OfflineSourceSeparationSpleeterImpl(
|
||||
const OfflineSourceSeparationConfig &config)
|
||||
: config_(config), model_(config_.model) {}
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSourceSeparationSpleeterImpl(
|
||||
Manager *mgr, const OfflineSourceSeparationConfig &config)
|
||||
: config_(config), model_(mgr, config_.model) {}
|
||||
|
||||
OfflineSourceSeparationOutput Process(
|
||||
const OfflineSourceSeparationInput &input) const override {
|
||||
const OfflineSourceSeparationInput *p_input = &input;
|
||||
OfflineSourceSeparationInput tmp_input;
|
||||
|
||||
int32_t output_sample_rate = GetOutputSampleRate();
|
||||
|
||||
if (input.sample_rate != output_sample_rate) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Creating a resampler:\n"
|
||||
" in_sample_rate: %d\n"
|
||||
" output_sample_rate: %d\n",
|
||||
input.sample_rate, output_sample_rate);
|
||||
|
||||
float min_freq = std::min<int32_t>(input.sample_rate, output_sample_rate);
|
||||
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
|
||||
|
||||
int32_t lowpass_filter_width = 6;
|
||||
auto resampler = std::make_unique<LinearResample>(
|
||||
input.sample_rate, output_sample_rate, lowpass_cutoff,
|
||||
lowpass_filter_width);
|
||||
|
||||
std::vector<float> s;
|
||||
for (const auto &samples : input.samples.data) {
|
||||
resampler->Reset();
|
||||
resampler->Resample(samples.data(), samples.size(), true, &s);
|
||||
tmp_input.samples.data.push_back(std::move(s));
|
||||
}
|
||||
|
||||
tmp_input.sample_rate = output_sample_rate;
|
||||
p_input = &tmp_input;
|
||||
}
|
||||
|
||||
if (p_input->samples.data.size() > 1) {
|
||||
if (config_.model.debug) {
|
||||
SHERPA_ONNX_LOGE("input ch1 samples size: %d",
|
||||
static_cast<int32_t>(p_input->samples.data[1].size()));
|
||||
}
|
||||
|
||||
if (p_input->samples.data[0].size() != p_input->samples.data[1].size()) {
|
||||
SHERPA_ONNX_LOGE("ch0 samples size %d vs ch1 samples size %d",
|
||||
static_cast<int32_t>(p_input->samples.data[0].size()),
|
||||
static_cast<int32_t>(p_input->samples.data[1].size()));
|
||||
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
}
|
||||
|
||||
auto stft_ch0 = ComputeStft(*p_input, 0);
|
||||
|
||||
auto stft_ch1 = ComputeStft(*p_input, 1);
|
||||
knf::StftResult *p_stft_ch1 = stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1;
|
||||
|
||||
int32_t num_frames = stft_ch0.num_frames;
|
||||
int32_t fft_bins = stft_ch0.real.size() / num_frames;
|
||||
|
||||
int32_t pad = 512 - (stft_ch0.num_frames % 512);
|
||||
if (pad < 512) {
|
||||
num_frames += pad;
|
||||
}
|
||||
|
||||
if (num_frames % 512) {
|
||||
SHERPA_ONNX_LOGE("num_frames should be multiple of 512, actual: %d. %d",
|
||||
num_frames, num_frames % 512);
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
Eigen::VectorXf real(2 * num_frames * 1024);
|
||||
Eigen::VectorXf imag(2 * num_frames * 1024);
|
||||
real.setZero();
|
||||
imag.setZero();
|
||||
|
||||
float *p_real = &real[0];
|
||||
float *p_imag = &imag[0];
|
||||
|
||||
// copy stft result of channel 0
|
||||
for (int32_t i = 0; i != stft_ch0.num_frames; ++i) {
|
||||
std::copy(stft_ch0.real.data() + i * fft_bins,
|
||||
stft_ch0.real.data() + i * fft_bins + 1024, p_real + 1024 * i);
|
||||
|
||||
std::copy(stft_ch0.imag.data() + i * fft_bins,
|
||||
stft_ch0.imag.data() + i * fft_bins + 1024, p_imag + 1024 * i);
|
||||
}
|
||||
|
||||
p_real += num_frames * 1024;
|
||||
p_imag += num_frames * 1024;
|
||||
|
||||
// copy stft result of channel 1
|
||||
for (int32_t i = 0; i != stft_ch1.num_frames; ++i) {
|
||||
std::copy(p_stft_ch1->real.data() + i * fft_bins,
|
||||
p_stft_ch1->real.data() + i * fft_bins + 1024,
|
||||
p_real + 1024 * i);
|
||||
|
||||
std::copy(p_stft_ch1->imag.data() + i * fft_bins,
|
||||
p_stft_ch1->imag.data() + i * fft_bins + 1024,
|
||||
p_imag + 1024 * i);
|
||||
}
|
||||
|
||||
Eigen::VectorXf x = (real.array().square() + imag.array().square()).sqrt();
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 4> x_shape{2, num_frames / 512, 512, 1024};
|
||||
Ort::Value x_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, &x[0], x.size(), x_shape.data(), x_shape.size());
|
||||
|
||||
Ort::Value vocals_spec_tensor = model_.RunVocals(View(&x_tensor));
|
||||
Ort::Value accompaniment_spec_tensor =
|
||||
model_.RunAccompaniment(std::move(x_tensor));
|
||||
|
||||
Eigen::VectorXf vocals_spec = Eigen::Map<Eigen::VectorXf>(
|
||||
vocals_spec_tensor.GetTensorMutableData<float>(), x.size());
|
||||
|
||||
Eigen::VectorXf accompaniment_spec = Eigen::Map<Eigen::VectorXf>(
|
||||
accompaniment_spec_tensor.GetTensorMutableData<float>(), x.size());
|
||||
|
||||
Eigen::VectorXf sum_spec = vocals_spec.array().square() +
|
||||
accompaniment_spec.array().square() + 1e-10;
|
||||
|
||||
vocals_spec = (vocals_spec.array().square() + 1e-10 / 2) / sum_spec.array();
|
||||
|
||||
accompaniment_spec =
|
||||
(accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array();
|
||||
|
||||
auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0);
|
||||
auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1);
|
||||
|
||||
auto accompaniment_samples_ch0 =
|
||||
ProcessSpec(accompaniment_spec, stft_ch0, 0);
|
||||
auto accompaniment_samples_ch1 =
|
||||
ProcessSpec(accompaniment_spec, *p_stft_ch1, 1);
|
||||
|
||||
OfflineSourceSeparationOutput ans;
|
||||
ans.sample_rate = GetOutputSampleRate();
|
||||
|
||||
ans.stems.resize(2);
|
||||
ans.stems[0].data.reserve(2);
|
||||
ans.stems[1].data.reserve(2);
|
||||
|
||||
ans.stems[0].data.push_back(std::move(vocals_samples_ch0));
|
||||
ans.stems[0].data.push_back(std::move(vocals_samples_ch1));
|
||||
|
||||
ans.stems[1].data.push_back(std::move(accompaniment_samples_ch0));
|
||||
ans.stems[1].data.push_back(std::move(accompaniment_samples_ch1));
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t GetOutputSampleRate() const override {
|
||||
return model_.GetMetaData().sample_rate;
|
||||
}
|
||||
|
||||
int32_t GetNumberOfStems() const override {
|
||||
return model_.GetMetaData().num_stems;
|
||||
}
|
||||
|
||||
private:
|
||||
// spec is of shape (2, num_chunks, 512, 1024)
|
||||
std::vector<float> ProcessSpec(const Eigen::VectorXf &spec,
|
||||
const knf::StftResult &stft,
|
||||
int32_t channel) const {
|
||||
int32_t fft_bins = stft.real.size() / stft.num_frames;
|
||||
|
||||
Eigen::VectorXf mask(stft.real.size());
|
||||
mask.setZero();
|
||||
|
||||
float *p_mask = &mask[0];
|
||||
|
||||
// assume there are 2 channels
|
||||
const float *p_spec = &spec[0] + (spec.size() / 2) * channel;
|
||||
|
||||
for (int32_t i = 0; i != stft.num_frames; ++i) {
|
||||
std::copy(p_spec + i * 1024, p_spec + (i + 1) * 1024,
|
||||
p_mask + i * fft_bins);
|
||||
}
|
||||
|
||||
knf::StftResult masked_stft;
|
||||
|
||||
masked_stft.num_frames = stft.num_frames;
|
||||
masked_stft.real.resize(stft.real.size());
|
||||
masked_stft.imag.resize(stft.imag.size());
|
||||
|
||||
Eigen::Map<Eigen::VectorXf>(masked_stft.real.data(),
|
||||
masked_stft.real.size()) =
|
||||
mask.array() *
|
||||
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.real.data()),
|
||||
stft.real.size())
|
||||
.array();
|
||||
|
||||
Eigen::Map<Eigen::VectorXf>(masked_stft.imag.data(),
|
||||
masked_stft.imag.size()) =
|
||||
mask.array() *
|
||||
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.imag.data()),
|
||||
stft.imag.size())
|
||||
.array();
|
||||
|
||||
auto stft_config = GetStftConfig();
|
||||
knf::IStft istft(stft_config);
|
||||
|
||||
return istft.Compute(masked_stft);
|
||||
}
|
||||
|
||||
knf::StftResult ComputeStft(const OfflineSourceSeparationInput &input,
|
||||
int32_t ch) const {
|
||||
if (ch >= input.samples.data.size()) {
|
||||
SHERPA_ONNX_LOGE("Invalid channel %d. Max %d", ch,
|
||||
static_cast<int32_t>(input.samples.data.size()));
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
if (input.samples.data[ch].empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
return ComputeStft(input.samples.data[ch]);
|
||||
}
|
||||
|
||||
knf::StftResult ComputeStft(const std::vector<float> &samples) const {
|
||||
auto stft_config = GetStftConfig();
|
||||
knf::Stft stft(stft_config);
|
||||
|
||||
return stft.Compute(samples.data(), samples.size());
|
||||
}
|
||||
|
||||
knf::StftConfig GetStftConfig() const {
|
||||
const auto &meta = model_.GetMetaData();
|
||||
|
||||
knf::StftConfig stft_config;
|
||||
stft_config.n_fft = meta.n_fft;
|
||||
stft_config.hop_length = meta.hop_length;
|
||||
stft_config.win_length = meta.window_length;
|
||||
stft_config.window_type = meta.window_type;
|
||||
stft_config.center = meta.center;
|
||||
stft_config.center = false;
|
||||
|
||||
return stft_config;
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineSourceSeparationConfig config_;
|
||||
OfflineSourceSeparationSpleeterModel model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
|
||||
@@ -0,0 +1,54 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers --spleeter-vocals and --spleeter-accompaniment with the given
// option parser.
void OfflineSourceSeparationSpleeterModelConfig::Register(ParseOptions *po) {
  ParseOptions &opts = *po;

  opts.Register("spleeter-vocals", &vocals,
                "Path to the spleeter vocals model");

  opts.Register("spleeter-accompaniment", &accompaniment,
                "Path to the spleeter accompaniment model");
}
|
||||
|
||||
bool OfflineSourceSeparationSpleeterModelConfig::Validate() const {
|
||||
if (vocals.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --spleeter-vocals");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(vocals)) {
|
||||
SHERPA_ONNX_LOGE("spleeter vocals '%s' does not exist. ", vocals.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (accompaniment.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --spleeter-accompaniment");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(accompaniment)) {
|
||||
SHERPA_ONNX_LOGE("spleeter accompaniment '%s' does not exist. ",
|
||||
accompaniment.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns a human readable, single-line description of this config.
std::string OfflineSourceSeparationSpleeterModelConfig::ToString() const {
  std::ostringstream os;
  os << "OfflineSourceSeparationSpleeterModelConfig("
     << "vocals=\"" << vocals << "\", "
     << "accompaniment=\"" << accompaniment << "\")";

  return os.str();
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,35 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineSourceSeparationSpleeterModelConfig {
|
||||
std::string vocals;
|
||||
|
||||
std::string accompaniment;
|
||||
|
||||
OfflineSourceSeparationSpleeterModelConfig() = default;
|
||||
|
||||
OfflineSourceSeparationSpleeterModelConfig(const std::string &vocals,
|
||||
const std::string &accompaniment)
|
||||
: vocals(vocals), accompaniment(accompaniment) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
|
||||
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Meta data of a spleeter model together with the STFT parameters used to
// pre- and post-process its input and output.
//
// See also
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/spleeter/separate_onnx.py
struct OfflineSourceSeparationSpleeterModelMetaData {
  // Sample rate of the audio the model works on.
  int32_t sample_rate{44100};

  // Number of stems the model separates the input into.
  int32_t num_stems{2};

  // STFT parameters.
  int32_t n_fft{4096};
  int32_t hop_length{1024};
  int32_t window_length{4096};
  bool center{false};
  std::string window_type{"hann"};
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
|
||||
212
sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
Normal file
212
sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
Normal file
@@ -0,0 +1,212 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Owns the ONNX Runtime sessions of the vocals and accompaniment models and
// the meta data read from the vocals model.
class OfflineSourceSeparationSpleeterModel::Impl {
 public:
  // Loads both models from the paths in config.spleeter.
  explicit Impl(const OfflineSourceSeparationModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto vocals_buf = ReadFile(config.spleeter.vocals);
      InitVocals(vocals_buf.data(), vocals_buf.size());
    }  // the vocals buffer is released before the next model is read

    {
      auto accompaniment_buf = ReadFile(config.spleeter.accompaniment);
      InitAccompaniment(accompaniment_buf.data(), accompaniment_buf.size());
    }
  }

  // Loads both models through the given resource manager.
  template <typename Manager>
  Impl(Manager *mgr, const OfflineSourceSeparationModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto vocals_buf = ReadFile(mgr, config.spleeter.vocals);
      InitVocals(vocals_buf.data(), vocals_buf.size());
    }  // the vocals buffer is released before the next model is read

    {
      auto accompaniment_buf = ReadFile(mgr, config.spleeter.accompaniment);
      InitAccompaniment(accompaniment_buf.data(), accompaniment_buf.size());
    }
  }

  const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const {
    return meta_;
  }

  // Runs the vocals model on `x` and returns its first output.
  Ort::Value RunVocals(Ort::Value x) const {
    const char *const *input_names = vocals_input_names_ptr_.data();
    const char *const *output_names = vocals_output_names_ptr_.data();
    const size_t num_outputs = vocals_output_names_ptr_.size();

    auto outputs =
        vocals_sess_->Run({}, input_names, &x, 1, output_names, num_outputs);

    return std::move(outputs[0]);
  }

  // Runs the accompaniment model on `x` and returns its first output.
  Ort::Value RunAccompaniment(Ort::Value x) const {
    const char *const *input_names = accompaniment_input_names_ptr_.data();
    const char *const *output_names = accompaniment_output_names_ptr_.data();
    const size_t num_outputs = accompaniment_output_names_ptr_.size();

    auto outputs = accompaniment_sess_->Run({}, input_names, &x, 1,
                                            output_names, num_outputs);

    return std::move(outputs[0]);
  }

 private:
  // Creates the vocals session from an in-memory model, optionally logs
  // its meta data and input/output names, and reads the model meta data.
  // Exits if the model type is not 'spleeter' or the number of stems is
  // not 2.
  void InitVocals(void *model_data, size_t model_data_length) {
    vocals_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(vocals_sess_.get(), &vocals_input_names_,
                  &vocals_input_names_ptr_);

    GetOutputNames(vocals_sess_.get(), &vocals_output_names_,
                   &vocals_output_names_ptr_);

    Ort::ModelMetadata meta_data = vocals_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---vocals model---\n";
      PrintModelMetadata(os, meta_data);

      os << "----------input names----------\n";
      for (size_t k = 0; k != vocals_input_names_.size(); ++k) {
        os << k << " " << vocals_input_names_[k] << "\n";
      }

      os << "----------output names----------\n";
      for (size_t k = 0; k != vocals_output_names_.size(); ++k) {
        os << k << " " << vocals_output_names_[k] << "\n";
      }

#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below

    std::string model_type;
    SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
    if (model_type != "spleeter") {
      SHERPA_ONNX_LOGE("Expect model type 'spleeter'. Given: '%s'",
                       model_type.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    SHERPA_ONNX_READ_META_DATA(meta_.num_stems, "stems");
    if (meta_.num_stems != 2) {
      SHERPA_ONNX_LOGE("Only 2stems is supported. Given %d stems",
                       meta_.num_stems);
      SHERPA_ONNX_EXIT(-1);
    }
  }

  // Creates the accompaniment session from an in-memory model and caches
  // its input and output names.
  void InitAccompaniment(void *model_data, size_t model_data_length) {
    accompaniment_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(accompaniment_sess_.get(), &accompaniment_input_names_,
                  &accompaniment_input_names_ptr_);

    GetOutputNames(accompaniment_sess_.get(), &accompaniment_output_names_,
                   &accompaniment_output_names_ptr_);
  }

 private:
  OfflineSourceSeparationModelConfig config_;
  OfflineSourceSeparationSpleeterModelMetaData meta_;

  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  // Vocals model
  std::unique_ptr<Ort::Session> vocals_sess_;

  std::vector<std::string> vocals_input_names_;
  std::vector<const char *> vocals_input_names_ptr_;

  std::vector<std::string> vocals_output_names_;
  std::vector<const char *> vocals_output_names_ptr_;

  // Accompaniment model
  std::unique_ptr<Ort::Session> accompaniment_sess_;

  std::vector<std::string> accompaniment_input_names_;
  std::vector<const char *> accompaniment_input_names_ptr_;

  std::vector<std::string> accompaniment_output_names_;
  std::vector<const char *> accompaniment_output_names_ptr_;
};
|
||||
|
||||
// Loads the models described by `config`.
OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel(
    const OfflineSourceSeparationModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

// Loads the models described by `config` through the resource manager `mgr`.
template <typename Manager>
OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel(
    Manager *mgr, const OfflineSourceSeparationModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}

// Defined in this file, after Impl, where Impl is a complete type.
OfflineSourceSeparationSpleeterModel::~OfflineSourceSeparationSpleeterModel() =
    default;
|
||||
|
||||
// Hands `x` over to the implementation's RunVocals() and returns its result.
Ort::Value OfflineSourceSeparationSpleeterModel::RunVocals(Ort::Value x) const {
  Ort::Value vocals = impl_->RunVocals(std::move(x));
  return vocals;
}
|
||||
|
||||
// Hands `x` over to the implementation's RunAccompaniment() and returns its
// result.
Ort::Value OfflineSourceSeparationSpleeterModel::RunAccompaniment(
    Ort::Value x) const {
  Ort::Value accompaniment = impl_->RunAccompaniment(std::move(x));
  return accompaniment;
}
|
||||
|
||||
// Returns a reference to the meta data held by the implementation.
const OfflineSourceSeparationSpleeterModelMetaData &
OfflineSourceSeparationSpleeterModel::GetMetaData() const {
  const auto &meta_data = impl_->GetMetaData();
  return meta_data;
}
|
||||
|
||||
// Explicit instantiations of the templated constructor for the resource
// managers of the supported platforms.

// Android: AAssetManager
#if __ANDROID_API__ >= 9
template OfflineSourceSeparationSpleeterModel::
    OfflineSourceSeparationSpleeterModel(
        AAssetManager *mgr, const OfflineSourceSeparationModelConfig &config);
#endif

// OHOS: NativeResourceManager
#if __OHOS__
template OfflineSourceSeparationSpleeterModel::
    OfflineSourceSeparationSpleeterModel(
        NativeResourceManager *mgr,
        const OfflineSourceSeparationModelConfig &config);
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
37
sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
Normal file
37
sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
Normal file
@@ -0,0 +1,37 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
|
||||
#include <memory>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSourceSeparationSpleeterModel {
|
||||
public:
|
||||
~OfflineSourceSeparationSpleeterModel();
|
||||
|
||||
explicit OfflineSourceSeparationSpleeterModel(
|
||||
const OfflineSourceSeparationModelConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSourceSeparationSpleeterModel(
|
||||
Manager *mgr, const OfflineSourceSeparationModelConfig &config);
|
||||
|
||||
Ort::Value RunVocals(Ort::Value x) const;
|
||||
Ort::Value RunAccompaniment(Ort::Value x) const;
|
||||
|
||||
const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
|
||||
74
sherpa-onnx/csrc/offline-source-separation.cc
Normal file
74
sherpa-onnx/csrc/offline-source-separation.cc
Normal file
@@ -0,0 +1,74 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers the options of the nested model config with `po`.
void OfflineSourceSeparationConfig::Register(ParseOptions *po) {
  this->model.Register(po);
}
|
||||
|
||||
bool OfflineSourceSeparationConfig::Validate() const {
|
||||
return model.Validate();
|
||||
}
|
||||
|
||||
// Returns a human readable description of this config in the form
//   OfflineSourceSeparationConfig(model=<model.ToString()>)
//
// The string is assembled with std::string appends so that this file does
// not depend on <sstream>, which it does not include.
std::string OfflineSourceSeparationConfig::ToString() const {
  std::string ans = "OfflineSourceSeparationConfig(";
  ans += "model=";
  ans += model.ToString();
  ans += ")";

  return ans;
}
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSourceSeparation::OfflineSourceSeparation(
|
||||
Manager *mgr, const OfflineSourceSeparationConfig &config)
|
||||
: impl_(OfflineSourceSeparationImpl::Create(mgr, config)) {}
|
||||
|
||||
OfflineSourceSeparation::OfflineSourceSeparation(
|
||||
const OfflineSourceSeparationConfig &config)
|
||||
: impl_(OfflineSourceSeparationImpl::Create(config)) {}
|
||||
|
||||
// Defaulted out of line; the header only forward declares
// OfflineSourceSeparationImpl.
OfflineSourceSeparation::~OfflineSourceSeparation() = default;
|
||||
|
||||
// Forwards `input` to the implementation and returns what it produces.
OfflineSourceSeparationOutput OfflineSourceSeparation::Process(
    const OfflineSourceSeparationInput &input) const {
  OfflineSourceSeparationOutput separated = impl_->Process(input);
  return separated;
}
|
||||
|
||||
int32_t OfflineSourceSeparation::GetOutputSampleRate() const {
|
||||
return impl_->GetOutputSampleRate();
|
||||
}
|
||||
|
||||
// e.g., it is 2 for 2stems from spleeter
|
||||
int32_t OfflineSourceSeparation::GetNumberOfStems() const {
|
||||
return impl_->GetNumberOfStems();
|
||||
}
|
||||
|
||||
// Explicit instantiations of the templated constructor for the resource
// managers of the supported platforms.

// Android: AAssetManager
#if __ANDROID_API__ >= 9
template OfflineSourceSeparation::OfflineSourceSeparation(
    AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
#endif

// OHOS: NativeResourceManager
#if __OHOS__
template OfflineSourceSeparation::OfflineSourceSeparation(
    NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
77
sherpa-onnx/csrc/offline-source-separation.h
Normal file
77
sherpa-onnx/csrc/offline-source-separation.h
Normal file
@@ -0,0 +1,77 @@
|
||||
// sherpa-onnx/csrc/offline-source-separation.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineSourceSeparationConfig {
|
||||
OfflineSourceSeparationModelConfig model;
|
||||
|
||||
OfflineSourceSeparationConfig() = default;
|
||||
|
||||
OfflineSourceSeparationConfig(const OfflineSourceSeparationModelConfig &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
// Audio samples of one or more channels.
struct MultiChannelSamples {
  // data[i] is for the i-th channel
  //
  // each sample is in the range [-1, 1]
  std::vector<std::vector<float>> data;
};

// Input of OfflineSourceSeparation::Process().
struct OfflineSourceSeparationInput {
  // Samples of the audio to separate.
  MultiChannelSamples samples;

  // Sample rate of `samples`. Defaults to 0 so that a default-constructed
  // input never carries an indeterminate value.
  int32_t sample_rate = 0;
};

// Output of OfflineSourceSeparation::Process().
struct OfflineSourceSeparationOutput {
  // stems[i] holds the samples of the i-th separated stem.
  std::vector<MultiChannelSamples> stems;

  // Sample rate of the samples in `stems`. Defaults to 0 so that a
  // default-constructed output never carries an indeterminate value.
  int32_t sample_rate = 0;
};
|
||||
|
||||
class OfflineSourceSeparationImpl;
|
||||
|
||||
class OfflineSourceSeparation {
|
||||
public:
|
||||
~OfflineSourceSeparation();
|
||||
|
||||
OfflineSourceSeparation(const OfflineSourceSeparationConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSourceSeparation(Manager *mgr,
|
||||
const OfflineSourceSeparationConfig &config);
|
||||
|
||||
OfflineSourceSeparationOutput Process(
|
||||
const OfflineSourceSeparationInput &input) const;
|
||||
|
||||
int32_t GetOutputSampleRate() const;
|
||||
|
||||
// e.g., it is 2 for 2stems from spleeter
|
||||
int32_t GetNumberOfStems() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<OfflineSourceSeparationImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
|
||||
@@ -12,7 +12,7 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// please refer to
|
||||
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py
|
||||
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/gtcrn/add_meta_data.py
|
||||
struct OfflineSpeechDenoiserGtcrnModelMetaData {
|
||||
int32_t sample_rate = 0;
|
||||
int32_t version = 1;
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Non-stremaing speech denoising with sherpa-onnx.
|
||||
Non-streaming speech denoising with sherpa-onnx.
|
||||
|
||||
Please visit
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
|
||||
|
||||
138
sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
Normal file
138
sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
Normal file
@@ -0,0 +1,138 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Non-streaming source separation with sherpa-onnx.
|
||||
|
||||
Please visit
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models
|
||||
to download models.
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Use spleeter models
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
|
||||
tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav
|
||||
|
||||
./bin/sherpa-onnx-offline-source-separation \
|
||||
--spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \
|
||||
--spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \
|
||||
--input-wav=audio_example.wav \
|
||||
--output-vocals-wav=output_vocals.wav \
|
||||
--output-accompaniment-wav=output_accompaniment.wav
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::OfflineSourceSeparationConfig config;
|
||||
|
||||
std::string input_wave;
|
||||
std::string output_vocals_wave;
|
||||
std::string output_accompaniment_wave;
|
||||
|
||||
config.Register(&po);
|
||||
po.Register("input-wav", &input_wave, "Path to input wav.");
|
||||
po.Register("output-vocals-wav", &output_vocals_wave,
|
||||
"Path to output vocals wav");
|
||||
po.Register("output-accompaniment-wav", &output_accompaniment_wave,
|
||||
"Path to output accompaniment wav");
|
||||
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() != 0) {
|
||||
fprintf(stderr, "Please don't give positional arguments\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (input_wave.empty()) {
|
||||
fprintf(stderr, "Please provide --input-wav\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (output_vocals_wave.empty()) {
|
||||
fprintf(stderr, "Please provide --output-vocals-wav\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (output_accompaniment_wave.empty()) {
|
||||
fprintf(stderr, "Please provide --output-accompaniment-wav\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
bool is_ok = false;
|
||||
sherpa_onnx::OfflineSourceSeparationInput input;
|
||||
input.samples.data =
|
||||
sherpa_onnx::ReadWaveMultiChannel(input_wave, &input.sample_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Started\n");
|
||||
|
||||
sherpa_onnx::OfflineSourceSeparation sp(config);
|
||||
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
auto output = sp.Process(input);
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
is_ok = sherpa_onnx::WriteWave(
|
||||
output_vocals_wave, output.sample_rate, output.stems[0].data[0].data(),
|
||||
output.stems[0].data[1].data(), output.stems[0].data[0].size());
|
||||
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to write to '%s'\n", output_vocals_wave.c_str());
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
is_ok = sherpa_onnx::WriteWave(output_accompaniment_wave, output.sample_rate,
|
||||
output.stems[1].data[0].data(),
|
||||
output.stems[1].data[1].data(),
|
||||
output.stems[1].data[0].size());
|
||||
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to write to '%s'\n",
|
||||
output_accompaniment_wave.c_str());
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "Done\n");
|
||||
fprintf(stderr, "Saved to write to '%s' and '%s'\n",
|
||||
output_vocals_wave.c_str(), output_accompaniment_wave.c_str());
|
||||
|
||||
float duration =
|
||||
input.samples.data[0].size() / static_cast<float>(input.sample_rate);
|
||||
fprintf(stderr, "num threads: %d\n", config.model.num_threads);
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||
elapsed_seconds, duration, rtf);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -63,8 +63,9 @@ in sherpa-onnx.
|
||||
|
||||
// Read a wave file of mono-channel.
|
||||
// Return its samples normalized to the range [-1, 1).
|
||||
std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
|
||||
bool *is_ok) {
|
||||
std::vector<std::vector<float>> ReadWaveImpl(std::istream &is,
|
||||
int32_t *sampling_rate,
|
||||
bool *is_ok) {
|
||||
WaveHeader header{};
|
||||
is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id));
|
||||
|
||||
@@ -144,12 +145,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
|
||||
is.read(reinterpret_cast<char *>(&header.num_channels),
|
||||
sizeof(header.num_channels));
|
||||
|
||||
if (header.num_channels != 1) { // we support only single channel for now
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Warning: %d channels are found. We only use the first channel.\n",
|
||||
header.num_channels);
|
||||
}
|
||||
|
||||
is.read(reinterpret_cast<char *>(&header.sample_rate),
|
||||
sizeof(header.sample_rate));
|
||||
|
||||
@@ -219,7 +214,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
|
||||
|
||||
*sampling_rate = header.sample_rate;
|
||||
|
||||
std::vector<float> ans;
|
||||
std::vector<std::vector<float>> ans(header.num_channels);
|
||||
|
||||
if (header.bits_per_sample == 16 && header.audio_format == 1) {
|
||||
// header.subchunk2_size contains the number of bytes in the data.
|
||||
@@ -233,11 +228,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
|
||||
return {};
|
||||
}
|
||||
|
||||
ans.resize(samples.size() / header.num_channels);
|
||||
for (auto &v : ans) {
|
||||
v.resize(samples.size() / header.num_channels);
|
||||
}
|
||||
|
||||
// samples are interleaved
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
|
||||
ans[i] = samples[i * header.num_channels] / 32768.;
|
||||
for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
|
||||
i += header.num_channels, ++k) {
|
||||
for (int32_t c = 0; c != header.num_channels; ++c) {
|
||||
ans[c][k] = samples[i + c] / 32768.;
|
||||
}
|
||||
}
|
||||
} else if (header.bits_per_sample == 8 && header.audio_format == 1) {
|
||||
// number of samples == number of bytes for 8-bit encoded samples
|
||||
@@ -252,14 +252,21 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
|
||||
return {};
|
||||
}
|
||||
|
||||
ans.resize(samples.size() / header.num_channels);
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
|
||||
// Note(fangjun): We want to normalize each sample into the range [-1, 1]
|
||||
// Since each original sample is in the range [0, 256], dividing
|
||||
// them by 128 converts them to the range [0, 2];
|
||||
// so after subtracting 1, we get the range [-1, 1]
|
||||
//
|
||||
ans[i] = samples[i * header.num_channels] / 128. - 1;
|
||||
for (auto &v : ans) {
|
||||
v.resize(samples.size() / header.num_channels);
|
||||
}
|
||||
|
||||
// samples are interleaved
|
||||
for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
|
||||
i += header.num_channels, ++k) {
|
||||
for (int32_t c = 0; c != header.num_channels; ++c) {
|
||||
// Note(fangjun): We want to normalize each sample into the range [-1,
|
||||
// 1] Since each original sample is in the range [0, 256], dividing them
|
||||
// by 128 converts them to the range [0, 2]; so after subtracting 1, we
|
||||
// get the range [-1, 1]
|
||||
//
|
||||
ans[c][k] = samples[i + c] / 128. - 1;
|
||||
}
|
||||
}
|
||||
} else if (header.bits_per_sample == 32 && header.audio_format == 1) {
|
||||
// 32 here is for int32
|
||||
@@ -275,9 +282,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
|
||||
return {};
|
||||
}
|
||||
|
||||
ans.resize(samples.size() / header.num_channels);
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
|
||||
ans[i] = static_cast<float>(samples[i * header.num_channels]) / (1 << 31);
|
||||
for (auto &v : ans) {
|
||||
v.resize(samples.size() / header.num_channels);
|
||||
}
|
||||
|
||||
// samples are interleaved
|
||||
for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
|
||||
i += header.num_channels, ++k) {
|
||||
for (int32_t c = 0; c != header.num_channels; ++c) {
|
||||
ans[c][k] = static_cast<float>(samples[i + c]) / (1 << 31);
|
||||
}
|
||||
}
|
||||
} else if (header.bits_per_sample == 32 && header.audio_format == 3) {
|
||||
// 32 here is for float32
|
||||
@@ -293,9 +307,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
|
||||
return {};
|
||||
}
|
||||
|
||||
ans.resize(samples.size() / header.num_channels);
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
|
||||
ans[i] = samples[i * header.num_channels];
|
||||
for (auto &v : ans) {
|
||||
v.resize(samples.size() / header.num_channels);
|
||||
}
|
||||
|
||||
// samples are interleaved
|
||||
for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
|
||||
i += header.num_channels, ++k) {
|
||||
for (int32_t c = 0; c != header.num_channels; ++c) {
|
||||
ans[c][k] = samples[i + c];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
@@ -321,7 +342,27 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
|
||||
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
|
||||
bool *is_ok) {
|
||||
auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
|
||||
|
||||
if (samples.size() > 1) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Warning: %d channels are found. We only use the first channel.\n",
|
||||
static_cast<int32_t>(samples.size()));
|
||||
}
|
||||
|
||||
return samples[0];
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
|
||||
int32_t *sampling_rate,
|
||||
bool *is_ok) {
|
||||
auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
|
||||
return samples;
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> ReadWaveMultiChannel(
|
||||
const std::string &filename, int32_t *sampling_rate, bool *is_ok) {
|
||||
std::ifstream is(filename, std::ifstream::binary);
|
||||
return ReadWaveMultiChannel(is, sampling_rate, is_ok);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -26,6 +26,13 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
|
||||
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
|
||||
bool *is_ok);
|
||||
|
||||
std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
|
||||
int32_t *sampling_rate,
|
||||
bool *is_ok);
|
||||
|
||||
std::vector<std::vector<float>> ReadWaveMultiChannel(
|
||||
const std::string &filename, int32_t *sampling_rate, bool *is_ok);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
@@ -36,12 +37,44 @@ struct WaveHeader {
|
||||
|
||||
} // namespace
|
||||
|
||||
int64_t WaveFileSize(int32_t n_samples) {
|
||||
return sizeof(WaveHeader) + n_samples * sizeof(int16_t);
|
||||
int64_t WaveFileSize(int32_t n_samples, int32_t num_channels /*= 1*/) {
|
||||
return sizeof(WaveHeader) + n_samples * sizeof(int16_t) * num_channels;
|
||||
}
|
||||
|
||||
void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
|
||||
int32_t n) {
|
||||
WriteWave(buffer, sampling_rate, samples, nullptr, n);
|
||||
}
|
||||
|
||||
bool WriteWave(const std::string &filename, int32_t sampling_rate,
|
||||
const float *samples, int32_t n) {
|
||||
return WriteWave(filename, sampling_rate, samples, nullptr, n);
|
||||
}
|
||||
|
||||
bool WriteWave(const std::string &filename, int32_t sampling_rate,
|
||||
const float *samples_ch0, const float *samples_ch1, int32_t n) {
|
||||
std::string buffer;
|
||||
buffer.resize(WaveFileSize(n, samples_ch1 == nullptr ? 1 : 2));
|
||||
|
||||
WriteWave(buffer.data(), sampling_rate, samples_ch0, samples_ch1, n);
|
||||
|
||||
std::ofstream os(filename, std::ios::binary);
|
||||
if (!os) {
|
||||
SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
os << buffer;
|
||||
if (!os) {
|
||||
SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0,
|
||||
const float *samples_ch1, int32_t n) {
|
||||
WaveHeader header{};
|
||||
header.chunk_id = 0x46464952; // FFIR
|
||||
header.format = 0x45564157; // EVAW
|
||||
@@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
|
||||
header.subchunk1_size = 16; // 16 for PCM
|
||||
header.audio_format = 1; // PCM =1
|
||||
|
||||
int32_t num_channels = 1;
|
||||
int32_t num_channels = samples_ch1 == nullptr ? 1 : 2;
|
||||
int32_t bits_per_sample = 16; // int16_t
|
||||
|
||||
header.num_channels = num_channels;
|
||||
header.sample_rate = sampling_rate;
|
||||
header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8;
|
||||
@@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
|
||||
|
||||
header.chunk_size = 36 + header.subchunk2_size;
|
||||
|
||||
std::vector<int16_t> samples_int16(n);
|
||||
std::vector<int16_t> samples_int16_ch0(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
samples_int16[i] = samples[i] * 32767;
|
||||
samples_int16_ch0[i] = std::min<int32_t>(samples_ch0[i] * 32767, 32767);
|
||||
}
|
||||
|
||||
std::vector<int16_t> samples_int16_ch1;
|
||||
if (samples_ch1) {
|
||||
samples_int16_ch1.resize(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
samples_int16_ch1[i] = std::min<int32_t>(samples_ch1[i] * 32767, 32767);
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(buffer, &header, sizeof(WaveHeader));
|
||||
memcpy(buffer + sizeof(WaveHeader), samples_int16.data(),
|
||||
n * sizeof(int16_t));
|
||||
}
|
||||
|
||||
bool WriteWave(const std::string &filename, int32_t sampling_rate,
|
||||
const float *samples, int32_t n) {
|
||||
std::string buffer;
|
||||
buffer.resize(WaveFileSize(n));
|
||||
WriteWave(buffer.data(), sampling_rate, samples, n);
|
||||
std::ofstream os(filename, std::ios::binary);
|
||||
if (!os) {
|
||||
SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());
|
||||
return false;
|
||||
if (samples_ch1 == nullptr) {
|
||||
memcpy(buffer + sizeof(WaveHeader), samples_int16_ch0.data(),
|
||||
n * sizeof(int16_t));
|
||||
} else {
|
||||
auto p = reinterpret_cast<int16_t *>(buffer + sizeof(WaveHeader));
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
p[2 * i] = samples_int16_ch0[i];
|
||||
p[2 * i + 1] = samples_int16_ch1[i];
|
||||
}
|
||||
}
|
||||
os << buffer;
|
||||
if (!os) {
|
||||
SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -25,7 +25,13 @@ bool WriteWave(const std::string &filename, int32_t sampling_rate,
|
||||
void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
|
||||
int32_t n);
|
||||
|
||||
int64_t WaveFileSize(int32_t n_samples);
|
||||
bool WriteWave(const std::string &filename, int32_t sampling_rate,
|
||||
const float *samples_ch0, const float *samples_ch1, int32_t n);
|
||||
|
||||
void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0,
|
||||
const float *samples_ch1, int32_t n);
|
||||
|
||||
int64_t WaveFileSize(int32_t n_samples, int32_t num_channels = 1);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) {
|
||||
console.log('ArrayBuffer length:', arrayBuffer.byteLength);
|
||||
|
||||
const uint8Array = new Uint8Array(arrayBuffer);
|
||||
const wave = readWaveFromBinaryData(uint8Array);
|
||||
const wave = readWaveFromBinaryData(uint8Array, Module);
|
||||
if (wave == null) {
|
||||
alert(
|
||||
`${file.name} is not a valid .wav file. Please select a *.wav file`);
|
||||
|
||||
Reference in New Issue
Block a user