Add GigaAM NeMo transducer model for Russian ASR (#1467)

This commit is contained in:
Fangjun Kuang
2024-10-25 15:20:13 +08:00
committed by GitHub
parent b41f6d2c94
commit 707cf792c5
12 changed files with 543 additions and 21 deletions

View File

@@ -166,7 +166,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
}
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
if ((model_type == "EncDecHybridRNNTCTCBPEModel" ||
model_type == "EncDecRNNTBPEModel") &&
!config.model_config.transducer.decoder_filename.empty() &&
!config.model_config.transducer.joiner_filename.empty()) {
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
@@ -191,6 +192,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - EncDecCTCModelBPE models from NeMo\n"
" - EncDecCTCModel models from NeMo\n"
" - EncDecHybridRNNTCTCBPEModel models from NeMo\n"
" - EncDecRNNTBPEModel models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n"
@@ -338,7 +340,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
}
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
if ((model_type == "EncDecHybridRNNTCTCBPEModel" ||
model_type == "EncDecRNNTBPEModel") &&
!config.model_config.transducer.decoder_filename.empty() &&
!config.model_config.transducer.joiner_filename.empty()) {
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
@@ -363,6 +366,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - EncDecCTCModelBPE models from NeMo\n"
" - EncDecCTCModel models from NeMo\n"
" - EncDecHybridRNNTCTCBPEModel models from NeMo\n"
" - EncDecRNNTBPEModel models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n"

View File

@@ -139,23 +139,29 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
}
}
OfflineRecognizerConfig GetConfig() const override {
return config_;
}
// Returns, by value, the configuration stored in config_.
OfflineRecognizerConfig GetConfig() const override {
  return config_;
}
private:
void PostInit() {
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
config_.feat_config.low_freq = 0;
// config_.feat_config.high_freq = 8000;
config_.feat_config.is_librosa = true;
config_.feat_config.remove_dc_offset = false;
// config_.feat_config.window_type = "hann";
config_.feat_config.dither = 0;
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
if (model_->IsGigaAM()) {
config_.feat_config.low_freq = 0;
config_.feat_config.high_freq = 8000;
config_.feat_config.remove_dc_offset = false;
config_.feat_config.preemph_coeff = 0;
config_.feat_config.window_type = "hann";
config_.feat_config.feature_dim = 64;
} else {
config_.feat_config.low_freq = 0;
// config_.feat_config.high_freq = 8000;
config_.feat_config.is_librosa = true;
config_.feat_config.remove_dc_offset = false;
// config_.feat_config.window_type = "hann";
}
int32_t vocab_size = model_->VocabSize();

View File

@@ -153,6 +153,8 @@ class OfflineTransducerNeMoModel::Impl {
// Returns the feature normalization method held in normalize_type_.
std::string FeatureNormalizationMethod() const {
  return normalize_type_;
}
// Reports whether the integer flag is_giga_am_ is set (i.e. non-zero).
bool IsGigaAM() const {
  return is_giga_am_ != 0;
}
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
@@ -181,9 +183,11 @@ class OfflineTransducerNeMoModel::Impl {
vocab_size_ += 1;
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_,
"normalize_type");
SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0);
if (normalize_type_ == "NA") {
normalize_type_ = "";
@@ -245,6 +249,7 @@ class OfflineTransducerNeMoModel::Impl {
std::string normalize_type_;
int32_t pred_rnn_layers_ = -1;
int32_t pred_hidden_ = -1;
int32_t is_giga_am_ = 0;
};
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
@@ -298,4 +303,6 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
return impl_->FeatureNormalizationMethod();
}
// Forwards the query to the implementation object and returns its answer.
bool OfflineTransducerNeMoModel::IsGigaAM() const {
  return impl_->IsGigaAM();
}
} // namespace sherpa_onnx

View File

@@ -93,6 +93,8 @@ class OfflineTransducerNeMoModel {
// for details
std::string FeatureNormalizationMethod() const;
// Returns true if the loaded model is flagged as a GigaAM model.
// The flag is taken from the "is_giga_am" entry of the model metadata;
// the metadata reader is given 0 as its default for that entry.
bool IsGigaAM() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;

View File

@@ -404,6 +404,19 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
tokens = "$modelDir/tokens.txt",
)
}
20 -> {
val modelDir = "sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24"
return OfflineModelConfig(
transducer = OfflineTransducerModelConfig(
encoder = "$modelDir/encoder.int8.onnx",
decoder = "$modelDir/decoder.onnx",
joiner = "$modelDir/joiner.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "nemo_transducer",
)
}
}
return null
}