From d57e4f84de78568936494c324c78a9d090ac68a2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 5 Jun 2025 20:44:26 +0800 Subject: [PATCH] Add Python API for source separation (#2283) --- .github/scripts/test-python.sh | 26 ++++ .github/workflows/run-python-test-macos.yaml | 5 + .github/workflows/run-python-test.yaml | 35 ++--- .../offline-source-separation-spleeter.py | 122 ++++++++++++++++ .../offline-source-separation-uvr.py | 118 ++++++++++++++++ sherpa-onnx/python/csrc/CMakeLists.txt | 4 + sherpa-onnx/python/csrc/fast-clustering.cc | 8 ++ sherpa-onnx/python/csrc/offline-recognizer.cc | 4 +- .../offline-source-separation-model-config.cc | 37 +++++ .../offline-source-separation-model-config.h | 16 +++ ...source-separation-spleeter-model-config.cc | 24 ++++ ...-source-separation-spleeter-model-config.h | 16 +++ ...line-source-separation-uvr-model-config.cc | 22 +++ ...fline-source-separation-uvr-model-config.h | 16 +++ .../python/csrc/offline-source-separation.cc | 133 ++++++++++++++++++ .../python/csrc/offline-source-separation.h | 16 +++ .../python/csrc/offline-speech-denoiser.cc | 2 + sherpa-onnx/python/csrc/online-recognizer.cc | 11 +- sherpa-onnx/python/csrc/sherpa-onnx.cc | 2 + sherpa-onnx/python/sherpa_onnx/__init__.py | 5 + 20 files changed, 599 insertions(+), 23 deletions(-) create mode 100755 python-api-examples/offline-source-separation-spleeter.py create mode 100755 python-api-examples/offline-source-separation-uvr.py create mode 100644 sherpa-onnx/python/csrc/offline-source-separation-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-source-separation-model-config.h create mode 100644 sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h create mode 100644 sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h 
create mode 100644 sherpa-onnx/python/csrc/offline-source-separation.cc create mode 100644 sherpa-onnx/python/csrc/offline-source-separation.h diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index b9cbe8b9..fe665501 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -8,6 +8,32 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "test spleeter" + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2 +tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2 +rm sherpa-onnx-spleeter-2stems-fp16.tar.bz2 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/qi-feng-le-zh.wav +./python-api-examples/offline-source-separation-spleeter.py +rm -rf sherpa-onnx-spleeter-2stems-fp16 +rm qi-feng-le-zh.wav + +log "test UVR" + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/UVR_MDXNET_9482.onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/qi-feng-le-zh.wav +./python-api-examples/offline-source-separation-uvr.py +rm UVR_MDXNET_9482.onnx +rm qi-feng-le-zh.wav + +mkdir source-separation + +mv spleeter-*.wav source-separation +mv uvr-*.wav source-separation + +ls -lh source-separation + + log "test offline dolphin ctc" curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2 tar xvf sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2 diff --git a/.github/workflows/run-python-test-macos.yaml b/.github/workflows/run-python-test-macos.yaml index c9fafe68..afcdee02 100644 --- a/.github/workflows/run-python-test-macos.yaml +++ b/.github/workflows/run-python-test-macos.yaml @@ -97,6 +97,11 @@ jobs: .github/scripts/test-python.sh .github/scripts/test-speaker-recognition-python.sh 
+ - uses: actions/upload-artifact@v4 + with: + name: source-separation-${{ matrix.os }}-${{ matrix.python-version }} + path: ./source-separation + - uses: actions/upload-artifact@v4 with: name: tts-generated-test-files-${{ matrix.os }}-${{ matrix.python-version }} diff --git a/.github/workflows/run-python-test.yaml b/.github/workflows/run-python-test.yaml index 7080420f..9cea32bd 100644 --- a/.github/workflows/run-python-test.yaml +++ b/.github/workflows/run-python-test.yaml @@ -36,22 +36,18 @@ jobs: fail-fast: false matrix: include: - # it fails to install ffmpeg on ubuntu 20.04 - # - # - os: ubuntu-20.04 - # python-version: "3.7" - # - os: ubuntu-20.04 - # python-version: "3.8" - # - os: ubuntu-20.04 - # python-version: "3.9" + - os: ubuntu-24.04 + python-version: "3.8" + - os: ubuntu-24.04 + python-version: "3.9" - - os: ubuntu-22.04 + - os: ubuntu-24.04 python-version: "3.10" - - os: ubuntu-22.04 + - os: ubuntu-24.04 python-version: "3.11" - - os: ubuntu-22.04 + - os: ubuntu-24.04 python-version: "3.12" - - os: ubuntu-22.04 + - os: ubuntu-24.04 python-version: "3.13" steps: @@ -81,10 +77,12 @@ jobs: python3 -m pip install --upgrade pip numpy pypinyin sentencepiece>=0.1.96 soundfile python3 -m pip install wheel twine setuptools - - name: Install ffmpeg - shell: bash - run: | - sudo apt-get install ffmpeg + - uses: afoley587/setup-ffmpeg@main + id: setup-ffmpeg + with: + ffmpeg-version: release + architecture: '' + github-token: ${{ github.server_url == 'https://github.com' && github.token || '' }} - name: Install ninja shell: bash @@ -189,6 +187,11 @@ jobs: .github/scripts/test-python.sh .github/scripts/test-speaker-recognition-python.sh + - uses: actions/upload-artifact@v4 + with: + name: source-separation-${{ matrix.os }}-${{ matrix.python-version }}-whl + path: ./source-separation + - uses: actions/upload-artifact@v4 with: name: tts-generated-test-files-${{ matrix.os }}-${{ matrix.python-version }} diff --git 
a/python-api-examples/offline-source-separation-spleeter.py b/python-api-examples/offline-source-separation-spleeter.py new file mode 100755 index 00000000..0d167fdb --- /dev/null +++ b/python-api-examples/offline-source-separation-spleeter.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Xiaomi Corporation + +""" +This file shows how to use spleeter for source separation. + +Please first download a spleeter model from + +https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models + +The following is an example: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2 + +Please also download a test file + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/qi-feng-le-zh.wav + +The test wav file is 16-bit encoded with 2 channels. If you have other +formats, e.g., .mp4 or .mp3, please first use ffmpeg to convert it. +For instance + + ffmpeg -i your.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 out.wav + +Then you can use out.wav as input for this example. 
+""" + +import time +from pathlib import Path + +import numpy as np +import sherpa_onnx +import soundfile as sf + + +def create_offline_source_separation(): + # Please read the help message at the beginning of this file + # to download model files + vocals = "./sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx" + accompaniment = "./sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx" + + if not Path(vocals).is_file(): + raise ValueError(f"{vocals} does not exist.") + + if not Path(accompaniment).is_file(): + raise ValueError(f"{accompaniment} does not exist.") + + config = sherpa_onnx.OfflineSourceSeparationConfig( + model=sherpa_onnx.OfflineSourceSeparationModelConfig( + spleeter=sherpa_onnx.OfflineSourceSeparationSpleeterModelConfig( + vocals=vocals, + accompaniment=accompaniment, + ), + num_threads=1, + debug=False, + provider="cpu", + ) + ) + if not config.validate(): + raise ValueError("Please check your config.") + + return sherpa_onnx.OfflineSourceSeparation(config) + + +def load_audio(): + # Please read the help message at the beginning of this file to download + # the following wav_file + wav_file = "./qi-feng-le-zh.wav" + if not Path(wav_file).is_file(): + raise ValueError(f"{wav_file} does not exist") + + samples, sample_rate = sf.read(wav_file, dtype="float32", always_2d=True) + samples = np.transpose(samples) + # now samples is of shape (num_channels, num_samples) + assert ( + samples.shape[1] > samples.shape[0] + ), f"You should use (num_channels, num_samples). {samples.shape}" + + assert ( + samples.dtype == np.float32 + ), f"Expect np.float32 as dtype. 
Given: {samples.dtype}" + + return samples, sample_rate + + +def main(): + sp = create_offline_source_separation() + samples, sample_rate = load_audio() + samples = np.ascontiguousarray(samples) + + start = time.time() + output = sp.process(sample_rate=sample_rate, samples=samples) + end = time.time() + + print("output.sample_rate", output.sample_rate) + + assert len(output.stems) == 2, len(output.stems) + + vocals = output.stems[0].data + non_vocals = output.stems[1].data + # vocals.shape (num_channels, num_samples) + + vocals = np.transpose(vocals) + non_vocals = np.transpose(non_vocals) + + # vocals.shape (num_samples,num_channels) + + sf.write("./spleeter-vocals.wav", vocals, samplerate=output.sample_rate) + sf.write("./spleeter-non-vocals.wav", non_vocals, samplerate=output.sample_rate) + + elapsed_seconds = end - start + audio_duration = samples.shape[1] / sample_rate + real_time_factor = elapsed_seconds / audio_duration + + print("Saved to ./spleeter-vocals.wav and ./spleeter-non-vocals.wav") + print(f"Elapsed seconds: {elapsed_seconds:.3f}") + print(f"Audio duration in seconds: {audio_duration:.3f}") + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") + + +if __name__ == "__main__": + main() diff --git a/python-api-examples/offline-source-separation-uvr.py b/python-api-examples/offline-source-separation-uvr.py new file mode 100755 index 00000000..bffb9717 --- /dev/null +++ b/python-api-examples/offline-source-separation-uvr.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Xiaomi Corporation + +""" +This file shows how to use UVR for source separation. 
+ +Please first download a UVR model from + +https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models + +The following is an example: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/UVR_MDXNET_9482.onnx + +Please also download a test file + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/qi-feng-le-zh.wav + +The test wav file is 16-bit encoded with 2 channels. If you have other +formats, e.g., .mp4 or .mp3, please first use ffmpeg to convert it. +For instance + + ffmpeg -i your.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 out.wav + +Then you can use out.wav as input for this example. +""" + +import time +from pathlib import Path + +import numpy as np +import sherpa_onnx +import soundfile as sf + + +def create_offline_source_separation(): + # Please read the help message at the beginning of this file + # to download model files + model = "./UVR_MDXNET_9482.onnx" + + if not Path(model).is_file(): + raise ValueError(f"{model} does not exist.") + + config = sherpa_onnx.OfflineSourceSeparationConfig( + model=sherpa_onnx.OfflineSourceSeparationModelConfig( + uvr=sherpa_onnx.OfflineSourceSeparationUvrModelConfig( + model=model, + ), + num_threads=1, + debug=False, + provider="cpu", + ) + ) + if not config.validate(): + raise ValueError("Please check your config.") + + return sherpa_onnx.OfflineSourceSeparation(config) + + +def load_audio(): + # Please read the help message at the beginning of this file to download + # the following wav_file + wav_file = "./qi-feng-le-zh.wav" + if not Path(wav_file).is_file(): + raise ValueError(f"{wav_file} does not exist") + + samples, sample_rate = sf.read(wav_file, dtype="float32", always_2d=True) + samples = np.transpose(samples) + # now samples is of shape (num_channels, num_samples) + assert ( + samples.shape[1] > samples.shape[0] + ), f"You should use (num_channels, num_samples). 
{samples.shape}" + + assert ( + samples.dtype == np.float32 + ), f"Expect np.float32 as dtype. Given: {samples.dtype}" + + return samples, sample_rate + + +def main(): + sp = create_offline_source_separation() + samples, sample_rate = load_audio() + samples = np.ascontiguousarray(samples) + + print("Started. Please wait") + start = time.time() + output = sp.process(sample_rate=sample_rate, samples=samples) + end = time.time() + + print("output.sample_rate", output.sample_rate) + + assert len(output.stems) == 2, len(output.stems) + + vocals = output.stems[0].data + non_vocals = output.stems[1].data + # vocals.shape (num_channels, num_samples) + + vocals = np.transpose(vocals) + non_vocals = np.transpose(non_vocals) + + # vocals.shape (num_samples,num_channels) + + sf.write("./uvr-vocals.wav", vocals, samplerate=output.sample_rate) + sf.write("./uvr-non-vocals.wav", non_vocals, samplerate=output.sample_rate) + + elapsed_seconds = end - start + audio_duration = samples.shape[1] / sample_rate + real_time_factor = elapsed_seconds / audio_duration + + print("Saved to ./uvr-vocals.wav and ./uvr-non-vocals.wav") + print(f"Elapsed seconds: {elapsed_seconds:.3f}") + print(f"Audio duration in seconds: {audio_duration:.3f}") + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 2626a251..e3df5028 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -20,6 +20,10 @@ set(srcs offline-punctuation.cc offline-recognizer.cc offline-sense-voice-model-config.cc + offline-source-separation-model-config.cc + offline-source-separation-spleeter-model-config.cc + offline-source-separation-uvr-model-config.cc + offline-source-separation.cc offline-speech-denoiser-gtcrn-model-config.cc offline-speech-denoiser-model-config.cc offline-speech-denoiser.cc diff --git 
a/sherpa-onnx/python/csrc/fast-clustering.cc b/sherpa-onnx/python/csrc/fast-clustering.cc index b0342b3f..17956ccc 100644 --- a/sherpa-onnx/python/csrc/fast-clustering.cc +++ b/sherpa-onnx/python/csrc/fast-clustering.cc @@ -9,6 +9,8 @@ #include "sherpa-onnx/csrc/fast-clustering.h" +#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_ + namespace sherpa_onnx { static void PybindFastClusteringConfig(py::module *m) { @@ -32,6 +34,12 @@ void PybindFastClustering(py::module *m) { "__call__", [](const PyClass &self, py::array_t features) -> std::vector { + if (!(C_CONTIGUOUS == (features.flags() & C_CONTIGUOUS))) { + throw py::value_error( + "input features should be contiguous. Please use " + "np.ascontiguousarray(features)"); + } + int num_dim = features.ndim(); if (num_dim != 2) { std::ostringstream os; diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 13b44d7f..dbc68862 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -59,14 +59,14 @@ void PybindOfflineRecognizer(py::module *m) { return self.CreateStream(hotwords); }, py::arg("hotwords"), py::call_guard()) - .def("decode_stream", &PyClass::DecodeStream, + .def("decode_stream", &PyClass::DecodeStream, py::arg("s"), py::call_guard()) .def( "decode_streams", [](const PyClass &self, std::vector ss) { self.DecodeStreams(ss.data(), ss.size()); }, - py::call_guard()); + py::arg("ss"), py::call_guard()); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-source-separation-model-config.cc b/sherpa-onnx/python/csrc/offline-source-separation-model-config.cc new file mode 100644 index 00000000..8e4b874a --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-source-separation-model-config.cc @@ -0,0 +1,37 @@ +// sherpa-onnx/python/csrc/offline-source-separation-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include 
"sherpa-onnx/python/csrc/offline-source-separation-model-config.h" + +#include + +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h" +#include "sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h" +#include "sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineSourceSeparationModelConfig(py::module *m) { + PybindOfflineSourceSeparationSpleeterModelConfig(m); + PybindOfflineSourceSeparationUvrModelConfig(m); + + using PyClass = OfflineSourceSeparationModelConfig; + py::class_(*m, "OfflineSourceSeparationModelConfig") + .def(py::init(), + py::arg("spleeter") = OfflineSourceSeparationSpleeterModelConfig{}, + py::arg("uvr") = OfflineSourceSeparationUvrModelConfig{}, + py::arg("num_threads") = 1, py::arg("debug") = false, + py::arg("provider") = "cpu") + .def_readwrite("spleeter", &PyClass::spleeter) + .def_readwrite("uvr", &PyClass::uvr) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-source-separation-model-config.h b/sherpa-onnx/python/csrc/offline-source-separation-model-config.h new file mode 100644 index 00000000..9999d40e --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-source-separation-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-source-separation-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineSourceSeparationModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ 
diff --git a/sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.cc b/sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.cc new file mode 100644 index 00000000..78c739c6 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.cc @@ -0,0 +1,24 @@ +// sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h" + +#include + +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineSourceSeparationSpleeterModelConfig(py::module *m) { + using PyClass = OfflineSourceSeparationSpleeterModelConfig; + py::class_(*m, "OfflineSourceSeparationSpleeterModelConfig") + .def(py::init(), + py::arg("vocals") = "", py::arg("accompaniment") = "") + .def_readwrite("vocals", &PyClass::vocals) + .def_readwrite("accompaniment", &PyClass::accompaniment) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h b/sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h new file mode 100644 index 00000000..1ec3ff53 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineSourceSeparationSpleeterModelConfig(py::module *m); + +} + +#endif // 
SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.cc b/sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.cc new file mode 100644 index 00000000..11749459 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h" + +#include + +#include "sherpa-onnx/csrc/offline-source-separation-uvr-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineSourceSeparationUvrModelConfig(py::module *m) { + using PyClass = OfflineSourceSeparationUvrModelConfig; + py::class_(*m, "OfflineSourceSeparationUvrModelConfig") + .def(py::init(), py::arg("model") = "") + .def_readwrite("model", &PyClass::model) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h b/sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h new file mode 100644 index 00000000..5fc204c3 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineSourceSeparationUvrModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-source-separation.cc 
b/sherpa-onnx/python/csrc/offline-source-separation.cc new file mode 100644 index 00000000..f4040377 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-source-separation.cc @@ -0,0 +1,133 @@ +// sherpa-onnx/python/csrc/offline-source-separation.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-source-separation.h" + +#include + +#include "sherpa-onnx/python/csrc/offline-source-separation-model-config.h" +#include "sherpa-onnx/python/csrc/offline-source-separation.h" + +#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_ + +namespace sherpa_onnx { + +static void PybindOfflineSourceSeparationConfig(py::module *m) { + PybindOfflineSourceSeparationModelConfig(m); + + using PyClass = OfflineSourceSeparationConfig; + py::class_(*m, "OfflineSourceSeparationConfig") + .def(py::init(), + py::arg("model") = OfflineSourceSeparationModelConfig{}) + .def_readwrite("model", &PyClass::model) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindMultiChannelSamples(py::module *m) { + using PyClass = MultiChannelSamples; + + py::class_(*m, "MultiChannelSamples") + .def_property_readonly("data", [](PyClass &self) -> py::object { + // if data is not empty, return a float array of + // shape (num_channels, num_samples) + int32_t num_channels = self.data.size(); + if (num_channels == 0) { + return py::none(); + } + + int32_t num_samples = self.data[0].size(); + if (num_samples == 0) { + return py::none(); + } + + py::array_t ans({num_channels, num_samples}); + + py::buffer_info buf = ans.request(); + auto p = static_cast(buf.ptr); + + for (int32_t i = 0; i != num_channels; ++i) { + std::copy(self.data[i].begin(), self.data[i].end(), + p + i * num_samples); + } + + return ans; + }); +} + +static void PybindOfflineSourceSeparationOutput(py::module *m) { + using PyClass = OfflineSourceSeparationOutput; + py::class_(*m, "OfflineSourceSeparationOutput") + 
.def_property_readonly( + "sample_rate", [](const PyClass &self) { return self.sample_rate; }) + .def_property_readonly("stems", + [](const PyClass &self) { return self.stems; }); +} + +void PybindOfflineSourceSeparation(py::module *m) { + PybindOfflineSourceSeparationConfig(m); + PybindOfflineSourceSeparationOutput(m); + + PybindMultiChannelSamples(m); + + using PyClass = OfflineSourceSeparation; + py::class_(*m, "OfflineSourceSeparation") + .def(py::init(), + py::arg("config") = OfflineSourceSeparationConfig{}) + .def( + "process", + [](const PyClass &self, int32_t sample_rate, + const py::array_t &samples) { + if (!(C_CONTIGUOUS == (samples.flags() & C_CONTIGUOUS))) { + throw py::value_error( + "input samples should be contiguous. Please use " + "np.ascontiguousarray(samples)"); + } + + int num_dim = samples.ndim(); + if (samples.ndim() != 2) { + std::ostringstream os; + os << "Expect an array of 2 dimensions [num_channels x " + "num_samples]. " + "Given dim: " + << num_dim << "\n"; + throw py::value_error(os.str()); + } + + // if num_samples is less than 10, it is very likely the user + // has swapped num_channels and num_samples. + if (samples.shape(1) < 10) { + std::ostringstream os; + os << "Expect an array of 2 dimensions [num_channels x " + "num_samples]. 
" + "Given [" + << samples.shape(0) << " x " << samples.shape(1) << "]" + << "\n"; + throw py::value_error(os.str()); + } + + int32_t num_channels = samples.shape(0); + int32_t num_samples = samples.shape(1); + const float *p = samples.data(); + + OfflineSourceSeparationInput input; + + input.samples.data.resize(num_channels); + input.sample_rate = sample_rate; + + for (int32_t i = 0; i != num_channels; ++i) { + input.samples.data[i] = {p + i * num_samples, + p + (i + 1) * num_samples}; + } + + pybind11::gil_scoped_release release; + + return self.Process(input); + }, + py::arg("sample_rate"), py::arg("samples"), + "samples is of shape (num_channels, num-samples) with dtype " + "np.float32"); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-source-separation.h b/sherpa-onnx/python/csrc/offline-source-separation.h new file mode 100644 index 00000000..90c9fdcc --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-source-separation.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-source-separation-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineSourceSeparation(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-speech-denoiser.cc b/sherpa-onnx/python/csrc/offline-speech-denoiser.cc index a111c315..9bd956e4 100644 --- a/sherpa-onnx/python/csrc/offline-speech-denoiser.cc +++ b/sherpa-onnx/python/csrc/offline-speech-denoiser.cc @@ -47,6 +47,7 @@ void PybindOfflineSpeechDenoiser(py::module *m) { int32_t sample_rate) { return self.Run(samples.data(), samples.size(), sample_rate); }, + py::arg("samples"), py::arg("sample_rate"), py::call_guard()) .def( "run", @@ -54,6 +55,7 @@ void 
PybindOfflineSpeechDenoiser(py::module *m) { int32_t sample_rate) { return self.Run(samples.data(), samples.size(), sample_rate); }, + py::arg("samples"), py::arg("sample_rate"), py::call_guard()) .def_property_readonly("sample_rate", &PyClass::GetSampleRate); } diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index b9c74d54..e2f83be8 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -109,19 +109,20 @@ void PybindOnlineRecognizer(py::module *m) { py::arg("hotwords"), py::call_guard()) .def("is_ready", &PyClass::IsReady, py::call_guard()) - .def("decode_stream", &PyClass::DecodeStream, + .def("decode_stream", &PyClass::DecodeStream, py::arg("s"), py::call_guard()) .def( "decode_streams", [](PyClass &self, std::vector ss) { self.DecodeStreams(ss.data(), ss.size()); }, - py::call_guard()) - .def("get_result", &PyClass::GetResult, + py::arg("ss"), py::call_guard()) + .def("get_result", &PyClass::GetResult, py::arg("s"), py::call_guard()) - .def("is_endpoint", &PyClass::IsEndpoint, + .def("is_endpoint", &PyClass::IsEndpoint, py::arg("s"), py::call_guard()) - .def("reset", &PyClass::Reset, py::call_guard()); + .def("reset", &PyClass::Reset, py::arg("s"), + py::call_guard()); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index a6fa8cba..4636eea2 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -17,6 +17,7 @@ #include "sherpa-onnx/python/csrc/offline-model-config.h" #include "sherpa-onnx/python/csrc/offline-punctuation.h" #include "sherpa-onnx/python/csrc/offline-recognizer.h" +#include "sherpa-onnx/python/csrc/offline-source-separation.h" #include "sherpa-onnx/python/csrc/offline-speech-denoiser.h" #include "sherpa-onnx/python/csrc/offline-stream.h" #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" @@ -110,6 +111,7 
@@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindAlsa(&m); PybindOfflineSpeechDenoiser(&m); + PybindOfflineSourceSeparation(&m); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 529712b3..9bb82750 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -11,6 +11,11 @@ from _sherpa_onnx import ( OfflinePunctuation, OfflinePunctuationConfig, OfflinePunctuationModelConfig, + OfflineSourceSeparation, + OfflineSourceSeparationConfig, + OfflineSourceSeparationModelConfig, + OfflineSourceSeparationSpleeterModelConfig, + OfflineSourceSeparationUvrModelConfig, OfflineSpeakerDiarization, OfflineSpeakerDiarizationConfig, OfflineSpeakerDiarizationResult,