Reduce model initialization time for offline speech recognition (#213)
This commit is contained in:
@@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
|
||||
po->Register("model-type", &model_type,
|
||||
"Specify it to reduce model initialization time. "
|
||||
"Valid values are: transducer, paraformer, nemo_ctc. "
|
||||
"All other values lead to loading the model twice.");
|
||||
}
|
||||
|
||||
bool OfflineModelConfig::Validate() const {
|
||||
@@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const {
|
||||
}
|
||||
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
|
||||
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
os << "provider=\"" << provider << "\", ";
|
||||
os << "model_type=\"" << model_type << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -22,19 +22,31 @@ struct OfflineModelConfig {
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
// With the help of this field, we only need to load the model once
|
||||
// instead of twice; and therefore it reduces initialization time.
|
||||
//
|
||||
// Valid values:
|
||||
// - transducer. The given model is from icefall
|
||||
// - paraformer. It is a paraformer model
|
||||
// - nemo_ctc. It is a NeMo CTC model.
|
||||
//
|
||||
// All other values are invalid and lead to loading the model twice.
|
||||
std::string model_type;
|
||||
|
||||
OfflineModelConfig() = default;
|
||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||
const OfflineParaformerModelConfig &paraformer,
|
||||
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider)
|
||||
const std::string &provider, const std::string &model_type)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
nemo_ctc(nemo_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
provider(provider),
|
||||
model_type(model_type) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -18,6 +18,21 @@ namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
const OfflineRecognizerConfig &config) {
|
||||
if (!config.model_config.model_type.empty()) {
|
||||
const auto &model_type = config.model_config.model_type;
|
||||
if (model_type == "transducer") {
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
} else if (model_type == "nemo_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Invalid model_type: %s. Trying to load the model to get its type",
|
||||
model_type.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
|
||||
|
||||
Ort::SessionOptions sess_opts;
|
||||
|
||||
@@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
bool OfflineTransducerModelConfig::Validate() const {
|
||||
if (!FileExists(encoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(joiner_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
|
||||
SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user