Add C++ runtime for speech enhancement GTCRN models (#1977)
See also https://github.com/Xiaobin-Rong/gtcrn
This commit is contained in:
32
.github/scripts/test-offline-speech-denoiser.sh
vendored
Executable file
32
.github/scripts/test-offline-speech-denoiser.sh
vendored
Executable file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -ex
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
if [ -z $EXE ]; then
|
||||
EXE=./build/bin/sherpa-onnx-offline-denoiser
|
||||
fi
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
which $EXE
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run gtcrn"
|
||||
log "------------------------------------------------------------"
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav
|
||||
|
||||
$EXE \
|
||||
--debug=1 \
|
||||
--speech-denoiser-gtcrn-model=./gtcrn_simple.onnx \
|
||||
--input-wav=./speech_with_noise.wav \
|
||||
--output-wav=./enhanced_speech_16k.wav
|
||||
|
||||
rm ./gtcrn_simple.onnx
|
||||
16
.github/workflows/linux.yaml
vendored
16
.github/workflows/linux.yaml
vendored
@@ -10,6 +10,7 @@ on:
|
||||
- '.github/workflows/linux.yaml'
|
||||
- '.github/scripts/test-kws.sh'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-speech-denoiser.sh'
|
||||
- '.github/scripts/test-online-paraformer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
@@ -31,6 +32,7 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/linux.yaml'
|
||||
- '.github/scripts/test-kws.sh'
|
||||
- '.github/scripts/test-offline-speech-denoiser.sh'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-online-paraformer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
@@ -203,6 +205,15 @@ jobs:
|
||||
overwrite: true
|
||||
file: sherpa-onnx-*.tar.bz2
|
||||
|
||||
- name: Test offline speech denoiser
|
||||
shell: bash
|
||||
run: |
|
||||
du -h -d1 .
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline-denoiser
|
||||
|
||||
.github/scripts/test-offline-speech-denoiser.sh
|
||||
|
||||
- name: Test offline TTS
|
||||
if: matrix.with_tts == 'ON'
|
||||
shell: bash
|
||||
@@ -214,6 +225,11 @@ jobs:
|
||||
.github/scripts/test-offline-tts.sh
|
||||
du -h -d1 .
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: speech-denoiser-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
||||
path: ./*speech*.wav
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.with_tts == 'ON'
|
||||
with:
|
||||
|
||||
11
.github/workflows/macos.yaml
vendored
11
.github/workflows/macos.yaml
vendored
@@ -7,6 +7,7 @@ on:
|
||||
tags:
|
||||
- 'v[0-9]+.[0-9]+.[0-9]+*'
|
||||
paths:
|
||||
- '.github/scripts/test-offline-speech-denoiser.sh'
|
||||
- '.github/workflows/macos.yaml'
|
||||
- '.github/scripts/test-kws.sh'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
@@ -28,6 +29,7 @@ on:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- '.github/scripts/test-offline-speech-denoiser.sh'
|
||||
- '.github/workflows/macos.yaml'
|
||||
- '.github/scripts/test-kws.sh'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
@@ -160,6 +162,15 @@ jobs:
|
||||
overwrite: true
|
||||
file: sherpa-onnx-*osx-universal2*.tar.bz2
|
||||
|
||||
- name: Test offline speech denoiser
|
||||
shell: bash
|
||||
run: |
|
||||
du -h -d1 .
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline-denoiser
|
||||
|
||||
.github/scripts/test-offline-speech-denoiser.sh
|
||||
|
||||
- name: Test offline TTS
|
||||
if: matrix.with_tts == 'ON'
|
||||
shell: bash
|
||||
|
||||
@@ -12,9 +12,9 @@
|
||||
|--------------------------------|---------------|--------------------------|
|
||||
| ✔️ | ✔️ | ✔️ |
|
||||
|
||||
| Keyword spotting | Add punctuation |
|
||||
|------------------|-----------------|
|
||||
| ✔️ | ✔️ |
|
||||
| Keyword spotting | Add punctuation | Speech enhancement |
|
||||
|------------------|-----------------|--------------------|
|
||||
| ✔️ | ✔️ | ✔️ |
|
||||
|
||||
### Supported platforms
|
||||
|
||||
@@ -198,6 +198,7 @@ We also have spaces built using WebAssembly. They are listed below:
|
||||
| Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]|
|
||||
| Punctuation | [Address][punct-models] |
|
||||
| Speaker segmentation | [Address][speaker-segmentation-models] |
|
||||
| Speech enhancement | [Address][speech-enhancement-models] |
|
||||
|
||||
</details>
|
||||
|
||||
@@ -442,3 +443,4 @@ sherpa-onnx in Unity. See also [#1695](https://github.com/k2-fsa/sherpa-onnx/iss
|
||||
[Moonshine tiny]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
|
||||
[NVIDIA Jetson Orin NX]: https://developer.download.nvidia.com/assets/embedded/secure/jetson/orin_nx/docs/Jetson_Orin_NX_DS-10712-001_v0.5.pdf?RCPGu9Q6OVAOv7a7vgtwc9-BLScXRIWq6cSLuditMALECJ_dOj27DgnqAPGVnT2VpiNpQan9SyFy-9zRykR58CokzbXwjSA7Gj819e91AXPrWkGZR3oS1VLxiDEpJa_Y0lr7UT-N4GnXtb8NlUkP4GkCkkF_FQivGPrAucCUywL481GH_WpP_p7ziHU1Wg==&t=eyJscyI6ImdzZW8iLCJsc2QiOiJodHRwczovL3d3dy5nb29nbGUuY29tLmhrLyJ9
|
||||
[NVIDIA Jetson Nano B01]: https://www.seeedstudio.com/blog/2020/01/16/new-revision-of-jetson-nano-dev-kit-now-supports-new-jetson-nano-module/
|
||||
[speech-enhancement-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
function(download_kaldi_native_fbank)
|
||||
include(FetchContent)
|
||||
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
|
||||
set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.20.0.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9")
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.1.tar.gz")
|
||||
set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.1.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=37c1aa230b00fe062791d800d8fc50aa3de215918d3dce6440699e67275d859e")
|
||||
|
||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download kaldi-native-fbank
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz
|
||||
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz
|
||||
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz
|
||||
/tmp/kaldi-native-fbank-1.20.0.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz
|
||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.21.1.tar.gz
|
||||
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.1.tar.gz
|
||||
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.1.tar.gz
|
||||
/tmp/kaldi-native-fbank-1.21.1.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.21.1.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
|
||||
@@ -186,6 +186,14 @@ if(SHERPA_ONNX_ENABLE_TTS)
|
||||
)
|
||||
endif()
|
||||
|
||||
list(APPEND sources
|
||||
offline-speech-denoiser-gtcrn-model-config.cc
|
||||
offline-speech-denoiser-gtcrn-model.cc
|
||||
offline-speech-denoiser-impl.cc
|
||||
offline-speech-denoiser-model-config.cc
|
||||
offline-speech-denoiser.cc
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
|
||||
list(APPEND sources
|
||||
fast-clustering-config.cc
|
||||
@@ -301,6 +309,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
|
||||
add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
|
||||
add_executable(sherpa-onnx-offline-denoiser sherpa-onnx-offline-denoiser.cc)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
@@ -318,6 +327,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
sherpa-onnx-offline-language-identification
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-punctuation
|
||||
sherpa-onnx-offline-denoiser
|
||||
sherpa-onnx-online-punctuation
|
||||
)
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
|
||||
149
sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h
Normal file
149
sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h
Normal file
@@ -0,0 +1,149 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-native-fbank/csrc/feature-window.h"
|
||||
#include "kaldi-native-fbank/csrc/istft.h"
|
||||
#include "kaldi-native-fbank/csrc/stft.h"
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
|
||||
#include "sherpa-onnx/csrc/resample.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSpeechDenoiserGtcrnImpl : public OfflineSpeechDenoiserImpl {
|
||||
public:
|
||||
explicit OfflineSpeechDenoiserGtcrnImpl(
|
||||
const OfflineSpeechDenoiserConfig &config)
|
||||
: model_(config.model) {}
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSpeechDenoiserGtcrnImpl(Manager *mgr,
|
||||
const OfflineSpeechDenoiserConfig &config)
|
||||
: model_(mgr, config.model) {}
|
||||
|
||||
DenoisedAudio Run(const float *samples, int32_t n,
|
||||
int32_t sample_rate) const override {
|
||||
SHERPA_ONNX_LOGE("n: %d, sample_rate: %d", n, sample_rate);
|
||||
const auto &meta = model_.GetMetaData();
|
||||
|
||||
std::vector<float> tmp;
|
||||
auto p = samples;
|
||||
|
||||
if (sample_rate != meta.sample_rate) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Creating a resampler:\n"
|
||||
" in_sample_rate: %d\n"
|
||||
" output_sample_rate: %d\n",
|
||||
sample_rate, meta.sample_rate);
|
||||
|
||||
float min_freq = std::min<int32_t>(sample_rate, meta.sample_rate);
|
||||
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
|
||||
|
||||
int32_t lowpass_filter_width = 6;
|
||||
auto resampler = std::make_unique<LinearResample>(
|
||||
sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width);
|
||||
resampler->Resample(samples, n, true, &tmp);
|
||||
p = tmp.data();
|
||||
n = tmp.size();
|
||||
}
|
||||
|
||||
knf::StftConfig stft_config;
|
||||
stft_config.n_fft = meta.n_fft;
|
||||
stft_config.hop_length = meta.hop_length;
|
||||
stft_config.win_length = meta.window_length;
|
||||
stft_config.window_type = meta.window_type;
|
||||
if (stft_config.window_type == "hann_sqrt") {
|
||||
auto window = knf::GetWindow("hann", stft_config.win_length);
|
||||
for (auto &w : window) {
|
||||
w = std::sqrt(w);
|
||||
}
|
||||
stft_config.window = std::move(window);
|
||||
}
|
||||
|
||||
knf::Stft stft(stft_config);
|
||||
knf::StftResult stft_result = stft.Compute(p, n);
|
||||
|
||||
auto states = model_.GetInitStates();
|
||||
OfflineSpeechDenoiserGtcrnModel::States next_states;
|
||||
|
||||
knf::StftResult enhanced_stft_result;
|
||||
enhanced_stft_result.num_frames = stft_result.num_frames;
|
||||
for (int32_t i = 0; i < stft_result.num_frames; ++i) {
|
||||
auto p = Process(stft_result, i, std::move(states), &next_states);
|
||||
states = std::move(next_states);
|
||||
|
||||
enhanced_stft_result.real.insert(enhanced_stft_result.real.end(),
|
||||
p.first.begin(), p.first.end());
|
||||
enhanced_stft_result.imag.insert(enhanced_stft_result.imag.end(),
|
||||
p.second.begin(), p.second.end());
|
||||
}
|
||||
|
||||
knf::IStft istft(stft_config);
|
||||
|
||||
DenoisedAudio denoised_audio;
|
||||
denoised_audio.sample_rate = meta.sample_rate;
|
||||
denoised_audio.samples = istft.Compute(enhanced_stft_result);
|
||||
return denoised_audio;
|
||||
}
|
||||
|
||||
int32_t GetSampleRate() const override {
|
||||
return model_.GetMetaData().sample_rate;
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<std::vector<float>, std::vector<float>> Process(
|
||||
const knf::StftResult &stft_result, int32_t frame_index,
|
||||
OfflineSpeechDenoiserGtcrnModel::States states,
|
||||
OfflineSpeechDenoiserGtcrnModel::States *next_states) const {
|
||||
const auto &meta = model_.GetMetaData();
|
||||
int32_t n_fft = meta.n_fft;
|
||||
std::vector<float> x((n_fft / 2 + 1) * 2);
|
||||
|
||||
const float *p_real =
|
||||
stft_result.real.data() + frame_index * (n_fft / 2 + 1);
|
||||
const float *p_imag =
|
||||
stft_result.imag.data() + frame_index * (n_fft / 2 + 1);
|
||||
|
||||
for (int32_t i = 0; i < n_fft / 2 + 1; ++i) {
|
||||
x[2 * i] = p_real[i];
|
||||
x[2 * i + 1] = p_imag[i];
|
||||
}
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 4> x_shape{1, n_fft / 2 + 1, 1, 2};
|
||||
Ort::Value x_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
|
||||
|
||||
Ort::Value output{nullptr};
|
||||
std::tie(output, *next_states) =
|
||||
model_.Run(std::move(x_tensor), std::move(states));
|
||||
|
||||
std::vector<float> real(n_fft / 2 + 1);
|
||||
std::vector<float> imag(n_fft / 2 + 1);
|
||||
const auto *p = output.GetTensorData<float>();
|
||||
for (int32_t i = 0; i < n_fft / 2 + 1; ++i) {
|
||||
real[i] = p[2 * i];
|
||||
imag[i] = p[2 * i + 1];
|
||||
}
|
||||
|
||||
return {std::move(real), std::move(imag)};
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineSpeechDenoiserGtcrnModel model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
|
||||
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineSpeechDenoiserGtcrnModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("speech-denoiser-gtcrn-model", &model,
|
||||
"Path to the gtcrn model for speech denoising");
|
||||
}
|
||||
|
||||
bool OfflineSpeechDenoiserGtcrnModelConfig::Validate() const {
|
||||
if (model.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --speech-denoiser-gtcrn-model");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("gtcrn model file '%s' does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineSpeechDenoiserGtcrnModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineSpeechDenoiserGtcrnModelConfig(";
|
||||
os << "model=\"" << model << "\")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,24 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineSpeechDenoiserGtcrnModelConfig {
|
||||
std::string model;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
|
||||
@@ -0,0 +1,31 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// please refer to
|
||||
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py
|
||||
struct OfflineSpeechDenoiserGtcrnModelMetaData {
|
||||
int32_t sample_rate = 0;
|
||||
int32_t version = 1;
|
||||
int32_t n_fft = 0;
|
||||
int32_t hop_length = 0;
|
||||
int32_t window_length = 0;
|
||||
std::string window_type;
|
||||
|
||||
std::vector<int64_t> conv_cache_shape;
|
||||
std::vector<int64_t> tra_cache_shape;
|
||||
std::vector<int64_t> inter_cache_shape;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
|
||||
196
sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc
Normal file
196
sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc
Normal file
@@ -0,0 +1,196 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSpeechDenoiserGtcrnModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineSpeechDenoiserModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(config.gtcrn.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
Impl(Manager *mgr, const OfflineSpeechDenoiserModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.gtcrn.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const {
|
||||
return meta_;
|
||||
}
|
||||
|
||||
States GetInitStates() const {
|
||||
Ort::Value conv_cache = Ort::Value::CreateTensor<float>(
|
||||
allocator_, meta_.conv_cache_shape.data(),
|
||||
meta_.conv_cache_shape.size());
|
||||
|
||||
Ort::Value tra_cache = Ort::Value::CreateTensor<float>(
|
||||
allocator_, meta_.tra_cache_shape.data(), meta_.tra_cache_shape.size());
|
||||
|
||||
Ort::Value inter_cache = Ort::Value::CreateTensor<float>(
|
||||
allocator_, meta_.inter_cache_shape.data(),
|
||||
meta_.inter_cache_shape.size());
|
||||
|
||||
Fill<float>(&conv_cache, 0);
|
||||
Fill<float>(&tra_cache, 0);
|
||||
Fill<float>(&inter_cache, 0);
|
||||
|
||||
std::vector<Ort::Value> states;
|
||||
|
||||
states.reserve(3);
|
||||
states.push_back(std::move(conv_cache));
|
||||
states.push_back(std::move(tra_cache));
|
||||
states.push_back(std::move(inter_cache));
|
||||
|
||||
return states;
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, States> Run(Ort::Value x, States states) const {
|
||||
std::vector<Ort::Value> inputs;
|
||||
inputs.reserve(1 + states.size());
|
||||
inputs.push_back(std::move(x));
|
||||
for (auto &s : states) {
|
||||
inputs.push_back(std::move(s));
|
||||
}
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
std::vector<Ort::Value> next_states;
|
||||
next_states.reserve(out.size() - 1);
|
||||
for (int32_t k = 1; k < out.size(); ++k) {
|
||||
next_states.push_back(std::move(out[k]));
|
||||
}
|
||||
|
||||
return {std::move(out[0]), std::move(next_states)};
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
os << "---gtcrn model---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
|
||||
os << "----------input names----------\n";
|
||||
int32_t i = 0;
|
||||
for (const auto &s : input_names_) {
|
||||
os << i << " " << s << "\n";
|
||||
++i;
|
||||
}
|
||||
os << "----------output names----------\n";
|
||||
i = 0;
|
||||
for (const auto &s : output_names_) {
|
||||
os << i << " " << s << "\n";
|
||||
++i;
|
||||
}
|
||||
|
||||
#if __OHOS__
|
||||
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
|
||||
#else
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
#endif
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
std::string model_type;
|
||||
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
|
||||
if (model_type != "gtcrn") {
|
||||
SHERPA_ONNX_LOGE("Expect model type 'gtcrn'. Given: '%s'",
|
||||
model_type.c_str());
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.sample_rate, "sample_rate");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.n_fft, "n_fft");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.hop_length, "hop_length");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.window_length, "window_length");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(meta_.window_type, "window_type");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_.version, "version");
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(meta_.conv_cache_shape, "conv_cache_shape");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(meta_.tra_cache_shape, "tra_cache_shape");
|
||||
SHERPA_ONNX_READ_META_DATA_VEC(meta_.inter_cache_shape,
|
||||
"inter_cache_shape");
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineSpeechDenoiserModelConfig config_;
|
||||
OfflineSpeechDenoiserGtcrnModelMetaData meta_;
|
||||
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
};
|
||||
|
||||
OfflineSpeechDenoiserGtcrnModel::~OfflineSpeechDenoiserGtcrnModel() = default;
|
||||
|
||||
OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel(
|
||||
const OfflineSpeechDenoiserModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel(
|
||||
Manager *mgr, const OfflineSpeechDenoiserModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
|
||||
OfflineSpeechDenoiserGtcrnModel::States
|
||||
OfflineSpeechDenoiserGtcrnModel::GetInitStates() const {
|
||||
return impl_->GetInitStates();
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, OfflineSpeechDenoiserGtcrnModel::States>
|
||||
OfflineSpeechDenoiserGtcrnModel::Run(Ort::Value x, States states) const {
|
||||
return impl_->Run(std::move(x), std::move(states));
|
||||
}
|
||||
|
||||
const OfflineSpeechDenoiserGtcrnModelMetaData &
|
||||
OfflineSpeechDenoiserGtcrnModel::GetMetaData() const {
|
||||
return impl_->GetMetaData();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
42
sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h
Normal file
42
sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h
Normal file
@@ -0,0 +1,42 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h"
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSpeechDenoiserGtcrnModel {
|
||||
public:
|
||||
~OfflineSpeechDenoiserGtcrnModel();
|
||||
explicit OfflineSpeechDenoiserGtcrnModel(
|
||||
const OfflineSpeechDenoiserModelConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSpeechDenoiserGtcrnModel(
|
||||
Manager *mgr, const OfflineSpeechDenoiserModelConfig &config);
|
||||
|
||||
using States = std::vector<Ort::Value>;
|
||||
|
||||
States GetInitStates() const;
|
||||
|
||||
std::pair<Ort::Value, States> Run(Ort::Value x, States states) const;
|
||||
|
||||
const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
|
||||
53
sherpa-onnx/csrc/offline-speech-denoiser-impl.cc
Normal file
53
sherpa-onnx/csrc/offline-speech-denoiser-impl.cc
Normal file
@@ -0,0 +1,53 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-impl.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflineSpeechDenoiserImpl> OfflineSpeechDenoiserImpl::Create(
|
||||
const OfflineSpeechDenoiserConfig &config) {
|
||||
if (!config.model.gtcrn.model.empty()) {
|
||||
return std::make_unique<OfflineSpeechDenoiserGtcrnImpl>(config);
|
||||
}
|
||||
SHERPA_ONNX_LOGE("Please provide a speech denoising model.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
std::unique_ptr<OfflineSpeechDenoiserImpl> OfflineSpeechDenoiserImpl::Create(
|
||||
Manager *mgr, const OfflineSpeechDenoiserConfig &config) {
|
||||
if (!config.model.gtcrn.model.empty()) {
|
||||
return std::make_unique<OfflineSpeechDenoiserGtcrnImpl>(mgr, config);
|
||||
}
|
||||
SHERPA_ONNX_LOGE("Please provide a speech denoising model.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template std::unique_ptr<OfflineSpeechDenoiserImpl>
|
||||
OfflineSpeechDenoiserImpl::Create(AAssetManager *mgr,
|
||||
const OfflineSpeechDenoiserConfig &config);
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
template std::unique_ptr<OfflineSpeechDenoiserImpl>
|
||||
OfflineSpeechDenoiserImpl::Create(NativeResourceManager *mgr,
|
||||
const OfflineSpeechDenoiserConfig &config);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
33
sherpa-onnx/csrc/offline-speech-denoiser-impl.h
Normal file
33
sherpa-onnx/csrc/offline-speech-denoiser-impl.h
Normal file
@@ -0,0 +1,33 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-speech-denoiser-impl.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSpeechDenoiserImpl {
|
||||
public:
|
||||
virtual ~OfflineSpeechDenoiserImpl() = default;
|
||||
|
||||
static std::unique_ptr<OfflineSpeechDenoiserImpl> Create(
|
||||
const OfflineSpeechDenoiserConfig &config);
|
||||
|
||||
template <typename Manager>
|
||||
static std::unique_ptr<OfflineSpeechDenoiserImpl> Create(
|
||||
Manager *mgr, const OfflineSpeechDenoiserConfig &config);
|
||||
|
||||
virtual DenoisedAudio Run(const float *samples, int32_t n,
|
||||
int32_t sample_rate) const = 0;
|
||||
|
||||
virtual int32_t GetSampleRate() const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
|
||||
40
sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc
Normal file
40
sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc
Normal file
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineSpeechDenoiserModelConfig::Register(ParseOptions *po) {
|
||||
gtcrn.Register(po);
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
}
|
||||
|
||||
bool OfflineSpeechDenoiserModelConfig::Validate() const {
|
||||
return gtcrn.Validate();
|
||||
}
|
||||
|
||||
std::string OfflineSpeechDenoiserModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineSpeechDenoiserModelConfig(";
|
||||
os << "gtcrn=" << gtcrn.ToString() << ", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
39
sherpa-onnx/csrc/offline-speech-denoiser-model-config.h
Normal file
39
sherpa-onnx/csrc/offline-speech-denoiser-model-config.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser-model-config.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineSpeechDenoiserModelConfig {
|
||||
OfflineSpeechDenoiserGtcrnModelConfig gtcrn;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
OfflineSpeechDenoiserModelConfig() = default;
|
||||
|
||||
OfflineSpeechDenoiserModelConfig(OfflineSpeechDenoiserGtcrnModelConfig gtcrn,
|
||||
int32_t num_threads, bool debug,
|
||||
const std::string &provider)
|
||||
: gtcrn(gtcrn),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
|
||||
64
sherpa-onnx/csrc/offline-speech-denoiser.cc
Normal file
64
sherpa-onnx/csrc/offline-speech-denoiser.cc
Normal file
@@ -0,0 +1,64 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineSpeechDenoiserConfig::Register(ParseOptions *po) {
|
||||
model.Register(po);
|
||||
}
|
||||
|
||||
bool OfflineSpeechDenoiserConfig::Validate() const { return model.Validate(); }
|
||||
|
||||
std::string OfflineSpeechDenoiserConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineSpeechDenoiserConfig(";
|
||||
os << "model=" << model.ToString() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSpeechDenoiser::OfflineSpeechDenoiser(
|
||||
Manager *mgr, const OfflineSpeechDenoiserConfig &config)
|
||||
: impl_(OfflineSpeechDenoiserImpl::Create(mgr, config)) {}
|
||||
|
||||
OfflineSpeechDenoiser::OfflineSpeechDenoiser(
|
||||
const OfflineSpeechDenoiserConfig &config)
|
||||
: impl_(OfflineSpeechDenoiserImpl::Create(config)) {}
|
||||
|
||||
OfflineSpeechDenoiser::~OfflineSpeechDenoiser() = default;
|
||||
|
||||
DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n,
|
||||
int32_t sample_rate) const {
|
||||
return impl_->Run(samples, n, sample_rate);
|
||||
}
|
||||
|
||||
int32_t OfflineSpeechDenoiser::GetSampleRate() const {
|
||||
return impl_->GetSampleRate();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
template OfflineSpeechDenoiser::OfflineSpeechDenoiser(
|
||||
AAssetManager *mgr, const OfflineSpeechDenoiserConfig &config);
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
template OfflineSpeechDenoiser::OfflineSpeechDenoiser(
|
||||
NativeResourceManager *mgr, const OfflineSpeechDenoiserConfig &config);
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
61
sherpa-onnx/csrc/offline-speech-denoiser.h
Normal file
61
sherpa-onnx/csrc/offline-speech-denoiser.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// sherpa-onnx/csrc/offline-speech-denoiser.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct DenoisedAudio {
|
||||
std::vector<float> samples;
|
||||
int32_t sample_rate;
|
||||
};
|
||||
|
||||
struct OfflineSpeechDenoiserConfig {
|
||||
OfflineSpeechDenoiserModelConfig model;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class OfflineSpeechDenoiserImpl;
|
||||
|
||||
class OfflineSpeechDenoiser {
|
||||
public:
|
||||
explicit OfflineSpeechDenoiser(const OfflineSpeechDenoiserConfig &config);
|
||||
~OfflineSpeechDenoiser();
|
||||
|
||||
template <typename Manager>
|
||||
OfflineSpeechDenoiser(Manager *mgr,
|
||||
const OfflineSpeechDenoiserConfig &config);
|
||||
|
||||
/*
|
||||
* @param samples 1-D array of audio samples. Each sample is in the
|
||||
* range [-1, 1].
|
||||
* @param n Number of samples
|
||||
* @param sample_rate Sample rate of the input samples
|
||||
*
|
||||
*/
|
||||
DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const;
|
||||
|
||||
/*
|
||||
* Return the sample rate of the denoised audio
|
||||
*/
|
||||
int32_t GetSampleRate() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<OfflineSpeechDenoiserImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa-onnx/csrc/offline-tts-kokoro-model-metadata.h
|
||||
// sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
|
||||
95
sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
Normal file
95
sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
Normal file
@@ -0,0 +1,95 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Non-stremaing speech denoising with sherpa-onnx.
|
||||
|
||||
Please visit
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
|
||||
to download models.
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Use gtcrn models
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
|
||||
./bin/sherpa-onnx-offline-denoiser \
|
||||
--speech-denoiser-gtcrn-model=gtcrn_simple.onnx \
|
||||
--input-wav input.wav \
|
||||
--output-wav output_16k.wav
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::OfflineSpeechDenoiserConfig config;
|
||||
std::string input_wave;
|
||||
std::string output_wave;
|
||||
|
||||
config.Register(&po);
|
||||
po.Register("input-wav", &input_wave, "Path to input wav.");
|
||||
po.Register("output-wav", &output_wave, "Path to output wav");
|
||||
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() != 0) {
|
||||
fprintf(stderr, "Please don't give positional arguments\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (input_wave.empty()) {
|
||||
fprintf(stderr, "Please provide --input-wav\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (output_wave.empty()) {
|
||||
fprintf(stderr, "Please provide --output-wav\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
sherpa_onnx::OfflineSpeechDenoiser denoiser(config);
|
||||
int32_t sampling_rate = -1;
|
||||
bool is_ok = false;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(input_wave, &sampling_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Started\n");
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
auto result = denoiser.Run(samples.data(), samples.size(), sampling_rate);
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "Done\n");
|
||||
is_ok = sherpa_onnx::WriteWave(output_wave, result.sample_rate,
|
||||
result.samples.data(), result.samples.size());
|
||||
if (is_ok) {
|
||||
fprintf(stderr, "Saved to %s\n", output_wave.c_str());
|
||||
} else {
|
||||
fprintf(stderr, "Failed to save to %s\n", output_wave.c_str());
|
||||
}
|
||||
|
||||
float duration = samples.size() / static_cast<float>(sampling_rate);
|
||||
fprintf(stderr, "num threads: %d\n", config.model.num_threads);
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||
elapsed_seconds, duration, rtf);
|
||||
}
|
||||
Reference in New Issue
Block a user