Add GigaAM NeMo transducer model for Russian ASR (#1467)
This commit is contained in:
@@ -166,7 +166,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
|
||||
if ((model_type == "EncDecHybridRNNTCTCBPEModel" ||
|
||||
model_type == "EncDecRNNTBPEModel") &&
|
||||
!config.model_config.transducer.decoder_filename.empty() &&
|
||||
!config.model_config.transducer.joiner_filename.empty()) {
|
||||
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config);
|
||||
@@ -191,6 +192,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
" - EncDecCTCModelBPE models from NeMo\n"
|
||||
" - EncDecCTCModel models from NeMo\n"
|
||||
" - EncDecHybridRNNTCTCBPEModel models from NeMo\n"
|
||||
" - EncDecRNNTBPEModel models from NeMo\n"
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n"
|
||||
" - Zipformer CTC models\n"
|
||||
@@ -338,7 +340,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecHybridRNNTCTCBPEModel" &&
|
||||
if ((model_type == "EncDecHybridRNNTCTCBPEModel" ||
|
||||
model_type == "EncDecRNNTBPEModel") &&
|
||||
!config.model_config.transducer.decoder_filename.empty() &&
|
||||
!config.model_config.transducer.joiner_filename.empty()) {
|
||||
return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config);
|
||||
@@ -363,6 +366,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
" - EncDecCTCModelBPE models from NeMo\n"
|
||||
" - EncDecCTCModel models from NeMo\n"
|
||||
" - EncDecHybridRNNTCTCBPEModel models from NeMo\n"
|
||||
" - EncDecRNNTBPEModel models from NeMo\n"
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n"
|
||||
" - Zipformer CTC models\n"
|
||||
|
||||
@@ -139,23 +139,29 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
OfflineRecognizerConfig GetConfig() const override {
|
||||
return config_;
|
||||
}
|
||||
// Returns, by value, the configuration stored in this recognizer.
OfflineRecognizerConfig GetConfig() const override {
  return config_;
}
|
||||
|
||||
private:
|
||||
void PostInit() {
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
config_.feat_config.low_freq = 0;
|
||||
// config_.feat_config.high_freq = 8000;
|
||||
config_.feat_config.is_librosa = true;
|
||||
config_.feat_config.remove_dc_offset = false;
|
||||
// config_.feat_config.window_type = "hann";
|
||||
config_.feat_config.dither = 0;
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
if (model_->IsGigaAM()) {
|
||||
config_.feat_config.low_freq = 0;
|
||||
config_.feat_config.high_freq = 8000;
|
||||
config_.feat_config.remove_dc_offset = false;
|
||||
config_.feat_config.preemph_coeff = 0;
|
||||
config_.feat_config.window_type = "hann";
|
||||
config_.feat_config.feature_dim = 64;
|
||||
} else {
|
||||
config_.feat_config.low_freq = 0;
|
||||
// config_.feat_config.high_freq = 8000;
|
||||
config_.feat_config.is_librosa = true;
|
||||
config_.feat_config.remove_dc_offset = false;
|
||||
// config_.feat_config.window_type = "hann";
|
||||
}
|
||||
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
|
||||
|
||||
@@ -153,6 +153,8 @@ class OfflineTransducerNeMoModel::Impl {
|
||||
|
||||
// Returns a copy of the stored feature normalization type string.
std::string FeatureNormalizationMethod() const {
  std::string method = normalize_type_;
  return method;
}
|
||||
|
||||
// Reports whether the stored is_giga_am_ flag is non-zero.
bool IsGigaAM() const {
  return is_giga_am_ != 0;
}
|
||||
|
||||
private:
|
||||
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||
@@ -181,9 +183,11 @@ class OfflineTransducerNeMoModel::Impl {
|
||||
vocab_size_ += 1;
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
|
||||
SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_,
|
||||
"normalize_type");
|
||||
SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
|
||||
SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
|
||||
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0);
|
||||
|
||||
if (normalize_type_ == "NA") {
|
||||
normalize_type_ = "";
|
||||
@@ -245,6 +249,7 @@ class OfflineTransducerNeMoModel::Impl {
|
||||
std::string normalize_type_;
|
||||
int32_t pred_rnn_layers_ = -1;
|
||||
int32_t pred_hidden_ = -1;
|
||||
int32_t is_giga_am_ = 0;
|
||||
};
|
||||
|
||||
OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
|
||||
@@ -298,4 +303,6 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
|
||||
return impl_->FeatureNormalizationMethod();
|
||||
}
|
||||
|
||||
// Forwards the query to the private implementation object.
bool OfflineTransducerNeMoModel::IsGigaAM() const {
  const bool is_giga_am = impl_->IsGigaAM();
  return is_giga_am;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -93,6 +93,8 @@ class OfflineTransducerNeMoModel {
|
||||
// for details
|
||||
std::string FeatureNormalizationMethod() const;
|
||||
|
||||
// Returns true if the loaded model is flagged as a GigaAM model.
// NOTE(review): callers appear to use this to select GigaAM-specific
// feature extraction settings -- confirm against the recognizer code.
bool IsGigaAM() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -404,6 +404,19 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
)
|
||||
}
|
||||
|
||||
20 -> {
|
||||
val modelDir = "sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24"
|
||||
return OfflineModelConfig(
|
||||
transducer = OfflineTransducerModelConfig(
|
||||
encoder = "$modelDir/encoder.int8.onnx",
|
||||
decoder = "$modelDir/decoder.onnx",
|
||||
joiner = "$modelDir/joiner.onnx",
|
||||
),
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
modelType = "nemo_transducer",
|
||||
)
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user