// sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h // // Copyright (c) 2025 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ #include #include #include #include "Eigen/Dense" #include "kaldi-native-fbank/csrc/istft.h" #include "kaldi-native-fbank/csrc/stft.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h" #include "sherpa-onnx/csrc/offline-source-separation.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { public: explicit OfflineSourceSeparationSpleeterImpl( const OfflineSourceSeparationConfig &config) : config_(config), model_(config_.model) {} template OfflineSourceSeparationSpleeterImpl( Manager *mgr, const OfflineSourceSeparationConfig &config) : config_(config), model_(mgr, config_.model) {} OfflineSourceSeparationOutput Process( const OfflineSourceSeparationInput &_input) const override { auto input = Resample(_input, config_.model.debug); auto stft_ch0 = ComputeStft(input, 0); auto stft_ch1 = ComputeStft(input, 1); knf::StftResult *p_stft_ch1 = stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1; int32_t num_frames = stft_ch0.num_frames; int32_t fft_bins = stft_ch0.real.size() / num_frames; int32_t pad = 512 - (stft_ch0.num_frames % 512); if (pad < 512) { num_frames += pad; } if (num_frames % 512) { SHERPA_ONNX_LOGE("num_frames should be multiple of 512, actual: %d. %d", num_frames, num_frames % 512); SHERPA_ONNX_EXIT(-1); } Eigen::VectorXf real(2 * num_frames * 1024); Eigen::VectorXf imag(2 * num_frames * 1024); real.setZero(); imag.setZero(); float *p_real = &real[0]; float *p_imag = &imag[0]; // copy stft result of channel 0 for (int32_t i = 0; i != stft_ch0.num_frames; ++i) { std::copy(stft_ch0.real.data() + i * fft_bins, stft_ch0.real.data() + i * fft_bins + 1024, p_real + 1024 * i); std::copy(stft_ch0.imag.data() + i * fft_bins, stft_ch0.imag.data() + i * fft_bins + 1024, p_imag + 1024 * i); } p_real += num_frames * 1024; p_imag += num_frames * 1024; // copy stft result of channel 1 for (int32_t i = 0; i != stft_ch1.num_frames; ++i) { std::copy(p_stft_ch1->real.data() + i * fft_bins, p_stft_ch1->real.data() + i * fft_bins + 1024, p_real + 1024 * i); std::copy(p_stft_ch1->imag.data() + i * fft_bins, p_stft_ch1->imag.data() + i * fft_bins + 1024, p_imag + 1024 * i); } Eigen::VectorXf x = (real.array().square() + imag.array().square()).sqrt(); auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); std::array x_shape{2, num_frames / 512, 512, 1024}; Ort::Value x_tensor = Ort::Value::CreateTensor( memory_info, &x[0], x.size(), x_shape.data(), x_shape.size()); Ort::Value vocals_spec_tensor = model_.RunVocals(View(&x_tensor)); Ort::Value accompaniment_spec_tensor = model_.RunAccompaniment(std::move(x_tensor)); Eigen::VectorXf vocals_spec = Eigen::Map( vocals_spec_tensor.GetTensorMutableData(), x.size()); Eigen::VectorXf accompaniment_spec = Eigen::Map( accompaniment_spec_tensor.GetTensorMutableData(), x.size()); Eigen::VectorXf sum_spec = vocals_spec.array().square() + accompaniment_spec.array().square() + 1e-10; vocals_spec = (vocals_spec.array().square() + 1e-10 / 2) / sum_spec.array(); accompaniment_spec = (accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array(); auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0); auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1); auto accompaniment_samples_ch0 = ProcessSpec(accompaniment_spec, stft_ch0, 0); auto accompaniment_samples_ch1 = ProcessSpec(accompaniment_spec, *p_stft_ch1, 1); OfflineSourceSeparationOutput ans; ans.sample_rate = GetOutputSampleRate(); ans.stems.resize(2); ans.stems[0].data.reserve(2); ans.stems[1].data.reserve(2); ans.stems[0].data.push_back(std::move(vocals_samples_ch0)); ans.stems[0].data.push_back(std::move(vocals_samples_ch1)); ans.stems[1].data.push_back(std::move(accompaniment_samples_ch0)); ans.stems[1].data.push_back(std::move(accompaniment_samples_ch1)); return ans; } int32_t GetOutputSampleRate() const override { return model_.GetMetaData().sample_rate; } int32_t GetNumberOfStems() const override { return model_.GetMetaData().num_stems; } private: // spec is of shape (2, num_chunks, 512, 1024) std::vector ProcessSpec(const Eigen::VectorXf &spec, const knf::StftResult &stft, int32_t channel) const { int32_t fft_bins = stft.real.size() / stft.num_frames; Eigen::VectorXf mask(stft.real.size()); mask.setZero(); float *p_mask = &mask[0]; // assume there are 2 channels const float *p_spec = &spec[0] + (spec.size() / 2) * channel; for (int32_t i = 0; i != stft.num_frames; ++i) { std::copy(p_spec + i * 1024, p_spec + (i + 1) * 1024, p_mask + i * fft_bins); } knf::StftResult masked_stft; masked_stft.num_frames = stft.num_frames; masked_stft.real.resize(stft.real.size()); masked_stft.imag.resize(stft.imag.size()); Eigen::Map(masked_stft.real.data(), masked_stft.real.size()) = mask.array() * Eigen::Map(const_cast(stft.real.data()), stft.real.size()) .array(); Eigen::Map(masked_stft.imag.data(), masked_stft.imag.size()) = mask.array() * Eigen::Map(const_cast(stft.imag.data()), stft.imag.size()) .array(); auto stft_config = GetStftConfig(); knf::IStft istft(stft_config); return istft.Compute(masked_stft); } knf::StftResult ComputeStft(const OfflineSourceSeparationInput &input, int32_t ch) const { if (ch >= input.samples.data.size()) { SHERPA_ONNX_LOGE("Invalid channel %d. Max %d", ch, static_cast(input.samples.data.size())); SHERPA_ONNX_EXIT(-1); } if (input.samples.data[ch].empty()) { return {}; } return ComputeStft(input.samples.data[ch]); } knf::StftResult ComputeStft(const std::vector &samples) const { auto stft_config = GetStftConfig(); knf::Stft stft(stft_config); return stft.Compute(samples.data(), samples.size()); } knf::StftConfig GetStftConfig() const { const auto &meta = model_.GetMetaData(); knf::StftConfig stft_config; stft_config.n_fft = meta.n_fft; stft_config.hop_length = meta.hop_length; stft_config.win_length = meta.window_length; stft_config.window_type = meta.window_type; stft_config.center = meta.center; return stft_config; } private: OfflineSourceSeparationConfig config_; OfflineSourceSeparationSpleeterModel model_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_