diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh index 7e0b840d..1648a18a 100755 --- a/.github/scripts/test-offline-ctc.sh +++ b/.github/scripts/test-offline-ctc.sh @@ -13,6 +13,50 @@ echo "PATH: $PATH" which $EXE +log "------------------------------------------------------------" +log "Run tdnn yesno (Hebrew)" +log "------------------------------------------------------------" +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" +ls -lh *.onnx +popd + +log "test float32 models" +time $EXE \ + --sample-rate=8000 \ + --feat-dim=23 \ + \ + --tokens=$repo/tokens.txt \ + --tdnn-model=$repo/model-epoch-14-avg-2.onnx \ + $repo/test_wavs/0_0_0_1_0_0_0_1.wav \ + $repo/test_wavs/0_0_1_0_0_0_1_0.wav \ + $repo/test_wavs/0_0_1_0_0_1_1_1.wav \ + $repo/test_wavs/0_0_1_0_1_0_0_1.wav \ + $repo/test_wavs/0_0_1_1_0_0_0_1.wav \ + $repo/test_wavs/0_0_1_1_0_1_1_0.wav + +log "test int8 models" +time $EXE \ + --sample-rate=8000 \ + --feat-dim=23 \ + \ + --tokens=$repo/tokens.txt \ + --tdnn-model=$repo/model-epoch-14-avg-2.int8.onnx \ + $repo/test_wavs/0_0_0_1_0_0_0_1.wav \ + $repo/test_wavs/0_0_1_0_0_0_1_0.wav \ + $repo/test_wavs/0_0_1_0_0_1_1_1.wav \ + $repo/test_wavs/0_0_1_0_1_0_0_1.wav \ + $repo/test_wavs/0_0_1_1_0_0_0_1.wav \ + $repo/test_wavs/0_0_1_1_0_1_1_0.wav + +rm -rf $repo + log "------------------------------------------------------------" log "Run Citrinet (stt_en_citrinet_512, English)" log "------------------------------------------------------------" diff --git a/.github/workflows/test-python-offline-websocket-server.yaml b/.github/workflows/test-python-offline-websocket-server.yaml index ce538d18..7ec4e29d 100644 --- a/.github/workflows/test-python-offline-websocket-server.yaml +++ b/.github/workflows/test-python-offline-websocket-server.yaml @@ -24,7 +24,7 @@ jobs: matrix: os: [ubuntu-latest, windows-latest, macos-latest] python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] - model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"] + model_type: ["transducer", "paraformer", "nemo_ctc", "whisper", "tdnn"] steps: - uses: actions/checkout@v2 @@ -172,3 +172,41 @@ jobs: ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \ ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \ ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav + + - name: Start server for tdnn models + if: matrix.model_type == 'tdnn' + shell: bash + run: | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno + cd sherpa-onnx-tdnn-yesno + git lfs pull --include "*.onnx" + cd .. + + python3 ./python-api-examples/non_streaming_server.py \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + --sample-rate=8000 \ + --feat-dim=23 & + + echo "sleep 10 seconds to wait the server start" + sleep 10 + + - name: Start client for tdnn models + if: matrix.model_type == 'tdnn' + shell: bash + run: | + python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav + + python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav diff --git a/CMakeLists.txt b/CMakeLists.txt index c0ff2901..abb77787 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.7.2") +set(SHERPA_ONNX_VERSION "1.7.3") # Disable warning about # diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index 52210d8b..7d3502fa 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -71,6 +71,20 @@ python3 ./python-api-examples/non_streaming_server.py \ --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt +(5) Use a tdnn model of the yesno recipe from icefall + +cd /path/to/sherpa-onnx + +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno +cd sherpa-onnx-tdnn-yesno +git lfs pull --include "*.onnx" + +python3 ./python-api-examples/non_streaming_server.py \ + --sample-rate=8000 \ + --feat-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt + ---- To use a certificate so that you can use https, please use @@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): ) +def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + + def add_whisper_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--whisper-encoder", @@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser): add_transducer_model_args(parser) add_paraformer_model_args(parser) add_nemo_ctc_model_args(parser) + add_tdnn_ctc_model_args(parser) add_whisper_model_args(parser) parser.add_argument( @@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.nemo_ctc) == 0, args.nemo_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.encoder) assert_file_exists(args.decoder) @@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: assert len(args.nemo_ctc) == 0, args.nemo_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.paraformer) @@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: elif args.nemo_ctc: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.nemo_ctc) @@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: decoding_method=args.decoding_method, ) elif args.whisper_encoder: + assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_decoder) @@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: num_threads=args.num_threads, decoding_method=args.decoding_method, ) + elif args.tdnn_model: + assert_file_exists(args.tdnn_model) + + recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc( + model=args.tdnn_model, + tokens=args.tokens, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + ) else: raise ValueError("Please specify at least one model") diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index c6b63ee0..f3f7949c 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -8,6 +8,7 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe file(s) with a non-streaming model. (1) For paraformer + ./python-api-examples/offline-decode-files.py \ --tokens=/path/to/tokens.txt \ --paraformer=/path/to/paraformer.onnx \ @@ -20,6 +21,7 @@ file(s) with a non-streaming model. /path/to/1.wav (2) For transducer models from icefall + ./python-api-examples/offline-decode-files.py \ --tokens=/path/to/tokens.txt \ --encoder=/path/to/encoder.onnx \ @@ -56,9 +58,20 @@ python3 ./python-api-examples/offline-decode-files.py \ ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav +(5) For tdnn models of the yesno recipe from icefall + +python3 ./python-api-examples/offline-decode-files.py \ + --sample-rate=8000 \ + --feature-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav + Please refer to https://k2-fsa.github.io/sherpa/onnx/index.html -to install sherpa-onnx and to download the pre-trained models +to install sherpa-onnx and to download non-streaming pre-trained models used in this file. """ import argparse @@ -159,6 +172,13 @@ def get_args(): help="Path to the model.onnx from NeMo CTC", ) + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + parser.add_argument( "--num-threads", type=int, @@ -285,6 +305,7 @@ def main(): assert len(args.nemo_ctc) == 0, args.nemo_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] if contexts: @@ -311,6 +332,7 @@ def main(): assert len(args.nemo_ctc) == 0, args.nemo_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.paraformer) @@ -326,6 +348,7 @@ def main(): elif args.nemo_ctc: assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.nemo_ctc) @@ -339,6 +362,7 @@ def main(): debug=args.debug, ) elif args.whisper_encoder: + assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_decoder) @@ -347,6 +371,20 @@ def main(): decoder=args.whisper_decoder, tokens=args.tokens, num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.tdnn_model: + assert_file_exists(args.tdnn_model) + + recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc( + model=args.tdnn_model, + tokens=args.tokens, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + num_threads=args.num_threads, decoding_method=args.decoding_method, debug=args.debug, ) diff --git a/python-api-examples/web/js/upload.js b/python-api-examples/web/js/upload.js index 34315010..308c36ef 100644 --- a/python-api-examples/web/js/upload.js +++ b/python-api-examples/web/js/upload.js @@ -97,20 +97,18 @@ function onFileChange() { console.log('file.type ' + file.type); console.log('file.size ' + file.size); + let audioCtx = new AudioContext({sampleRate: 16000}); + let reader = new FileReader(); reader.onload = function() { console.log('reading file!'); - let view = new Int16Array(reader.result); - // we assume the input file is a wav file. - // TODO: add some checks here. - let int16_samples = view.subarray(22); // header has 44 bytes == 22 shorts - let num_samples = int16_samples.length; - let float32_samples = new Float32Array(num_samples); - console.log('num_samples ' + num_samples) + audioCtx.decodeAudioData(reader.result, decodedDone); + }; - for (let i = 0; i < num_samples; ++i) { - float32_samples[i] = int16_samples[i] / 32768. - } + function decodedDone(decoded) { + let typedArray = new Float32Array(decoded.length); + let float32_samples = decoded.getChannelData(0); + let buf = float32_samples.buffer // Send 1024 audio samples per request. // @@ -119,14 +117,13 @@ function onFileChange() { // (2) There is a limit on the number of bytes in the payload that can be // sent by websocket, which is 1MB, I think. We can send a large // audio file for decoding in this approach. - let buf = float32_samples.buffer let n = 1024 * 4; // send this number of bytes per request. console.log('buf length, ' + buf.byteLength); send_header(buf.byteLength); for (let start = 0; start < buf.byteLength; start += n) { socket.send(buf.slice(start, start + n)); } - }; + } reader.readAsArrayBuffer(file); } diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 3426d565..cb4953c5 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -32,6 +32,8 @@ set(sources offline-recognizer.cc offline-rnn-lm.cc offline-stream.cc + offline-tdnn-ctc-model.cc + offline-tdnn-model-config.cc offline-transducer-greedy-search-decoder.cc offline-transducer-model-config.cc offline-transducer-model.cc diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc index 1d19253d..f4529bcc 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -11,12 +11,14 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" +#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace { enum class ModelType { kEncDecCTCModelBPE, + kTdnn, kUnkown, }; @@ -55,6 +57,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, if (model_type.get() == std::string("EncDecCTCModelBPE")) { return ModelType::kEncDecCTCModelBPE; + } else if (model_type.get() == std::string("tdnn")) { + return ModelType::kTdnn; } else { SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); return ModelType::kUnkown; @@ -65,8 +69,18 @@ std::unique_ptr OfflineCtcModel::Create( const OfflineModelConfig &config) { ModelType model_type = ModelType::kUnkown; + std::string filename; + if (!config.nemo_ctc.model.empty()) { + filename = config.nemo_ctc.model; + } else if (!config.tdnn.model.empty()) { + filename = config.tdnn.model; + } else { + SHERPA_ONNX_LOGE("Please specify a CTC model"); + exit(-1); + } + { - auto buffer = ReadFile(config.nemo_ctc.model); + auto buffer = ReadFile(filename); model_type = GetModelType(buffer.data(), buffer.size(), config.debug); } @@ -75,6 +89,9 @@ std::unique_ptr OfflineCtcModel::Create( case ModelType::kEncDecCTCModelBPE: return std::make_unique(config); break; + case ModelType::kTdnn: + return std::make_unique(config); + break; case ModelType::kUnkown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; diff --git a/sherpa-onnx/csrc/offline-ctc-model.h b/sherpa-onnx/csrc/offline-ctc-model.h index 8be7f99b..8ef43d55 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.h +++ b/sherpa-onnx/csrc/offline-ctc-model.h @@ -39,10 +39,10 @@ class OfflineCtcModel { /** SubsamplingFactor of the model * - * For Citrinet, the subsampling factor is usually 4. - * For Conformer CTC, the subsampling factor is usually 8. + * For NeMo Citrinet, the subsampling factor is usually 4. + * For NeMo Conformer CTC, the subsampling factor is usually 8. */ - virtual int32_t SubsamplingFactor() const = 0; + virtual int32_t SubsamplingFactor() const { return 1; } /** Return an allocator for allocating memory */ diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 9808d8f6..c491ed55 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -15,6 +15,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { paraformer.Register(po); nemo_ctc.Register(po); whisper.Register(po); + tdnn.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -29,7 +30,8 @@ void OfflineModelConfig::Register(ParseOptions *po) { po->Register("model-type", &model_type, "Specify it to reduce model initialization time. " - "Valid values are: transducer, paraformer, nemo_ctc, whisper." + "Valid values are: transducer, paraformer, nemo_ctc, whisper, " + "tdnn." "All other values lead to loading the model twice."); } @@ -56,6 +58,10 @@ bool OfflineModelConfig::Validate() const { return whisper.Validate(); } + if (!tdnn.model.empty()) { + return tdnn.Validate(); + } + return transducer.Validate(); } @@ -67,6 +73,7 @@ std::string OfflineModelConfig::ToString() const { os << "paraformer=" << paraformer.ToString() << ", "; os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; os << "whisper=" << whisper.ToString() << ", "; + os << "tdnn=" << tdnn.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index 41f441c9..2664db31 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -8,6 +8,7 @@ #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" +#include "sherpa-onnx/csrc/offline-tdnn-model-config.h" #include "sherpa-onnx/csrc/offline-transducer-model-config.h" #include "sherpa-onnx/csrc/offline-whisper-model-config.h" @@ -18,6 +19,7 @@ struct OfflineModelConfig { OfflineParaformerModelConfig paraformer; OfflineNemoEncDecCtcModelConfig nemo_ctc; OfflineWhisperModelConfig whisper; + OfflineTdnnModelConfig tdnn; std::string tokens; int32_t num_threads = 2; @@ -40,12 +42,14 @@ struct OfflineModelConfig { const OfflineParaformerModelConfig ¶former, const OfflineNemoEncDecCtcModelConfig &nemo_ctc, const OfflineWhisperModelConfig &whisper, + const OfflineTdnnModelConfig &tdnn, const std::string &tokens, int32_t num_threads, bool debug, const std::string &provider, const std::string &model_type) : transducer(transducer), paraformer(paraformer), nemo_ctc(nemo_ctc), whisper(whisper), + tdnn(tdnn), tokens(tokens), num_threads(num_threads), debug(debug), diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 80c946c7..d62fe09b 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -27,6 +27,10 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, std::string text; for (int32_t i = 0; i != src.tokens.size(); ++i) { + if (sym_table.contains("SIL") && src.tokens[i] == sym_table["SIL"]) { + // tdnn models from yesno have a SIL token, we should remove it. + continue; + } auto sym = sym_table[src.tokens[i]]; text.append(sym); r.tokens.push_back(std::move(sym)); @@ -46,14 +50,22 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { model_->FeatureNormalizationMethod(); if (config.decoding_method == "greedy_search") { - if (!symbol_table_.contains("")) { + if (!symbol_table_.contains("") && + !symbol_table_.contains("")) { SHERPA_ONNX_LOGE( "We expect that tokens.txt contains " - "the symbol and its ID."); + "the symbol or and its ID."); exit(-1); } - int32_t blank_id = symbol_table_[""]; + int32_t blank_id = 0; + if (symbol_table_.contains("")) { + blank_id = symbol_table_[""]; + } else if (symbol_table_.contains("")) { + // for tdnn models of the yesno recipe from icefall + blank_id = symbol_table_[""]; + } + decoder_ = std::make_unique(blank_id); } else { SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 5058a8ce..c818b95e 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -27,6 +27,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } else if (model_type == "nemo_ctc") { return std::make_unique(config); + } else if (model_type == "tdnn") { + return std::make_unique(config); } else if (model_type == "whisper") { return std::make_unique(config); } else { @@ -46,6 +48,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( model_filename = config.model_config.paraformer.model; } else if (!config.model_config.nemo_ctc.model.empty()) { model_filename = config.model_config.nemo_ctc.model; + } else if (!config.model_config.tdnn.model.empty()) { + model_filename = config.model_config.tdnn.model; } else if (!config.model_config.whisper.encoder.empty()) { model_filename = config.model_config.whisper.encoder; } else { @@ -84,6 +88,11 @@ std::unique_ptr OfflineRecognizerImpl::Create( "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" "\n " "(3) Whisper" + "\n " + "(4) Tdnn models of the yesno recipe from icefall" + "\n " + "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" + "\n" "\n"); exit(-1); } @@ -102,6 +111,10 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } + if (model_type == "tdnn") { + return std::make_unique(config); + } + if (strncmp(model_type.c_str(), "whisper", 7) == 0) { return std::make_unique(config); } @@ -112,7 +125,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - Non-streaming transducer models from icefall\n" " - Non-streaming Paraformer models from FunASR\n" " - EncDecCTCModelBPE models from NeMo\n" - " - Whisper models\n", + " - Whisper models\n" + " - Tdnn models\n", model_type.c_str()); exit(-1); diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc new file mode 100644 index 00000000..ff08295c --- /dev/null +++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc @@ -0,0 +1,106 @@ +// sherpa-onnx/csrc/offline-tdnn-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +class OfflineTdnnCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + Init(); + } + + std::pair Forward(Ort::Value features) { + auto nnet_out = + sess_->Run({}, input_names_ptr_.data(), &features, 1, + output_names_ptr_.data(), output_names_ptr_.size()); + + std::vector nnet_out_shape = + nnet_out[0].GetTensorTypeAndShapeInfo().GetShape(); + + std::vector out_length_vec(nnet_out_shape[0], nnet_out_shape[1]); + std::vector out_length_shape(1, nnet_out_shape[0]); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + Ort::Value nnet_out_length = Ort::Value::CreateTensor( + memory_info, out_length_vec.data(), out_length_vec.size(), + out_length_shape.data(), out_length_shape.size()); + + return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)}; + } + + int32_t VocabSize() const { return vocab_size_; } + + OrtAllocator *Allocator() const { return allocator_; } + + private: + void Init() { + auto buf = ReadFile(config_.tdnn.model); + + sess_ = std::make_unique(env_, buf.data(), buf.size(), + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + } + + private: + OfflineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; +}; + +OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default; + +std::pair OfflineTdnnCtcModel::Forward( + Ort::Value features, Ort::Value /*features_length*/) { + return impl_->Forward(std::move(features)); +} + +int32_t OfflineTdnnCtcModel::VocabSize() const { return impl_->VocabSize(); } + +OrtAllocator *OfflineTdnnCtcModel::Allocator() const { + return impl_->Allocator(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.h b/sherpa-onnx/csrc/offline-tdnn-ctc-model.h new file mode 100644 index 00000000..882e6e57 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.h @@ -0,0 +1,56 @@ +// sherpa-onnx/csrc/offline-tdnn-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-ctc-model.h" +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +/** This class implements the tdnn model of the yesno recipe from icefall. + * + * See + * https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn + */ +class OfflineTdnnCtcModel : public OfflineCtcModel { + public: + explicit OfflineTdnnCtcModel(const OfflineModelConfig &config); + ~OfflineTdnnCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a pair containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t + */ + std::pair Forward( + Ort::Value features, Ort::Value /*features_length*/) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-tdnn-model-config.cc b/sherpa-onnx/csrc/offline-tdnn-model-config.cc new file mode 100644 index 00000000..be1b11cd --- /dev/null +++ b/sherpa-onnx/csrc/offline-tdnn-model-config.cc @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/offline-tdnn-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tdnn-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineTdnnModelConfig::Register(ParseOptions *po) { + po->Register("tdnn-model", &model, "Path to onnx model"); +} + +bool OfflineTdnnModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("tdnn model file %s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineTdnnModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTdnnModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tdnn-model-config.h b/sherpa-onnx/csrc/offline-tdnn-model-config.h new file mode 100644 index 00000000..bddea551 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tdnn-model-config.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/offline-tdnn-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +// for https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn +struct OfflineTdnnModelConfig { + std::string model; + + OfflineTdnnModelConfig() = default; + explicit OfflineTdnnModelConfig(const std::string &model) : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc index c51549af..e8c1a7b2 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc @@ -14,10 +14,14 @@ int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( +Speech recognition using non-streaming models with sherpa-onnx. + Usage: (1) Transducer from icefall +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html + ./bin/sherpa-onnx-offline \ --tokens=/path/to/tokens.txt \ --encoder=/path/to/encoder.onnx \ @@ -30,6 +34,8 @@ Usage: (2) Paraformer from FunASR +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html + ./bin/sherpa-onnx-offline \ --tokens=/path/to/tokens.txt \ --paraformer=/path/to/model.onnx \ @@ -39,6 +45,8 @@ Usage: (3) Whisper models +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html + ./bin/sherpa-onnx-offline \ --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ @@ -46,6 +54,31 @@ Usage: --num-threads=1 \ /path/to/foo.wav [bar.wav foobar.wav ...] +(4) NeMo CTC models + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html + + ./bin/sherpa-onnx-offline \ + --tokens=./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt \ + --nemo-ctc-model=./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \ + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ + ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav + +(5) TDNN CTC model for the yesno recipe from icefall + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html + // + ./build/bin/sherpa-onnx-offline \ + --sample-rate=8000 \ + --feat-dim=23 \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav Note: It supports decoding multiple files in batches diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index e58d60d9..28612924 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -10,6 +10,7 @@ pybind11_add_module(_sherpa_onnx offline-paraformer-model-config.cc offline-recognizer.cc offline-stream.cc + offline-tdnn-model-config.cc offline-transducer-model-config.cc offline-whisper-model-config.cc online-lm-config.cc diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc index cfec6f14..4ed0483c 100644 --- a/sherpa-onnx/python/csrc/offline-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -10,6 +10,7 @@ #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h" +#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" @@ -20,24 +21,28 @@ void PybindOfflineModelConfig(py::module *m) { PybindOfflineParaformerModelConfig(m); PybindOfflineNemoEncDecCtcModelConfig(m); PybindOfflineWhisperModelConfig(m); + PybindOfflineTdnnModelConfig(m); using PyClass = OfflineModelConfig; py::class_(*m, "OfflineModelConfig") .def(py::init(), py::arg("transducer") = OfflineTransducerModelConfig(), py::arg("paraformer") = OfflineParaformerModelConfig(), py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), - py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"), + py::arg("whisper") = OfflineWhisperModelConfig(), + py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("model_type") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) .def_readwrite("whisper", &PyClass::whisper) + .def_readwrite("tdnn", &PyClass::tdnn) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) diff --git a/sherpa-onnx/python/csrc/offline-tdnn-model-config.cc b/sherpa-onnx/python/csrc/offline-tdnn-model-config.cc new file mode 100644 index 00000000..5a96fcfa --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-tdnn-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-onnx/python/csrc/offline-tdnn-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tdnn-model-config.h" + +#include +#include + +#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineTdnnModelConfig(py::module *m) { + using PyClass = OfflineTdnnModelConfig; + py::class_(*m, "OfflineTdnnModelConfig") + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-tdnn-model-config.h b/sherpa-onnx/python/csrc/offline-tdnn-model-config.h new file mode 100644 index 00000000..f68ba1ba --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-tdnn-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-tdnn-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineTdnnModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index cc5b5559..c87f34cf 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -8,6 +8,7 @@ from _sherpa_onnx import ( OfflineModelConfig, OfflineNemoEncDecCtcModelConfig, OfflineParaformerModelConfig, + OfflineTdnnModelConfig, OfflineWhisperModelConfig, ) from _sherpa_onnx import OfflineRecognizer as _Recognizer @@ -37,7 +38,7 @@ class OfflineRecognizer(object): decoder: str, joiner: str, tokens: str, - num_threads: int, + num_threads: int = 1, sample_rate: int = 16000, feature_dim: int = 80, decoding_method: str = "greedy_search", @@ -48,7 +49,7 @@ class OfflineRecognizer(object): ): """ Please refer to - ``_ + ``_ to download pre-trained models for different languages, e.g., Chinese, English, etc. @@ -115,7 +116,7 @@ class OfflineRecognizer(object): cls, paraformer: str, tokens: str, - num_threads: int, + num_threads: int = 1, sample_rate: int = 16000, feature_dim: int = 80, decoding_method: str = "greedy_search", @@ -124,9 +125,8 @@ class OfflineRecognizer(object): ): """ Please refer to - ``_ - to download pre-trained models for different languages, e.g., Chinese, - English, etc. + ``_ + to download pre-trained models. Args: tokens: @@ -179,7 +179,7 @@ class OfflineRecognizer(object): cls, model: str, tokens: str, - num_threads: int, + num_threads: int = 1, sample_rate: int = 16000, feature_dim: int = 80, decoding_method: str = "greedy_search", @@ -188,7 +188,7 @@ class OfflineRecognizer(object): ): """ Please refer to - ``_ + ``_ to download pre-trained models for different languages, e.g., Chinese, English, etc. @@ -244,14 +244,14 @@ class OfflineRecognizer(object): encoder: str, decoder: str, tokens: str, - num_threads: int, + num_threads: int = 1, decoding_method: str = "greedy_search", debug: bool = False, provider: str = "cpu", ): """ Please refer to - ``_ + ``_ to download pre-trained models for different kinds of whisper models, e.g., tiny, tiny.en, base, base.en, etc. @@ -301,6 +301,69 @@ class OfflineRecognizer(object): self.config = recognizer_config return self + @classmethod + def from_tdnn_ctc( + cls, + model: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 8000, + feature_dim: int = 23, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + ): + """ + Please refer to + ``_ + to download pre-trained models. + + Args: + model: + Path to ``model.onnx``. + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + tdnn=OfflineTdnnModelConfig(model=model), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="tdnn", + ) + + feat_config = OfflineFeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + def create_stream(self, contexts_list: Optional[List[List[int]]] = None): if contexts_list is None: return self.recognizer.create_stream()