diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/keyword-spotting.cc b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/keyword-spotting.cc index 6980f1ab..6562ef5a 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/keyword-spotting.cc +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/keyword-spotting.cc @@ -16,6 +16,16 @@ SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj); static Napi::External CreateKeywordSpotterWrapper( const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); +#if __OHOS__ + if (info.Length() != 2) { + std::ostringstream os; + os << "Expect only 2 arguments. Given: " << info.Length(); + + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); + + return {}; + } +#else if (info.Length() != 1) { std::ostringstream os; os << "Expect only 1 argument. Given: " << info.Length(); @@ -24,7 +34,7 @@ static Napi::External CreateKeywordSpotterWrapper( return {}; } - +#endif if (!info[0].IsObject()) { Napi::TypeError::New(env, "Expect an object as the argument") .ThrowAsJavaScriptException(); @@ -46,7 +56,18 @@ static Napi::External CreateKeywordSpotterWrapper( SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf); SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize); +#if __OHOS__ + std::unique_ptr + mgr(OH_ResourceManager_InitNativeResourceManager(env, info[1]), + &OH_ResourceManager_ReleaseNativeResourceManager); + + const SherpaOnnxKeywordSpotter *kws = + SherpaOnnxCreateKeywordSpotterOHOS(&c, mgr.get()); +#else const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c); +#endif + SHERPA_ONNX_DELETE_C_STR(c.model_config.transducer.encoder); SHERPA_ONNX_DELETE_C_STR(c.model_config.transducer.decoder); SHERPA_ONNX_DELETE_C_STR(c.model_config.transducer.joiner); @@ -79,9 +100,9 @@ static Napi::External CreateKeywordSpotterWrapper( static Napi::External CreateKeywordStreamWrapper( const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); - if (info.Length() != 1) { + if (info.Length() != 1 && info.Length() != 2) { std::ostringstream os; - os << "Expect only 1 argument. Given: " << info.Length(); + os << "Expect only 1 or 2 arguments. Given: " << info.Length(); Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); @@ -96,10 +117,24 @@ static Napi::External CreateKeywordStreamWrapper( return {}; } + if (info.Length() == 2 && !info[1].IsString()) { + std::ostringstream os; + os << "Argument 2 should be a string."; + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); + return {}; + } + const SherpaOnnxKeywordSpotter *kws = info[0].As>().Data(); - const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws); + const SherpaOnnxOnlineStream *stream; + if (info.Length() == 1) { + stream = SherpaOnnxCreateKeywordStream(kws); + } else { + Napi::String js_keywords = info[1].As(); + std::string keywords = js_keywords.Utf8Value(); + stream = SherpaOnnxCreateKeywordStreamWithKeywords(kws, keywords.c_str()); + } return Napi::External::New( env, const_cast(stream), diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index ce6d5563..78701cf7 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter { std::unique_ptr impl; }; -const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( +static sherpa_onnx::KeywordSpotterConfig GetKeywordSpotterConfig( const SherpaOnnxKeywordSpotterConfig *config) { sherpa_onnx::KeywordSpotterConfig spotter_config; @@ -739,10 +739,20 @@ const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( std::string(config->keywords_buf, config->keywords_buf_size); } - if (config->model_config.debug) { + if (spotter_config.model_config.debug) { +#if OHOS + SHERPA_ONNX_LOGE("%{public}s\n", spotter_config.ToString().c_str()); +#else SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); +#endif } + return spotter_config; +} + +const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( + const SherpaOnnxKeywordSpotterConfig *config) { + auto spotter_config = GetKeywordSpotterConfig(config); if (!spotter_config.Validate()) { SHERPA_ONNX_LOGE("Errors in config!"); return nullptr; @@ -2272,6 +2282,22 @@ SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS( return p; } +const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotterOHOS( + const SherpaOnnxKeywordSpotterConfig *config, NativeResourceManager *mgr) { + if (!mgr) { + return SherpaOnnxCreateKeywordSpotter(config); + } + + auto spotter_config = GetKeywordSpotterConfig(config); + + SherpaOnnxKeywordSpotter *spotter = new SherpaOnnxKeywordSpotter; + + spotter->impl = + std::make_unique(mgr, spotter_config); + + return spotter; +} + #if SHERPA_ONNX_ENABLE_TTS == 1 const SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTtsOHOS( const SherpaOnnxOfflineTtsConfig *config, NativeResourceManager *mgr) { diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 10bd101f..5fe124d4 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1645,6 +1645,10 @@ SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS( const SherpaOnnxSpeakerEmbeddingExtractorConfig *config, NativeResourceManager *mgr); +SHERPA_ONNX_API const SherpaOnnxKeywordSpotter * +SherpaOnnxCreateKeywordSpotterOHOS(const SherpaOnnxKeywordSpotterConfig *config, + NativeResourceManager *mgr); + SHERPA_ONNX_API const SherpaOnnxOfflineSpeakerDiarization * SherpaOnnxCreateOfflineSpeakerDiarizationOHOS( const SherpaOnnxOfflineSpeakerDiarizationConfig *config, diff --git a/sherpa-onnx/csrc/keyword-spotter-impl.cc b/sherpa-onnx/csrc/keyword-spotter-impl.cc index 1c9d5948..affb212c 100644 --- a/sherpa-onnx/csrc/keyword-spotter-impl.cc +++ b/sherpa-onnx/csrc/keyword-spotter-impl.cc @@ -6,6 +6,15 @@ #include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h" +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + namespace sherpa_onnx { std::unique_ptr KeywordSpotterImpl::Create( @@ -18,9 +27,9 @@ std::unique_ptr KeywordSpotterImpl::Create( exit(-1); } -#if __ANDROID_API__ >= 9 +template std::unique_ptr KeywordSpotterImpl::Create( - AAssetManager *mgr, const KeywordSpotterConfig &config) { + Manager *mgr, const KeywordSpotterConfig &config) { if (!config.model_config.transducer.encoder.empty()) { return std::make_unique(mgr, config); } @@ -28,6 +37,15 @@ std::unique_ptr KeywordSpotterImpl::Create( SHERPA_ONNX_LOGE("Please specify a model"); exit(-1); } + +#if __ANDROID_API__ >= 9 +template std::unique_ptr KeywordSpotterImpl::Create( + AAssetManager *mgr, const KeywordSpotterConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr KeywordSpotterImpl::Create( + NativeResourceManager *mgr, const KeywordSpotterConfig &config); #endif } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/keyword-spotter-impl.h b/sherpa-onnx/csrc/keyword-spotter-impl.h index 6180f917..dafb0a8b 100644 --- a/sherpa-onnx/csrc/keyword-spotter-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-impl.h @@ -9,11 +9,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/keyword-spotter.h" #include "sherpa-onnx/csrc/online-stream.h" @@ -24,10 +19,9 @@ class KeywordSpotterImpl { static std::unique_ptr Create( const KeywordSpotterConfig &config); -#if __ANDROID_API__ >= 9 + template static std::unique_ptr Create( - AAssetManager *mgr, const KeywordSpotterConfig &config); -#endif + Manager *mgr, const KeywordSpotterConfig &config); virtual ~KeywordSpotterImpl() = default; diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index d29b8b58..e62b3b2c 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -9,16 +9,10 @@ #include #include // NOLINT #include +#include #include #include -#if __ANDROID_API__ >= 9 -#include - -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/keyword-spotter-impl.h" #include "sherpa-onnx/csrc/keyword-spotter.h" @@ -91,9 +85,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { unk_id_); } -#if __ANDROID_API__ >= 9 - KeywordSpotterTransducerImpl(AAssetManager *mgr, - const KeywordSpotterConfig &config) + template + KeywordSpotterTransducerImpl(Manager *mgr, const KeywordSpotterConfig &config) : config_(config), model_(OnlineTransducerModel::Create(mgr, config.model_config)), sym_(mgr, config.model_config.tokens) { @@ -109,7 +102,6 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { model_.get(), config_.max_active_paths, config_.num_trailing_blanks, unk_id_); } -#endif std::unique_ptr CreateStream() const override { auto stream = @@ -130,7 +122,11 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { if (!EncodeKeywords(is, sym_, ¤t_ids, ¤t_kws, ¤t_scores, ¤t_thresholds)) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Encode keywords %{public}s failed.", keywords.c_str()); +#else SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str()); +#endif return nullptr; } @@ -306,16 +302,21 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { // each line in keywords_file contains space-separated words std::ifstream is(config_.keywords_file); if (!is) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Open keywords file failed: %{public}s", + config_.keywords_file.c_str()); +#else SHERPA_ONNX_LOGE("Open keywords file failed: %s", config_.keywords_file.c_str()); +#endif exit(-1); } InitKeywords(is); #endif } -#if __ANDROID_API__ >= 9 - void InitKeywords(AAssetManager *mgr) { + template + void InitKeywords(Manager *mgr) { // each line in keywords_file contains space-separated words auto buf = ReadFile(mgr, config_.keywords_file); @@ -323,13 +324,17 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { std::istrstream is(buf.data(), buf.size()); if (!is) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Open keywords file failed: %{public}s", + config_.keywords_file.c_str()); +#else SHERPA_ONNX_LOGE("Open keywords file failed: %s", config_.keywords_file.c_str()); +#endif exit(-1); } InitKeywords(is); } -#endif void InitKeywordsFromBufStr() { // keywords_buf's content is supposed to be same as the keywords_file's diff --git a/sherpa-onnx/csrc/keyword-spotter.cc b/sherpa-onnx/csrc/keyword-spotter.cc index 66d0907a..615aab9c 100644 --- a/sherpa-onnx/csrc/keyword-spotter.cc +++ b/sherpa-onnx/csrc/keyword-spotter.cc @@ -13,6 +13,15 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + #include "sherpa-onnx/csrc/keyword-spotter-impl.h" namespace sherpa_onnx { @@ -136,11 +145,9 @@ std::string KeywordSpotterConfig::ToString() const { KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config) : impl_(KeywordSpotterImpl::Create(config)) {} -#if __ANDROID_API__ >= 9 -KeywordSpotter::KeywordSpotter(AAssetManager *mgr, - const KeywordSpotterConfig &config) +template +KeywordSpotter::KeywordSpotter(Manager *mgr, const KeywordSpotterConfig &config) : impl_(KeywordSpotterImpl::Create(mgr, config)) {} -#endif KeywordSpotter::~KeywordSpotter() = default; @@ -167,4 +174,14 @@ KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const { return impl_->GetResult(s); } +#if __ANDROID_API__ >= 9 +template KeywordSpotter::KeywordSpotter(AAssetManager *mgr, + const KeywordSpotterConfig &config); +#endif + +#if __OHOS__ +template KeywordSpotter::KeywordSpotter(NativeResourceManager *mgr, + const KeywordSpotterConfig &config); +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/keyword-spotter.h b/sherpa-onnx/csrc/keyword-spotter.h index c933f4b2..e494b3e5 100644 --- a/sherpa-onnx/csrc/keyword-spotter.h +++ b/sherpa-onnx/csrc/keyword-spotter.h @@ -9,11 +9,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-stream.h" @@ -101,9 +96,8 @@ class KeywordSpotter { public: explicit KeywordSpotter(const KeywordSpotterConfig &config); -#if __ANDROID_API__ >= 9 - KeywordSpotter(AAssetManager *mgr, const KeywordSpotterConfig &config); -#endif + template + KeywordSpotter(Manager *mgr, const KeywordSpotterConfig &config); ~KeywordSpotter();