diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h index cf8e18da..eda0295d 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h @@ -138,6 +138,12 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { private: void PostInit() { + int32_t feat_dim = model_->FeatureDim(); + + if (feat_dim > 0) { + config_.feat_config.feature_dim = feat_dim; + } + config_.feat_config.nemo_normalize_type = model_->FeatureNormalizationMethod(); diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc index 6fbb2b61..3470fcc0 100644 --- a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc @@ -164,6 +164,8 @@ class OfflineTransducerNeMoModel::Impl { bool IsGigaAM() const { return is_giga_am_; } + int32_t FeatureDim() const { return feat_dim_; } + private: void InitEncoder(void *model_data, size_t model_data_length) { encoder_sess_ = std::make_unique( @@ -201,6 +203,7 @@ class OfflineTransducerNeMoModel::Impl { SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(feat_dim_, "feat_dim", -1); if (normalize_type_ == "NA") { normalize_type_ = ""; @@ -263,6 +266,11 @@ class OfflineTransducerNeMoModel::Impl { int32_t pred_rnn_layers_ = -1; int32_t pred_hidden_ = -1; int32_t is_giga_am_ = 0; + + // GigaAM models use 64 + // parakeet-tdt-0.6b-v2 uses 128 + // other models use 80 + int32_t feat_dim_ = -1; // -1: feat_dim is absent from the model metadata; callers keep their default.
}; OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( @@ -317,6 +325,10 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } +int32_t OfflineTransducerNeMoModel::FeatureDim() const { + return impl_->FeatureDim(); +} + #if __ANDROID_API__ >= 9 template OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( AAssetManager *mgr, const OfflineModelConfig &config); diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.h b/sherpa-onnx/csrc/offline-transducer-nemo-model.h index 697f749e..e5101a57 100644 --- a/sherpa-onnx/csrc/offline-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/offline-transducer-nemo-model.h @@ -88,6 +88,8 @@ class OfflineTransducerNeMoModel { bool IsGigaAM() const; + int32_t FeatureDim() const; + private: class Impl; std::unique_ptr impl_;