rnnlm model inference supports num_threads setting (#169)

Co-authored-by: cuidongcai1035 <cuidongcai1035@wezhuiyi.com>
This commit is contained in:
keanu
2023-06-07 09:32:27 +08:00
committed by GitHub
parent 8fad17c87e
commit 9c017c2ccb
10 changed files with 25 additions and 20 deletions

View File

@@ -12,7 +12,8 @@
namespace sherpa_onnx { namespace sherpa_onnx {
std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) { std::unique_ptr<OfflineLM> OfflineLM::Create(
const OfflineRecognizerConfig &config) {
return std::make_unique<OfflineRnnLM>(config); return std::make_unique<OfflineRnnLM>(config);
} }

View File

@@ -10,7 +10,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h" #include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-recognizer.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -18,7 +18,8 @@ class OfflineLM {
public: public:
virtual ~OfflineLM() = default; virtual ~OfflineLM() = default;
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config); static std::unique_ptr<OfflineLM> Create(
const OfflineRecognizerConfig &config);
/** Rescore a batch of sentences. /** Rescore a batch of sentences.
* *

View File

@@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
} else if (config_.decoding_method == "modified_beam_search") { } else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) { if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(config.lm_config); lm_ = OfflineLM::Create(config);
} }
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(

View File

@@ -12,17 +12,18 @@
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h" #include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/session.h"
namespace sherpa_onnx { namespace sherpa_onnx {
class OfflineRnnLM::Impl { class OfflineRnnLM::Impl {
public: public:
explicit Impl(const OfflineLMConfig &config) explicit Impl(const OfflineRecognizerConfig &config)
: config_(config), : config_(config.lm_config),
env_(ORT_LOGGING_LEVEL_ERROR), env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_{}, sess_opts_{GetSessionOptions(config.model_config)},
allocator_{} { allocator_{} {
Init(config); Init(config.lm_config);
} }
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) { Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
@@ -62,7 +63,7 @@ class OfflineRnnLM::Impl {
std::vector<const char *> output_names_ptr_; std::vector<const char *> output_names_ptr_;
}; };
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) OfflineRnnLM::OfflineRnnLM(const OfflineRecognizerConfig &config)
: impl_(std::make_unique<Impl>(config)) {} : impl_(std::make_unique<Impl>(config)) {}
OfflineRnnLM::~OfflineRnnLM() = default; OfflineRnnLM::~OfflineRnnLM() = default;

View File

@@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM {
public: public:
~OfflineRnnLM() override; ~OfflineRnnLM() override;
explicit OfflineRnnLM(const OfflineLMConfig &config); explicit OfflineRnnLM(const OfflineRecognizerConfig &config);
/** Rescore a batch of sentences. /** Rescore a batch of sentences.
* *

View File

@@ -13,7 +13,8 @@
namespace sherpa_onnx { namespace sherpa_onnx {
std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) { std::unique_ptr<OnlineLM> OnlineLM::Create(
const OnlineRecognizerConfig &config) {
return std::make_unique<OnlineRnnLM>(config); return std::make_unique<OnlineRnnLM>(config);
} }

View File

@@ -11,7 +11,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h" #include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-recognizer.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -19,7 +19,7 @@ class OnlineLM {
public: public:
virtual ~OnlineLM() = default; virtual ~OnlineLM() = default;
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); static std::unique_ptr<OnlineLM> Create(const OnlineRecognizerConfig &config);
virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0; virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;

View File

@@ -129,7 +129,7 @@ class OnlineRecognizer::Impl {
endpoint_(config_.endpoint_config) { endpoint_(config_.endpoint_config) {
if (config.decoding_method == "modified_beam_search") { if (config.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) { if (!config_.lm_config.model.empty()) {
lm_ = OnlineLM::Create(config.lm_config); lm_ = OnlineLM::Create(config);
} }
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(

View File

@@ -13,17 +13,18 @@
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h" #include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/session.h"
namespace sherpa_onnx { namespace sherpa_onnx {
class OnlineRnnLM::Impl { class OnlineRnnLM::Impl {
public: public:
explicit Impl(const OnlineLMConfig &config) explicit Impl(const OnlineRecognizerConfig &config)
: config_(config), : config_(config.lm_config),
env_(ORT_LOGGING_LEVEL_ERROR), env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_{}, sess_opts_{GetSessionOptions(config.model_config)},
allocator_{} { allocator_{} {
Init(config); Init(config.lm_config);
} }
void ComputeLMScore(float scale, Hypothesis *hyp) { void ComputeLMScore(float scale, Hypothesis *hyp) {
@@ -142,7 +143,7 @@ class OnlineRnnLM::Impl {
int32_t sos_id_ = 1; int32_t sos_id_ = 1;
}; };
OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) OnlineRnnLM::OnlineRnnLM(const OnlineRecognizerConfig &config)
: impl_(std::make_unique<Impl>(config)) {} : impl_(std::make_unique<Impl>(config)) {}
OnlineRnnLM::~OnlineRnnLM() = default; OnlineRnnLM::~OnlineRnnLM() = default;

View File

@@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM {
public: public:
~OnlineRnnLM() override; ~OnlineRnnLM() override;
explicit OnlineRnnLM(const OnlineLMConfig &config); explicit OnlineRnnLM(const OnlineRecognizerConfig &config);
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override; std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;