From 80060c276d2d5d75c9b2f24a132d7694ed1e4942 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 7 Apr 2023 23:11:34 +0800 Subject: [PATCH] Begin to support CTC models (#119) Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo. --- .github/scripts/test-offline-ctc.sh | 47 +++++++ .github/scripts/test-python.sh | 38 +++++ .github/workflows/linux.yaml | 10 ++ .github/workflows/macos.yaml | 10 ++ .github/workflows/windows-x64.yaml | 10 ++ .github/workflows/windows-x86.yaml | 11 ++ .gitignore | 3 + python-api-examples/offline-decode-files.py | 47 +++++-- python-api-examples/online-decode-files.py | 8 +- sherpa-onnx/csrc/CMakeLists.txt | 4 + sherpa-onnx/csrc/macros.h | 12 +- sherpa-onnx/csrc/offline-ctc-decoder.h | 42 ++++++ .../csrc/offline-ctc-greedy-search-decoder.cc | 54 ++++++++ .../csrc/offline-ctc-greedy-search-decoder.h | 28 ++++ sherpa-onnx/csrc/offline-ctc-model.cc | 86 ++++++++++++ sherpa-onnx/csrc/offline-ctc-model.h | 59 ++++++++ sherpa-onnx/csrc/offline-model-config.cc | 6 + sherpa-onnx/csrc/offline-model-config.h | 4 + .../offline-nemo-enc-dec-ctc-model-config.cc | 35 +++++ .../offline-nemo-enc-dec-ctc-model-config.h | 28 ++++ .../csrc/offline-nemo-enc-dec-ctc-model.cc | 131 ++++++++++++++++++ .../csrc/offline-nemo-enc-dec-ctc-model.h | 75 ++++++++++ .../csrc/offline-recognizer-ctc-impl.h | 128 +++++++++++++++++ sherpa-onnx/csrc/offline-recognizer-impl.cc | 38 ++++- sherpa-onnx/csrc/offline-stream.cc | 70 ++++++++++ sherpa-onnx/csrc/offline-stream.h | 15 +- sherpa-onnx/csrc/online-transducer-model.cc | 19 ++- sherpa-onnx/csrc/transpose-test.cc | 24 ++++ sherpa-onnx/csrc/transpose.cc | 29 +++- sherpa-onnx/csrc/transpose.h | 16 ++- sherpa-onnx/python/csrc/CMakeLists.txt | 1 + .../python/csrc/offline-model-config.cc | 21 +-- .../offline-nemo-enc-dec-ctc-model-config.cc | 22 +++ .../offline-nemo-enc-dec-ctc-model-config.h | 16 +++ .../csrc/offline-paraformer-model-config.cc | 4 +- sherpa-onnx/python/csrc/offline-recognizer.cc | 2 - sherpa-onnx/python/csrc/offline-stream.cc | 4 +- sherpa-onnx/python/csrc/sherpa-onnx.cc | 9 +- .../python/sherpa_onnx/offline_recognizer.py | 73 ++++++++-- .../python/tests/test_offline_recognizer.py | 65 +++++++++ 40 files changed, 1244 insertions(+), 60 deletions(-) create mode 100755 .github/scripts/test-offline-ctc.sh create mode 100644 sherpa-onnx/csrc/offline-ctc-decoder.h create mode 100644 sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc create mode 100644 sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h create mode 100644 sherpa-onnx/csrc/offline-ctc-model.cc create mode 100644 sherpa-onnx/csrc/offline-ctc-model.h create mode 100644 sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h create mode 100644 sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc create mode 100644 sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h create mode 100644 sherpa-onnx/csrc/offline-recognizer-ctc-impl.h create mode 100644 sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh new file mode 100755 index 00000000..7e0b840d --- /dev/null +++ b/.github/scripts/test-offline-ctc.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +log "------------------------------------------------------------" +log "Run Citrinet (stt_en_citrinet_512, English)" +log "------------------------------------------------------------" + +repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512 +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" +ls -lh *.onnx +popd + +time $EXE \ + --tokens=$repo/tokens.txt \ + --nemo-ctc-model=$repo/model.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +time $EXE \ + --tokens=$repo/tokens.txt \ + --nemo-ctc-model=$repo/model.int8.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +rm -rf $repo diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index cd09f785..506adebc 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -95,6 +95,8 @@ python3 ./python-api-examples/offline-decode-files.py \ python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose +rm -rf $repo + log "Test non-streaming paraformer models" pushd $dir @@ -128,3 +130,39 @@ python3 ./python-api-examples/offline-decode-files.py \ $repo/test_wavs/8k.wav python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose + +rm -rf $repo + +log "Test non-streaming NeMo CTC models" + +pushd $dir +repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512 + +log "Start testing ${repo_url}" +repo=$dir/$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +cd $repo +git lfs pull --include "*.onnx" +popd + +ls -lh $repo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=$repo/tokens.txt \ + --nemo-ctc=$repo/model.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=$repo/tokens.txt \ + --nemo-ctc=$repo/model.int8.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose + +rm -rf $repo diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index bd536e08..a328616e 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -8,6 +8,7 @@ on: - '.github/workflows/linux.yaml' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-offline-transducer.sh' + - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -20,6 +21,7 @@ on: - '.github/workflows/linux.yaml' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-offline-transducer.sh' + - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -68,6 +70,14 @@ jobs: file build/bin/sherpa-onnx readelf -d build/bin/sherpa-onnx + - name: Test offline CTC + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline + + .github/scripts/test-offline-ctc.sh + - name: Test offline transducer shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index a8a0f1c6..0967290b 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -8,6 +8,7 @@ on: - '.github/workflows/macos.yaml' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-offline-transducer.sh' + - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -18,6 +19,7 @@ on: - '.github/workflows/macos.yaml' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-offline-transducer.sh' + - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -67,6 +69,14 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline CTC + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline + + .github/scripts/test-offline-ctc.sh + - name: Test offline transducer shell: bash run: | diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 5600c2dd..225ab514 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -8,6 +8,7 @@ on: - '.github/workflows/windows-x64.yaml' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-offline-transducer.sh' + - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -18,6 +19,7 @@ on: - '.github/workflows/windows-x64.yaml' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-offline-transducer.sh' + - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -73,6 +75,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test offline CTC for windows x64 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline.exe + + .github/scripts/test-offline-ctc.sh + - name: Test offline transducer for Windows x64 shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 17171b71..6ed029ca 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -8,6 +8,7 @@ on: - '.github/workflows/windows-x86.yaml' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-offline-transducer.sh' + - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -18,6 +19,7 @@ on: - '.github/workflows/windows-x86.yaml' - '.github/scripts/test-online-transducer.sh' - '.github/scripts/test-offline-transducer.sh' + - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -31,6 +33,7 @@ permissions: jobs: windows_x86: + if: false # disable windows x86 CI for now runs-on: ${{ matrix.os }} name: ${{ matrix.vs-version }} strategy: @@ -73,6 +76,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test offline CTC for windows x86 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline.exe + + .github/scripts/test-offline-ctc.sh + - name: Test offline transducer for Windows x86 shell: bash run: | diff --git a/.gitignore b/.gitignore index dcba6799..212f04ea 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,6 @@ run-offline-websocket-client-*.sh run-sherpa-onnx-*.sh sherpa-onnx-zipformer-en-2023-03-30 sherpa-onnx-zipformer-en-2023-04-01 +run-offline-decode-files.sh +sherpa-onnx-nemo-ctc-en-citrinet-512 +run-offline-decode-files-nemo-ctc.sh diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index e41aa01c..8604d77c 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -6,7 +6,7 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe file(s) with a non-streaming model. -paraformer Usage: +(1) For paraformer ./python-api-examples/offline-decode-files.py \ --tokens=/path/to/tokens.txt \ --paraformer=/path/to/paraformer.onnx \ @@ -18,7 +18,7 @@ paraformer Usage: /path/to/0.wav \ /path/to/1.wav -transducer Usage: +(2) For transducer models from icefall ./python-api-examples/offline-decode-files.py \ --tokens=/path/to/tokens.txt \ --encoder=/path/to/encoder.onnx \ @@ -32,6 +32,8 @@ transducer Usage: /path/to/0.wav \ /path/to/1.wav +(3) For CTC models from NeMo + Please refer to https://k2-fsa.github.io/sherpa/onnx/index.html to install sherpa-onnx and to download the pre-trained models @@ -83,7 +85,14 @@ def get_args(): "--paraformer", default="", type=str, - help="Path to the paraformer model", + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", ) parser.add_argument( @@ -171,11 +180,14 @@ def main(): args = get_args() assert_file_exists(args.tokens) assert args.num_threads > 0, args.num_threads - if len(args.encoder) > 0: + if args.encoder: + assert len(args.paraformer) == 0, args.paraformer + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert_file_exists(args.encoder) assert_file_exists(args.decoder) assert_file_exists(args.joiner) - assert len(args.paraformer) == 0, args.paraformer + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( encoder=args.encoder, decoder=args.decoder, @@ -187,8 +199,10 @@ def main(): decoding_method=args.decoding_method, debug=args.debug, ) - else: + elif args.paraformer: + assert len(args.nemo_ctc) == 0, args.nemo_ctc assert_file_exists(args.paraformer) + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( paraformer=args.paraformer, tokens=args.tokens, @@ -198,6 +212,19 @@ def main(): decoding_method=args.decoding_method, debug=args.debug, ) + elif args.nemo_ctc: + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( + model=args.nemo_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + else: + print("Please specify at least one model") + return print("Started!") start_time = time.time() @@ -225,12 +252,14 @@ def main(): print("-" * 10) elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / duration + rtf = elapsed_seconds / total_duration print(f"num_threads: {args.num_threads}") print(f"decoding_method: {args.decoding_method}") - print(f"Wave duration: {duration:.3f} s") + print(f"Wave duration: {total_duration:.3f} s") print(f"Elapsed time: {elapsed_seconds:.3f} s") - print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) if __name__ == "__main__": diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index a61ed300..fff8bf94 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -172,12 +172,14 @@ def main(): print("-" * 10) elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / duration + rtf = elapsed_seconds / total_duration print(f"num_threads: {args.num_threads}") print(f"decoding_method: {args.decoding_method}") - print(f"Wave duration: {duration:.3f} s") + print(f"Wave duration: {total_duration:.3f} s") print(f"Elapsed time: {elapsed_seconds:.3f} s") - print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) if __name__ == "__main__": diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 672172c7..8a50f770 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -16,7 +16,11 @@ set(sources features.cc file-utils.cc hypothesis.cc + offline-ctc-greedy-search-decoder.cc + offline-ctc-model.cc offline-model-config.cc + offline-nemo-enc-dec-ctc-model-config.cc + offline-nemo-enc-dec-ctc-model.cc offline-paraformer-greedy-search-decoder.cc offline-paraformer-model-config.cc offline-paraformer-model.cc diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h index efe61289..685d34ac 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -11,15 +11,19 @@ #include "android/log.h" #define SHERPA_ONNX_LOGE(...) \ do { \ + fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \ + static_cast(__LINE__)); \ fprintf(stderr, ##__VA_ARGS__); \ fprintf(stderr, "\n"); \ __android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \ } while (0) #else -#define SHERPA_ONNX_LOGE(...) \ - do { \ - fprintf(stderr, ##__VA_ARGS__); \ - fprintf(stderr, "\n"); \ +#define SHERPA_ONNX_LOGE(...) \ + do { \ + fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \ + static_cast(__LINE__)); \ + fprintf(stderr, ##__VA_ARGS__); \ + fprintf(stderr, "\n"); \ } while (0) #endif diff --git a/sherpa-onnx/csrc/offline-ctc-decoder.h b/sherpa-onnx/csrc/offline-ctc-decoder.h new file mode 100644 index 00000000..23e8d0bd --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-decoder.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/offline-ctc-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +struct OfflineCtcDecoderResult { + /// The decoded token IDs + std::vector tokens; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + /// Note: The index is after subsampling + std::vector timestamps; +}; + +class OfflineCtcDecoder { + public: + virtual ~OfflineCtcDecoder() = default; + + /** Run CTC decoding given the output from the encoder model. + * + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing + * lob_probs. + * @param log_probs_length A 1-D tensor of shape (N,) containing number + * of valid frames in log_probs before padding. + * + * @return Return a vector of size `N` containing the decoded results. + */ + virtual std::vector Decode( + Ort::Value log_probs, Ort::Value log_probs_length) = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc new file mode 100644 index 00000000..8c4451b3 --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc @@ -0,0 +1,54 @@ +// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +std::vector OfflineCtcGreedySearchDecoder::Decode( + Ort::Value log_probs, Ort::Value log_probs_length) { + std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); + int32_t batch_size = static_cast(shape[0]); + int32_t num_frames = static_cast(shape[1]); + int32_t vocab_size = static_cast(shape[2]); + + const int64_t *p_log_probs_length = log_probs_length.GetTensorData(); + + std::vector ans; + ans.reserve(batch_size); + + for (int32_t b = 0; b != batch_size; ++b) { + const float *p_log_probs = + log_probs.GetTensorData() + b * num_frames * vocab_size; + + OfflineCtcDecoderResult r; + int64_t prev_id = -1; + + for (int32_t t = 0; t != static_cast(p_log_probs_length[b]); ++t) { + auto y = static_cast(std::distance( + static_cast(p_log_probs), + std::max_element( + static_cast(p_log_probs), + static_cast(p_log_probs) + vocab_size))); + p_log_probs += vocab_size; + + if (y != blank_id_ && y != prev_id) { + r.tokens.push_back(y); + r.timestamps.push_back(t); + prev_id = y; + } + } // for (int32_t t = 0; ...) + + ans.push_back(std::move(r)); + } + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h new file mode 100644 index 00000000..ccc2f728 --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" + +namespace sherpa_onnx { + +class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder { + public: + explicit OfflineCtcGreedySearchDecoder(int32_t blank_id) + : blank_id_(blank_id) {} + + std::vector Decode( + Ort::Value log_probs, Ort::Value log_probs_length) override; + + private: + int32_t blank_id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc new file mode 100644 index 00000000..1d19253d --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -0,0 +1,86 @@ +// sherpa-onnx/csrc/offline-ctc-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-ctc-model.h" + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace { + +enum class ModelType { + kEncDecCTCModelBPE, + kUnkown, +}; + +} + +namespace sherpa_onnx { + +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + Ort::SessionOptions sess_opts; + + auto sess = std::make_unique(env, model_data, model_data_length, + sess_opts); + + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); + if (debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; + auto model_type = + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); + if (!model_type) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n" + "If you are using models from NeMo, please refer to\n" + "https://huggingface.co/csukuangfj/" + "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" + "\n" + "for how to add metadta to model.onnx\n"); + return ModelType::kUnkown; + } + + if (model_type.get() == std::string("EncDecCTCModelBPE")) { + return ModelType::kEncDecCTCModelBPE; + } else { + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); + return ModelType::kUnkown; + } +} + +std::unique_ptr OfflineCtcModel::Create( + const OfflineModelConfig &config) { + ModelType model_type = ModelType::kUnkown; + + { + auto buffer = ReadFile(config.nemo_ctc.model); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kEncDecCTCModelBPE: + return std::make_unique(config); + break; + case ModelType::kUnkown: + SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); + return nullptr; + } + + return nullptr; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-model.h b/sherpa-onnx/csrc/offline-ctc-model.h new file mode 100644 index 00000000..8be7f99b --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-model.h @@ -0,0 +1,59 @@ +// sherpa-onnx/csrc/offline-ctc-model.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +class OfflineCtcModel { + public: + virtual ~OfflineCtcModel() = default; + static std::unique_ptr Create( + const OfflineModelConfig &config); + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a pair containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t + */ + virtual std::pair Forward( + Ort::Value features, Ort::Value features_length) = 0; + + /** Return the vocabulary size of the model + */ + virtual int32_t VocabSize() const = 0; + + /** SubsamplingFactor of the model + * + * For Citrinet, the subsampling factor is usually 4. + * For Conformer CTC, the subsampling factor is usually 8. + */ + virtual int32_t SubsamplingFactor() const = 0; + + /** Return an allocator for allocating memory + */ + virtual OrtAllocator *Allocator() const = 0; + + /** For some models, e.g., those from NeMo, they require some preprocessing + * for the features. + */ + virtual std::string FeatureNormalizationMethod() const { return {}; } +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 29f7b8a7..c4912abb 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -13,6 +13,7 @@ namespace sherpa_onnx { void OfflineModelConfig::Register(ParseOptions *po) { transducer.Register(po); paraformer.Register(po); + nemo_ctc.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -38,6 +39,10 @@ bool OfflineModelConfig::Validate() const { return paraformer.Validate(); } + if (!nemo_ctc.model.empty()) { + return nemo_ctc.Validate(); + } + return transducer.Validate(); } @@ -47,6 +52,7 @@ std::string OfflineModelConfig::ToString() const { os << "OfflineModelConfig("; os << "transducer=" << transducer.ToString() << ", "; os << "paraformer=" << paraformer.ToString() << ", "; + os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ")"; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index d8412316..da17c7b5 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -6,6 +6,7 @@ #include +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" #include "sherpa-onnx/csrc/offline-transducer-model-config.h" @@ -14,6 +15,7 @@ namespace sherpa_onnx { struct OfflineModelConfig { OfflineTransducerModelConfig transducer; OfflineParaformerModelConfig paraformer; + OfflineNemoEncDecCtcModelConfig nemo_ctc; std::string tokens; int32_t num_threads = 2; @@ -22,9 +24,11 @@ struct OfflineModelConfig { OfflineModelConfig() = default; OfflineModelConfig(const OfflineTransducerModelConfig &transducer, const OfflineParaformerModelConfig ¶former, + const OfflineNemoEncDecCtcModelConfig &nemo_ctc, const std::string &tokens, int32_t num_threads, bool debug) : transducer(transducer), paraformer(paraformer), + nemo_ctc(nemo_ctc), tokens(tokens), num_threads(num_threads), debug(debug) {} diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc new file mode 100644 index 00000000..c28c522d --- /dev/null +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc @@ -0,0 +1,35 @@ +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) { + po->Register("nemo-ctc-model", &model, + "Path to model.onnx of Nemo EncDecCtcModel."); +} + +bool OfflineNemoEncDecCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineNemoEncDecCtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineNemoEncDecCtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h new file mode 100644 index 00000000..9ef7b54a --- /dev/null +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineNemoEncDecCtcModelConfig { + std::string model; + + OfflineNemoEncDecCtcModelConfig() = default; + explicit OfflineNemoEncDecCtcModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc new file mode 100644 index 00000000..c981a453 --- /dev/null +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc @@ -0,0 +1,131 @@ +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +class OfflineNemoEncDecCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_{}, + allocator_{} { + sess_opts_.SetIntraOpNumThreads(config_.num_threads); + sess_opts_.SetInterOpNumThreads(config_.num_threads); + + Init(); + } + + std::pair Forward(Ort::Value features, + Ort::Value features_length) { + std::vector shape = + features_length.GetTensorTypeAndShapeInfo().GetShape(); + + Ort::Value out_features_length = Ort::Value::CreateTensor( + allocator_, shape.data(), shape.size()); + + const int64_t *src = features_length.GetTensorData(); + int64_t *dst = out_features_length.GetTensorMutableData(); + for (int64_t i = 0; i != shape[0]; ++i) { + dst[i] = src[i] / subsampling_factor_; + } + + // (B, T, C) -> (B, C, T) + features = Transpose12(allocator_, &features); + + std::array inputs = {std::move(features), + std::move(features_length)}; + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + return {std::move(out[0]), std::move(out_features_length)}; + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + + OrtAllocator *Allocator() const { return allocator_; } + + std::string FeatureNormalizationMethod() const { return normalize_type_; } + + private: + void Init() { + auto buf = ReadFile(config_.nemo_ctc.model); + + sess_ = std::make_unique(env_, buf.data(), buf.size(), + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); + } + + private: + OfflineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 0; + std::string normalize_type_; +}; + +OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( + const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default; + +std::pair OfflineNemoEncDecCtcModel::Forward( + Ort::Value features, Ort::Value features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineNemoEncDecCtcModel::VocabSize() const { + return impl_->VocabSize(); +} +int32_t OfflineNemoEncDecCtcModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +OrtAllocator *OfflineNemoEncDecCtcModel::Allocator() const { + return impl_->Allocator(); +} + +std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const { + return impl_->FeatureNormalizationMethod(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h new file mode 100644 index 00000000..2b11f655 --- /dev/null +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h @@ -0,0 +1,75 @@ +// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-ctc-model.h" +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +/** This class implements the EncDecCTCModelBPE model from NeMo. + * + * See + * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_bpe_models.py + * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_models.py + */ +class OfflineNemoEncDecCtcModel : public OfflineCtcModel { + public: + explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config); + ~OfflineNemoEncDecCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a pair containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t + */ + std::pair Forward( + Ort::Value features, Ort::Value features_length) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** SubsamplingFactor of the model + * + * For Citrinet, the subsampling factor is usually 4. + * For Conformer CTC, the subsampling factor is usually 8. + */ + int32_t SubsamplingFactor() const override; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const override; + + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string FeatureNormalizationMethod() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h new file mode 100644 index 00000000..80c946c7 --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -0,0 +1,128 @@ +// sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" +#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/offline-ctc-model.h" +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" +#include "sherpa-onnx/csrc/pad-sequence.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, + const SymbolTable &sym_table) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + + std::string text; + + for (int32_t i = 0; i != src.tokens.size(); ++i) { + auto sym = sym_table[src.tokens[i]]; + text.append(sym); + r.tokens.push_back(std::move(sym)); + } + r.text = std::move(text); + + return r; +} + +class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config) + : config_(config), + symbol_table_(config_.model_config.tokens), + model_(OfflineCtcModel::Create(config_.model_config)) { + config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + if (config.decoding_method == "greedy_search") { + if (!symbol_table_.contains("")) { + SHERPA_ONNX_LOGE( + "We expect that tokens.txt contains " + "the symbol and its ID."); + exit(-1); + } + + int32_t blank_id = symbol_table_[""]; + decoder_ = std::make_unique(blank_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t feat_dim = config_.feat_config.feature_dim; + + std::vector features; + features.reserve(n); + + std::vector> features_vec(n); + std::vector features_length_vec(n); + + for (int32_t i = 0; i != n; ++i) { + std::vector f = ss[i]->GetFrames(); + + int32_t num_frames = f.size() / feat_dim; + features_vec[i] = std::move(f); + + features_length_vec[i] = num_frames; + + std::array shape = {num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } // for (int32_t i = 0; i != n; ++i) + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = &features[i]; + } + + std::array features_length_shape = {n}; + Ort::Value x_length = Ort::Value::CreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, + -23.025850929940457f); + auto t = model_->Forward(std::move(x), std::move(x_length)); + + auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); + + for (int32_t i = 0; i != n; ++i) { + auto r = Convert(results[i], symbol_table_); + ss[i]->SetResult(r); + } + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index d78a5bec..fa6f770a 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -8,6 +8,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" #include "sherpa-onnx/csrc/onnx-utils.h" @@ -25,6 +26,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( model_filename = config.model_config.transducer.encoder_filename; } else if (!config.model_config.paraformer.model.empty()) { model_filename = config.model_config.paraformer.model; + } else if (!config.model_config.nemo_ctc.model.empty()) { + model_filename = config.model_config.nemo_ctc.model; } else { SHERPA_ONNX_LOGE("Please provide a model"); exit(-1); @@ -39,8 +42,30 @@ std::unique_ptr OfflineRecognizerImpl::Create( Ort::AllocatorWithDefaultOptions allocator; // used in the macro below - std::string model_type; - SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); + auto model_type_ptr = + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); + if (!model_type_ptr) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n\n" + "Please refer to the following URLs to add metadata" + "\n" + "(0) Transducer models from icefall" + "\n " + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" + "pruned_transducer_stateless7/export-onnx.py#L303" + "\n" + "(1) Nemo CTC models\n " + "https://huggingface.co/csukuangfj/" + "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" + "\n" + "(2) Paraformer" + "\n " + "https://huggingface.co/csukuangfj/" + "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" + "\n"); + exit(-1); + } + std::string model_type(model_type_ptr.get()); if (model_type == "conformer" || model_type == "zipformer") { return std::make_unique(config); @@ -50,11 +75,16 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } + if (model_type == "EncDecCTCModelBPE") { + return std::make_unique(config); + } + SHERPA_ONNX_LOGE( "\nUnsupported model_type: %s\n" "We support only the following model types at present: \n" - " - transducer models from icefall\n" - " - Paraformer models from FunASR\n", + " - Non-streaming transducer models from icefall\n" + " - Non-streaming Paraformer models from FunASR\n" + " - EncDecCTCModelBPE models from NeMo\n", model_type.c_str()); exit(-1); diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index bfb9fb64..67781cf3 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -7,6 +7,7 @@ #include #include +#include #include "kaldi-native-fbank/csrc/online-feature.h" #include "sherpa-onnx/csrc/macros.h" @@ -15,6 +16,41 @@ namespace sherpa_onnx { +/* Compute mean and inverse stddev over rows. + * + * @param p A pointer to a 2-d array of shape (num_rows, num_cols) + * @param num_rows Number of rows + * @param num_cols Number of columns + * @param mean On return, it contains p.mean(axis=0) + * @param inv_stddev On return, it contains 1/p.std(axis=0) + */ +static void ComputeMeanAndInvStd(const float *p, int32_t num_rows, + int32_t num_cols, std::vector *mean, + std::vector *inv_stddev) { + std::vector sum(num_cols); + std::vector sum_sq(num_cols); + + for (int32_t i = 0; i != num_rows; ++i) { + for (int32_t c = 0; c != num_cols; ++c) { + auto t = p[c]; + sum[c] += t; + sum_sq[c] += t * t; + } + p += num_cols; + } + + mean->resize(num_cols); + inv_stddev->resize(num_cols); + + for (int32_t i = 0; i != num_cols; ++i) { + auto t = sum[i] / num_rows; + (*mean)[i] = t; + + float stddev = std::sqrt(sum_sq[i] / num_rows - t * t); + (*inv_stddev)[i] = 1.0f / (stddev + 1e-5f); + } +} + void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { po->Register("sample-rate", &sampling_rate, "Sampling rate of the input waveform. " @@ -106,6 +142,8 @@ class OfflineStream::Impl { p += feature_dim; } + NemoNormalizeFeatures(features.data(), n, feature_dim); + return features; } @@ -113,6 +151,38 @@ class OfflineStream::Impl { const OfflineRecognitionResult &GetResult() const { return r_; } + private: + void NemoNormalizeFeatures(float *p, int32_t num_frames, + int32_t feature_dim) const { + if (config_.nemo_normalize_type.empty()) { + return; + } + + if (config_.nemo_normalize_type != "per_feature") { + SHERPA_ONNX_LOGE( + "Only normalize_type=per_feature is implemented. Given: %s", + config_.nemo_normalize_type.c_str()); + exit(-1); + } + + NemoNormalizePerFeature(p, num_frames, feature_dim); + } + + static void NemoNormalizePerFeature(float *p, int32_t num_frames, + int32_t feature_dim) { + std::vector mean; + std::vector inv_stddev; + + ComputeMeanAndInvStd(p, num_frames, feature_dim, &mean, &inv_stddev); + + for (int32_t n = 0; n != num_frames; ++n) { + for (int32_t i = 0; i != feature_dim; ++i) { + p[i] = (p[i] - mean[i]) * inv_stddev[i]; + } + p += feature_dim; + } + } + private: OfflineFeatureExtractorConfig config_; std::unique_ptr fbank_; diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index ba0798bf..99fbd43c 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -37,13 +37,26 @@ struct OfflineFeatureExtractorConfig { // Feature dimension int32_t feature_dim = 80; - // Set internally by some models, e.g., paraformer + // Set internally by some models, e.g., paraformer sets it to false. // This parameter is not exposed to users from the commandline // If true, the feature extractor expects inputs to be normalized to // the range [-1, 1]. // If false, we will multiply the inputs by 32768 bool normalize_samples = true; + // For models from NeMo + // This option is not exposed and is set internally when loading models. + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string nemo_normalize_type; + std::string ToString() const; void Register(ParseOptions *po); diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 960d277f..89ad630e 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -14,10 +14,12 @@ #include #include +#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" #include "sherpa-onnx/csrc/onnx-utils.h" -namespace sherpa_onnx { + +namespace { enum class ModelType { kLstm, @@ -25,6 +27,10 @@ enum class ModelType { kUnkown, }; +} + +namespace sherpa_onnx { + static ModelType GetModelType(char *model_data, size_t model_data_length, bool debug) { Ort::Env env(ORT_LOGGING_LEVEL_WARNING); @@ -37,14 +43,17 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, if (debug) { std::ostringstream os; PrintModelMetadata(os, meta_data); - fprintf(stderr, "%s\n", os.str().c_str()); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); } Ort::AllocatorWithDefaultOptions allocator; auto model_type = meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); if (!model_type) { - fprintf(stderr, "No model_type in the metadata!\n"); + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n" + "Please make sure you are using the latest export-onnx.py from icefall " + "to export your transducer models"); return ModelType::kUnkown; } @@ -53,7 +62,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, } else if (model_type.get() == std::string("zipformer")) { return ModelType::kZipformer; } else { - fprintf(stderr, "Unsupported model_type: %s\n", model_type.get()); + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); return ModelType::kUnkown; } } @@ -74,6 +83,7 @@ std::unique_ptr OnlineTransducerModel::Create( case ModelType::kZipformer: return std::make_unique(config); case ModelType::kUnkown: + SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); return nullptr; } @@ -127,6 +137,7 @@ std::unique_ptr OnlineTransducerModel::Create( case ModelType::kZipformer: return std::make_unique(mgr, config); case ModelType::kUnkown: + SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); return nullptr; } diff --git a/sherpa-onnx/csrc/transpose-test.cc b/sherpa-onnx/csrc/transpose-test.cc index 98fd179b..36d50372 100644 --- a/sherpa-onnx/csrc/transpose-test.cc +++ b/sherpa-onnx/csrc/transpose-test.cc @@ -35,4 +35,28 @@ TEST(Tranpose, Tranpose01) { } } +TEST(Tranpose, Tranpose12) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{3, 2, 5}; + Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); + + auto ans = Transpose12(allocator, &v); + auto v2 = Transpose12(allocator, &ans); + + Print3D(&v); + Print3D(&ans); + Print3D(&v2); + + const float *q = v2.GetTensorData(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + EXPECT_EQ(p[i], q[i]); + } +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/transpose.cc b/sherpa-onnx/csrc/transpose.cc index 09a434de..5ec32667 100644 --- a/sherpa-onnx/csrc/transpose.cc +++ b/sherpa-onnx/csrc/transpose.cc @@ -17,8 +17,8 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { assert(shape.size() == 3); std::array ans_shape{shape[1], shape[0], shape[2]}; - Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), - ans_shape.size()); + Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); T *dst = ans.GetTensorMutableData(); auto plane_offset = shape[1] * shape[2]; @@ -35,7 +35,32 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) { return ans; } +template +Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + assert(shape.size() == 3); + + std::array ans_shape{shape[0], shape[2], shape[1]}; + Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + T *dst = ans.GetTensorMutableData(); + auto row_stride = shape[2]; + for (int64_t b = 0; b != ans_shape[0]; ++b) { + const T *src = v->GetTensorData() + b * shape[1] * shape[2]; + for (int64_t i = 0; i != ans_shape[1]; ++i) { + for (int64_t k = 0; k != ans_shape[2]; ++k, ++dst) { + *dst = (src + k * row_stride)[i]; + } + } + } + + return ans; +} + template Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); +template Ort::Value Transpose12(OrtAllocator *allocator, + const Ort::Value *v); + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/transpose.h b/sherpa-onnx/csrc/transpose.h index 404064a3..aba7f44c 100644 --- a/sherpa-onnx/csrc/transpose.h +++ b/sherpa-onnx/csrc/transpose.h @@ -10,13 +10,23 @@ namespace sherpa_onnx { /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). * * @param allocator - * @param v A 3-D tensor of shape (B, T, C). Its dataype is T. + * @param v A 3-D tensor of shape (B, T, C). Its dataype is type. * - * @return Return a 3-D tensor of shape (T, B, C). Its datatype is T. + * @return Return a 3-D tensor of shape (T, B, C). Its datatype is type. */ -template +template Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); +/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T). + * + * @param allocator + * @param v A 3-D tensor of shape (B, T, C). Its dataype is type. + * + * @return Return a 3-D tensor of shape (B, C, T). Its datatype is type. + */ +template +Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_ diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 560adfcb..f32735db 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx endpoint.cc features.cc offline-model-config.cc + offline-nemo-enc-dec-ctc-model-config.cc offline-paraformer-model-config.cc offline-recognizer.cc offline-stream.cc diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc index 522e16d6..26b561ea 100644 --- a/sherpa-onnx/python/csrc/offline-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -7,26 +7,31 @@ #include #include -#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" -#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" - #include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" +#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" +#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" namespace sherpa_onnx { void PybindOfflineModelConfig(py::module *m) { PybindOfflineTransducerModelConfig(m); PybindOfflineParaformerModelConfig(m); + PybindOfflineNemoEncDecCtcModelConfig(m); using PyClass = OfflineModelConfig; py::class_(*m, "OfflineModelConfig") - .def(py::init(), - py::arg("transducer"), py::arg("paraformer"), py::arg("tokens"), - py::arg("num_threads"), py::arg("debug") = false) + .def(py::init(), + py::arg("transducer") = OfflineTransducerModelConfig(), + py::arg("paraformer") = OfflineParaformerModelConfig(), + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false) .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) + .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) diff --git a/sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc b/sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc new file mode 100644 index 00000000..e65f0566 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" + +#include +#include + +#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineNemoEncDecCtcModelConfig(py::module *m) { + using PyClass = OfflineNemoEncDecCtcModelConfig; + py::class_(*m, "OfflineNemoEncDecCtcModelConfig") + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h b/sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h new file mode 100644 index 00000000..46d65cd6 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineNemoEncDecCtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc b/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc index ca5c7d24..4b0ca491 100644 --- a/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-paraformer-model-config.cc @@ -4,7 +4,6 @@ #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" - #include #include @@ -15,8 +14,7 @@ namespace sherpa_onnx { void PybindOfflineParaformerModelConfig(py::module *m) { using PyClass = OfflineParaformerModelConfig; py::class_(*m, "OfflineParaformerModelConfig") - .def(py::init(), - py::arg("model")) + .def(py::init(), py::arg("model")) .def_readwrite("model", &PyClass::model) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 7365acf1..8a7779ba 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -11,8 +11,6 @@ namespace sherpa_onnx { - - static void PybindOfflineRecognizerConfig(py::module *m) { using PyClass = OfflineRecognizerConfig; py::class_(*m, "OfflineRecognizerConfig") diff --git a/sherpa-onnx/python/csrc/offline-stream.cc b/sherpa-onnx/python/csrc/offline-stream.cc index be989aca..bf851fd5 100644 --- a/sherpa-onnx/python/csrc/offline-stream.cc +++ b/sherpa-onnx/python/csrc/offline-stream.cc @@ -31,7 +31,6 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT "timestamps", [](const PyClass &self) { return self.timestamps; }); } - static void PybindOfflineFeatureExtractorConfig(py::module *m) { using PyClass = OfflineFeatureExtractorConfig; py::class_(*m, "OfflineFeatureExtractorConfig") @@ -42,7 +41,6 @@ static void PybindOfflineFeatureExtractorConfig(py::module *m) { .def("__str__", &PyClass::ToString); } - void PybindOfflineStream(py::module *m) { PybindOfflineFeatureExtractorConfig(m); PybindOfflineRecognitionResult(m); @@ -55,7 +53,7 @@ void PybindOfflineStream(py::module *m) { self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); }, py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) - .def_property_readonly("result", &PyClass::GetResult); + .def_property_readonly("result", &PyClass::GetResult); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index b235d47f..e099bd79 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -7,16 +7,13 @@ #include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" +#include "sherpa-onnx/python/csrc/offline-model-config.h" +#include "sherpa-onnx/python/csrc/offline-recognizer.h" +#include "sherpa-onnx/python/csrc/offline-stream.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-stream.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" -#include "sherpa-onnx/python/csrc/offline-model-config.h" -#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" -#include "sherpa-onnx/python/csrc/offline-recognizer.h" -#include "sherpa-onnx/python/csrc/offline-stream.h" -#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" - namespace sherpa_onnx { PYBIND11_MODULE(_sherpa_onnx, m) { diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index f4371e7a..1c25c7d1 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -4,12 +4,15 @@ from typing import List from _sherpa_onnx import ( OfflineFeatureExtractorConfig, - OfflineRecognizer as _Recognizer, + OfflineModelConfig, + OfflineNemoEncDecCtcModelConfig, + OfflineParaformerModelConfig, +) +from _sherpa_onnx import OfflineRecognizer as _Recognizer +from _sherpa_onnx import ( OfflineRecognizerConfig, OfflineStream, - OfflineModelConfig, OfflineTransducerModelConfig, - OfflineParaformerModelConfig, ) @@ -75,7 +78,6 @@ class OfflineRecognizer(object): decoder_filename=decoder, joiner_filename=joiner, ), - paraformer=OfflineParaformerModelConfig(model=""), tokens=tokens, num_threads=num_threads, debug=debug, @@ -119,7 +121,7 @@ class OfflineRecognizer(object): symbol integer_id paraformer: - Path to ``paraformer.onnx``. + Path to ``model.onnx``. num_threads: Number of threads for neural network computation. sample_rate: @@ -133,9 +135,6 @@ class OfflineRecognizer(object): """ self = cls.__new__(cls) model_config = OfflineModelConfig( - transducer=OfflineTransducerModelConfig( - encoder_filename="", decoder_filename="", joiner_filename="" - ), paraformer=OfflineParaformerModelConfig(model=paraformer), tokens=tokens, num_threads=num_threads, @@ -155,6 +154,64 @@ class OfflineRecognizer(object): self.recognizer = _Recognizer(recognizer_config) return self + @classmethod + def from_nemo_ctc( + cls, + model: str, + tokens: str, + num_threads: int, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search, modified_beam_search. + debug: + True to show debug messages. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + nemo_ctc=OfflineNemoEncDecCtcModelConfig(model=model), + tokens=tokens, + num_threads=num_threads, + debug=debug, + ) + + feat_config = OfflineFeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + ) + self.recognizer = _Recognizer(recognizer_config) + return self + def create_stream(self): return self.recognizer.create_stream() diff --git a/sherpa-onnx/python/tests/test_offline_recognizer.py b/sherpa-onnx/python/tests/test_offline_recognizer.py index 5f9924d9..bb6c994c 100755 --- a/sherpa-onnx/python/tests/test_offline_recognizer.py +++ b/sherpa-onnx/python/tests/test_offline_recognizer.py @@ -196,6 +196,71 @@ class TestOfflineRecognizer(unittest.TestCase): print(s2.result.text) print(s3.result.text) + def test_nemo_ctc_single_file(self): + for use_int8 in [True, False]: + if use_int8: + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx" + else: + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx" + + tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt" + wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav" + + if not Path(model).is_file(): + print("skipping test_nemo_ctc_single_file()") + return + + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( + model=model, + tokens=tokens, + num_threads=1, + ) + + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave0) + s.accept_waveform(sample_rate, samples) + recognizer.decode_stream(s) + print(s.result.text) + + def test_nemo_ctc_multiple_files(self): + for use_int8 in [True, False]: + if use_int8: + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx" + else: + model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx" + + tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt" + wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav" + wave1 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav" + wave2 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav" + + if not Path(model).is_file(): + print("skipping test_nemo_ctc_multiple_files()") + return + + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( + model=model, + tokens=tokens, + num_threads=1, + ) + + s0 = recognizer.create_stream() + samples0, sample_rate0 = read_wave(wave0) + s0.accept_waveform(sample_rate0, samples0) + + s1 = recognizer.create_stream() + samples1, sample_rate1 = read_wave(wave1) + s1.accept_waveform(sample_rate1, samples1) + + s2 = recognizer.create_stream() + samples2, sample_rate2 = read_wave(wave2) + s2.accept_waveform(sample_rate2, samples2) + + recognizer.decode_streams([s0, s1, s2]) + print(s0.result.text) + print(s1.result.text) + print(s2.result.text) + if __name__ == "__main__": unittest.main()