// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h // // Copyright (c) 2025 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ #include #include #include #include #include "kaldi-native-fbank/csrc/feature-window.h" #include "kaldi-native-fbank/csrc/istft.h" #include "kaldi-native-fbank/csrc/stft.h" #include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h" #include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h" #include "sherpa-onnx/csrc/offline-speech-denoiser.h" #include "sherpa-onnx/csrc/resample.h" namespace sherpa_onnx { class OfflineSpeechDenoiserGtcrnImpl : public OfflineSpeechDenoiserImpl { public: explicit OfflineSpeechDenoiserGtcrnImpl( const OfflineSpeechDenoiserConfig &config) : model_(config.model) {} template OfflineSpeechDenoiserGtcrnImpl(Manager *mgr, const OfflineSpeechDenoiserConfig &config) : model_(mgr, config.model) {} DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const override { const auto &meta = model_.GetMetaData(); std::vector tmp; auto p = samples; if (sample_rate != meta.sample_rate) { SHERPA_ONNX_LOGE( "Creating a resampler:\n" " in_sample_rate: %d\n" " output_sample_rate: %d\n", sample_rate, meta.sample_rate); float min_freq = std::min(sample_rate, meta.sample_rate); float lowpass_cutoff = 0.99 * 0.5 * min_freq; int32_t lowpass_filter_width = 6; auto resampler = std::make_unique( sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width); resampler->Resample(samples, n, true, &tmp); p = tmp.data(); n = tmp.size(); } knf::StftConfig stft_config; stft_config.n_fft = meta.n_fft; stft_config.hop_length = meta.hop_length; stft_config.win_length = meta.window_length; stft_config.window_type = meta.window_type; if (stft_config.window_type == "hann_sqrt") { auto window = knf::GetWindow("hann", stft_config.win_length); for (auto &w : window) { w = std::sqrt(w); } stft_config.window = std::move(window); } knf::Stft stft(stft_config); knf::StftResult stft_result = stft.Compute(p, n); auto states = model_.GetInitStates(); OfflineSpeechDenoiserGtcrnModel::States next_states; knf::StftResult enhanced_stft_result; enhanced_stft_result.num_frames = stft_result.num_frames; for (int32_t i = 0; i < stft_result.num_frames; ++i) { auto p = Process(stft_result, i, std::move(states), &next_states); states = std::move(next_states); enhanced_stft_result.real.insert(enhanced_stft_result.real.end(), p.first.begin(), p.first.end()); enhanced_stft_result.imag.insert(enhanced_stft_result.imag.end(), p.second.begin(), p.second.end()); } knf::IStft istft(stft_config); DenoisedAudio denoised_audio; denoised_audio.sample_rate = meta.sample_rate; denoised_audio.samples = istft.Compute(enhanced_stft_result); return denoised_audio; } int32_t GetSampleRate() const override { return model_.GetMetaData().sample_rate; } private: std::pair, std::vector> Process( const knf::StftResult &stft_result, int32_t frame_index, OfflineSpeechDenoiserGtcrnModel::States states, OfflineSpeechDenoiserGtcrnModel::States *next_states) const { const auto &meta = model_.GetMetaData(); int32_t n_fft = meta.n_fft; std::vector x((n_fft / 2 + 1) * 2); const float *p_real = stft_result.real.data() + frame_index * (n_fft / 2 + 1); const float *p_imag = stft_result.imag.data() + frame_index * (n_fft / 2 + 1); for (int32_t i = 0; i < n_fft / 2 + 1; ++i) { x[2 * i] = p_real[i]; x[2 * i + 1] = p_imag[i]; } auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); std::array x_shape{1, n_fft / 2 + 1, 1, 2}; Ort::Value x_tensor = Ort::Value::CreateTensor( memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); Ort::Value output{nullptr}; std::tie(output, *next_states) = model_.Run(std::move(x_tensor), std::move(states)); std::vector real(n_fft / 2 + 1); std::vector imag(n_fft / 2 + 1); const auto *p = output.GetTensorData(); for (int32_t i = 0; i < n_fft / 2 + 1; ++i) { real[i] = p[2 * i]; imag[i] = p[2 * i + 1]; } return {std::move(real), std::move(imag)}; } private: OfflineSpeechDenoiserGtcrnModel model_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_