Add C++ runtime for Tele-AI/TeleSpeech-ASR (#970)
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-telespeech-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
@@ -24,6 +25,7 @@ enum class ModelType {
|
||||
kTdnn,
|
||||
kZipformerCtc,
|
||||
kWenetCtc,
|
||||
kTeleSpeechCtc,
|
||||
kUnknown,
|
||||
};
|
||||
|
||||
@@ -63,6 +65,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
"If you are using models from WeNet, please refer to\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
|
||||
"run.sh\n"
|
||||
"If you are using models from TeleSpeech, please refer to\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/tele-speech/"
|
||||
"add-metadata.py"
|
||||
"\n"
|
||||
"for how to add metadta to model.onnx\n");
|
||||
return ModelType::kUnknown;
|
||||
@@ -78,6 +83,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kZipformerCtc;
|
||||
} else if (model_type.get() == std::string("wenet_ctc")) {
|
||||
return ModelType::kWenetCtc;
|
||||
} else if (model_type.get() == std::string("telespeech_ctc")) {
|
||||
return ModelType::kTeleSpeechCtc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnknown;
|
||||
@@ -97,6 +104,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else if (!config.wenet_ctc.model.empty()) {
|
||||
filename = config.wenet_ctc.model;
|
||||
} else if (!config.telespeech_ctc.empty()) {
|
||||
filename = config.telespeech_ctc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -124,6 +133,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kTeleSpeechCtc:
|
||||
return std::make_unique<OfflineTeleSpeechCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
@@ -147,6 +159,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else if (!config.wenet_ctc.model.empty()) {
|
||||
filename = config.wenet_ctc.model;
|
||||
} else if (!config.telespeech_ctc.empty()) {
|
||||
filename = config.telespeech_ctc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -175,6 +189,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kTeleSpeechCtc:
|
||||
return std::make_unique<OfflineTeleSpeechCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
|
||||
Reference in New Issue
Block a user