diff --git a/sherpa-onnx/csrc/offline-lm-config.cc b/sherpa-onnx/csrc/offline-lm-config.cc
index 429e5144..262d91f0 100644
--- a/sherpa-onnx/csrc/offline-lm-config.cc
+++ b/sherpa-onnx/csrc/offline-lm-config.cc
@@ -14,6 +14,10 @@ namespace sherpa_onnx {
 void OfflineLMConfig::Register(ParseOptions *po) {
   po->Register("lm", &model, "Path to LM model.");
   po->Register("lm-scale", &scale, "LM scale.");
+  po->Register("lm-num-threads", &lm_num_threads,
+               "Number of threads to run the neural network of the LM model");
+  po->Register("lm-provider", &lm_provider,
+               "Specify a provider for the LM model: cpu, cuda, coreml");
 }
 
 bool OfflineLMConfig::Validate() const {
diff --git a/sherpa-onnx/csrc/offline-lm-config.h b/sherpa-onnx/csrc/offline-lm-config.h
index 1a35dc85..3468c58a 100644
--- a/sherpa-onnx/csrc/offline-lm-config.h
+++ b/sherpa-onnx/csrc/offline-lm-config.h
@@ -16,11 +16,17 @@ struct OfflineLMConfig {
   // LM scale
   float scale = 0.5;
+  int32_t lm_num_threads = 1;
+  std::string lm_provider = "cpu";
 
   OfflineLMConfig() = default;
 
-  OfflineLMConfig(const std::string &model, float scale)
-      : model(model), scale(scale) {}
+  OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
+                  const std::string &lm_provider)
+      : model(model),
+        scale(scale),
+        lm_num_threads(lm_num_threads),
+        lm_provider(lm_provider) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
diff --git a/sherpa-onnx/csrc/offline-lm.cc b/sherpa-onnx/csrc/offline-lm.cc
index d1301393..f76dcfd6 100644
--- a/sherpa-onnx/csrc/offline-lm.cc
+++ b/sherpa-onnx/csrc/offline-lm.cc
@@ -12,8 +12,7 @@ namespace sherpa_onnx {
 
-std::unique_ptr<OfflineLM> OfflineLM::Create(
-    const OfflineRecognizerConfig &config) {
+std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
   return std::make_unique<OfflineRnnLM>(config);
 }
 
diff --git a/sherpa-onnx/csrc/offline-lm.h b/sherpa-onnx/csrc/offline-lm.h
index 73a8051c..f99a8ad9 100644
--- a/sherpa-onnx/csrc/offline-lm.h
+++ b/sherpa-onnx/csrc/offline-lm.h
@@ -10,7 +10,7 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
-#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/offline-lm-config.h"
 
 namespace sherpa_onnx {
 
@@ -18,8 +18,7 @@ class OfflineLM {
  public:
   virtual ~OfflineLM() = default;
 
-  static std::unique_ptr<OfflineLM> Create(
-      const OfflineRecognizerConfig &config);
+  static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
 
   /** Rescore a batch of sentences.
    *
diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
index 0ff36c63..d8360dce 100644
--- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
+++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
@@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
           std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
     } else if (config_.decoding_method == "modified_beam_search") {
       if (!config_.lm_config.model.empty()) {
-        lm_ = OfflineLM::Create(config);
+        lm_ = OfflineLM::Create(config.lm_config);
       }
 
       decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
diff --git a/sherpa-onnx/csrc/offline-rnn-lm.cc b/sherpa-onnx/csrc/offline-rnn-lm.cc
index da717aba..77d2d393 100644
--- a/sherpa-onnx/csrc/offline-rnn-lm.cc
+++ b/sherpa-onnx/csrc/offline-rnn-lm.cc
@@ -18,12 +18,12 @@ namespace sherpa_onnx {
 class OfflineRnnLM::Impl {
  public:
-  explicit Impl(const OfflineRecognizerConfig &config)
-      : config_(config.lm_config),
+  explicit Impl(const OfflineLMConfig &config)
+      : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{GetSessionOptions(config.model_config)},
+        sess_opts_{GetSessionOptions(config)},
         allocator_{} {
-    Init(config.lm_config);
+    Init(config);
   }
 
   Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
@@ -63,7 +63,7 @@ class OfflineRnnLM::Impl {
   std::vector<const char *> output_names_ptr_;
 };
 
-OfflineRnnLM::OfflineRnnLM(const OfflineRecognizerConfig &config)
+OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
     : impl_(std::make_unique<Impl>(config)) {}
 
 OfflineRnnLM::~OfflineRnnLM() = default;
diff --git a/sherpa-onnx/csrc/offline-rnn-lm.h b/sherpa-onnx/csrc/offline-rnn-lm.h
index e0e9f804..e9f8cc97 100644
--- a/sherpa-onnx/csrc/offline-rnn-lm.h
+++ b/sherpa-onnx/csrc/offline-rnn-lm.h
@@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM {
  public:
   ~OfflineRnnLM() override;
 
-  explicit OfflineRnnLM(const OfflineRecognizerConfig &config);
+  explicit OfflineRnnLM(const OfflineLMConfig &config);
 
   /** Rescore a batch of sentences.
    *
diff --git a/sherpa-onnx/csrc/online-lm-config.cc b/sherpa-onnx/csrc/online-lm-config.cc
index 80f597e7..d5b41d2b 100644
--- a/sherpa-onnx/csrc/online-lm-config.cc
+++ b/sherpa-onnx/csrc/online-lm-config.cc
@@ -14,6 +14,10 @@ namespace sherpa_onnx {
 void OnlineLMConfig::Register(ParseOptions *po) {
   po->Register("lm", &model, "Path to LM model.");
   po->Register("lm-scale", &scale, "LM scale.");
+  po->Register("lm-num-threads", &lm_num_threads,
+               "Number of threads to run the neural network of the LM model");
+  po->Register("lm-provider", &lm_provider,
+               "Specify a provider for the LM model: cpu, cuda, coreml");
 }
 
 bool OnlineLMConfig::Validate() const {
diff --git a/sherpa-onnx/csrc/online-lm-config.h b/sherpa-onnx/csrc/online-lm-config.h
index 8bb4ab53..90bc13d9 100644
--- a/sherpa-onnx/csrc/online-lm-config.h
+++ b/sherpa-onnx/csrc/online-lm-config.h
@@ -16,11 +16,17 @@ struct OnlineLMConfig {
   // LM scale
   float scale = 0.5;
+  int32_t lm_num_threads = 1;
+  std::string lm_provider = "cpu";
 
   OnlineLMConfig() = default;
 
-  OnlineLMConfig(const std::string &model, float scale)
-      : model(model), scale(scale) {}
+  OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
+                 const std::string &lm_provider)
+      : model(model),
+        scale(scale),
+        lm_num_threads(lm_num_threads),
+        lm_provider(lm_provider) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
diff --git a/sherpa-onnx/csrc/online-lm.cc b/sherpa-onnx/csrc/online-lm.cc
index 964a0fd7..dfec00cc 100644
--- a/sherpa-onnx/csrc/online-lm.cc
+++ b/sherpa-onnx/csrc/online-lm.cc
@@ -13,8 +13,7 @@ namespace sherpa_onnx {
 
-std::unique_ptr<OnlineLM> OnlineLM::Create(
-    const OnlineRecognizerConfig &config) {
+std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
   return std::make_unique<OnlineRnnLM>(config);
 }
 
diff --git a/sherpa-onnx/csrc/online-lm.h b/sherpa-onnx/csrc/online-lm.h
index fc22e43b..6c73f46c 100644
--- a/sherpa-onnx/csrc/online-lm.h
+++ b/sherpa-onnx/csrc/online-lm.h
@@ -11,7 +11,7 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
-#include "sherpa-onnx/csrc/online-recognizer.h"
+#include "sherpa-onnx/csrc/online-lm-config.h"
 
 namespace sherpa_onnx {
 
@@ -19,7 +19,7 @@ class OnlineLM {
  public:
   virtual ~OnlineLM() = default;
 
-  static std::unique_ptr<OnlineLM> Create(const OnlineRecognizerConfig &config);
+  static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
 
   virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc
index 8489c174..06aab880 100644
--- a/sherpa-onnx/csrc/online-recognizer.cc
+++ b/sherpa-onnx/csrc/online-recognizer.cc
@@ -129,7 +129,7 @@ class OnlineRecognizer::Impl {
         endpoint_(config_.endpoint_config) {
     if (config.decoding_method == "modified_beam_search") {
       if (!config_.lm_config.model.empty()) {
-        lm_ = OnlineLM::Create(config);
+        lm_ = OnlineLM::Create(config.lm_config);
       }
 
       decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc
index 38149326..29b150e4 100644
--- a/sherpa-onnx/csrc/online-rnn-lm.cc
+++ b/sherpa-onnx/csrc/online-rnn-lm.cc
@@ -19,12 +19,12 @@ namespace sherpa_onnx {
 class OnlineRnnLM::Impl {
  public:
-  explicit Impl(const OnlineRecognizerConfig &config)
-      : config_(config.lm_config),
+  explicit Impl(const OnlineLMConfig &config)
+      : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{GetSessionOptions(config.model_config)},
+        sess_opts_{GetSessionOptions(config)},
         allocator_{} {
-    Init(config.lm_config);
+    Init(config);
   }
 
   void ComputeLMScore(float scale, Hypothesis *hyp) {
@@ -143,7 +143,7 @@ class OnlineRnnLM::Impl {
   int32_t sos_id_ = 1;
 };
 
-OnlineRnnLM::OnlineRnnLM(const OnlineRecognizerConfig &config)
+OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
     : impl_(std::make_unique<Impl>(config)) {}
 
 OnlineRnnLM::~OnlineRnnLM() = default;
diff --git a/sherpa-onnx/csrc/online-rnn-lm.h b/sherpa-onnx/csrc/online-rnn-lm.h
index 00cac454..dee17f73 100644
--- a/sherpa-onnx/csrc/online-rnn-lm.h
+++ b/sherpa-onnx/csrc/online-rnn-lm.h
@@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM {
  public:
   ~OnlineRnnLM() override;
 
-  explicit OnlineRnnLM(const OnlineRecognizerConfig &config);
+  explicit OnlineRnnLM(const OnlineLMConfig &config);
 
   std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc
index 49979873..5c2abac6 100644
--- a/sherpa-onnx/csrc/session.cc
+++ b/sherpa-onnx/csrc/session.cc
@@ -69,4 +69,12 @@ Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
   return GetSessionOptionsImpl(config.num_threads, config.provider);
 }
 
+Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) {
+  return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
+}
+
+Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
+  return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
+}
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h
index 8e0508ed..7f28742a 100644
--- a/sherpa-onnx/csrc/session.h
+++ b/sherpa-onnx/csrc/session.h
@@ -6,7 +6,9 @@
 #define SHERPA_ONNX_CSRC_SESSION_H_
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-lm-config.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/online-lm-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
 
 namespace sherpa_onnx {
@@ -16,6 +18,9 @@ Ort::SessionOptions GetSessionOptions(
 
 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
 
+Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
+
+Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
 }  // namespace sherpa_onnx
 
 #endif  // SHERPA_ONNX_CSRC_SESSION_H_
diff --git a/sherpa-onnx/python/csrc/offline-lm-config.cc b/sherpa-onnx/python/csrc/offline-lm-config.cc
index a5f58cfd..ba9ea94d 100644
--- a/sherpa-onnx/python/csrc/offline-lm-config.cc
+++ b/sherpa-onnx/python/csrc/offline-lm-config.cc
@@ -13,10 +13,13 @@ namespace sherpa_onnx {
 void PybindOfflineLMConfig(py::module *m) {
   using PyClass = OfflineLMConfig;
   py::class_<PyClass>(*m, "OfflineLMConfig")
-      .def(py::init<const std::string &, float>(), py::arg("model"),
-           py::arg("scale"))
+      .def(py::init<const std::string &, float, int32_t, const std::string &>(),
+           py::arg("model"), py::arg("scale") = 0.5f,
+           py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("scale", &PyClass::scale)
+      .def_readwrite("lm_provider", &PyClass::lm_provider)
+      .def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
       .def("__str__", &PyClass::ToString);
 }
diff --git a/sherpa-onnx/python/csrc/online-lm-config.cc b/sherpa-onnx/python/csrc/online-lm-config.cc
index f7097e49..56da7399 100644
--- a/sherpa-onnx/python/csrc/online-lm-config.cc
+++ b/sherpa-onnx/python/csrc/online-lm-config.cc
@@ -13,10 +13,13 @@ namespace sherpa_onnx {
 void PybindOnlineLMConfig(py::module *m) {
   using PyClass = OnlineLMConfig;
   py::class_<PyClass>(*m, "OnlineLMConfig")
-      .def(py::init<const std::string &, float>(), py::arg("model") = "",
-           py::arg("scale") = 0.5f)
+      .def(py::init<const std::string &, float, int32_t, const std::string &>(),
+           py::arg("model") = "", py::arg("scale") = 0.5f,
+           py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("scale", &PyClass::scale)
+      .def_readwrite("lm_provider", &PyClass::lm_provider)
+      .def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
       .def("__str__", &PyClass::ToString);
 }