Add Python API for source separation (#2283)
This commit is contained in:
@@ -20,6 +20,10 @@ set(srcs
|
||||
offline-punctuation.cc
|
||||
offline-recognizer.cc
|
||||
offline-sense-voice-model-config.cc
|
||||
offline-source-separation-model-config.cc
|
||||
offline-source-separation-spleeter-model-config.cc
|
||||
offline-source-separation-uvr-model-config.cc
|
||||
offline-source-separation.cc
|
||||
offline-speech-denoiser-gtcrn-model-config.cc
|
||||
offline-speech-denoiser-model-config.cc
|
||||
offline-speech-denoiser.cc
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/fast-clustering.h"
|
||||
|
||||
#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void PybindFastClusteringConfig(py::module *m) {
|
||||
@@ -32,6 +34,12 @@ void PybindFastClustering(py::module *m) {
|
||||
"__call__",
|
||||
[](const PyClass &self,
|
||||
py::array_t<float> features) -> std::vector<int32_t> {
|
||||
if (!(C_CONTIGUOUS == (features.flags() & C_CONTIGUOUS))) {
|
||||
throw py::value_error(
|
||||
"input features should be contiguous. Please use "
|
||||
"np.ascontiguousarray(features)");
|
||||
}
|
||||
|
||||
int num_dim = features.ndim();
|
||||
if (num_dim != 2) {
|
||||
std::ostringstream os;
|
||||
|
||||
@@ -59,14 +59,14 @@ void PybindOfflineRecognizer(py::module *m) {
|
||||
return self.CreateStream(hotwords);
|
||||
},
|
||||
py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
|
||||
.def("decode_stream", &PyClass::DecodeStream,
|
||||
.def("decode_stream", &PyClass::DecodeStream, py::arg("s"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"decode_streams",
|
||||
[](const PyClass &self, std::vector<OfflineStream *> ss) {
|
||||
self.DecodeStreams(ss.data(), ss.size());
|
||||
},
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
py::arg("ss"), py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
// sherpa-onnx/python/csrc/offline-source-separation-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-source-separation-model-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSourceSeparationModelConfig(py::module *m) {
|
||||
PybindOfflineSourceSeparationSpleeterModelConfig(m);
|
||||
PybindOfflineSourceSeparationUvrModelConfig(m);
|
||||
|
||||
using PyClass = OfflineSourceSeparationModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineSourceSeparationModelConfig")
|
||||
.def(py::init<const OfflineSourceSeparationSpleeterModelConfig &,
|
||||
const OfflineSourceSeparationUvrModelConfig &, int32_t,
|
||||
bool, const std::string &>(),
|
||||
py::arg("spleeter") = OfflineSourceSeparationSpleeterModelConfig{},
|
||||
py::arg("uvr") = OfflineSourceSeparationUvrModelConfig{},
|
||||
py::arg("num_threads") = 1, py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu")
|
||||
.def_readwrite("spleeter", &PyClass::spleeter)
|
||||
.def_readwrite("uvr", &PyClass::uvr)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def_readwrite("provider", &PyClass::provider)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-source-separation-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSourceSeparationModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
|
||||
@@ -0,0 +1,24 @@
|
||||
// sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSourceSeparationSpleeterModelConfig(py::module *m) {
|
||||
using PyClass = OfflineSourceSeparationSpleeterModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineSourceSeparationSpleeterModelConfig")
|
||||
.def(py::init<const std::string &, const std::string &>(),
|
||||
py::arg("vocals") = "", py::arg("accompaniment") = "")
|
||||
.def_readwrite("vocals", &PyClass::vocals)
|
||||
.def_readwrite("accompaniment", &PyClass::accompaniment)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSourceSeparationSpleeterModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
|
||||
@@ -0,0 +1,22 @@
|
||||
// sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation-uvr-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSourceSeparationUvrModelConfig(py::module *m) {
|
||||
using PyClass = OfflineSourceSeparationUvrModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineSourceSeparationUvrModelConfig")
|
||||
.def(py::init<const std::string &>(), py::arg("model") = "")
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSourceSeparationUvrModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_
|
||||
133
sherpa-onnx/python/csrc/offline-source-separation.cc
Normal file
133
sherpa-onnx/python/csrc/offline-source-separation.cc
Normal file
@@ -0,0 +1,133 @@
|
||||
// sherpa-onnx/python/csrc/offline-source-separation-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-source-separation.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-source-separation-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-source-separation.h"
|
||||
|
||||
#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void PybindOfflineSourceSeparationConfig(py::module *m) {
|
||||
PybindOfflineSourceSeparationModelConfig(m);
|
||||
|
||||
using PyClass = OfflineSourceSeparationConfig;
|
||||
py::class_<PyClass>(*m, "OfflineSourceSeparationConfig")
|
||||
.def(py::init<const OfflineSourceSeparationModelConfig &>(),
|
||||
py::arg("model") = OfflineSourceSeparationModelConfig{})
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
static void PybindMultiChannelSamples(py::module *m) {
|
||||
using PyClass = MultiChannelSamples;
|
||||
|
||||
py::class_<PyClass>(*m, "MultiChannelSamples")
|
||||
.def_property_readonly("data", [](PyClass &self) -> py::object {
|
||||
// if data is not empty, return a float array of
|
||||
// shape (num_channels, num_samples)
|
||||
int32_t num_channels = self.data.size();
|
||||
if (num_channels == 0) {
|
||||
return py::none();
|
||||
}
|
||||
|
||||
int32_t num_samples = self.data[0].size();
|
||||
if (num_samples == 0) {
|
||||
return py::none();
|
||||
}
|
||||
|
||||
py::array_t<float> ans({num_channels, num_samples});
|
||||
|
||||
py::buffer_info buf = ans.request();
|
||||
auto p = static_cast<float *>(buf.ptr);
|
||||
|
||||
for (int32_t i = 0; i != num_channels; ++i) {
|
||||
std::copy(self.data[i].begin(), self.data[i].end(),
|
||||
p + i * num_samples);
|
||||
}
|
||||
|
||||
return ans;
|
||||
});
|
||||
}
|
||||
|
||||
static void PybindOfflineSourceSeparationOutput(py::module *m) {
|
||||
using PyClass = OfflineSourceSeparationOutput;
|
||||
py::class_<PyClass>(*m, "OfflineSourceSeparationOutput")
|
||||
.def_property_readonly(
|
||||
"sample_rate", [](const PyClass &self) { return self.sample_rate; })
|
||||
.def_property_readonly("stems",
|
||||
[](const PyClass &self) { return self.stems; });
|
||||
}
|
||||
|
||||
void PybindOfflineSourceSeparation(py::module *m) {
|
||||
PybindOfflineSourceSeparationConfig(m);
|
||||
PybindOfflineSourceSeparationOutput(m);
|
||||
|
||||
PybindMultiChannelSamples(m);
|
||||
|
||||
using PyClass = OfflineSourceSeparation;
|
||||
py::class_<PyClass>(*m, "OfflineSourceSeparation")
|
||||
.def(py::init<const OfflineSourceSeparationConfig &>(),
|
||||
py::arg("config") = OfflineSourceSeparationConfig{})
|
||||
.def(
|
||||
"process",
|
||||
[](const PyClass &self, int32_t sample_rate,
|
||||
const py::array_t<float> &samples) {
|
||||
if (!(C_CONTIGUOUS == (samples.flags() & C_CONTIGUOUS))) {
|
||||
throw py::value_error(
|
||||
"input samples should be contiguous. Please use "
|
||||
"np.ascontiguousarray(samples)");
|
||||
}
|
||||
|
||||
int num_dim = samples.ndim();
|
||||
if (samples.ndim() != 2) {
|
||||
std::ostringstream os;
|
||||
os << "Expect an array of 2 dimensions [num_channels x "
|
||||
"num_samples]. "
|
||||
"Given dim: "
|
||||
<< num_dim << "\n";
|
||||
throw py::value_error(os.str());
|
||||
}
|
||||
|
||||
// if num_samples is less than 10, it is very likely the user
|
||||
// has swapped num_channels and num_samples.
|
||||
if (samples.shape(1) < 10) {
|
||||
std::ostringstream os;
|
||||
os << "Expect an array of 2 dimensions [num_channels x "
|
||||
"num_samples]. "
|
||||
"Given ["
|
||||
<< samples.shape(0) << " x " << samples.shape(1) << "]"
|
||||
<< "\n";
|
||||
throw py::value_error(os.str());
|
||||
}
|
||||
|
||||
int32_t num_channels = samples.shape(0);
|
||||
int32_t num_samples = samples.shape(1);
|
||||
const float *p = samples.data();
|
||||
|
||||
OfflineSourceSeparationInput input;
|
||||
|
||||
input.samples.data.resize(num_channels);
|
||||
input.sample_rate = sample_rate;
|
||||
|
||||
for (int32_t i = 0; i != num_channels; ++i) {
|
||||
input.samples.data[i] = {p + i * num_samples,
|
||||
p + (i + 1) * num_samples};
|
||||
}
|
||||
|
||||
pybind11::gil_scoped_release release;
|
||||
|
||||
return self.Process(input);
|
||||
},
|
||||
py::arg("sample_rate"), py::arg("samples"),
|
||||
"samples is of shape (num_channels, num-samples) with dtype "
|
||||
"np.float32");
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-source-separation.h
Normal file
16
sherpa-onnx/python/csrc/offline-source-separation.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-source-separation-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSourceSeparation(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_
|
||||
@@ -47,6 +47,7 @@ void PybindOfflineSpeechDenoiser(py::module *m) {
|
||||
int32_t sample_rate) {
|
||||
return self.Run(samples.data(), samples.size(), sample_rate);
|
||||
},
|
||||
py::arg("samples"), py::arg("sample_rate"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"run",
|
||||
@@ -54,6 +55,7 @@ void PybindOfflineSpeechDenoiser(py::module *m) {
|
||||
int32_t sample_rate) {
|
||||
return self.Run(samples.data(), samples.size(), sample_rate);
|
||||
},
|
||||
py::arg("samples"), py::arg("sample_rate"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def_property_readonly("sample_rate", &PyClass::GetSampleRate);
|
||||
}
|
||||
|
||||
@@ -109,19 +109,20 @@ void PybindOnlineRecognizer(py::module *m) {
|
||||
py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_ready", &PyClass::IsReady,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("decode_stream", &PyClass::DecodeStream,
|
||||
.def("decode_stream", &PyClass::DecodeStream, py::arg("s"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"decode_streams",
|
||||
[](PyClass &self, std::vector<OnlineStream *> ss) {
|
||||
self.DecodeStreams(ss.data(), ss.size());
|
||||
},
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("get_result", &PyClass::GetResult,
|
||||
py::arg("ss"), py::call_guard<py::gil_scoped_release>())
|
||||
.def("get_result", &PyClass::GetResult, py::arg("s"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_endpoint", &PyClass::IsEndpoint,
|
||||
.def("is_endpoint", &PyClass::IsEndpoint, py::arg("s"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>());
|
||||
.def("reset", &PyClass::Reset, py::arg("s"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-punctuation.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-source-separation.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-speech-denoiser.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
|
||||
@@ -110,6 +111,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
|
||||
PybindAlsa(&m);
|
||||
PybindOfflineSpeechDenoiser(&m);
|
||||
PybindOfflineSourceSeparation(&m);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -11,6 +11,11 @@ from _sherpa_onnx import (
|
||||
OfflinePunctuation,
|
||||
OfflinePunctuationConfig,
|
||||
OfflinePunctuationModelConfig,
|
||||
OfflineSourceSeparation,
|
||||
OfflineSourceSeparationConfig,
|
||||
OfflineSourceSeparationModelConfig,
|
||||
OfflineSourceSeparationSpleeterModelConfig,
|
||||
OfflineSourceSeparationUvrModelConfig,
|
||||
OfflineSpeakerDiarization,
|
||||
OfflineSpeakerDiarizationConfig,
|
||||
OfflineSpeakerDiarizationResult,
|
||||
|
||||
Reference in New Issue
Block a user