Support TDNN models from the yesno recipe from icefall (#262)

This commit is contained in:
Fangjun Kuang
2023-08-12 19:50:22 +08:00
committed by GitHub
parent b094868fb8
commit a4bff28e21
23 changed files with 612 additions and 36 deletions

View File

@@ -11,12 +11,14 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace {
enum class ModelType {
kEncDecCTCModelBPE,
kTdnn,
kUnkown,
};
@@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
return ModelType::kEncDecCTCModelBPE;
} else if (model_type.get() == std::string("tdnn")) {
return ModelType::kTdnn;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
@@ -65,8 +69,18 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
std::string filename;
if (!config.nemo_ctc.model.empty()) {
filename = config.nemo_ctc.model;
} else if (!config.tdnn.model.empty()) {
filename = config.tdnn.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
}
{
auto buffer = ReadFile(config.nemo_ctc.model);
auto buffer = ReadFile(filename);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
@@ -75,6 +89,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kEncDecCTCModelBPE:
return std::make_unique<OfflineNemoEncDecCtcModel>(config);
break;
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;