diff --git a/.github/scripts/test-offline-speech-denoiser.sh b/.github/scripts/test-offline-speech-denoiser.sh new file mode 100755 index 00000000..9a10129d --- /dev/null +++ b/.github/scripts/test-offline-speech-denoiser.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ -z $EXE ]; then + EXE=./build/bin/sherpa-onnx-offline-denoiser +fi + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +log "------------------------------------------------------------" +log "Run gtcrn" +log "------------------------------------------------------------" +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav + +$EXE \ + --debug=1 \ + --speech-denoiser-gtcrn-model=./gtcrn_simple.onnx \ + --input-wav=./speech_with_noise.wav \ + --output-wav=./enhanced_speech_16k.wav + +rm ./gtcrn_simple.onnx diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 7db76631..bcf27136 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -10,6 +10,7 @@ on: - '.github/workflows/linux.yaml' - '.github/scripts/test-kws.sh' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-offline-speech-denoiser.sh' - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' @@ -31,6 +32,7 @@ on: paths: - '.github/workflows/linux.yaml' - '.github/scripts/test-kws.sh' + - '.github/scripts/test-offline-speech-denoiser.sh' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' @@ -203,6 +205,15 @@ jobs: overwrite: true file: sherpa-onnx-*.tar.bz2 
+ - name: Test offline speech denoiser + shell: bash + run: | + du -h -d1 . + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-denoiser + + .github/scripts/test-offline-speech-denoiser.sh + - name: Test offline TTS if: matrix.with_tts == 'ON' shell: bash @@ -214,6 +225,11 @@ jobs: .github/scripts/test-offline-tts.sh du -h -d1 . + - uses: actions/upload-artifact@v4 + with: + name: speech-denoiser-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} + path: ./*speech*.wav + - uses: actions/upload-artifact@v4 if: matrix.with_tts == 'ON' with: diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 5619825a..bc7bbfd7 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -7,6 +7,7 @@ on: tags: - 'v[0-9]+.[0-9]+.[0-9]+*' paths: + - '.github/scripts/test-offline-speech-denoiser.sh' - '.github/workflows/macos.yaml' - '.github/scripts/test-kws.sh' - '.github/scripts/test-online-transducer.sh' @@ -28,6 +29,7 @@ on: branches: - master paths: + - '.github/scripts/test-offline-speech-denoiser.sh' - '.github/workflows/macos.yaml' - '.github/scripts/test-kws.sh' - '.github/scripts/test-online-transducer.sh' @@ -160,6 +162,15 @@ jobs: overwrite: true file: sherpa-onnx-*osx-universal2*.tar.bz2 + - name: Test offline speech denoiser + shell: bash + run: | + du -h -d1 . 
+ export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-denoiser + + .github/scripts/test-offline-speech-denoiser.sh + - name: Test offline TTS if: matrix.with_tts == 'ON' shell: bash diff --git a/README.md b/README.md index aab2d52e..1f08959c 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,9 @@ |--------------------------------|---------------|--------------------------| | ✔️ | ✔️ | ✔️ | -| Keyword spotting | Add punctuation | -|------------------|-----------------| -| ✔️ | ✔️ | +| Keyword spotting | Add punctuation | Speech enhancement | +|------------------|-----------------|--------------------| +| ✔️ | ✔️ | ✔️ | ### Supported platforms @@ -198,6 +198,7 @@ We also have spaces built using WebAssembly. They are listed below: | Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]| | Punctuation | [Address][punct-models] | | Speaker segmentation | [Address][speaker-segmentation-models] | +| Speech enhancement | [Address][speech-enhancement-models] | @@ -442,3 +443,4 @@ sherpa-onnx in Unity. 
See also [#1695](https://github.com/k2-fsa/sherpa-onnx/iss [Moonshine tiny]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 [NVIDIA Jetson Orin NX]: https://developer.download.nvidia.com/assets/embedded/secure/jetson/orin_nx/docs/Jetson_Orin_NX_DS-10712-001_v0.5.pdf?RCPGu9Q6OVAOv7a7vgtwc9-BLScXRIWq6cSLuditMALECJ_dOj27DgnqAPGVnT2VpiNpQan9SyFy-9zRykR58CokzbXwjSA7Gj819e91AXPrWkGZR3oS1VLxiDEpJa_Y0lr7UT-N4GnXtb8NlUkP4GkCkkF_FQivGPrAucCUywL481GH_WpP_p7ziHU1Wg==&t=eyJscyI6ImdzZW8iLCJsc2QiOiJodHRwczovL3d3dy5nb29nbGUuY29tLmhrLyJ9 [NVIDIA Jetson Nano B01]: https://www.seeedstudio.com/blog/2020/01/16/new-revision-of-jetson-nano-dev-kit-now-supports-new-jetson-nano-module/ +[speech-enhancement-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index 8f6803c8..f7aba1b5 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -1,9 +1,9 @@ function(download_kaldi_native_fbank) include(FetchContent) - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz") - set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.20.0.tar.gz") - set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9") + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.1.tar.gz") + set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.1.tar.gz") + set(kaldi_native_fbank_HASH "SHA256=37c1aa230b00fe062791d800d8fc50aa3de215918d3dce6440699e67275d859e") set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) # If you 
don't have access to the Internet, # please pre-download kaldi-native-fbank set(possible_file_locations - $ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz - ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz - ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz - /tmp/kaldi-native-fbank-1.20.0.tar.gz - /star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz + $ENV{HOME}/Downloads/kaldi-native-fbank-1.21.1.tar.gz + ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.1.tar.gz + ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.1.tar.gz + /tmp/kaldi-native-fbank-1.21.1.tar.gz + /star-fj/fangjun/download/github/kaldi-native-fbank-1.21.1.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 1415c2ae..fbe20792 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -186,6 +186,14 @@ if(SHERPA_ONNX_ENABLE_TTS) ) endif() +list(APPEND sources + offline-speech-denoiser-gtcrn-model-config.cc + offline-speech-denoiser-gtcrn-model.cc + offline-speech-denoiser-impl.cc + offline-speech-denoiser-model-config.cc + offline-speech-denoiser.cc +) + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) list(APPEND sources fast-clustering-config.cc @@ -301,6 +309,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) + add_executable(sherpa-onnx-offline-denoiser sherpa-onnx-offline-denoiser.cc) if(SHERPA_ONNX_ENABLE_TTS) add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) @@ -318,6 +327,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) sherpa-onnx-offline-language-identification sherpa-onnx-offline-parallel sherpa-onnx-offline-punctuation + sherpa-onnx-offline-denoiser sherpa-onnx-online-punctuation ) if(SHERPA_ONNX_ENABLE_TTS) diff --git 
a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h new file mode 100644 index 00000000..dcd959a6 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h @@ -0,0 +1,149 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ + +#include +#include +#include +#include + +#include "kaldi-native-fbank/csrc/feature-window.h" +#include "kaldi-native-fbank/csrc/istft.h" +#include "kaldi-native-fbank/csrc/stft.h" +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h" +#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h" +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" +#include "sherpa-onnx/csrc/resample.h" + +namespace sherpa_onnx { + +class OfflineSpeechDenoiserGtcrnImpl : public OfflineSpeechDenoiserImpl { + public: + explicit OfflineSpeechDenoiserGtcrnImpl( + const OfflineSpeechDenoiserConfig &config) + : model_(config.model) {} + + template + OfflineSpeechDenoiserGtcrnImpl(Manager *mgr, + const OfflineSpeechDenoiserConfig &config) + : model_(mgr, config.model) {} + + DenoisedAudio Run(const float *samples, int32_t n, + int32_t sample_rate) const override { + SHERPA_ONNX_LOGE("n: %d, sample_rate: %d", n, sample_rate); + const auto &meta = model_.GetMetaData(); + + std::vector tmp; + auto p = samples; + + if (sample_rate != meta.sample_rate) { + SHERPA_ONNX_LOGE( + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + sample_rate, meta.sample_rate); + + float min_freq = std::min(sample_rate, meta.sample_rate); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + auto resampler = std::make_unique( + sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width); + resampler->Resample(samples, n, true, &tmp); + p = 
tmp.data(); + n = tmp.size(); + } + + knf::StftConfig stft_config; + stft_config.n_fft = meta.n_fft; + stft_config.hop_length = meta.hop_length; + stft_config.win_length = meta.window_length; + stft_config.window_type = meta.window_type; + if (stft_config.window_type == "hann_sqrt") { + auto window = knf::GetWindow("hann", stft_config.win_length); + for (auto &w : window) { + w = std::sqrt(w); + } + stft_config.window = std::move(window); + } + + knf::Stft stft(stft_config); + knf::StftResult stft_result = stft.Compute(p, n); + + auto states = model_.GetInitStates(); + OfflineSpeechDenoiserGtcrnModel::States next_states; + + knf::StftResult enhanced_stft_result; + enhanced_stft_result.num_frames = stft_result.num_frames; + for (int32_t i = 0; i < stft_result.num_frames; ++i) { + auto p = Process(stft_result, i, std::move(states), &next_states); + states = std::move(next_states); + + enhanced_stft_result.real.insert(enhanced_stft_result.real.end(), + p.first.begin(), p.first.end()); + enhanced_stft_result.imag.insert(enhanced_stft_result.imag.end(), + p.second.begin(), p.second.end()); + } + + knf::IStft istft(stft_config); + + DenoisedAudio denoised_audio; + denoised_audio.sample_rate = meta.sample_rate; + denoised_audio.samples = istft.Compute(enhanced_stft_result); + return denoised_audio; + } + + int32_t GetSampleRate() const override { + return model_.GetMetaData().sample_rate; + } + + private: + std::pair, std::vector> Process( + const knf::StftResult &stft_result, int32_t frame_index, + OfflineSpeechDenoiserGtcrnModel::States states, + OfflineSpeechDenoiserGtcrnModel::States *next_states) const { + const auto &meta = model_.GetMetaData(); + int32_t n_fft = meta.n_fft; + std::vector x((n_fft / 2 + 1) * 2); + + const float *p_real = + stft_result.real.data() + frame_index * (n_fft / 2 + 1); + const float *p_imag = + stft_result.imag.data() + frame_index * (n_fft / 2 + 1); + + for (int32_t i = 0; i < n_fft / 2 + 1; ++i) { + x[2 * i] = p_real[i]; + x[2 * i + 1] = 
p_imag[i]; + } + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{1, n_fft / 2 + 1, 1, 2}; + Ort::Value x_tensor = Ort::Value::CreateTensor( + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); + + Ort::Value output{nullptr}; + std::tie(output, *next_states) = + model_.Run(std::move(x_tensor), std::move(states)); + + std::vector real(n_fft / 2 + 1); + std::vector imag(n_fft / 2 + 1); + const auto *p = output.GetTensorData(); + for (int32_t i = 0; i < n_fft / 2 + 1; ++i) { + real[i] = p[2 * i]; + imag[i] = p[2 * i + 1]; + } + + return {std::move(real), std::move(imag)}; + } + + private: + OfflineSpeechDenoiserGtcrnModel model_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.cc b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.cc new file mode 100644 index 00000000..049e517d --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.cc @@ -0,0 +1,40 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineSpeechDenoiserGtcrnModelConfig::Register(ParseOptions *po) { + po->Register("speech-denoiser-gtcrn-model", &model, + "Path to the gtcrn model for speech denoising"); +} + +bool OfflineSpeechDenoiserGtcrnModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --speech-denoiser-gtcrn-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("gtcrn model file '%s' does not exist", model.c_str()); + return false; + } + return true; +} + +std::string OfflineSpeechDenoiserGtcrnModelConfig::ToString() const { + 
std::ostringstream os; + + os << "OfflineSpeechDenoiserGtcrnModelConfig("; + os << "model=\"" << model << "\")"; + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h new file mode 100644 index 00000000..8d504c4f --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h @@ -0,0 +1,24 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSpeechDenoiserGtcrnModelConfig { + std::string model; + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h new file mode 100644 index 00000000..8cf0cdab --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_onnx { + +// please refer to +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py +struct OfflineSpeechDenoiserGtcrnModelMetaData { + int32_t sample_rate = 0; + int32_t version = 1; + int32_t n_fft = 0; + int32_t hop_length = 0; + int32_t window_length = 0; + std::string window_type; + 
+ std::vector conv_cache_shape; + std::vector tra_cache_shape; + std::vector inter_cache_shape; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc new file mode 100644 index 00000000..ef7e04f5 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc @@ -0,0 +1,196 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h" + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineSpeechDenoiserGtcrnModel::Impl { + public: + explicit Impl(const OfflineSpeechDenoiserModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.gtcrn.model); + Init(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineSpeechDenoiserModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.gtcrn.model); + Init(buf.data(), buf.size()); + } + } + + const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const { + return meta_; + } + + States GetInitStates() const { + Ort::Value conv_cache = Ort::Value::CreateTensor( + allocator_, meta_.conv_cache_shape.data(), + meta_.conv_cache_shape.size()); + + Ort::Value tra_cache = Ort::Value::CreateTensor( + allocator_, meta_.tra_cache_shape.data(), meta_.tra_cache_shape.size()); + + Ort::Value inter_cache = Ort::Value::CreateTensor( + allocator_, meta_.inter_cache_shape.data(), + 
meta_.inter_cache_shape.size()); + + Fill(&conv_cache, 0); + Fill(&tra_cache, 0); + Fill(&inter_cache, 0); + + std::vector states; + + states.reserve(3); + states.push_back(std::move(conv_cache)); + states.push_back(std::move(tra_cache)); + states.push_back(std::move(inter_cache)); + + return states; + } + + std::pair Run(Ort::Value x, States states) const { + std::vector inputs; + inputs.reserve(1 + states.size()); + inputs.push_back(std::move(x)); + for (auto &s : states) { + inputs.push_back(std::move(s)); + } + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + std::vector next_states; + next_states.reserve(out.size() - 1); + for (int32_t k = 1; k < out.size(); ++k) { + next_states.push_back(std::move(out[k])); + } + + return {std::move(out[0]), std::move(next_states)}; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---gtcrn model---\n"; + PrintModelMetadata(os, meta_data); + + os << "----------input names----------\n"; + int32_t i = 0; + for (const auto &s : input_names_) { + os << i << " " << s << "\n"; + ++i; + } + os << "----------output names----------\n"; + i = 0; + for (const auto &s : output_names_) { + os << i << " " << s << "\n"; + ++i; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + + std::string model_type; + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); + if (model_type != "gtcrn") { + SHERPA_ONNX_LOGE("Expect model type 
'gtcrn'. Given: '%s'", + model_type.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + SHERPA_ONNX_READ_META_DATA(meta_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_.n_fft, "n_fft"); + SHERPA_ONNX_READ_META_DATA(meta_.hop_length, "hop_length"); + SHERPA_ONNX_READ_META_DATA(meta_.window_length, "window_length"); + SHERPA_ONNX_READ_META_DATA_STR(meta_.window_type, "window_type"); + SHERPA_ONNX_READ_META_DATA(meta_.version, "version"); + + SHERPA_ONNX_READ_META_DATA_VEC(meta_.conv_cache_shape, "conv_cache_shape"); + SHERPA_ONNX_READ_META_DATA_VEC(meta_.tra_cache_shape, "tra_cache_shape"); + SHERPA_ONNX_READ_META_DATA_VEC(meta_.inter_cache_shape, + "inter_cache_shape"); + } + + private: + OfflineSpeechDenoiserModelConfig config_; + OfflineSpeechDenoiserGtcrnModelMetaData meta_; + + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; +}; + +OfflineSpeechDenoiserGtcrnModel::~OfflineSpeechDenoiserGtcrnModel() = default; + +OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel( + const OfflineSpeechDenoiserModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel( + Manager *mgr, const OfflineSpeechDenoiserModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineSpeechDenoiserGtcrnModel::States +OfflineSpeechDenoiserGtcrnModel::GetInitStates() const { + return impl_->GetInitStates(); +} + +std::pair +OfflineSpeechDenoiserGtcrnModel::Run(Ort::Value x, States states) const { + return impl_->Run(std::move(x), std::move(states)); +} + +const OfflineSpeechDenoiserGtcrnModelMetaData & +OfflineSpeechDenoiserGtcrnModel::GetMetaData() const { + return impl_->GetMetaData(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h 
b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h new file mode 100644 index 00000000..bcd84e74 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h" +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h" +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" + +namespace sherpa_onnx { + +class OfflineSpeechDenoiserGtcrnModel { + public: + ~OfflineSpeechDenoiserGtcrnModel(); + explicit OfflineSpeechDenoiserGtcrnModel( + const OfflineSpeechDenoiserModelConfig &config); + + template + OfflineSpeechDenoiserGtcrnModel( + Manager *mgr, const OfflineSpeechDenoiserModelConfig &config); + + using States = std::vector; + + States GetInitStates() const; + + std::pair Run(Ort::Value x, States states) const; + + const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-impl.cc b/sherpa-onnx/csrc/offline-speech-denoiser-impl.cc new file mode 100644 index 00000000..ec3302a1 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-impl.cc @@ -0,0 +1,53 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-impl.cc +// +// Copyright (c) 2025 Xiaomi Corporation +#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include 
"sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h" + +namespace sherpa_onnx { + +std::unique_ptr OfflineSpeechDenoiserImpl::Create( + const OfflineSpeechDenoiserConfig &config) { + if (!config.model.gtcrn.model.empty()) { + return std::make_unique(config); + } + SHERPA_ONNX_LOGE("Please provide a speech denoising model."); + return nullptr; +} + +template +std::unique_ptr OfflineSpeechDenoiserImpl::Create( + Manager *mgr, const OfflineSpeechDenoiserConfig &config) { + if (!config.model.gtcrn.model.empty()) { + return std::make_unique(mgr, config); + } + SHERPA_ONNX_LOGE("Please provide a speech denoising model."); + return nullptr; +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr +OfflineSpeechDenoiserImpl::Create(AAssetManager *mgr, + const OfflineSpeechDenoiserConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr +OfflineSpeechDenoiserImpl::Create(NativeResourceManager *mgr, + const OfflineSpeechDenoiserConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-impl.h b/sherpa-onnx/csrc/offline-speech-denoiser-impl.h new file mode 100644 index 00000000..d31f37b0 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-impl.h @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ + +#include + +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" + +namespace sherpa_onnx { + +class OfflineSpeechDenoiserImpl { + public: + virtual ~OfflineSpeechDenoiserImpl() = default; + + static std::unique_ptr Create( + const OfflineSpeechDenoiserConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const OfflineSpeechDenoiserConfig &config); + + virtual DenoisedAudio Run(const float *samples, int32_t n, + int32_t sample_rate) const = 0; + + virtual 
int32_t GetSampleRate() const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc b/sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc new file mode 100644 index 00000000..af003d1d --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc @@ -0,0 +1,40 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h" + +#include + +namespace sherpa_onnx { + +void OfflineSpeechDenoiserModelConfig::Register(ParseOptions *po) { + gtcrn.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OfflineSpeechDenoiserModelConfig::Validate() const { + return gtcrn.Validate(); +} + +std::string OfflineSpeechDenoiserModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeechDenoiserModelConfig("; + os << "gtcrn=" << gtcrn.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-model-config.h b/sherpa-onnx/csrc/offline-speech-denoiser-model-config.h new file mode 100644 index 00000000..0c15e660 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser-model-config.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/offline-speech-denoiser-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSpeechDenoiserModelConfig { + OfflineSpeechDenoiserGtcrnModelConfig gtcrn; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OfflineSpeechDenoiserModelConfig() = default; + + OfflineSpeechDenoiserModelConfig(OfflineSpeechDenoiserGtcrnModelConfig gtcrn, + int32_t num_threads, bool debug, + const std::string &provider) + : gtcrn(gtcrn), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-speech-denoiser.cc b/sherpa-onnx/csrc/offline-speech-denoiser.cc new file mode 100644 index 00000000..afdd4d9f --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser.cc @@ -0,0 +1,64 @@ +// sherpa-onnx/csrc/offline-speech-denoiser.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" + +#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" 
+#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +namespace sherpa_onnx { + +void OfflineSpeechDenoiserConfig::Register(ParseOptions *po) { + model.Register(po); +} + +bool OfflineSpeechDenoiserConfig::Validate() const { return model.Validate(); } + +std::string OfflineSpeechDenoiserConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeechDenoiserConfig("; + os << "model=" << model.ToString() << ")"; + return os.str(); +} + +template +OfflineSpeechDenoiser::OfflineSpeechDenoiser( + Manager *mgr, const OfflineSpeechDenoiserConfig &config) + : impl_(OfflineSpeechDenoiserImpl::Create(mgr, config)) {} + +OfflineSpeechDenoiser::OfflineSpeechDenoiser( + const OfflineSpeechDenoiserConfig &config) + : impl_(OfflineSpeechDenoiserImpl::Create(config)) {} + +OfflineSpeechDenoiser::~OfflineSpeechDenoiser() = default; + +DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n, + int32_t sample_rate) const { + return impl_->Run(samples, n, sample_rate); +} + +int32_t OfflineSpeechDenoiser::GetSampleRate() const { + return impl_->GetSampleRate(); +} + +#if __ANDROID_API__ >= 9 +template OfflineSpeechDenoiser::OfflineSpeechDenoiser( + AAssetManager *mgr, const OfflineSpeechDenoiserConfig &config); +#endif + +#if __OHOS__ +template OfflineSpeechDenoiser::OfflineSpeechDenoiser( + NativeResourceManager *mgr, const OfflineSpeechDenoiserConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speech-denoiser.h b/sherpa-onnx/csrc/offline-speech-denoiser.h new file mode 100644 index 00000000..80422ba4 --- /dev/null +++ b/sherpa-onnx/csrc/offline-speech-denoiser.h @@ -0,0 +1,61 @@ +// sherpa-onnx/csrc/offline-speech-denoiser.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ + +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h" +#include 
"sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct DenoisedAudio { + std::vector samples; + int32_t sample_rate; +}; + +struct OfflineSpeechDenoiserConfig { + OfflineSpeechDenoiserModelConfig model; + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class OfflineSpeechDenoiserImpl; + +class OfflineSpeechDenoiser { + public: + explicit OfflineSpeechDenoiser(const OfflineSpeechDenoiserConfig &config); + ~OfflineSpeechDenoiser(); + + template + OfflineSpeechDenoiser(Manager *mgr, + const OfflineSpeechDenoiserConfig &config); + + /* + * @param samples 1-D array of audio samples. Each sample is in the + * range [-1, 1]. + * @param n Number of samples + * @param sample_rate Sample rate of the input samples + * + */ + DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const; + + /* + * Return the sample rate of the denoised audio + */ + int32_t GetSampleRate() const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ diff --git a/sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h b/sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h index b37babb3..37446241 100644 --- a/sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h +++ b/sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h @@ -1,4 +1,4 @@ -// sherpa-onnx/csrc/offline-tts-kokoro-model-metadata.h +// sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h // // Copyright (c) 2025 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc new file mode 100644 index 00000000..b8ce169f --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc @@ -0,0 +1,95 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc +// +// Copyright (c) 2025 Xiaomi Corporation +#include + +#include // NOLINT + +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" +#include 
"sherpa-onnx/csrc/wave-reader.h"
+#include "sherpa-onnx/csrc/wave-writer.h"
+
+// Command-line front end of the offline speech denoiser.
+//
+// Reads the wave file given by --input-wav, runs the denoiser configured by
+// the registered options, writes the result to --output-wav and prints
+// timing statistics to stderr.
+//
+// Exit status: 0 on success; non-zero if the arguments or the config are
+// invalid, the input cannot be read, or the output cannot be written.
+int main(int32_t argc, char *argv[]) {
+  const char *kUsageMessage = R"usage(
+Non-streaming speech denoising with sherpa-onnx.
+
+Please visit
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
+to download models.
+
+Usage:
+
+(1) Use gtcrn models
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
+./bin/sherpa-onnx-offline-denoiser \
+  --speech-denoiser-gtcrn-model=gtcrn_simple.onnx \
+  --input-wav=input.wav \
+  --output-wav=output_16k.wav
+)usage";
+
+  sherpa_onnx::ParseOptions po(kUsageMessage);
+  sherpa_onnx::OfflineSpeechDenoiserConfig config;
+  std::string input_wave;
+  std::string output_wave;
+
+  config.Register(&po);
+  po.Register("input-wav", &input_wave, "Path to input wav.");
+  po.Register("output-wav", &output_wave, "Path to output wav");
+
+  po.Read(argc, argv);
+  if (po.NumArgs() != 0) {
+    fprintf(stderr, "Please don't give positional arguments\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+  fprintf(stderr, "%s\n", config.ToString().c_str());
+
+  // Reject an unusable configuration before constructing the denoiser.
+  if (!config.Validate()) {
+    fprintf(stderr, "Errors in config!\n");
+    exit(EXIT_FAILURE);
+  }
+
+  if (input_wave.empty()) {
+    fprintf(stderr, "Please provide --input-wav\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  if (output_wave.empty()) {
+    fprintf(stderr, "Please provide --output-wav\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  sherpa_onnx::OfflineSpeechDenoiser denoiser(config);
+  int32_t sampling_rate = -1;
+  bool is_ok = false;
+  std::vector<float> samples =
+      sherpa_onnx::ReadWave(input_wave, &sampling_rate, &is_ok);
+  if (!is_ok) {
+    fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
+    return -1;
+  }
+
+  fprintf(stderr, "Started\n");
+  // Only the call to Run() is timed; reading and writing are excluded.
+  const auto begin = std::chrono::steady_clock::now();
+  auto result = denoiser.Run(samples.data(),
+                             static_cast<int32_t>(samples.size()),
+                             sampling_rate);
+  const auto end = std::chrono::steady_clock::now();
+
+  float elapsed_seconds =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+          .count() /
+      1000.;
+
+  fprintf(stderr, "Done\n");
+  is_ok = sherpa_onnx::WriteWave(output_wave, result.sample_rate,
+                                 result.samples.data(), result.samples.size());
+  if (is_ok) {
+    fprintf(stderr, "Saved to %s\n", output_wave.c_str());
+  } else {
+    fprintf(stderr, "Failed to save to %s\n", output_wave.c_str());
+  }
+
+  // Duration of the input audio in seconds.
+  float duration = samples.size() / static_cast<float>(sampling_rate);
+  fprintf(stderr, "num threads: %d\n", config.model.num_threads);
+  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
+  float rtf = elapsed_seconds / duration;
+  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
+          elapsed_seconds, duration, rtf);
+
+  // Report a failed write through the exit status, using the same code as a
+  // failed read above.
+  return is_ok ? 0 : -1;
+}