Ebranchformer (#1951)

* adding ebranchformer encoder

* extend surfaced FeatureExtractorConfig

- so ebranchformer feature extraction can be configured from Python
- the GlobCmvn is not needed, as it is a module in the OnnxEncoder

* clean the code

* Integrating remarks from Fangjun
This commit is contained in:
Karel Vesely
2025-03-04 12:41:09 +01:00
committed by GitHub
parent 209eaaae1d
commit 7740dbfb96
8 changed files with 609 additions and 5 deletions

View File

@@ -21,6 +21,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
#include "sherpa-onnx/csrc/online-ebranchformer-transducer-model.h"
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h"
@@ -30,6 +31,7 @@ namespace {
enum class ModelType : std::uint8_t {
kConformer,
kEbranchformer,
kLstm,
kZipformer,
kZipformer2,
@@ -74,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
if (model_type == "conformer") {
return ModelType::kConformer;
} else if (model_type == "ebranchformer") {
return ModelType::kEbranchformer;
} else if (model_type == "lstm") {
return ModelType::kLstm;
} else if (model_type == "zipformer") {
@@ -92,6 +96,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
const auto &model_type = config.model_type;
if (model_type == "conformer") {
return std::make_unique<OnlineConformerTransducerModel>(config);
} else if (model_type == "ebranchformer") {
return std::make_unique<OnlineEbranchformerTransducerModel>(config);
} else if (model_type == "lstm") {
return std::make_unique<OnlineLstmTransducerModel>(config);
} else if (model_type == "zipformer") {
@@ -115,6 +121,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
switch (model_type) {
case ModelType::kConformer:
return std::make_unique<OnlineConformerTransducerModel>(config);
case ModelType::kEbranchformer:
return std::make_unique<OnlineEbranchformerTransducerModel>(config);
case ModelType::kLstm:
return std::make_unique<OnlineLstmTransducerModel>(config);
case ModelType::kZipformer:
@@ -171,6 +179,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
const auto &model_type = config.model_type;
if (model_type == "conformer") {
return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
} else if (model_type == "ebranchformer") {
return std::make_unique<OnlineEbranchformerTransducerModel>(mgr, config);
} else if (model_type == "lstm") {
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
} else if (model_type == "zipformer") {
@@ -190,6 +200,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
switch (model_type) {
case ModelType::kConformer:
return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
case ModelType::kEbranchformer:
return std::make_unique<OnlineEbranchformerTransducerModel>(mgr, config);
case ModelType::kLstm:
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
case ModelType::kZipformer: