diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index b204f7ae..96cc822b 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -50,6 +50,8 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( SHERPA_ONNX_OR(config->model_config.num_threads, 1); recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); + recognizer_config.model_config.model_type = + SHERPA_ONNX_OR(config->model_config.model_type, ""); recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 751c1894..698f7f38 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -53,6 +53,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { const char *tokens; int32_t num_threads; const char *provider; + const char *model_type; int32_t debug; // true to print debug information of the model } SherpaOnnxOnlineTransducerModelConfig; diff --git a/sherpa-onnx/csrc/online-transducer-model-config.cc b/sherpa-onnx/csrc/online-transducer-model-config.cc index b61f9bc8..f13a2791 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.cc +++ b/sherpa-onnx/csrc/online-transducer-model-config.cc @@ -22,26 +22,30 @@ void OnlineTransducerModelConfig::Register(ParseOptions *po) { po->Register("debug", &debug, "true to print model information while loading it."); + po->Register("model-type", &model_type, + "Specify it to reduce model initialization time. " + "Valid values are: conformer, lstm, zipformer, zipformer2. " + "All other values lead to loading the model twice."); } bool OnlineTransducerModelConfig::Validate() const { if (!FileExists(tokens)) { - SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); + SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str()); return false; } if (!FileExists(encoder_filename)) { - SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); + SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str()); return false; } if (!FileExists(decoder_filename)) { - SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); + SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str()); return false; } if (!FileExists(joiner_filename)) { - SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); + SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str()); return false; } @@ -63,6 +67,7 @@ std::string OnlineTransducerModelConfig::ToString() const { os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "provider=\"" << provider << "\", "; + os << "model_type=\"" << model_type << "\", "; os << "debug=" << (debug ? "True" : "False") << ")"; return os.str(); diff --git a/sherpa-onnx/csrc/online-transducer-model-config.h b/sherpa-onnx/csrc/online-transducer-model-config.h index c9fc1b73..040dfe28 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.h +++ b/sherpa-onnx/csrc/online-transducer-model-config.h @@ -19,19 +19,33 @@ struct OnlineTransducerModelConfig { bool debug = false; std::string provider = "cpu"; + // With the help of this field, we only need to load the model once + // instead of twice; and therefore it reduces initialization time. + // + // Valid values: + // - conformer + // - lstm + // - zipformer + // - zipformer2 + // + // All other values are invalid and lead to loading the model twice. + std::string model_type; + OnlineTransducerModelConfig() = default; OnlineTransducerModelConfig(const std::string &encoder_filename, const std::string &decoder_filename, const std::string &joiner_filename, const std::string &tokens, int32_t num_threads, - bool debug, const std::string &provider) + bool debug, const std::string &provider, + const std::string &model_type) : encoder_filename(encoder_filename), decoder_filename(decoder_filename), joiner_filename(joiner_filename), tokens(tokens), num_threads(num_threads), debug(debug), - provider(provider) {} + provider(provider), + model_type(model_type) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index d00afcbe..bf471526 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -77,6 +77,22 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, std::unique_ptr OnlineTransducerModel::Create( const OnlineTransducerModelConfig &config) { + if (!config.model_type.empty()) { + const auto &model_type = config.model_type; + if (model_type == "conformer") { + return std::make_unique(config); + } else if (model_type == "lstm") { + return std::make_unique(config); + } else if (model_type == "zipformer") { + return std::make_unique(config); + } else if (model_type == "zipformer2") { + return std::make_unique(config); + } else { + SHERPA_ONNX_LOGE( + "Invalid model_type: %s. Trying to load the model to get its type", + model_type.c_str()); + } + } ModelType model_type = ModelType::kUnkown; { @@ -140,6 +156,23 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput( #if __ANDROID_API__ >= 9 std::unique_ptr OnlineTransducerModel::Create( AAssetManager *mgr, const OnlineTransducerModelConfig &config) { + if (!config.model_type.empty()) { + const auto &model_type = config.model_type; + if (model_type == "conformer") { + return std::make_unique(mgr, config); + } else if (model_type == "lstm") { + return std::make_unique(mgr, config); + } else if (model_type == "zipformer") { + return std::make_unique(mgr, config); + } else if (model_type == "zipformer2") { + return std::make_unique(mgr, config); + } else { + SHERPA_ONNX_LOGE( + "Invalid model_type: %s. Trying to load the model to get its type", + model_type.c_str()); + } + } + auto buffer = ReadFile(mgr, config.encoder_filename); auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); diff --git a/sherpa-onnx/python/csrc/online-transducer-model-config.cc b/sherpa-onnx/python/csrc/online-transducer-model-config.cc index 62c89e3e..8246acdd 100644 --- a/sherpa-onnx/python/csrc/online-transducer-model-config.cc +++ b/sherpa-onnx/python/csrc/online-transducer-model-config.cc @@ -15,11 +15,11 @@ void PybindOnlineTransducerModelConfig(py::module *m) { py::class_(*m, "OnlineTransducerModelConfig") .def(py::init(), + const std::string &, const std::string &>(), py::arg("encoder_filename"), py::arg("decoder_filename"), py::arg("joiner_filename"), py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, - py::arg("provider") = "cpu") + py::arg("provider") = "cpu", py::arg("model_type") = "") .def_readwrite("encoder_filename", &PyClass::encoder_filename) .def_readwrite("decoder_filename", &PyClass::decoder_filename) .def_readwrite("joiner_filename", &PyClass::joiner_filename) @@ -27,6 +27,7 @@ void PybindOnlineTransducerModelConfig(py::module *m) { .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) .def_readwrite("provider", &PyClass::provider) + .def_readwrite("model_type", &PyClass::model_type) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index bcfe1da3..e0f47068 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -41,6 +41,7 @@ class OnlineRecognizer(object): max_active_paths: int = 4, context_score: float = 1.5, provider: str = "cpu", + model_type: str = "", ): """ Please refer to @@ -90,6 +91,9 @@ class OnlineRecognizer(object): the maximum number of active paths during beam search. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + model_type: + Online transducer model type. Valid values are: conformer, lstm, + zipformer, zipformer2. All other values lead to loading the model twice. """ _assert_file_exists(tokens) _assert_file_exists(encoder) @@ -105,6 +109,7 @@ class OnlineRecognizer(object): tokens=tokens, num_threads=num_threads, provider=provider, + model_type=model_type, ) feat_config = FeatureExtractorConfig(