Reduce model initialization time for offline speech recognition (#213)
This commit is contained in:
@@ -387,6 +387,7 @@ void CNonStreamingSpeechRecognitionDlg::InitParaformer() {
|
||||
config_.model_config.tokens = tokens.c_str();
|
||||
config_.model_config.num_threads = 1;
|
||||
config_.model_config.debug = 1;
|
||||
config_.model_config.model_type = "paraformer";
|
||||
|
||||
config_.decoding_method = "greedy_search";
|
||||
config_.max_active_paths = 4;
|
||||
@@ -447,6 +448,7 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() {
|
||||
config_.model_config.tokens = tokens.c_str();
|
||||
config_.model_config.num_threads = 1;
|
||||
config_.model_config.debug = 0;
|
||||
config_.model_config.model_type = "transducer";
|
||||
|
||||
config_.decoding_method = "greedy_search";
|
||||
config_.max_active_paths = 4;
|
||||
|
||||
@@ -76,6 +76,8 @@ namespace SherpaOnnx
|
||||
Tokens = "";
|
||||
NumThreads = 1;
|
||||
Debug = 0;
|
||||
Provider = "cpu";
|
||||
ModelType = "";
|
||||
}
|
||||
public OfflineTransducerModelConfig Transducer;
|
||||
public OfflineParaformerModelConfig Paraformer;
|
||||
@@ -87,6 +89,12 @@ namespace SherpaOnnx
|
||||
public int NumThreads;
|
||||
|
||||
public int Debug;
|
||||
|
||||
[MarshalAs(UnmanagedType.LPStr)]
|
||||
public string Provider;
|
||||
|
||||
[MarshalAs(UnmanagedType.LPStr)]
|
||||
public string ModelType;
|
||||
}
|
||||
|
||||
[StructLayout(LayoutKind.Sequential)]
|
||||
|
||||
@@ -33,23 +33,33 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
|
||||
const SherpaOnnxOnlineRecognizerConfig *config) {
|
||||
sherpa_onnx::OnlineRecognizerConfig recognizer_config;
|
||||
|
||||
recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
|
||||
recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
|
||||
recognizer_config.feat_config.sampling_rate =
|
||||
SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
|
||||
recognizer_config.feat_config.feature_dim =
|
||||
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
|
||||
|
||||
recognizer_config.model_config.encoder_filename =
|
||||
SHERPA_ONNX_OR(config->model_config.encoder, "");
|
||||
recognizer_config.model_config.decoder_filename =
|
||||
SHERPA_ONNX_OR(config->model_config.decoder, "");
|
||||
recognizer_config.model_config.joiner_filename = SHERPA_ONNX_OR(config->model_config.joiner, "");
|
||||
recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, "");
|
||||
recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1);
|
||||
recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu");
|
||||
recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0);
|
||||
recognizer_config.model_config.joiner_filename =
|
||||
SHERPA_ONNX_OR(config->model_config.joiner, "");
|
||||
recognizer_config.model_config.tokens =
|
||||
SHERPA_ONNX_OR(config->model_config.tokens, "");
|
||||
recognizer_config.model_config.num_threads =
|
||||
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
|
||||
recognizer_config.model_config.provider =
|
||||
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
|
||||
recognizer_config.model_config.debug =
|
||||
SHERPA_ONNX_OR(config->model_config.debug, 0);
|
||||
|
||||
recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
|
||||
recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);
|
||||
recognizer_config.decoding_method =
|
||||
SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
|
||||
recognizer_config.max_active_paths =
|
||||
SHERPA_ONNX_OR(config->max_active_paths, 4);
|
||||
|
||||
recognizer_config.enable_endpoint = SHERPA_ONNX_OR(config->enable_endpoint, 0);
|
||||
recognizer_config.enable_endpoint =
|
||||
SHERPA_ONNX_OR(config->enable_endpoint, 0);
|
||||
|
||||
recognizer_config.endpoint_config.rule1.min_trailing_silence =
|
||||
SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4);
|
||||
@@ -173,9 +183,11 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
||||
const SherpaOnnxOfflineRecognizerConfig *config) {
|
||||
sherpa_onnx::OfflineRecognizerConfig recognizer_config;
|
||||
|
||||
recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
|
||||
recognizer_config.feat_config.sampling_rate =
|
||||
SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
|
||||
|
||||
recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
|
||||
recognizer_config.feat_config.feature_dim =
|
||||
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
|
||||
|
||||
recognizer_config.model_config.transducer.encoder_filename =
|
||||
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
|
||||
@@ -184,7 +196,7 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
||||
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
|
||||
|
||||
recognizer_config.model_config.transducer.joiner_filename =
|
||||
SHERPA_ONNX_OR(config->model_config.transducer.joiner,"");
|
||||
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");
|
||||
|
||||
recognizer_config.model_config.paraformer.model =
|
||||
SHERPA_ONNX_OR(config->model_config.paraformer.model, "");
|
||||
@@ -192,15 +204,26 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
|
||||
recognizer_config.model_config.nemo_ctc.model =
|
||||
SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, "");
|
||||
|
||||
recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, "");
|
||||
recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1);
|
||||
recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0);
|
||||
recognizer_config.model_config.tokens =
|
||||
SHERPA_ONNX_OR(config->model_config.tokens, "");
|
||||
recognizer_config.model_config.num_threads =
|
||||
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
|
||||
recognizer_config.model_config.debug =
|
||||
SHERPA_ONNX_OR(config->model_config.debug, 0);
|
||||
recognizer_config.model_config.provider =
|
||||
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
|
||||
recognizer_config.model_config.model_type =
|
||||
SHERPA_ONNX_OR(config->model_config.model_type, "");
|
||||
|
||||
recognizer_config.lm_config.model = SHERPA_ONNX_OR(config->lm_config.model, "");
|
||||
recognizer_config.lm_config.scale = SHERPA_ONNX_OR(config->lm_config.scale, 1.0);
|
||||
recognizer_config.lm_config.model =
|
||||
SHERPA_ONNX_OR(config->lm_config.model, "");
|
||||
recognizer_config.lm_config.scale =
|
||||
SHERPA_ONNX_OR(config->lm_config.scale, 1.0);
|
||||
|
||||
recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
|
||||
recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);
|
||||
recognizer_config.decoding_method =
|
||||
SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
|
||||
recognizer_config.max_active_paths =
|
||||
SHERPA_ONNX_OR(config->max_active_paths, 4);
|
||||
|
||||
if (config->model_config.debug) {
|
||||
fprintf(stderr, "%s\n", recognizer_config.ToString().c_str());
|
||||
|
||||
@@ -272,6 +272,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
|
||||
const char *tokens;
|
||||
int32_t num_threads;
|
||||
int32_t debug;
|
||||
const char *provider;
|
||||
const char *model_type;
|
||||
} SherpaOnnxOfflineModelConfig;
|
||||
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
|
||||
|
||||
@@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
|
||||
po->Register("model-type", &model_type,
|
||||
"Specify it to reduce model initialization time. "
|
||||
"Valid values are: transducer, paraformer, nemo_ctc. "
|
||||
"All other values lead to loading the model twice.");
|
||||
}
|
||||
|
||||
bool OfflineModelConfig::Validate() const {
|
||||
@@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const {
|
||||
}
|
||||
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
|
||||
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
os << "provider=\"" << provider << "\", ";
|
||||
os << "model_type=\"" << model_type << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -22,19 +22,31 @@ struct OfflineModelConfig {
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
// With the help of this field, we only need to load the model once
|
||||
// instead of twice; and therefore it reduces initialization time.
|
||||
//
|
||||
// Valid values:
|
||||
// - transducer. The given model is from icefall
|
||||
// - paraformer. It is a paraformer model
|
||||
// - nemo_ctc. It is a NeMo CTC model.
|
||||
//
|
||||
// All other values are invalid and lead to loading the model twice.
|
||||
std::string model_type;
|
||||
|
||||
OfflineModelConfig() = default;
|
||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||
const OfflineParaformerModelConfig &paraformer,
|
||||
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider)
|
||||
const std::string &provider, const std::string &model_type)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
nemo_ctc(nemo_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
provider(provider),
|
||||
model_type(model_type) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -18,6 +18,21 @@ namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
const OfflineRecognizerConfig &config) {
|
||||
if (!config.model_config.model_type.empty()) {
|
||||
const auto &model_type = config.model_config.model_type;
|
||||
if (model_type == "transducer") {
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
} else if (model_type == "nemo_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||
model_type.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
|
||||
|
||||
Ort::SessionOptions sess_opts;
|
||||
|
||||
@@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
bool OfflineTransducerModelConfig::Validate() const {
|
||||
if (!FileExists(encoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(joiner_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,15 +21,16 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
|
||||
using PyClass = OfflineModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||
.def(py::init<const OfflineTransducerModelConfig &,
|
||||
const OfflineParaformerModelConfig &,
|
||||
const OfflineNemoEncDecCtcModelConfig &,
|
||||
const std::string &, int32_t, bool, const std::string &>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu")
|
||||
.def(
|
||||
py::init<const OfflineTransducerModelConfig &,
|
||||
const OfflineParaformerModelConfig &,
|
||||
const OfflineNemoEncDecCtcModelConfig &, const std::string &,
|
||||
int32_t, bool, const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||
@@ -37,6 +38,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def_readwrite("provider", &PyClass::provider)
|
||||
.def_readwrite("model_type", &PyClass::model_type)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ class OfflineRecognizer(object):
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
model_type="transducer",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
@@ -149,6 +150,7 @@ class OfflineRecognizer(object):
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
model_type="paraformer",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
@@ -211,6 +213,7 @@ class OfflineRecognizer(object):
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
provider=provider,
|
||||
model_type="nemo_ctc",
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
|
||||
Reference in New Issue
Block a user