Reduce model initialization time for offline speech recognition (#213)

This commit is contained in:
Fangjun Kuang
2023-07-14 18:07:27 +08:00
committed by GitHub
parent 0abd7ce881
commit f3206c49dc
10 changed files with 109 additions and 36 deletions

View File

@@ -387,6 +387,7 @@ void CNonStreamingSpeechRecognitionDlg::InitParaformer() {
config_.model_config.tokens = tokens.c_str(); config_.model_config.tokens = tokens.c_str();
config_.model_config.num_threads = 1; config_.model_config.num_threads = 1;
config_.model_config.debug = 1; config_.model_config.debug = 1;
config_.model_config.model_type = "paraformer";
config_.decoding_method = "greedy_search"; config_.decoding_method = "greedy_search";
config_.max_active_paths = 4; config_.max_active_paths = 4;
@@ -447,6 +448,7 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() {
config_.model_config.tokens = tokens.c_str(); config_.model_config.tokens = tokens.c_str();
config_.model_config.num_threads = 1; config_.model_config.num_threads = 1;
config_.model_config.debug = 0; config_.model_config.debug = 0;
config_.model_config.model_type = "transducer";
config_.decoding_method = "greedy_search"; config_.decoding_method = "greedy_search";
config_.max_active_paths = 4; config_.max_active_paths = 4;

View File

@@ -76,6 +76,8 @@ namespace SherpaOnnx
Tokens = ""; Tokens = "";
NumThreads = 1; NumThreads = 1;
Debug = 0; Debug = 0;
Provider = "cpu";
ModelType = "";
} }
public OfflineTransducerModelConfig Transducer; public OfflineTransducerModelConfig Transducer;
public OfflineParaformerModelConfig Paraformer; public OfflineParaformerModelConfig Paraformer;
@@ -87,6 +89,12 @@ namespace SherpaOnnx
public int NumThreads; public int NumThreads;
public int Debug; public int Debug;
[MarshalAs(UnmanagedType.LPStr)]
public string Provider;
[MarshalAs(UnmanagedType.LPStr)]
public string ModelType;
} }
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]

View File

@@ -33,23 +33,33 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
const SherpaOnnxOnlineRecognizerConfig *config) { const SherpaOnnxOnlineRecognizerConfig *config) {
sherpa_onnx::OnlineRecognizerConfig recognizer_config; sherpa_onnx::OnlineRecognizerConfig recognizer_config;
recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); recognizer_config.feat_config.sampling_rate =
recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
recognizer_config.feat_config.feature_dim =
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
recognizer_config.model_config.encoder_filename = recognizer_config.model_config.encoder_filename =
SHERPA_ONNX_OR(config->model_config.encoder, ""); SHERPA_ONNX_OR(config->model_config.encoder, "");
recognizer_config.model_config.decoder_filename = recognizer_config.model_config.decoder_filename =
SHERPA_ONNX_OR(config->model_config.decoder, ""); SHERPA_ONNX_OR(config->model_config.decoder, "");
recognizer_config.model_config.joiner_filename = SHERPA_ONNX_OR(config->model_config.joiner, ""); recognizer_config.model_config.joiner_filename =
recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); SHERPA_ONNX_OR(config->model_config.joiner, "");
recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); recognizer_config.model_config.tokens =
recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
recognizer_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);
recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); recognizer_config.decoding_method =
recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
recognizer_config.max_active_paths =
SHERPA_ONNX_OR(config->max_active_paths, 4);
recognizer_config.enable_endpoint = SHERPA_ONNX_OR(config->enable_endpoint, 0); recognizer_config.enable_endpoint =
SHERPA_ONNX_OR(config->enable_endpoint, 0);
recognizer_config.endpoint_config.rule1.min_trailing_silence = recognizer_config.endpoint_config.rule1.min_trailing_silence =
SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4); SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4);
@@ -173,9 +183,11 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
const SherpaOnnxOfflineRecognizerConfig *config) { const SherpaOnnxOfflineRecognizerConfig *config) {
sherpa_onnx::OfflineRecognizerConfig recognizer_config; sherpa_onnx::OfflineRecognizerConfig recognizer_config;
recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); recognizer_config.feat_config.sampling_rate =
SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); recognizer_config.feat_config.feature_dim =
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
recognizer_config.model_config.transducer.encoder_filename = recognizer_config.model_config.transducer.encoder_filename =
SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
@@ -184,7 +196,7 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
recognizer_config.model_config.transducer.joiner_filename = recognizer_config.model_config.transducer.joiner_filename =
SHERPA_ONNX_OR(config->model_config.transducer.joiner,""); SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");
recognizer_config.model_config.paraformer.model = recognizer_config.model_config.paraformer.model =
SHERPA_ONNX_OR(config->model_config.paraformer.model, ""); SHERPA_ONNX_OR(config->model_config.paraformer.model, "");
@@ -192,15 +204,26 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
recognizer_config.model_config.nemo_ctc.model = recognizer_config.model_config.nemo_ctc.model =
SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, "");
recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); recognizer_config.model_config.tokens =
recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0); recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);
recognizer_config.model_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
recognizer_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, "");
recognizer_config.lm_config.model = SHERPA_ONNX_OR(config->lm_config.model, ""); recognizer_config.lm_config.model =
recognizer_config.lm_config.scale = SHERPA_ONNX_OR(config->lm_config.scale, 1.0); SHERPA_ONNX_OR(config->lm_config.model, "");
recognizer_config.lm_config.scale =
SHERPA_ONNX_OR(config->lm_config.scale, 1.0);
recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); recognizer_config.decoding_method =
recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
recognizer_config.max_active_paths =
SHERPA_ONNX_OR(config->max_active_paths, 4);
if (config->model_config.debug) { if (config->model_config.debug) {
fprintf(stderr, "%s\n", recognizer_config.ToString().c_str()); fprintf(stderr, "%s\n", recognizer_config.ToString().c_str());

View File

@@ -272,6 +272,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
const char *tokens; const char *tokens;
int32_t num_threads; int32_t num_threads;
int32_t debug; int32_t debug;
const char *provider;
const char *model_type;
} SherpaOnnxOfflineModelConfig; } SherpaOnnxOfflineModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {

View File

@@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider, po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml"); "Specify a provider to use: cpu, cuda, coreml");
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc. "
"All other values lead to loading the model twice.");
} }
bool OfflineModelConfig::Validate() const { bool OfflineModelConfig::Validate() const {
@@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const {
} }
if (!FileExists(tokens)) { if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
return false; return false;
} }
@@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const {
os << "tokens=\"" << tokens << "\", "; os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", "; os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", "; os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")"; os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\")";
return os.str(); return os.str();
} }

View File

@@ -22,19 +22,31 @@ struct OfflineModelConfig {
bool debug = false; bool debug = false;
std::string provider = "cpu"; std::string provider = "cpu";
// With the help of this field, we only need to load the model once
// instead of twice; and therefore it reduces initialization time.
//
// Valid values:
// - transducer. The given model is from icefall
// - paraformer. It is a paraformer model
// - nemo_ctc. It is a NeMo CTC model.
//
// All other values are invalid and lead to loading the model twice.
std::string model_type;
OfflineModelConfig() = default; OfflineModelConfig() = default;
OfflineModelConfig(const OfflineTransducerModelConfig &transducer, OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
const OfflineParaformerModelConfig &paraformer, const OfflineParaformerModelConfig &paraformer,
const OfflineNemoEncDecCtcModelConfig &nemo_ctc, const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
const std::string &tokens, int32_t num_threads, bool debug, const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider) const std::string &provider, const std::string &model_type)
: transducer(transducer), : transducer(transducer),
paraformer(paraformer), paraformer(paraformer),
nemo_ctc(nemo_ctc), nemo_ctc(nemo_ctc),
tokens(tokens), tokens(tokens),
num_threads(num_threads), num_threads(num_threads),
debug(debug), debug(debug),
provider(provider) {} provider(provider),
model_type(model_type) {}
void Register(ParseOptions *po); void Register(ParseOptions *po);
bool Validate() const; bool Validate() const;

View File

@@ -18,6 +18,21 @@ namespace sherpa_onnx {
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
const OfflineRecognizerConfig &config) { const OfflineRecognizerConfig &config) {
if (!config.model_config.model_type.empty()) {
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
model_type.c_str());
}
}
Ort::Env env(ORT_LOGGING_LEVEL_ERROR); Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
Ort::SessionOptions sess_opts; Ort::SessionOptions sess_opts;

View File

@@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
bool OfflineTransducerModelConfig::Validate() const { bool OfflineTransducerModelConfig::Validate() const {
if (!FileExists(encoder_filename)) { if (!FileExists(encoder_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
return false; return false;
} }
if (!FileExists(decoder_filename)) { if (!FileExists(decoder_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
return false; return false;
} }
if (!FileExists(joiner_filename)) { if (!FileExists(joiner_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
return false; return false;
} }

View File

@@ -21,15 +21,16 @@ void PybindOfflineModelConfig(py::module *m) {
using PyClass = OfflineModelConfig; using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig") py::class_<PyClass>(*m, "OfflineModelConfig")
.def(py::init<const OfflineTransducerModelConfig &, .def(
const OfflineParaformerModelConfig &, py::init<const OfflineTransducerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &, const OfflineParaformerModelConfig &,
const std::string &, int32_t, bool, const std::string &>(), const OfflineNemoEncDecCtcModelConfig &, const std::string &,
py::arg("transducer") = OfflineTransducerModelConfig(), int32_t, bool, const std::string &, const std::string &>(),
py::arg("paraformer") = OfflineParaformerModelConfig(), py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("provider") = "cpu") py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer) .def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
@@ -37,6 +38,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug) .def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider) .def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }

View File

@@ -86,6 +86,7 @@ class OfflineRecognizer(object):
num_threads=num_threads, num_threads=num_threads,
debug=debug, debug=debug,
provider=provider, provider=provider,
model_type="transducer",
) )
feat_config = OfflineFeatureExtractorConfig( feat_config = OfflineFeatureExtractorConfig(
@@ -149,6 +150,7 @@ class OfflineRecognizer(object):
num_threads=num_threads, num_threads=num_threads,
debug=debug, debug=debug,
provider=provider, provider=provider,
model_type="paraformer",
) )
feat_config = OfflineFeatureExtractorConfig( feat_config = OfflineFeatureExtractorConfig(
@@ -211,6 +213,7 @@ class OfflineRecognizer(object):
num_threads=num_threads, num_threads=num_threads,
debug=debug, debug=debug,
provider=provider, provider=provider,
model_type="nemo_ctc",
) )
feat_config = OfflineFeatureExtractorConfig( feat_config = OfflineFeatureExtractorConfig(