RNNLM model support lm_num_thread and lm_provider setting (#173)
* rnnlm model inference supports num_threads setting * rnnlm params decouple num_thread and provider with Transducer. * fix python csrc bug which offline-lm-config.cc and online-lm-config.cc arguments problem * lm_num_threads and lm_provider set default values --------- Co-authored-by: cuidongcai1035 <cuidongcai1035@wezhuiyi.com>
This commit is contained in:
@@ -14,6 +14,10 @@ namespace sherpa_onnx {
|
|||||||
void OfflineLMConfig::Register(ParseOptions *po) {
|
void OfflineLMConfig::Register(ParseOptions *po) {
|
||||||
po->Register("lm", &model, "Path to LM model.");
|
po->Register("lm", &model, "Path to LM model.");
|
||||||
po->Register("lm-scale", &scale, "LM scale.");
|
po->Register("lm-scale", &scale, "LM scale.");
|
||||||
|
po->Register("lm-num-threads", &lm_num_threads,
|
||||||
|
"Number of threads to run the neural network of LM model");
|
||||||
|
po->Register("lm-provider", &lm_provider,
|
||||||
|
"Specify a provider to LM model use: cpu, cuda, coreml");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OfflineLMConfig::Validate() const {
|
bool OfflineLMConfig::Validate() const {
|
||||||
|
|||||||
@@ -16,11 +16,17 @@ struct OfflineLMConfig {
|
|||||||
|
|
||||||
// LM scale
|
// LM scale
|
||||||
float scale = 0.5;
|
float scale = 0.5;
|
||||||
|
int32_t lm_num_threads = 1;
|
||||||
|
std::string lm_provider = "cpu";
|
||||||
|
|
||||||
OfflineLMConfig() = default;
|
OfflineLMConfig() = default;
|
||||||
|
|
||||||
OfflineLMConfig(const std::string &model, float scale)
|
OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
|
||||||
: model(model), scale(scale) {}
|
const std::string &lm_provider)
|
||||||
|
: model(model),
|
||||||
|
scale(scale),
|
||||||
|
lm_num_threads(lm_num_threads),
|
||||||
|
lm_provider(lm_provider) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -12,8 +12,7 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
std::unique_ptr<OfflineLM> OfflineLM::Create(
|
std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
|
||||||
const OfflineRecognizerConfig &config) {
|
|
||||||
return std::make_unique<OfflineRnnLM>(config);
|
return std::make_unique<OfflineRnnLM>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -18,8 +18,7 @@ class OfflineLM {
|
|||||||
public:
|
public:
|
||||||
virtual ~OfflineLM() = default;
|
virtual ~OfflineLM() = default;
|
||||||
|
|
||||||
static std::unique_ptr<OfflineLM> Create(
|
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
|
||||||
const OfflineRecognizerConfig &config);
|
|
||||||
|
|
||||||
/** Rescore a batch of sentences.
|
/** Rescore a batch of sentences.
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
||||||
} else if (config_.decoding_method == "modified_beam_search") {
|
} else if (config_.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OfflineLM::Create(config);
|
lm_ = OfflineLM::Create(config.lm_config);
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
class OfflineRnnLM::Impl {
|
class OfflineRnnLM::Impl {
|
||||||
public:
|
public:
|
||||||
explicit Impl(const OfflineRecognizerConfig &config)
|
explicit Impl(const OfflineLMConfig &config)
|
||||||
: config_(config.lm_config),
|
: config_(config),
|
||||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
sess_opts_{GetSessionOptions(config.model_config)},
|
sess_opts_{GetSessionOptions(config)},
|
||||||
allocator_{} {
|
allocator_{} {
|
||||||
Init(config.lm_config);
|
Init(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
|
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
|
||||||
@@ -63,7 +63,7 @@ class OfflineRnnLM::Impl {
|
|||||||
std::vector<const char *> output_names_ptr_;
|
std::vector<const char *> output_names_ptr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
OfflineRnnLM::OfflineRnnLM(const OfflineRecognizerConfig &config)
|
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
|
||||||
: impl_(std::make_unique<Impl>(config)) {}
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
OfflineRnnLM::~OfflineRnnLM() = default;
|
OfflineRnnLM::~OfflineRnnLM() = default;
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class OfflineRnnLM : public OfflineLM {
|
|||||||
public:
|
public:
|
||||||
~OfflineRnnLM() override;
|
~OfflineRnnLM() override;
|
||||||
|
|
||||||
explicit OfflineRnnLM(const OfflineRecognizerConfig &config);
|
explicit OfflineRnnLM(const OfflineLMConfig &config);
|
||||||
|
|
||||||
/** Rescore a batch of sentences.
|
/** Rescore a batch of sentences.
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ namespace sherpa_onnx {
|
|||||||
void OnlineLMConfig::Register(ParseOptions *po) {
|
void OnlineLMConfig::Register(ParseOptions *po) {
|
||||||
po->Register("lm", &model, "Path to LM model.");
|
po->Register("lm", &model, "Path to LM model.");
|
||||||
po->Register("lm-scale", &scale, "LM scale.");
|
po->Register("lm-scale", &scale, "LM scale.");
|
||||||
|
po->Register("lm-num-threads", &lm_num_threads,
|
||||||
|
"Number of threads to run the neural network of LM model");
|
||||||
|
po->Register("lm-provider", &lm_provider,
|
||||||
|
"Specify a provider to LM model use: cpu, cuda, coreml");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OnlineLMConfig::Validate() const {
|
bool OnlineLMConfig::Validate() const {
|
||||||
|
|||||||
@@ -16,11 +16,17 @@ struct OnlineLMConfig {
|
|||||||
|
|
||||||
// LM scale
|
// LM scale
|
||||||
float scale = 0.5;
|
float scale = 0.5;
|
||||||
|
int32_t lm_num_threads = 1;
|
||||||
|
std::string lm_provider = "cpu";
|
||||||
|
|
||||||
OnlineLMConfig() = default;
|
OnlineLMConfig() = default;
|
||||||
|
|
||||||
OnlineLMConfig(const std::string &model, float scale)
|
OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
|
||||||
: model(model), scale(scale) {}
|
const std::string &lm_provider)
|
||||||
|
: model(model),
|
||||||
|
scale(scale),
|
||||||
|
lm_num_threads(lm_num_threads),
|
||||||
|
lm_provider(lm_provider) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -13,8 +13,7 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
std::unique_ptr<OnlineLM> OnlineLM::Create(
|
std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
|
||||||
const OnlineRecognizerConfig &config) {
|
|
||||||
return std::make_unique<OnlineRnnLM>(config);
|
return std::make_unique<OnlineRnnLM>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -19,7 +19,7 @@ class OnlineLM {
|
|||||||
public:
|
public:
|
||||||
virtual ~OnlineLM() = default;
|
virtual ~OnlineLM() = default;
|
||||||
|
|
||||||
static std::unique_ptr<OnlineLM> Create(const OnlineRecognizerConfig &config);
|
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
|
||||||
|
|
||||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
|
virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
|
||||||
|
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class OnlineRecognizer::Impl {
|
|||||||
endpoint_(config_.endpoint_config) {
|
endpoint_(config_.endpoint_config) {
|
||||||
if (config.decoding_method == "modified_beam_search") {
|
if (config.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OnlineLM::Create(config);
|
lm_ = OnlineLM::Create(config.lm_config);
|
||||||
}
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
|
|||||||
@@ -19,12 +19,12 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
class OnlineRnnLM::Impl {
|
class OnlineRnnLM::Impl {
|
||||||
public:
|
public:
|
||||||
explicit Impl(const OnlineRecognizerConfig &config)
|
explicit Impl(const OnlineLMConfig &config)
|
||||||
: config_(config.lm_config),
|
: config_(config),
|
||||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
sess_opts_{GetSessionOptions(config.model_config)},
|
sess_opts_{GetSessionOptions(config)},
|
||||||
allocator_{} {
|
allocator_{} {
|
||||||
Init(config.lm_config);
|
Init(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ComputeLMScore(float scale, Hypothesis *hyp) {
|
void ComputeLMScore(float scale, Hypothesis *hyp) {
|
||||||
@@ -143,7 +143,7 @@ class OnlineRnnLM::Impl {
|
|||||||
int32_t sos_id_ = 1;
|
int32_t sos_id_ = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
OnlineRnnLM::OnlineRnnLM(const OnlineRecognizerConfig &config)
|
OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
|
||||||
: impl_(std::make_unique<Impl>(config)) {}
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
OnlineRnnLM::~OnlineRnnLM() = default;
|
OnlineRnnLM::~OnlineRnnLM() = default;
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class OnlineRnnLM : public OnlineLM {
|
|||||||
public:
|
public:
|
||||||
~OnlineRnnLM() override;
|
~OnlineRnnLM() override;
|
||||||
|
|
||||||
explicit OnlineRnnLM(const OnlineRecognizerConfig &config);
|
explicit OnlineRnnLM(const OnlineLMConfig &config);
|
||||||
|
|
||||||
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
|
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
|
||||||
|
|
||||||
|
|||||||
@@ -69,4 +69,12 @@ Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
|
|||||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) {
|
||||||
|
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
|
||||||
|
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -6,7 +6,9 @@
|
|||||||
#define SHERPA_ONNX_CSRC_SESSION_H_
|
#define SHERPA_ONNX_CSRC_SESSION_H_
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -16,6 +18,9 @@ Ort::SessionOptions GetSessionOptions(
|
|||||||
|
|
||||||
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
|
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
|
||||||
|
|
||||||
|
Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
|
||||||
|
|
||||||
|
Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||||
|
|||||||
@@ -13,10 +13,13 @@ namespace sherpa_onnx {
|
|||||||
void PybindOfflineLMConfig(py::module *m) {
|
void PybindOfflineLMConfig(py::module *m) {
|
||||||
using PyClass = OfflineLMConfig;
|
using PyClass = OfflineLMConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineLMConfig")
|
py::class_<PyClass>(*m, "OfflineLMConfig")
|
||||||
.def(py::init<const std::string &, float>(), py::arg("model"),
|
.def(py::init<const std::string &, float, int32_t, const std::string &>(),
|
||||||
py::arg("scale"))
|
py::arg("model"), py::arg("scale") = 0.5f,
|
||||||
|
py::arg("lm_num_threads") = 1, py::arg("lm-provider") = "cpu")
|
||||||
.def_readwrite("model", &PyClass::model)
|
.def_readwrite("model", &PyClass::model)
|
||||||
.def_readwrite("scale", &PyClass::scale)
|
.def_readwrite("scale", &PyClass::scale)
|
||||||
|
.def_readwrite("lm_provider", &PyClass::lm_provider)
|
||||||
|
.def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,10 +13,13 @@ namespace sherpa_onnx {
|
|||||||
void PybindOnlineLMConfig(py::module *m) {
|
void PybindOnlineLMConfig(py::module *m) {
|
||||||
using PyClass = OnlineLMConfig;
|
using PyClass = OnlineLMConfig;
|
||||||
py::class_<PyClass>(*m, "OnlineLMConfig")
|
py::class_<PyClass>(*m, "OnlineLMConfig")
|
||||||
.def(py::init<const std::string &, float>(), py::arg("model") = "",
|
.def(py::init<const std::string &, float, int32_t, const std::string &>(),
|
||||||
py::arg("scale") = 0.5f)
|
py::arg("model") = "", py::arg("scale") = 0.5f,
|
||||||
|
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
|
||||||
.def_readwrite("model", &PyClass::model)
|
.def_readwrite("model", &PyClass::model)
|
||||||
.def_readwrite("scale", &PyClass::scale)
|
.def_readwrite("scale", &PyClass::scale)
|
||||||
|
.def_readwrite("lm_provider", &PyClass::lm_provider)
|
||||||
|
.def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user