diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh index 8cb3c5e8..8e979e49 100755 --- a/.github/scripts/test-offline-transducer.sh +++ b/.github/scripts/test-offline-transducer.sh @@ -13,6 +13,105 @@ echo "PATH: $PATH" which $EXE +log "------------------------------------------------------------------------" +log "Run Nemo fast conformer hybrid transducer ctc models (transducer branch)" +log "------------------------------------------------------------------------" + +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k.tar.bz2 +name=$(basename $url) +curl -SL -O $url +tar xvf $name +rm $name +repo=$(basename -s .tar.bz2 $name) +ls -lh $repo + +log "test $repo" +test_wavs=( +de-german.wav +es-spanish.wav +hr-croatian.wav +po-polish.wav +uk-ukrainian.wav +en-english.wav +fr-french.wav +it-italian.wav +ru-russian.wav +) +for w in ${test_wavs[@]}; do + time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder.onnx \ + --decoder=$repo/decoder.onnx \ + --joiner=$repo/joiner.onnx \ + --debug=1 \ + $repo/test_wavs/$w +done + +rm -rf $repo + +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-24500.tar.bz2 +name=$(basename $url) +curl -SL -O $url +tar xvf $name +rm $name +repo=$(basename -s .tar.bz2 $name) +ls -lh $repo + +log "Test $repo" + +time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder.onnx \ + --decoder=$repo/decoder.onnx \ + --joiner=$repo/joiner.onnx \ + --debug=1 \ + $repo/test_wavs/en-english.wav + +rm -rf $repo + +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-es-1424.tar.bz2 +name=$(basename $url) +curl -SL -O $url +tar xvf $name +rm $name +repo=$(basename -s .tar.bz2 $name) +ls -lh $repo + +log "test $repo" + +time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder.onnx \ + --decoder=$repo/decoder.onnx \ + --joiner=$repo/joiner.onnx \ + --debug=1 \ + $repo/test_wavs/es-spanish.wav + +rm -rf $repo + +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-fast-conformer-transducer-en-de-es-fr-14288.tar.bz2 +name=$(basename $url) +curl -SL -O $url +tar xvf $name +rm $name +repo=$(basename -s .tar.bz2 $name) +ls -lh $repo + +log "Test $repo" + +time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder.onnx \ + --decoder=$repo/decoder.onnx \ + --joiner=$repo/joiner.onnx \ + --debug=1 \ + $repo/test_wavs/en-english.wav \ + $repo/test_wavs/de-german.wav \ + $repo/test_wavs/fr-french.wav \ + $repo/test_wavs/es-spanish.wav + +rm -rf $repo + log "------------------------------------------------------------" log "Run Conformer transducer (English)" log "------------------------------------------------------------" diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index c1e8a69a..59ef986b 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -128,6 +128,14 @@ jobs: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: install/* + - name: Test offline transducer + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline + + .github/scripts/test-offline-transducer.sh + - name: Test spoken language identification (C++ API) shell: bash run: | @@ -215,14 +223,6 @@ jobs: 
.github/scripts/test-online-paraformer.sh - - name: Test offline transducer - shell: bash - run: | - export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx-offline - - .github/scripts/test-offline-transducer.sh - - name: Test online transducer shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 61687712..95030e57 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -107,6 +107,14 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline transducer + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline + + .github/scripts/test-offline-transducer.sh + - name: Test online CTC shell: bash run: | @@ -192,14 +200,6 @@ jobs: .github/scripts/test-offline-ctc.sh - - name: Test offline transducer - shell: bash - run: | - export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx-offline - - .github/scripts/test-offline-transducer.sh - - name: Test online transducer shell: bash run: | diff --git a/.gitignore b/.gitignore index 0f023f0d..89b78433 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,4 @@ sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 sherpa-onnx-ced-* node_modules package-lock.json +sherpa-onnx-nemo-* diff --git a/python-api-examples/offline-nemo-ctc-decode-files.py b/python-api-examples/offline-nemo-ctc-decode-files.py new file mode 100755 index 00000000..d40b4089 --- /dev/null +++ b/python-api-examples/offline-nemo-ctc-decode-files.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming CTC model from NeMo +to decode files. + +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + + +The example model supports 10 languages and it is converted from +https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc +""" + +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/model.onnx" + tokens = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt" + + test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav" + + if not Path(model).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_onnx.OfflineRecognizer.from_nemo_ctc( + model=model, + tokens=tokens, + 
debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/python-api-examples/offline-nemo-transducer-decode-files.py b/python-api-examples/offline-nemo-transducer-decode-files.py new file mode 100755 index 00000000..cd807764 --- /dev/null +++ b/python-api-examples/offline-nemo-transducer-decode-files.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming transducer model from NeMo +to decode files. + +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + + +The example model supports 10 languages and it is converted from +https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc +""" + +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/encoder.onnx" + decoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/decoder.onnx" + joiner = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/joiner.onnx" + tokens = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt" + + test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav" + + if not Path(encoder).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_onnx.OfflineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + model_type="nemo_transducer", + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to 
be 16000 Hz + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 8c9fdbb7..fc5d240c 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -40,9 +40,11 @@ set(sources offline-tdnn-ctc-model.cc offline-tdnn-model-config.cc offline-transducer-greedy-search-decoder.cc + offline-transducer-greedy-search-nemo-decoder.cc offline-transducer-model-config.cc offline-transducer-model.cc offline-transducer-modified-beam-search-decoder.cc + offline-transducer-nemo-model.cc offline-wenet-ctc-model-config.cc offline-wenet-ctc-model.cc offline-whisper-greedy-search-decoder.cc diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index 08464418..c3bc02d5 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -56,6 +56,19 @@ struct FeatureExtractorConfig { bool remove_dc_offset = true; // Subtract mean of wave before FFT. std::string window_type = "povey"; // e.g. Hamming window + // For models from NeMo + // This option is not exposed and is set internally when loading models. + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string nemo_normalize_type; + std::string ToString() const; void Register(ParseOptions *po); diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index a8a8242c..1b897f69 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -68,7 +68,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { : config_(config), model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens) { - if (sym_.contains("<unk>")) { + if (sym_.Contains("<unk>")) { unk_id_ = sym_["<unk>"]; } @@ -87,7 +87,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { : config_(config), model_(OnlineTransducerModel::Create(mgr, config.model_config)), sym_(mgr, config.model_config.tokens) { - if (sym_.contains("<unk>")) { + if (sym_.Contains("<unk>")) { unk_id_ = sym_["<unk>"]; } diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc index 23123399..2d790954 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023-2024 Xiaomi Corporation #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index ca0b5522..988a487b 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -38,7 +38,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, std::string text; for (int32_t i = 0; i != src.tokens.size(); ++i) { - if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) { + if (sym_table.Contains("SIL") && src.tokens[i] ==
sym_table["SIL"]) { // tdnn models from yesno have a SIL token, we should remove it. continue; } @@ -103,9 +103,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { decoder_ = std::make_unique<OfflineCtcFstDecoder>( config_.ctc_fst_decoder_config); } else if (config_.decoding_method == "greedy_search") { - if (!symbol_table_.contains("<blk>") && - !symbol_table_.contains("<eps>") && - !symbol_table_.contains("<blank>")) { + if (!symbol_table_.Contains("<blk>") && + !symbol_table_.Contains("<eps>") && + !symbol_table_.Contains("<blank>")) { SHERPA_ONNX_LOGE( "We expect that tokens.txt contains " "the symbol <blk> or <eps> or <blank> and its ID."); @@ -113,12 +113,12 @@ } int32_t blank_id = 0; - if (symbol_table_.contains("<blk>")) { + if (symbol_table_.Contains("<blk>")) { blank_id = symbol_table_["<blk>"]; - } else if (symbol_table_.contains("<eps>")) { + } else if (symbol_table_.Contains("<eps>")) { // for tdnn models of the yesno recipe from icefall blank_id = symbol_table_["<eps>"]; - } else if (symbol_table_.contains("<blank>")) { + } else if (symbol_table_.Contains("<blank>")) { // for Wenet CTC models blank_id = symbol_table_["<blank>"]; } diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index cb9246b2..c23acf12 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -11,6 +11,7 @@ #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" +#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" @@ -23,6 +24,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( const auto &model_type = config.model_config.model_type; if (model_type == "transducer") { return std::make_unique<OfflineRecognizerTransducerImpl>(config); + } else if (model_type == "nemo_transducer") { + return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config); } else if (model_type == "paraformer") { return std::make_unique<OfflineRecognizerParaformerImpl>(config); } else if (model_type == "nemo_ctc" || model_type == "tdnn" || @@ -122,6 +125,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( return std::make_unique<OfflineRecognizerParaformerImpl>(config); } + if (model_type == "EncDecHybridRNNTCTCBPEModel" && + !config.model_config.transducer.decoder_filename.empty() && + !config.model_config.transducer.joiner_filename.empty()) { + return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(config); + } + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { @@ -155,6 +164,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( const auto &model_type = config.model_config.model_type; if (model_type == "transducer") { return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config); + } else if (model_type == "nemo_transducer") { + return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config); } else if (model_type == "paraformer") { return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); } else if (model_type == "nemo_ctc" || model_type == "tdnn" || @@ -254,6 +265,12 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); } + if (model_type == "EncDecHybridRNNTCTCBPEModel" && + !config.model_config.transducer.decoder_filename.empty() && + !config.model_config.transducer.joiner_filename.empty()) { + return std::make_unique<OfflineRecognizerTransducerNeMoImpl>(mgr, config); + } + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" ||
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") { diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h new file mode 100644 index 00000000..127fe343 --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h @@ -0,0 +1,182 @@ +// sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ + +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h" +#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h" +#include "sherpa-onnx/csrc/pad-sequence.h" +#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/transpose.h" +#include "sherpa-onnx/csrc/utils.h" + +namespace sherpa_onnx { + +// defined in ./offline-recognizer-transducer-impl.h +OfflineRecognitionResult Convert(const OfflineTransducerDecoderResult &src, + const SymbolTable &sym_table, + int32_t frame_shift_ms, + int32_t subsampling_factor); + +class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerTransducerNeMoImpl( + const OfflineRecognizerConfig &config) + : config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique( + config_.model_config)) { + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + PostInit(); + } + +#if __ANDROID_API__ >= 9 + explicit OfflineRecognizerTransducerNeMoImpl( + AAssetManager *mgr, const OfflineRecognizerConfig &config) + : config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique( + mgr, config_.model_config)) { + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + + PostInit(); + } +#endif + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t feat_dim = ss[0]->FeatureDim(); + + std::vector features; + + features.reserve(n); + + std::vector> features_vec(n); + std::vector features_length_vec(n); + for (int32_t i = 0; i != n; ++i) { + auto f = ss[i]->GetFrames(); + int32_t num_frames = f.size() / feat_dim; + + features_length_vec[i] = num_frames; + features_vec[i] = std::move(f); + + std::array shape = {num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = &features[i]; + } + + 
std::array features_length_shape = {n}; + Ort::Value x_length = Ort::Value::CreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0); + + auto t = model_->RunEncoder(std::move(x), std::move(x_length)); + // t[0] encoder_out, float tensor, (batch_size, dim, T) + // t[1] encoder_out_length, int64 tensor, (batch_size,) + + Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); + + auto results = decoder_->Decode(std::move(encoder_out), std::move(t[1])); + + int32_t frame_shift_ms = 10; + for (int32_t i = 0; i != n; ++i) { + auto r = Convert(results[i], symbol_table_, frame_shift_ms, + model_->SubsamplingFactor()); + + ss[i]->SetResult(r); + } + } + + private: + void PostInit() { + config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + config_.feat_config.low_freq = 0; + // config_.feat_config.high_freq = 8000; + config_.feat_config.is_librosa = true; + config_.feat_config.remove_dc_offset = false; + // config_.feat_config.window_type = "hann"; + config_.feat_config.dither = 0; + config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + int32_t vocab_size = model_->VocabSize(); + + // check the blank ID + if (!symbol_table_.Contains("")) { + SHERPA_ONNX_LOGE("tokens.txt does not include the blank token "); + exit(-1); + } + + if (symbol_table_[""] != vocab_size - 1) { + SHERPA_ONNX_LOGE(" is not the last token!"); + exit(-1); + } + + if (symbol_table_.NumSymbols() != vocab_size) { + SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)", + symbol_table_.NumSymbols(), vocab_size); + exit(-1); + } + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index ea62925b..d224c860 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -35,7 +35,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, std::string text; for (auto i : src.tokens) { - if (!sym_table.contains(i)) { + if (!sym_table.Contains(i)) { continue; } diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 1a878d63..e93d7edc 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -14,6 +14,7 @@ #include "android/asset_manager_jni.h" #endif +#include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" @@ -26,7 +27,7 @@ namespace sherpa_onnx { struct OfflineRecognitionResult; struct OfflineRecognizerConfig { - OfflineFeatureExtractorConfig feat_config; + FeatureExtractorConfig feat_config; OfflineModelConfig model_config; OfflineLMConfig lm_config; OfflineCtcFstDecoderConfig ctc_fst_decoder_config; @@ -44,7 +45,7 @@ struct OfflineRecognizerConfig { OfflineRecognizerConfig() = default; OfflineRecognizerConfig( - const OfflineFeatureExtractorConfig &feat_config, + const FeatureExtractorConfig &feat_config, const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, const 
OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, const std::string &decoding_method, int32_t max_active_paths, diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index ebd5f186..206b3600 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -52,42 +52,25 @@ static void ComputeMeanAndInvStd(const float *p, int32_t num_rows, } } -void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { - po->Register("sample-rate", &sampling_rate, - "Sampling rate of the input waveform. " - "Note: You can have a different " - "sample rate for the input waveform. We will do resampling " - "inside the feature extractor"); - - po->Register("feat-dim", &feature_dim, - "Feature dimension. Must match the one expected by the model."); -} - -std::string OfflineFeatureExtractorConfig::ToString() const { - std::ostringstream os; - - os << "OfflineFeatureExtractorConfig("; - os << "sampling_rate=" << sampling_rate << ", "; - os << "feature_dim=" << feature_dim << ")"; - - return os.str(); -} - class OfflineStream::Impl { public: - explicit Impl(const OfflineFeatureExtractorConfig &config, + explicit Impl(const FeatureExtractorConfig &config, ContextGraphPtr context_graph) : config_(config), context_graph_(context_graph) { - opts_.frame_opts.dither = 0; - opts_.frame_opts.snip_edges = false; + opts_.frame_opts.dither = config.dither; + opts_.frame_opts.snip_edges = config.snip_edges; opts_.frame_opts.samp_freq = config.sampling_rate; + opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; + opts_.frame_opts.frame_length_ms = config.frame_length_ms; + opts_.frame_opts.remove_dc_offset = config.remove_dc_offset; + opts_.frame_opts.window_type = config.window_type; + opts_.mel_opts.num_bins = config.feature_dim; - // Please see - // https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27 - // and - // https://github.com/k2-fsa/sherpa-onnx/issues/514 - opts_.mel_opts.high_freq = -400; + opts_.mel_opts.high_freq = config.high_freq; + opts_.mel_opts.low_freq = config.low_freq; + + opts_.mel_opts.is_librosa = config.is_librosa; fbank_ = std::make_unique(opts_); } @@ -237,7 +220,7 @@ class OfflineStream::Impl { } private: - OfflineFeatureExtractorConfig config_; + FeatureExtractorConfig config_; std::unique_ptr fbank_; std::unique_ptr whisper_fbank_; knf::FbankOptions opts_; @@ -245,9 +228,8 @@ class OfflineStream::Impl { ContextGraphPtr context_graph_; }; -OfflineStream::OfflineStream( - const OfflineFeatureExtractorConfig &config /*= {}*/, - ContextGraphPtr context_graph /*= nullptr*/) +OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, + ContextGraphPtr context_graph /*= nullptr*/) : impl_(std::make_unique(config, context_graph)) {} OfflineStream::OfflineStream(WhisperTag tag) diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 4fd6cef9..13cc5600 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -11,6 +11,7 @@ #include #include "sherpa-onnx/csrc/context-graph.h" +#include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/parse-options.h" namespace sherpa_onnx { @@ -32,46 +33,12 @@ struct OfflineRecognitionResult { std::string AsJsonString() const; }; -struct OfflineFeatureExtractorConfig { - // Sampling rate used by the feature extractor. If it is different from - // the sampling rate of the input waveform, we will do resampling inside. 
- int32_t sampling_rate = 16000; - - // Feature dimension - int32_t feature_dim = 80; - - // Set internally by some models, e.g., paraformer and wenet CTC models set - // it to false. - // This parameter is not exposed to users from the commandline - // If true, the feature extractor expects inputs to be normalized to - // the range [-1, 1]. - // If false, we will multiply the inputs by 32768 - bool normalize_samples = true; - - // For models from NeMo - // This option is not exposed and is set internally when loading models. - // Possible values: - // - per_feature - // - all_features (not implemented yet) - // - fixed_mean (not implemented) - // - fixed_std (not implemented) - // - or just leave it to empty - // See - // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 - // for details - std::string nemo_normalize_type; - - std::string ToString() const; - - void Register(ParseOptions *po); -}; - struct WhisperTag {}; struct CEDTag {}; class OfflineStream { public: - explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, + explicit OfflineStream(const FeatureExtractorConfig &config = {}, ContextGraphPtr context_graph = {}); explicit OfflineStream(WhisperTag tag); diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h index ca638c97..b284d22a 100644 --- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h @@ -14,8 +14,8 @@ namespace sherpa_onnx { class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { public: - explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, - float blank_penalty) + OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, + float blank_penalty) : model_(model), blank_penalty_(blank_penalty) {} std::vector Decode( diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc new file mode 100644 index 00000000..9fccefad --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc @@ -0,0 +1,117 @@ +// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +static std::pair BuildDecoderInput( + int32_t token, OrtAllocator *allocator) { + std::array shape{1, 1}; + + Ort::Value decoder_input = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + + std::array length_shape{1}; + Ort::Value decoder_input_length = Ort::Value::CreateTensor( + allocator, length_shape.data(), length_shape.size()); + + int32_t *p = decoder_input.GetTensorMutableData(); + + int32_t *p_length = decoder_input_length.GetTensorMutableData(); + + p[0] = token; + + p_length[0] = 1; + + return {std::move(decoder_input), std::move(decoder_input_length)}; +} + +static OfflineTransducerDecoderResult DecodeOne( + const float *p, int32_t num_rows, int32_t num_cols, + OfflineTransducerNeMoModel *model, float blank_penalty) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + OfflineTransducerDecoderResult ans; + + int32_t vocab_size = model->VocabSize(); + int32_t blank_id = vocab_size - 1; + + 
auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); + + std::pair> decoder_output_pair = + model->RunDecoder(std::move(decoder_input_pair.first), + std::move(decoder_input_pair.second), + model->GetDecoderInitStates(1)); + + std::array encoder_shape{1, num_cols, 1}; + + for (int32_t t = 0; t != num_rows; ++t) { + Ort::Value cur_encoder_out = Ort::Value::CreateTensor( + memory_info, const_cast(p) + t * num_cols, num_cols, + encoder_shape.data(), encoder_shape.size()); + + Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), + View(&decoder_output_pair.first)); + + float *p_logit = logit.GetTensorMutableData(); + if (blank_penalty > 0) { + p_logit[blank_id] -= blank_penalty; + } + + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + + if (y != blank_id) { + ans.tokens.push_back(y); + ans.timestamps.push_back(t); + + decoder_input_pair = BuildDecoderInput(y, model->Allocator()); + + decoder_output_pair = + model->RunDecoder(std::move(decoder_input_pair.first), + std::move(decoder_input_pair.second), + std::move(decoder_output_pair.second)); + } // if (y != blank_id) + } // for (int32_t i = 0; i != num_rows; ++i) + + return ans; +} + +std::vector +OfflineTransducerGreedySearchNeMoDecoder::Decode( + Ort::Value encoder_out, Ort::Value encoder_out_length, + OfflineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { + auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + + int32_t batch_size = static_cast(shape[0]); + int32_t dim1 = static_cast(shape[1]); + int32_t dim2 = static_cast(shape[2]); + + const int64_t *p_length = encoder_out_length.GetTensorData(); + const float *p = encoder_out.GetTensorData(); + + std::vector ans(batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + const float *this_p = p + dim1 * dim2 * i; + int32_t this_len = p_length[i]; + + ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); + } + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h new file mode 100644 index 00000000..ab0ca422 --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/offline-transducer-greedy-search-nemo-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" +#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h" + +namespace sherpa_onnx { + +class OfflineTransducerGreedySearchNeMoDecoder + : public OfflineTransducerDecoder { + public: + OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model, + float blank_penalty) + : model_(model), blank_penalty_(blank_penalty) {} + + std::vector Decode( + Ort::Value encoder_out, Ort::Value encoder_out_length, + OfflineStream **ss = nullptr, int32_t n = 0) override; + + private: + OfflineTransducerNeMoModel *model_; // Not owned + float blank_penalty_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc new file mode 100644 index 00000000..eee744ab --- /dev/null +++ 
b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc @@ -0,0 +1,301 @@ +// sherpa-onnx/csrc/offline-transducer-nemo-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-transducer-nemo-model.h" + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +class OfflineTransducerNeMoModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } +#endif + + std::vector RunEncoder(Ort::Value features, + Ort::Value features_length) { + // (B, T, C) -> (B, C, T) + features = Transpose12(allocator_, &features); + + std::array encoder_inputs = {std::move(features), + std::move(features_length)}; + + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), + encoder_inputs.size(), encoder_output_names_ptr_.data(), + encoder_output_names_ptr_.size()); + + return encoder_out; + } + + std::pair> RunDecoder( + Ort::Value targets, Ort::Value targets_length, + std::vector states) { + std::vector decoder_inputs; + decoder_inputs.reserve(2 + states.size()); + + decoder_inputs.push_back(std::move(targets)); + decoder_inputs.push_back(std::move(targets_length)); + + for (auto &s : states) { + decoder_inputs.push_back(std::move(s)); + } + + auto decoder_out = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), decoder_inputs.data(), + decoder_inputs.size(), decoder_output_names_ptr_.data(), + decoder_output_names_ptr_.size()); + + std::vector states_next; + states_next.reserve(states.size()); + + // decoder_out[0]: decoder_output + // decoder_out[1]: decoder_output_length + // decoder_out[2:] states_next + + for (int32_t i = 0; i != states.size(); ++i) { + states_next.push_back(std::move(decoder_out[i + 2])); + } + + // we discard decoder_out[1] + return {std::move(decoder_out[0]), std::move(states_next)}; + } + + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { + std::array joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), + joiner_input.data(), joiner_input.size(), + joiner_output_names_ptr_.data(), + joiner_output_names_ptr_.size()); + + return std::move(logit[0]); + } + + std::vector GetDecoderInitStates(int32_t batch_size) const { + std::array s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + Ort::Value s0 
= Ort::Value::CreateTensor(allocator_, s0_shape.data(), + s0_shape.size()); + + Fill(&s0, 0); + + std::array s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + + Ort::Value s1 = Ort::Value::CreateTensor(allocator_, s1_shape.data(), + s1_shape.size()); + + Fill(&s1, 0); + + std::vector states; + + states.reserve(2); + states.push_back(std::move(s0)); + states.push_back(std::move(s1)); + + return states; + } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + int32_t VocabSize() const { return vocab_size_; } + + OrtAllocator *Allocator() const { return allocator_; } + + std::string FeatureNormalizationMethod() const { return normalize_type_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + + // need to increase by 1 since the blank token is not included in computing + // vocab_size in NeMo. + vocab_size_ += 1; + + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); + SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); + SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); + + if (normalize_type_ == "NA") { + normalize_type_ = ""; + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + void InitJoiner(void *model_data, size_t model_data_length) { + joiner_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + } + + private: + OfflineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 8; + std::string normalize_type_; + int32_t pred_rnn_layers_ = -1; + int32_t pred_hidden_ = -1; +}; + +OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( + const OfflineModelConfig &config) + : 
impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( + AAssetManager *mgr, const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OfflineTransducerNeMoModel::~OfflineTransducerNeMoModel() = default; + +std::vector OfflineTransducerNeMoModel::RunEncoder( + Ort::Value features, Ort::Value features_length) const { + return impl_->RunEncoder(std::move(features), std::move(features_length)); +} + +std::pair> +OfflineTransducerNeMoModel::RunDecoder(Ort::Value targets, + Ort::Value targets_length, + std::vector states) const { + return impl_->RunDecoder(std::move(targets), std::move(targets_length), + std::move(states)); +} + +std::vector OfflineTransducerNeMoModel::GetDecoderInitStates( + int32_t batch_size) const { + return impl_->GetDecoderInitStates(batch_size); +} + +Ort::Value OfflineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, + Ort::Value decoder_out) const { + return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); +} + +int32_t OfflineTransducerNeMoModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +int32_t OfflineTransducerNeMoModel::VocabSize() const { + return impl_->VocabSize(); +} + +OrtAllocator *OfflineTransducerNeMoModel::Allocator() const { + return impl_->Allocator(); +} + +std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { + return impl_->FeatureNormalizationMethod(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.h b/sherpa-onnx/csrc/offline-transducer-nemo-model.h new file mode 100644 index 00000000..9ac13591 --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-nemo-model.h @@ -0,0 +1,103 @@ +// sherpa-onnx/csrc/offline-transducer-nemo-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +// see +// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40 +// Its decoder is stateful, not stateless. +class OfflineTransducerNeMoModel { + public: + explicit OfflineTransducerNeMoModel(const OfflineModelConfig &config); + +#if __ANDROID_API__ >= 9 + OfflineTransducerNeMoModel(AAssetManager *mgr, + const OfflineModelConfig &config); +#endif + + ~OfflineTransducerNeMoModel(); + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a vector containing: + * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) + * - encoder_out_length: A 1-D tensor of shape (N,) containing number + * of frames in `encoder_out` before padding. + */ + std::vector RunEncoder(Ort::Value features, + Ort::Value features_length) const; + + /** Run the decoder network. + * + * @param targets A int32 tensor of shape (batch_size, 1) + * @param targets_length A int32 tensor of shape (batch_size,) + * @param states The states for the decoder model. 
+ * @return Return a vector: + * - ans[0] is the decoder_out (a float tensor) + * - ans[1] is the decoder_out_length (a int32 tensor) + * - ans[2:] is the states_next + */ + std::pair> RunDecoder( + Ort::Value targets, Ort::Value targets_length, + std::vector states) const; + + std::vector GetDecoderInitStates(int32_t batch_size) const; + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. + * @param decoder_out Output of the decoder network. + * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. + */ + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const; + + /** Return the subsampling factor of the model. + */ + int32_t SubsamplingFactor() const; + + int32_t VocabSize() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string FeatureNormalizationMethod() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 4b137e29..a54332d5 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -223,8 +223,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { private: void InitDecoder() { - if (!sym_.contains("") && !sym_.contains("") && - !sym_.contains("")) { + if (!sym_.Contains("") && !sym_.Contains("") && + !sym_.Contains("")) { SHERPA_ONNX_LOGE( "We expect that tokens.txt contains " "the symbol or or and its ID."); @@ -232,12 +232,12 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { } int32_t blank_id = 0; - if (sym_.contains("")) { + if (sym_.Contains("")) { blank_id = sym_[""]; - } else if (sym_.contains("")) { + } else if (sym_.Contains("")) { // for tdnn models of the yesno recipe from icefall blank_id = sym_[""]; - } else if (sym_.contains("")) { + } else if (sym_.Contains("")) { // for WeNet CTC models blank_id = sym_[""]; } diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index a2c24452..dcf52b99 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -87,7 +87,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { - if (sym_.contains("")) { + if (sym_.Contains("")) { unk_id_ = sym_[""]; } @@ -103,19 +103,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } decoder_ = std::make_unique( - model_.get(), - lm_.get(), - config_.max_active_paths, - config_.lm_config.scale, - unk_id_, - config_.blank_penalty, + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale, unk_id_, config_.blank_penalty, config_.temperature_scale); } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), - unk_id_, - config_.blank_penalty, + model_.get(), unk_id_, config_.blank_penalty, 
config_.temperature_scale); } else { @@ -132,7 +126,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(mgr, config.model_config)), sym_(mgr, config.model_config.tokens), endpoint_(config_.endpoint_config) { - if (sym_.contains("")) { + if (sym_.Contains("")) { unk_id_ = sym_[""]; } @@ -151,19 +145,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } decoder_ = std::make_unique( - model_.get(), - lm_.get(), - config_.max_active_paths, - config_.lm_config.scale, - unk_id_, - config_.blank_penalty, + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale, unk_id_, config_.blank_penalty, config_.temperature_scale); } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), - unk_id_, - config_.blank_penalty, + model_.get(), unk_id_, config_.blank_penalty, config_.temperature_scale); } else { diff --git a/sherpa-onnx/csrc/slice.h b/sherpa-onnx/csrc/slice.h index 21a93f1a..42ebf6ce 100644 --- a/sherpa-onnx/csrc/slice.h +++ b/sherpa-onnx/csrc/slice.h @@ -13,7 +13,7 @@ namespace sherpa_onnx { * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :] * * @param allocator - * @param v A 2-D tensor. Its data type is T. + * @param v A 3-D tensor. Its data type is T. * @param dim0_start Start index of the first dimension.. * @param dim0_end End index of the first dimension.. * @param dim1_start Start index of the second dimension. diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index d27249f4..1300919b 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -100,9 +100,9 @@ int32_t SymbolTable::operator[](const std::string &sym) const { return sym2id_.at(sym); } -bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; } +bool SymbolTable::Contains(int32_t id) const { return id2sym_.count(id) != 0; } -bool SymbolTable::contains(const std::string &sym) const { +bool SymbolTable::Contains(const std::string &sym) const { return sym2id_.count(sym) != 0; } diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index 7a83ab24..8d0a4e98 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -40,14 +40,16 @@ class SymbolTable { int32_t operator[](const std::string &sym) const; /// Return true if there is a symbol with the given ID. - bool contains(int32_t id) const; + bool Contains(int32_t id) const; /// Return true if there is a given symbol in the symbol table. 
- bool contains(const std::string &sym) const; + bool Contains(const std::string &sym) const; // for tokens.txt from Whisper void ApplyBase64Decode(); + int32_t NumSymbols() const { return id2sym_.size(); } + private: void Init(std::istream &is); diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc index a61323c9..657fcbf7 100644 --- a/sherpa-onnx/csrc/utils.cc +++ b/sherpa-onnx/csrc/utils.cc @@ -49,7 +49,7 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, word = word.replace(0, 3, " "); } } - if (symbol_table.contains(word)) { + if (symbol_table.Contains(word)) { int32_t id = symbol_table[word]; tmp_ids.push_back(id); } else { diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 823e280f..5ef9d4f2 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -14,10 +14,10 @@ namespace sherpa_onnx { static void PybindOfflineRecognizerConfig(py::module *m) { using PyClass = OfflineRecognizerConfig; py::class_(*m, "OfflineRecognizerConfig") - .def(py::init(), + .def(py::init(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OfflineLMConfig(), py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), diff --git a/sherpa-onnx/python/csrc/offline-stream.cc b/sherpa-onnx/python/csrc/offline-stream.cc index 80d54d6c..5679eca7 100644 --- a/sherpa-onnx/python/csrc/offline-stream.cc +++ b/sherpa-onnx/python/csrc/offline-stream.cc @@ -25,6 +25,7 @@ Args: static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT using PyClass = OfflineRecognitionResult; py::class_(*m, "OfflineRecognitionResult") + .def("__str__", &PyClass::AsJsonString) .def_property_readonly( "text", [](const PyClass &self) -> py::str { @@ -37,18 +38,7 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT "timestamps", [](const PyClass &self) { return self.timestamps; }); } -static void PybindOfflineFeatureExtractorConfig(py::module *m) { - using PyClass = OfflineFeatureExtractorConfig; - py::class_(*m, "OfflineFeatureExtractorConfig") - .def(py::init(), py::arg("sampling_rate") = 16000, - py::arg("feature_dim") = 80) - .def_readwrite("sampling_rate", &PyClass::sampling_rate) - .def_readwrite("feature_dim", &PyClass::feature_dim) - .def("__str__", &PyClass::ToString); -} - void PybindOfflineStream(py::module *m) { - PybindOfflineFeatureExtractorConfig(m); PybindOfflineRecognitionResult(m); using PyClass = OfflineStream; diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index c4b78e97..9bf5d18b 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -4,8 +4,8 @@ from pathlib import Path from typing import List, Optional from _sherpa_onnx import ( + FeatureExtractorConfig, OfflineCtcFstDecoderConfig, - OfflineFeatureExtractorConfig, OfflineModelConfig, OfflineNemoEncDecCtcModelConfig, OfflineParaformerModelConfig, @@ -51,6 +51,7 @@ class OfflineRecognizer(object): blank_penalty: float = 0.0, debug: bool = False, provider: str = "cpu", + model_type: str = "transducer", ): """ Please refer to @@ -106,10 +107,10 @@ class OfflineRecognizer(object): num_threads=num_threads, debug=debug, provider=provider, - model_type="transducer", + model_type=model_type, ) - feat_config = OfflineFeatureExtractorConfig( + feat_config = FeatureExtractorConfig( sampling_rate=sample_rate, 
feature_dim=feature_dim, ) @@ -182,7 +183,7 @@ class OfflineRecognizer(object): model_type="paraformer", ) - feat_config = OfflineFeatureExtractorConfig( + feat_config = FeatureExtractorConfig( sampling_rate=sample_rate, feature_dim=feature_dim, ) @@ -246,7 +247,7 @@ class OfflineRecognizer(object): model_type="nemo_ctc", ) - feat_config = OfflineFeatureExtractorConfig( + feat_config = FeatureExtractorConfig( sampling_rate=sample_rate, feature_dim=feature_dim, ) @@ -326,7 +327,7 @@ class OfflineRecognizer(object): model_type="whisper", ) - feat_config = OfflineFeatureExtractorConfig( + feat_config = FeatureExtractorConfig( sampling_rate=16000, feature_dim=80, ) @@ -389,7 +390,7 @@ class OfflineRecognizer(object): model_type="tdnn", ) - feat_config = OfflineFeatureExtractorConfig( + feat_config = FeatureExtractorConfig( sampling_rate=sample_rate, feature_dim=feature_dim, ) @@ -453,7 +454,7 @@ class OfflineRecognizer(object): model_type="wenet_ctc", ) - feat_config = OfflineFeatureExtractorConfig( + feat_config = FeatureExtractorConfig( sampling_rate=sample_rate, feature_dim=feature_dim, )
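
Note for reviewers: the new `OfflineTransducerGreedySearchNeMoDecoder` differs from the icefall-style greedy search because the NeMo prediction network is stateful (an LSTM), so the decoder states must be threaded through the per-frame loop. The sketch below mirrors `DecodeOne()` from `offline-transducer-greedy-search-nemo-decoder.cc`; `model` and its `get_decoder_init_states` / `run_decoder` / `run_joiner` methods are hypothetical stand-ins for the ONNX Runtime calls wrapped by `OfflineTransducerNeMoModel`, not an existing sherpa-onnx API.

```python
import numpy as np

def greedy_decode_one(encoder_out, model, blank_penalty=0.0):
    """encoder_out: (T, C) float32 matrix for a single utterance.

    Python sketch of DecodeOne(); `model` is a hypothetical wrapper
    around the exported NeMo decoder/joiner, NOT the sherpa_onnx API.
    """
    blank_id = model.vocab_size - 1  # NeMo appends the blank as the last token
    states = model.get_decoder_init_states(batch_size=1)
    # prime the stateful prediction network with the blank token
    decoder_out, states = model.run_decoder(token=blank_id, states=states)

    tokens, timestamps = [], []
    for t, frame in enumerate(encoder_out):            # frame: (C,)
        logits = model.run_joiner(frame, decoder_out)  # (vocab_size,)
        if blank_penalty > 0:
            logits[blank_id] -= blank_penalty
        y = int(np.argmax(logits))
        if y != blank_id:
            tokens.append(y)
            timestamps.append(t)
            # feed the emitted token back and advance the decoder states
            decoder_out, states = model.run_decoder(token=y, states=states)
    return tokens, timestamps
```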
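The `nemo_normalize_type` field moved into `FeatureExtractorConfig` carries the `normalize_type` metadata read from the exported model; only `per_feature` is handled, applied by `ComputeMeanAndInvStd` in `offline-stream.cc`. Roughly, `per_feature` means each feature dimension is normalized by its own mean and standard deviation over the frames of the utterance. A minimal NumPy illustration (the epsilon guard is an assumption, not necessarily the exact NeMo constant):

```python
import numpy as np

def per_feature_normalize(features: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """features: (num_frames, feat_dim) fbank matrix for one utterance."""
    mean = features.mean(axis=0, keepdims=True)  # per-dimension mean, (1, feat_dim)
    std = features.std(axis=0, keepdims=True)    # per-dimension std, (1, feat_dim)
    return (features - mean) / (std + eps)
```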
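With the `model_type` argument added to `OfflineRecognizer.from_transducer()`, the exported NeMo transducer models are used from Python the same way as in `python-api-examples/offline-nemo-transducer-decode-files.py`. A minimal call looks like this (the model directory is the `en-24500` release asset used in the CI script; adjust the paths to wherever the files were downloaded):

```python
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder="./sherpa-onnx-nemo-fast-conformer-transducer-en-24500/encoder.onnx",
    decoder="./sherpa-onnx-nemo-fast-conformer-transducer-en-24500/decoder.onnx",
    joiner="./sherpa-onnx-nemo-fast-conformer-transducer-en-24500/joiner.onnx",
    tokens="./sherpa-onnx-nemo-fast-conformer-transducer-en-24500/tokens.txt",
    model_type="nemo_transducer",  # selects the NeMo transducer runtime path
)
```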