// sherpa-onnx/c-api/c-api.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/c-api/c-api.h" #include #include #include #include #include "sherpa-onnx/csrc/display.h" #include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/online-recognizer.h" struct SherpaOnnxOnlineRecognizer { std::unique_ptr impl; }; struct SherpaOnnxOnlineStream { std::unique_ptr impl; explicit SherpaOnnxOnlineStream(std::unique_ptr p) : impl(std::move(p)) {} }; struct SherpaOnnxDisplay { std::unique_ptr impl; }; #define SHERPA_ONNX_OR(x, y) (x ? x : y) SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( const SherpaOnnxOnlineRecognizerConfig *config) { sherpa_onnx::OnlineRecognizerConfig recognizer_config; recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); recognizer_config.model_config.encoder_filename = SHERPA_ONNX_OR(config->model_config.encoder, ""); recognizer_config.model_config.decoder_filename = SHERPA_ONNX_OR(config->model_config.decoder, ""); recognizer_config.model_config.joiner_filename = SHERPA_ONNX_OR(config->model_config.joiner, ""); recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); recognizer_config.model_config.model_type = SHERPA_ONNX_OR(config->model_config.model_type, ""); recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); recognizer_config.enable_endpoint = SHERPA_ONNX_OR(config->enable_endpoint, 0); recognizer_config.endpoint_config.rule1.min_trailing_silence = SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4); recognizer_config.endpoint_config.rule2.min_trailing_silence = SHERPA_ONNX_OR(config->rule2_min_trailing_silence, 1.2); recognizer_config.endpoint_config.rule3.min_utterance_length = SHERPA_ONNX_OR(config->rule3_min_utterance_length, 20); if (config->model_config.debug) { fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); } SherpaOnnxOnlineRecognizer *recognizer = new SherpaOnnxOnlineRecognizer; recognizer->impl = std::make_unique(recognizer_config); return recognizer; } void DestroyOnlineRecognizer(SherpaOnnxOnlineRecognizer *recognizer) { delete recognizer; } SherpaOnnxOnlineStream *CreateOnlineStream( const SherpaOnnxOnlineRecognizer *recognizer) { SherpaOnnxOnlineStream *stream = new SherpaOnnxOnlineStream(recognizer->impl->CreateStream()); return stream; } void DestroyOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; } void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate, const float *samples, int32_t n) { stream->impl->AcceptWaveform(sample_rate, samples, n); } int32_t IsOnlineStreamReady(SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream) { return recognizer->impl->IsReady(stream->impl.get()); } void DecodeOnlineStream(SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream) { recognizer->impl->DecodeStream(stream->impl.get()); } void DecodeMultipleOnlineStreams(SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream **streams, int32_t n) { std::vector ss(n); for (int32_t i = 0; i != n; ++i) { ss[i] = streams[i]->impl.get(); } recognizer->impl->DecodeStreams(ss.data(), n); } SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream) { sherpa_onnx::OnlineRecognizerResult result = recognizer->impl->GetResult(stream->impl.get()); const auto &text = result.text; auto r = new SherpaOnnxOnlineRecognizerResult; r->text = new char[text.size() + 1]; std::copy(text.begin(), text.end(), const_cast(r->text)); const_cast(r->text)[text.size()] = 0; return r; } void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) { delete[] r->text; delete r; } void Reset(SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream) { recognizer->impl->Reset(stream->impl.get()); } void InputFinished(SherpaOnnxOnlineStream *stream) { stream->impl->InputFinished(); } int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream) { return recognizer->impl->IsEndpoint(stream->impl.get()); } SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) { SherpaOnnxDisplay *ans = new SherpaOnnxDisplay; ans->impl = std::make_unique(max_word_per_line); return ans; } void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; } void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) { display->impl->Print(idx, s); } // ============================================================ // For offline ASR (i.e., non-streaming ASR) // ============================================================ // struct SherpaOnnxOfflineRecognizer { std::unique_ptr impl; }; struct SherpaOnnxOfflineStream { std::unique_ptr impl; explicit SherpaOnnxOfflineStream( std::unique_ptr p) : impl(std::move(p)) {} }; SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( const SherpaOnnxOfflineRecognizerConfig *config) { sherpa_onnx::OfflineRecognizerConfig recognizer_config; recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); recognizer_config.model_config.transducer.encoder_filename = SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); recognizer_config.model_config.transducer.decoder_filename = SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); recognizer_config.model_config.transducer.joiner_filename = SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); recognizer_config.model_config.paraformer.model = SHERPA_ONNX_OR(config->model_config.paraformer.model, ""); recognizer_config.model_config.nemo_ctc.model = SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); recognizer_config.model_config.model_type = SHERPA_ONNX_OR(config->model_config.model_type, ""); recognizer_config.lm_config.model = SHERPA_ONNX_OR(config->lm_config.model, ""); recognizer_config.lm_config.scale = SHERPA_ONNX_OR(config->lm_config.scale, 1.0); recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); if (config->model_config.debug) { fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); } SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer; recognizer->impl = std::make_unique(recognizer_config); return recognizer; } void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { delete recognizer; } SherpaOnnxOfflineStream *CreateOfflineStream( const SherpaOnnxOfflineRecognizer *recognizer) { SherpaOnnxOfflineStream *stream = new SherpaOnnxOfflineStream(recognizer->impl->CreateStream()); return stream; } void DestroyOfflineStream(SherpaOnnxOfflineStream *stream) { delete stream; } void AcceptWaveformOffline(SherpaOnnxOfflineStream *stream, int32_t sample_rate, const float *samples, int32_t n) { stream->impl->AcceptWaveform(sample_rate, samples, n); } void DecodeOfflineStream(SherpaOnnxOfflineRecognizer *recognizer, SherpaOnnxOfflineStream *stream) { recognizer->impl->DecodeStream(stream->impl.get()); } void DecodeMultipleOfflineStreams(SherpaOnnxOfflineRecognizer *recognizer, SherpaOnnxOfflineStream **streams, int32_t n) { std::vector ss(n); for (int32_t i = 0; i != n; ++i) { ss[i] = streams[i]->impl.get(); } recognizer->impl->DecodeStreams(ss.data(), n); } SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( SherpaOnnxOfflineStream *stream) { const sherpa_onnx::OfflineRecognitionResult &result = stream->impl->GetResult(); const auto &text = result.text; auto r = new SherpaOnnxOfflineRecognizerResult; r->text = new char[text.size() + 1]; std::copy(text.begin(), text.end(), const_cast(r->text)); const_cast(r->text)[text.size()] = 0; return r; } void DestroyOfflineRecognizerResult( const SherpaOnnxOfflineRecognizerResult *r) { delete[] r->text; delete r; }