Add CXX API for speech enhancement GTCRN models (#1986)

This commit is contained in:
Fangjun Kuang
2025-03-11 17:07:52 +08:00
committed by GitHub
parent c5dbf1177c
commit 802119db17
9 changed files with 192 additions and 3 deletions

View File

@@ -513,4 +513,49 @@ void KeywordSpotter::Reset(const OnlineStream *s) const {
SherpaOnnxResetKeywordStream(p_, s->Get());
}
// ============================================================
// For Offline Speech Enhancement
// ============================================================
OfflineSpeechDenoiser OfflineSpeechDenoiser::Create(
const OfflineSpeechDenoiserConfig &config) {
struct SherpaOnnxOfflineSpeechDenoiserConfig c;
memset(&c, 0, sizeof(c));
c.model.gtcrn.model = config.model.gtcrn.model.c_str();
c.model.num_threads = config.model.num_threads;
c.model.provider = config.model.provider.c_str();
c.model.debug = config.model.debug;
auto p = SherpaOnnxCreateOfflineSpeechDenoiser(&c);
return OfflineSpeechDenoiser(p);
}
void OfflineSpeechDenoiser::Destroy(
const SherpaOnnxOfflineSpeechDenoiser *p) const {
SherpaOnnxDestroyOfflineSpeechDenoiser(p);
}
OfflineSpeechDenoiser::OfflineSpeechDenoiser(
const SherpaOnnxOfflineSpeechDenoiser *p)
: MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser>(p) {}
DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n,
int32_t sample_rate) const {
auto audio = SherpaOnnxOfflineSpeechDenoiserRun(p_, samples, n, sample_rate);
DenoisedAudio ans;
ans.samples = {audio->samples, audio->samples + audio->n};
ans.sample_rate = audio->sample_rate;
SherpaOnnxDestroyDenoisedAudio(audio);
return ans;
}
int32_t OfflineSpeechDenoiser::GetSampleRate() const {
return SherpaOnnxOfflineSpeechDenoiserGetSampleRate(p_);
}
} // namespace sherpa_onnx::cxx

View File

@@ -464,6 +464,42 @@ class SHERPA_ONNX_API KeywordSpotter
explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p);
};
struct OfflineSpeechDenoiserGtcrnModelConfig {
std::string model;
};
struct OfflineSpeechDenoiserModelConfig {
OfflineSpeechDenoiserGtcrnModelConfig gtcrn;
int32_t num_threads = 1;
int32_t debug = false;
std::string provider = "cpu";
};
struct OfflineSpeechDenoiserConfig {
OfflineSpeechDenoiserModelConfig model;
};
struct DenoisedAudio {
std::vector<float> samples; // in the range [-1, 1]
int32_t sample_rate;
};
class SHERPA_ONNX_API OfflineSpeechDenoiser
: public MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser> {
public:
static OfflineSpeechDenoiser Create(
const OfflineSpeechDenoiserConfig &config);
void Destroy(const SherpaOnnxOfflineSpeechDenoiser *p) const;
DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const;
int32_t GetSampleRate() const;
private:
explicit OfflineSpeechDenoiser(const SherpaOnnxOfflineSpeechDenoiser *p);
};
} // namespace sherpa_onnx::cxx
#endif // SHERPA_ONNX_C_API_CXX_API_H_