Add CXX API for speech enhancement GTCRN models (#1986)
This commit is contained in:
@@ -513,4 +513,49 @@ void KeywordSpotter::Reset(const OnlineStream *s) const {
|
||||
SherpaOnnxResetKeywordStream(p_, s->Get());
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// For Offline Speech Enhancement
|
||||
// ============================================================
|
||||
|
||||
OfflineSpeechDenoiser OfflineSpeechDenoiser::Create(
|
||||
const OfflineSpeechDenoiserConfig &config) {
|
||||
struct SherpaOnnxOfflineSpeechDenoiserConfig c;
|
||||
memset(&c, 0, sizeof(c));
|
||||
|
||||
c.model.gtcrn.model = config.model.gtcrn.model.c_str();
|
||||
|
||||
c.model.num_threads = config.model.num_threads;
|
||||
c.model.provider = config.model.provider.c_str();
|
||||
c.model.debug = config.model.debug;
|
||||
|
||||
auto p = SherpaOnnxCreateOfflineSpeechDenoiser(&c);
|
||||
|
||||
return OfflineSpeechDenoiser(p);
|
||||
}
|
||||
|
||||
void OfflineSpeechDenoiser::Destroy(
|
||||
const SherpaOnnxOfflineSpeechDenoiser *p) const {
|
||||
SherpaOnnxDestroyOfflineSpeechDenoiser(p);
|
||||
}
|
||||
|
||||
OfflineSpeechDenoiser::OfflineSpeechDenoiser(
|
||||
const SherpaOnnxOfflineSpeechDenoiser *p)
|
||||
: MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser>(p) {}
|
||||
|
||||
DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n,
|
||||
int32_t sample_rate) const {
|
||||
auto audio = SherpaOnnxOfflineSpeechDenoiserRun(p_, samples, n, sample_rate);
|
||||
|
||||
DenoisedAudio ans;
|
||||
ans.samples = {audio->samples, audio->samples + audio->n};
|
||||
ans.sample_rate = audio->sample_rate;
|
||||
SherpaOnnxDestroyDenoisedAudio(audio);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t OfflineSpeechDenoiser::GetSampleRate() const {
|
||||
return SherpaOnnxOfflineSpeechDenoiserGetSampleRate(p_);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx::cxx
|
||||
|
||||
@@ -464,6 +464,42 @@ class SHERPA_ONNX_API KeywordSpotter
|
||||
explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p);
|
||||
};
|
||||
|
||||
struct OfflineSpeechDenoiserGtcrnModelConfig {
|
||||
std::string model;
|
||||
};
|
||||
|
||||
struct OfflineSpeechDenoiserModelConfig {
|
||||
OfflineSpeechDenoiserGtcrnModelConfig gtcrn;
|
||||
int32_t num_threads = 1;
|
||||
int32_t debug = false;
|
||||
std::string provider = "cpu";
|
||||
};
|
||||
|
||||
struct OfflineSpeechDenoiserConfig {
|
||||
OfflineSpeechDenoiserModelConfig model;
|
||||
};
|
||||
|
||||
struct DenoisedAudio {
|
||||
std::vector<float> samples; // in the range [-1, 1]
|
||||
int32_t sample_rate;
|
||||
};
|
||||
|
||||
class SHERPA_ONNX_API OfflineSpeechDenoiser
|
||||
: public MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser> {
|
||||
public:
|
||||
static OfflineSpeechDenoiser Create(
|
||||
const OfflineSpeechDenoiserConfig &config);
|
||||
|
||||
void Destroy(const SherpaOnnxOfflineSpeechDenoiser *p) const;
|
||||
|
||||
DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const;
|
||||
|
||||
int32_t GetSampleRate() const;
|
||||
|
||||
private:
|
||||
explicit OfflineSpeechDenoiser(const SherpaOnnxOfflineSpeechDenoiser *p);
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx::cxx
|
||||
|
||||
#endif // SHERPA_ONNX_C_API_CXX_API_H_
|
||||
|
||||
Reference in New Issue
Block a user