Begin to support CTC models (#119)

Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
This commit is contained in:
Fangjun Kuang
2023-04-07 23:11:34 +08:00
committed by GitHub
parent 9ac747248b
commit 80060c276d
40 changed files with 1244 additions and 60 deletions

View File

@@ -14,10 +14,12 @@
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
namespace {
enum class ModelType {
kLstm,
@@ -25,6 +27,10 @@ enum class ModelType {
kUnkown,
};
}
namespace sherpa_onnx {
static ModelType GetModelType(char *model_data, size_t model_data_length,
bool debug) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
@@ -37,14 +43,17 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
if (debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
fprintf(stderr, "No model_type in the metadata!\n");
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you are using the latest export-onnx.py from icefall "
"to export your transducer models");
return ModelType::kUnkown;
}
@@ -53,7 +62,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
} else if (model_type.get() == std::string("zipformer")) {
return ModelType::kZipformer;
} else {
fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
}
}
@@ -74,6 +83,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
case ModelType::kZipformer:
return std::make_unique<OnlineZipformerTransducerModel>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
return nullptr;
}
@@ -127,6 +137,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
case ModelType::kZipformer:
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
return nullptr;
}