diff --git a/sherpa-onnx/csrc/offline-lm.cc b/sherpa-onnx/csrc/offline-lm.cc index f76dcfd6..d1301393 100644 --- a/sherpa-onnx/csrc/offline-lm.cc +++ b/sherpa-onnx/csrc/offline-lm.cc @@ -12,7 +12,8 @@ namespace sherpa_onnx { -std::unique_ptr OfflineLM::Create(const OfflineLMConfig &config) { +std::unique_ptr OfflineLM::Create( + const OfflineRecognizerConfig &config) { return std::make_unique(config); } diff --git a/sherpa-onnx/csrc/offline-lm.h b/sherpa-onnx/csrc/offline-lm.h index f99a8ad9..73a8051c 100644 --- a/sherpa-onnx/csrc/offline-lm.h +++ b/sherpa-onnx/csrc/offline-lm.h @@ -10,7 +10,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/hypothesis.h" -#include "sherpa-onnx/csrc/offline-lm-config.h" +#include "sherpa-onnx/csrc/offline-recognizer.h" namespace sherpa_onnx { @@ -18,7 +18,8 @@ class OfflineLM { public: virtual ~OfflineLM() = default; - static std::unique_ptr Create(const OfflineLMConfig &config); + static std::unique_ptr Create( + const OfflineRecognizerConfig &config); /** Rescore a batch of sentences. * diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index d8360dce..0ff36c63 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { std::make_unique(model_.get()); } else if (config_.decoding_method == "modified_beam_search") { if (!config_.lm_config.model.empty()) { - lm_ = OfflineLM::Create(config.lm_config); + lm_ = OfflineLM::Create(config); } decoder_ = std::make_unique( diff --git a/sherpa-onnx/csrc/offline-rnn-lm.cc b/sherpa-onnx/csrc/offline-rnn-lm.cc index a16118a7..da717aba 100644 --- a/sherpa-onnx/csrc/offline-rnn-lm.cc +++ b/sherpa-onnx/csrc/offline-rnn-lm.cc @@ -12,17 +12,18 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/session.h" namespace sherpa_onnx { class OfflineRnnLM::Impl { public: - explicit Impl(const OfflineLMConfig &config) - : config_(config), + explicit Impl(const OfflineRecognizerConfig &config) + : config_(config.lm_config), env_(ORT_LOGGING_LEVEL_ERROR), - sess_opts_{}, + sess_opts_{GetSessionOptions(config.model_config)}, allocator_{} { - Init(config); + Init(config.lm_config); } Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) { @@ -62,7 +63,7 @@ class OfflineRnnLM::Impl { std::vector output_names_ptr_; }; -OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) +OfflineRnnLM::OfflineRnnLM(const OfflineRecognizerConfig &config) : impl_(std::make_unique(config)) {} OfflineRnnLM::~OfflineRnnLM() = default; diff --git a/sherpa-onnx/csrc/offline-rnn-lm.h b/sherpa-onnx/csrc/offline-rnn-lm.h index e9f8cc97..e0e9f804 100644 --- a/sherpa-onnx/csrc/offline-rnn-lm.h +++ b/sherpa-onnx/csrc/offline-rnn-lm.h @@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM { public: ~OfflineRnnLM() override; - explicit OfflineRnnLM(const OfflineLMConfig &config); + explicit OfflineRnnLM(const OfflineRecognizerConfig &config); /** Rescore a batch of sentences. * diff --git a/sherpa-onnx/csrc/online-lm.cc b/sherpa-onnx/csrc/online-lm.cc index dfec00cc..964a0fd7 100644 --- a/sherpa-onnx/csrc/online-lm.cc +++ b/sherpa-onnx/csrc/online-lm.cc @@ -13,7 +13,8 @@ namespace sherpa_onnx { -std::unique_ptr OnlineLM::Create(const OnlineLMConfig &config) { +std::unique_ptr OnlineLM::Create( + const OnlineRecognizerConfig &config) { return std::make_unique(config); } diff --git a/sherpa-onnx/csrc/online-lm.h b/sherpa-onnx/csrc/online-lm.h index 6c73f46c..fc22e43b 100644 --- a/sherpa-onnx/csrc/online-lm.h +++ b/sherpa-onnx/csrc/online-lm.h @@ -11,7 +11,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/hypothesis.h" -#include "sherpa-onnx/csrc/online-lm-config.h" +#include "sherpa-onnx/csrc/online-recognizer.h" namespace sherpa_onnx { @@ -19,7 +19,7 @@ class OnlineLM { public: virtual ~OnlineLM() = default; - static std::unique_ptr Create(const OnlineLMConfig &config); + static std::unique_ptr Create(const OnlineRecognizerConfig &config); virtual std::pair> GetInitStates() = 0; diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 06aab880..8489c174 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -129,7 +129,7 @@ class OnlineRecognizer::Impl { endpoint_(config_.endpoint_config) { if (config.decoding_method == "modified_beam_search") { if (!config_.lm_config.model.empty()) { - lm_ = OnlineLM::Create(config.lm_config); + lm_ = OnlineLM::Create(config); } decoder_ = std::make_unique( diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc index f23e1ef7..38149326 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.cc +++ b/sherpa-onnx/csrc/online-rnn-lm.cc @@ -13,17 +13,18 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/session.h" namespace sherpa_onnx { class OnlineRnnLM::Impl { public: - explicit Impl(const OnlineLMConfig &config) - : config_(config), + explicit Impl(const OnlineRecognizerConfig &config) + : config_(config.lm_config), env_(ORT_LOGGING_LEVEL_ERROR), - sess_opts_{}, + sess_opts_{GetSessionOptions(config.model_config)}, allocator_{} { - Init(config); + Init(config.lm_config); } void ComputeLMScore(float scale, Hypothesis *hyp) { @@ -142,7 +143,7 @@ class OnlineRnnLM::Impl { int32_t sos_id_ = 1; }; -OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) +OnlineRnnLM::OnlineRnnLM(const OnlineRecognizerConfig &config) : impl_(std::make_unique(config)) {} OnlineRnnLM::~OnlineRnnLM() = default; diff --git a/sherpa-onnx/csrc/online-rnn-lm.h b/sherpa-onnx/csrc/online-rnn-lm.h index dee17f73..00cac454 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.h +++ b/sherpa-onnx/csrc/online-rnn-lm.h @@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM { public: ~OnlineRnnLM() override; - explicit OnlineRnnLM(const OnlineLMConfig &config); + explicit OnlineRnnLM(const OnlineRecognizerConfig &config); std::pair> GetInitStates() override;