Support non-streaming WeNet CTC models. (#426)
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -21,10 +22,11 @@ enum class ModelType {
|
||||
kEncDecCTCModelBPE,
|
||||
kTdnn,
|
||||
kZipformerCtc,
|
||||
kWenetCtc,
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
"If you are using models from NeMo, please refer to\n"
|
||||
"https://huggingface.co/csukuangfj/"
|
||||
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
||||
"If you are using models from WeNet, please refer to\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
|
||||
"run.sh\n"
|
||||
"\n"
|
||||
"for how to add metadta to model.onnx\n");
|
||||
return ModelType::kUnkown;
|
||||
@@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kTdnn;
|
||||
} else if (model_type.get() == std::string("zipformer2_ctc")) {
|
||||
return ModelType::kZipformerCtc;
|
||||
} else if (model_type.get() == std::string("wenet_ctc")) {
|
||||
return ModelType::kWenetCtc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
@@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.tdnn.model;
|
||||
} else if (!config.zipformer_ctc.model.empty()) {
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else if (!config.wenet_ctc.model.empty()) {
|
||||
filename = config.wenet_ctc.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kZipformerCtc:
|
||||
return std::make_unique<OfflineZipformerCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
@@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.tdnn.model;
|
||||
} else if (!config.zipformer_ctc.model.empty()) {
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else if (!config.wenet_ctc.model.empty()) {
|
||||
filename = config.wenet_ctc.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kZipformerCtc:
|
||||
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
|
||||
Reference in New Issue
Block a user