diff --git a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp index 52b699c0..aefd5b57 100644 --- a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp +++ b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp @@ -387,6 +387,7 @@ void CNonStreamingSpeechRecognitionDlg::InitParaformer() { config_.model_config.tokens = tokens.c_str(); config_.model_config.num_threads = 1; config_.model_config.debug = 1; + config_.model_config.model_type = "paraformer"; config_.decoding_method = "greedy_search"; config_.max_active_paths = 4; @@ -447,6 +448,7 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() { config_.model_config.tokens = tokens.c_str(); config_.model_config.num_threads = 1; config_.model_config.debug = 0; + config_.model_config.model_type = "transducer"; config_.decoding_method = "greedy_search"; config_.max_active_paths = 4; diff --git a/scripts/dotnet/offline.cs b/scripts/dotnet/offline.cs index afbec42a..60c1279e 100644 --- a/scripts/dotnet/offline.cs +++ b/scripts/dotnet/offline.cs @@ -76,6 +76,8 @@ namespace SherpaOnnx Tokens = ""; NumThreads = 1; Debug = 0; + Provider = "cpu"; + ModelType = ""; } public OfflineTransducerModelConfig Transducer; public OfflineParaformerModelConfig Paraformer; @@ -87,6 +89,12 @@ namespace SherpaOnnx public int NumThreads; public int Debug; + + [MarshalAs(UnmanagedType.LPStr)] + public string Provider; + + [MarshalAs(UnmanagedType.LPStr)] + public string ModelType; } [StructLayout(LayoutKind.Sequential)] diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index ba5f768f..b204f7ae 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -33,23 +33,33 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( const SherpaOnnxOnlineRecognizerConfig *config) { sherpa_onnx::OnlineRecognizerConfig recognizer_config; - recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); - recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); + recognizer_config.feat_config.sampling_rate = + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); + recognizer_config.feat_config.feature_dim = + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); recognizer_config.model_config.encoder_filename = SHERPA_ONNX_OR(config->model_config.encoder, ""); recognizer_config.model_config.decoder_filename = SHERPA_ONNX_OR(config->model_config.decoder, ""); - recognizer_config.model_config.joiner_filename = SHERPA_ONNX_OR(config->model_config.joiner, ""); - recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); - recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); - recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); - recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); + recognizer_config.model_config.joiner_filename = + SHERPA_ONNX_OR(config->model_config.joiner, ""); + recognizer_config.model_config.tokens = + SHERPA_ONNX_OR(config->model_config.tokens, ""); + recognizer_config.model_config.num_threads = + SHERPA_ONNX_OR(config->model_config.num_threads, 1); + recognizer_config.model_config.provider = + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); + recognizer_config.model_config.debug = + SHERPA_ONNX_OR(config->model_config.debug, 0); - recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); - recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); + recognizer_config.decoding_method = + SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); + recognizer_config.max_active_paths = + SHERPA_ONNX_OR(config->max_active_paths, 4); - recognizer_config.enable_endpoint = SHERPA_ONNX_OR(config->enable_endpoint, 0); + recognizer_config.enable_endpoint = + SHERPA_ONNX_OR(config->enable_endpoint, 0); recognizer_config.endpoint_config.rule1.min_trailing_silence = SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4); @@ -173,9 +183,11 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( const SherpaOnnxOfflineRecognizerConfig *config) { sherpa_onnx::OfflineRecognizerConfig recognizer_config; - recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); + recognizer_config.feat_config.sampling_rate = + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); - recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); + recognizer_config.feat_config.feature_dim = + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); recognizer_config.model_config.transducer.encoder_filename = SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); @@ -184,7 +196,7 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); recognizer_config.model_config.transducer.joiner_filename = - SHERPA_ONNX_OR(config->model_config.transducer.joiner,""); + SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); recognizer_config.model_config.paraformer.model = SHERPA_ONNX_OR(config->model_config.paraformer.model, ""); @@ -192,15 +204,26 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( recognizer_config.model_config.nemo_ctc.model = SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); - recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); - recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); - recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); + recognizer_config.model_config.tokens = + SHERPA_ONNX_OR(config->model_config.tokens, ""); + recognizer_config.model_config.num_threads = + SHERPA_ONNX_OR(config->model_config.num_threads, 1); + recognizer_config.model_config.debug = + SHERPA_ONNX_OR(config->model_config.debug, 0); + recognizer_config.model_config.provider = + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); + recognizer_config.model_config.model_type = + SHERPA_ONNX_OR(config->model_config.model_type, ""); - recognizer_config.lm_config.model = SHERPA_ONNX_OR(config->lm_config.model, ""); - recognizer_config.lm_config.scale = SHERPA_ONNX_OR(config->lm_config.scale, 1.0); + recognizer_config.lm_config.model = + SHERPA_ONNX_OR(config->lm_config.model, ""); + recognizer_config.lm_config.scale = + SHERPA_ONNX_OR(config->lm_config.scale, 1.0); - recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); - recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); + recognizer_config.decoding_method = + SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); + recognizer_config.max_active_paths = + SHERPA_ONNX_OR(config->max_active_paths, 4); if (config->model_config.debug) { fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 7cc32927..751c1894 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -272,6 +272,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { const char *tokens; int32_t num_threads; int32_t debug; + const char *provider; + const char *model_type; } SherpaOnnxOfflineModelConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index d9736649..92380e76 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) { po->Register("provider", &provider, "Specify a provider to use: cpu, cuda, coreml"); + + po->Register("model-type", &model_type, + "Specify it to reduce model initialization time. " + "Valid values are: transducer, paraformer, nemo_ctc. " + "All other values lead to loading the model twice."); } bool OfflineModelConfig::Validate() const { @@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const { } if (!FileExists(tokens)) { - SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); + SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str()); return false; } @@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const { os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; - os << "provider=\"" << provider << "\")"; + os << "provider=\"" << provider << "\", "; + os << "model_type=\"" << model_type << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index b440c9e3..4afdd65a 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -22,19 +22,31 @@ struct OfflineModelConfig { bool debug = false; std::string provider = "cpu"; + // With the help of this field, we only need to load the model once + // instead of twice; and therefore it reduces initialization time. + // + // Valid values: + // - transducer. The given model is from icefall + // - paraformer. It is a paraformer model + // - nemo_ctc. It is a NeMo CTC model. + // + // All other values are invalid and lead to loading the model twice. + std::string model_type; + OfflineModelConfig() = default; OfflineModelConfig(const OfflineTransducerModelConfig &transducer, const OfflineParaformerModelConfig ¶former, const OfflineNemoEncDecCtcModelConfig &nemo_ctc, const std::string &tokens, int32_t num_threads, bool debug, - const std::string &provider) + const std::string &provider, const std::string &model_type) : transducer(transducer), paraformer(paraformer), nemo_ctc(nemo_ctc), tokens(tokens), num_threads(num_threads), debug(debug), - provider(provider) {} + provider(provider), + model_type(model_type) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index d47426e8..8a2b42a0 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -18,6 +18,21 @@ namespace sherpa_onnx { std::unique_ptr OfflineRecognizerImpl::Create( const OfflineRecognizerConfig &config) { + if (!config.model_config.model_type.empty()) { + const auto &model_type = config.model_config.model_type; + if (model_type == "transducer") { + return std::make_unique(config); + } else if (model_type == "paraformer") { + return std::make_unique(config); + } else if (model_type == "nemo_ctc") { + return std::make_unique(config); + } else { + SHERPA_ONNX_LOGE( + "Invalid model_type: %s. Trying to load the model to get its type", + model_type.c_str()); + } + } + Ort::Env env(ORT_LOGGING_LEVEL_ERROR); Ort::SessionOptions sess_opts; diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.cc b/sherpa-onnx/csrc/offline-transducer-model-config.cc index 16b7a9f3..b90d68e7 100644 --- a/sherpa-onnx/csrc/offline-transducer-model-config.cc +++ b/sherpa-onnx/csrc/offline-transducer-model-config.cc @@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { bool OfflineTransducerModelConfig::Validate() const { if (!FileExists(encoder_filename)) { - SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); + SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str()); return false; } if (!FileExists(decoder_filename)) { - SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); + SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str()); return false; } if (!FileExists(joiner_filename)) { - SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); + SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str()); return false; } diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc index 48f99954..6665e85f 100644 --- a/sherpa-onnx/python/csrc/offline-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -21,15 +21,16 @@ void PybindOfflineModelConfig(py::module *m) { using PyClass = OfflineModelConfig; py::class_(*m, "OfflineModelConfig") - .def(py::init(), - py::arg("transducer") = OfflineTransducerModelConfig(), - py::arg("paraformer") = OfflineParaformerModelConfig(), - py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), - py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, - py::arg("provider") = "cpu") + .def( + py::init(), + py::arg("transducer") = OfflineTransducerModelConfig(), + py::arg("paraformer") = OfflineParaformerModelConfig(), + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, + py::arg("provider") = "cpu", py::arg("model_type") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) @@ -37,6 +38,7 @@ void PybindOfflineModelConfig(py::module *m) { .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) .def_readwrite("provider", &PyClass::provider) + .def_readwrite("model_type", &PyClass::model_type) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 9c321384..32fad47a 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -86,6 +86,7 @@ class OfflineRecognizer(object): num_threads=num_threads, debug=debug, provider=provider, + model_type="transducer", ) feat_config = OfflineFeatureExtractorConfig( @@ -149,6 +150,7 @@ class OfflineRecognizer(object): num_threads=num_threads, debug=debug, provider=provider, + model_type="paraformer", ) feat_config = OfflineFeatureExtractorConfig( @@ -211,6 +213,7 @@ class OfflineRecognizer(object): num_threads=num_threads, debug=debug, provider=provider, + model_type="nemo_ctc", ) feat_config = OfflineFeatureExtractorConfig(