diff --git a/.github/scripts/Main.kt b/.github/scripts/Main.kt index 475baff5..8bbeb76e 100644 --- a/.github/scripts/Main.kt +++ b/.github/scripts/Main.kt @@ -33,18 +33,20 @@ fun main() { config = config, ) - var samples = WaveReader.readWave( + var objArray = WaveReader.readWave( assetManager = AssetManager(), filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav", ) + var samples : FloatArray = objArray[0] as FloatArray + var sampleRate : Int = objArray[1] as Int - model.acceptWaveform(samples!!, sampleRate=16000) + model.acceptWaveform(samples, sampleRate=sampleRate) while (model.isReady()) { model.decode() } - var tail_paddings = FloatArray(8000) // 0.5 seconds - model.acceptWaveform(tail_paddings, sampleRate=16000) + var tail_paddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds + model.acceptWaveform(tail_paddings, sampleRate=sampleRate) model.inputFinished() while (model.isReady()) { model.decode() diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh new file mode 100755 index 00000000..9f98c522 --- /dev/null +++ b/.github/scripts/test-offline-transducer.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +log "------------------------------------------------------------" +log "Run Conformer transducer (English)" +log "------------------------------------------------------------" + +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-conformer-en-2023-03-18 +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" +cd test_wavs +popd + +waves=( +$repo/test_wavs/0.wav +$repo/test_wavs/1.wav +$repo/test_wavs/2.wav +) + +for wave in ${waves[@]}; do + time $EXE \ + $repo/tokens.txt \ + $repo/encoder-epoch-99-avg-1.onnx \ + $repo/decoder-epoch-99-avg-1.onnx \ + $repo/joiner-epoch-99-avg-1.onnx \ + $wave \ + 2 +done + + +if command -v sox &> /dev/null; then + echo "test 8kHz" + sox $repo/test_wavs/0.wav -r 8000 8k.wav + time $EXE \ + $repo/tokens.txt \ + $repo/encoder-epoch-99-avg-1.onnx \ + $repo/decoder-epoch-99-avg-1.onnx \ + $repo/joiner-epoch-99-avg-1.onnx \ + 8k.wav \ + 2 +fi + +rm -rf $repo diff --git a/.github/scripts/test-online-transducer.sh b/.github/scripts/test-online-transducer.sh index c92971b8..138e0f73 100755 --- a/.github/scripts/test-online-transducer.sh +++ b/.github/scripts/test-online-transducer.sh @@ -40,7 +40,7 @@ for wave in ${waves[@]}; do $repo/decoder-epoch-99-avg-1.onnx \ $repo/joiner-epoch-99-avg-1.onnx \ $wave \ - 4 + 2 done rm -rf $repo @@ -72,7 +72,7 @@ for wave in ${waves[@]}; do $repo/decoder-epoch-11-avg-1.onnx \ $repo/joiner-epoch-11-avg-1.onnx \ $wave \ - 4 + 2 done rm -rf $repo @@ -104,7 +104,7 @@ for wave in ${waves[@]}; do $repo/decoder-epoch-99-avg-1.onnx \ $repo/joiner-epoch-99-avg-1.onnx \ $wave \ - 4 + 2 done rm -rf $repo @@ -138,7 +138,7 @@ for wave in ${waves[@]}; do $repo/decoder-epoch-99-avg-1.onnx \ $repo/joiner-epoch-99-avg-1.onnx \ $wave \ - 4 + 2 done # Decode a URL @@ -149,7 +149,7 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then $repo/decoder-epoch-99-avg-1.onnx \ $repo/joiner-epoch-99-avg-1.onnx \ 
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \ - 4 + 2 fi rm -rf $repo diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 4870fc4b..bfb17f8b 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -7,11 +7,11 @@ on: paths: - '.github/workflows/linux.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-offline-transducer.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' - 'sherpa-onnx/c-api/*' - - 'ffmpeg-examples/**' - 'c-api-examples/**' pull_request: branches: @@ -19,11 +19,11 @@ on: paths: - '.github/workflows/linux.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-offline-transducer.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' - 'sherpa-onnx/c-api/*' - - 'ffmpeg-examples/**' concurrency: group: linux-${{ github.ref }} @@ -39,35 +39,26 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] + build_type: [Release, Debug] steps: - uses: actions/checkout@v2 with: fetch-depth: 0 - - name: Install ffmpeg + - name: Install sox shell: bash run: | - sudo apt-get install -y software-properties-common - sudo add-apt-repository ppa:savoury1/ffmpeg4 - sudo add-apt-repository ppa:savoury1/ffmpeg5 - - sudo apt-get install -y libavdevice-dev libavutil-dev ffmpeg - pkg-config --modversion libavutil - ffmpeg -version - - - name: Show ffmpeg version - shell: bash - run: | - pkg-config --modversion libavutil - ffmpeg -version + sudo apt-get update + sudo apt-get install -y sox + sox -h - name: Configure CMake shell: bash run: | mkdir build cd build - cmake -D CMAKE_BUILD_TYPE=Release .. + cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} .. - name: Build sherpa-onnx for ubuntu shell: bash @@ -78,21 +69,19 @@ jobs: ls -lh lib ls -lh bin - cd ../ffmpeg-examples - make - - name: Display dependencies of sherpa-onnx for linux shell: bash run: | file build/bin/sherpa-onnx readelf -d build/bin/sherpa-onnx - - name: Test sherpa-onnx-ffmpeg + - name: Test offline transducer + shell: bash run: | - export PATH=$PWD/ffmpeg-examples:$PATH - export EXE=sherpa-onnx-ffmpeg + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline - .github/scripts/test-online-transducer.sh + .github/scripts/test-offline-transducer.sh - name: Test online transducer shell: bash diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index b3243b84..a5d93307 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -7,6 +7,7 @@ on: paths: - '.github/workflows/macos.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-offline-transducer.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -16,6 +17,7 @@ on: paths: - '.github/workflows/macos.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-offline-transducer.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -34,18 +36,25 @@ jobs: fail-fast: false matrix: os: [macos-latest] + build_type: [Release, Debug] steps: - uses: actions/checkout@v2 with: fetch-depth: 0 + - name: Install sox + shell: bash + run: | + brew install sox + sox -h + - name: Configure CMake shell: bash run: | mkdir build cd build - cmake -D CMAKE_BUILD_TYPE=Release .. + cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} .. 
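# Not part of the workflow: a rough local equivalent of the new "Test offline
# transducer" CI step, sketched from the commands in
# .github/scripts/test-offline-transducer.sh above. The model directory and
# file names are the ones that script downloads; adjust them to your setup.
#
#   mkdir build && cd build
#   cmake -D CMAKE_BUILD_TYPE=Release ..
#   make -j2
#   export PATH=$PWD/bin:$PATH
#   cd ..
#
#   repo=sherpa-onnx-conformer-en-2023-03-18
#   sherpa-onnx-offline \
#     $repo/tokens.txt \
#     $repo/encoder-epoch-99-avg-1.onnx \
#     $repo/decoder-epoch-99-avg-1.onnx \
#     $repo/joiner-epoch-99-avg-1.onnx \
#     $repo/test_wavs/0.wav \
#     2   # num_threads, same argument order as in the test script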
- name: Build sherpa-onnx for macos shell: bash @@ -64,6 +73,14 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline transducer + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline + + .github/scripts/test-offline-transducer.sh + - name: Test online transducer shell: bash run: | diff --git a/.gitignore b/.gitignore index 716dfd3a..6c4f80e0 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,5 @@ tags run-decode-file-python.sh android/SherpaOnnx/app/src/main/assets/ *.ncnn.* +run-sherpa-onnx-offline.sh +sherpa-onnx-conformer-en-2023-03-18 diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b5601cc..0432ef4f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ endif() option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) -option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON) +option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" OFF) option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index 78caea6e..a653b639 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -121,7 +121,7 @@ class MainActivity : AppCompatActivity() { val ret = audioRecord?.read(buffer, 0, buffer.size) if (ret != null && ret > 0) { val samples = FloatArray(ret) { buffer[it] / 32768.0f } - model.acceptWaveform(samples, sampleRate=16000) + model.acceptWaveform(samples, sampleRate=sampleRateInHz) while (model.isReady()) { model.decode() } @@ -180,7 +180,7 @@ class MainActivity : AppCompatActivity() { val type = 0 println("Select model type ${type}") val config = OnlineRecognizerConfig( - featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), + featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), modelConfig = getModelConfig(type = type)!!, endpointConfig = getEndpointConfig(), enableEndpoint = true, diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt index 4444ab93..82cf7cac 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt @@ -8,7 +8,7 @@ class WaveReader { // No resampling is made. external fun readWave( assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f - ): FloatArray? 
+ ): Array init { System.loadLibrary("sherpa-onnx-jni") diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index a968bf24..4f39163f 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -1,9 +1,9 @@ function(download_kaldi_native_fbank) include(FetchContent) - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.13.tar.gz") - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.13.tar.gz") - set(kaldi_native_fbank_HASH "SHA256=1f4d228f9fe3e3e9f92a74a7eecd2489071a03982e4ba6d7c70fc5fa7444df57") + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.14.tar.gz") + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.14.tar.gz") + set(kaldi_native_fbank_HASH "SHA256=6a66638a111d3ce21fe6f29cbf9ab3dbcae2331c77391bf825927df5cbf2babe") set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) # If you don't have access to the Internet, # please pre-download kaldi-native-fbank set(possible_file_locations - $ENV{HOME}/Downloads/kaldi-native-fbank-1.13.tar.gz - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.13.tar.gz - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.13.tar.gz - /tmp/kaldi-native-fbank-1.13.tar.gz - /star-fj/fangjun/download/github/kaldi-native-fbank-1.13.tar.gz + $ENV{HOME}/Downloads/kaldi-native-fbank-1.14.tar.gz + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.14.tar.gz + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.14.tar.gz + /tmp/kaldi-native-fbank-1.14.tar.gz + /star-fj/fangjun/download/github/kaldi-native-fbank-1.14.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py index 9840b51e..fbc71e7f 100755 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -91,7 +91,6 @@ def create_recognizer(): rule2_min_trailing_silence=1.2, rule3_min_utterance_length=300, # it essentially disables this rule decoding_method=args.decoding_method, - max_feature_vectors=100, # 1 second ) return recognizer diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index e13b8d7f..6dc608d1 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -86,7 +86,6 @@ def create_recognizer(): sample_rate=16000, feature_dim=80, decoding_method=args.decoding_method, - max_feature_vectors=100, # 1 second ) return recognizer diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index a44f6678..5eb1fb49 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -6,6 +6,11 @@ set(sources features.cc file-utils.cc hypothesis.cc + offline-stream.cc + offline-transducer-greedy-search-decoder.cc + offline-transducer-model-config.cc + offline-transducer-model.cc + offline-recognizer.cc online-lstm-transducer-model.cc online-recognizer.cc online-stream.cc @@ -56,10 +61,13 @@ if(SHERPA_ONNX_ENABLE_CHECK) endif() add_executable(sherpa-onnx 
sherpa-onnx.cc) +add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) target_link_libraries(sherpa-onnx sherpa-onnx-core) +target_link_libraries(sherpa-onnx-offline sherpa-onnx-core) if(NOT WIN32) target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") + target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") endif() if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) @@ -68,7 +76,13 @@ else() install(TARGETS sherpa-onnx-core DESTINATION lib) endif() -install(TARGETS sherpa-onnx DESTINATION bin) +install( + TARGETS + sherpa-onnx + sherpa-onnx-offline + DESTINATION + bin +) if(SHERPA_ONNX_HAS_ALSA) add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc) diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index eab137e7..bdfb9fd8 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -19,7 +19,9 @@ namespace sherpa_onnx { void FeatureExtractorConfig::Register(ParseOptions *po) { po->Register("sample-rate", &sampling_rate, "Sampling rate of the input waveform. Must match the one " - "expected by the model."); + "expected by the model. Note: You can have a different " + "sample rate for the input waveform. We will do resampling " + "inside the feature extractor"); po->Register("feat-dim", &feature_dim, "Feature dimension. Must match the one expected by the model."); @@ -30,8 +32,7 @@ std::string FeatureExtractorConfig::ToString() const { os << "FeatureExtractorConfig("; os << "sampling_rate=" << sampling_rate << ", "; - os << "feature_dim=" << feature_dim << ", "; - os << "max_feature_vectors=" << max_feature_vectors << ")"; + os << "feature_dim=" << feature_dim << ")"; return os.str(); } @@ -43,8 +44,6 @@ class FeatureExtractor::Impl { opts_.frame_opts.snip_edges = false; opts_.frame_opts.samp_freq = config.sampling_rate; - opts_.frame_opts.max_feature_vectors = config.max_feature_vectors; - opts_.mel_opts.num_bins = config.feature_dim; fbank_ = std::make_unique(opts_); @@ -95,7 +94,7 @@ class FeatureExtractor::Impl { fbank_->AcceptWaveform(sampling_rate, waveform, n); } - void InputFinished() { + void InputFinished() const { std::lock_guard lock(mutex_); fbank_->InputFinished(); } @@ -110,12 +109,21 @@ class FeatureExtractor::Impl { return fbank_->IsLastFrame(frame); } - std::vector GetFrames(int32_t frame_index, int32_t n) const { - if (frame_index + n > NumFramesReady()) { - fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady()); + std::vector GetFrames(int32_t frame_index, int32_t n) { + std::lock_guard lock(mutex_); + if (frame_index + n > fbank_->NumFramesReady()) { + SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, + fbank_->NumFramesReady()); exit(-1); } - std::lock_guard lock(mutex_); + + int32_t discard_num = frame_index - last_frame_index_; + if (discard_num < 0) { + SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d", + last_frame_index_, frame_index); + exit(-1); + } + fbank_->Pop(discard_num); int32_t feature_dim = fbank_->Dim(); std::vector features(feature_dim * n); @@ -128,12 +136,9 @@ class FeatureExtractor::Impl { p += feature_dim; } - return features; - } + last_frame_index_ = frame_index; - void Reset() { - std::lock_guard lock(mutex_); - fbank_ = std::make_unique(opts_); + return features; } int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } @@ -143,6 +148,7 @@ class FeatureExtractor::Impl { knf::FbankOptions opts_; mutable std::mutex mutex_; std::unique_ptr resampler_; + int32_t last_frame_index_ = 0; }; 
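// A minimal, self-contained sketch (not sherpa-onnx code; all names below are
// made up) of the consume-as-you-read pattern that the new GetFrames() above
// implements: the caller requests frames at a monotonically increasing
// frame_index, so everything older than the previous request can be popped
// and memory stays bounded without the removed max_feature_vectors option.
#include <cstdint>
#include <cstdio>
#include <deque>
#include <vector>

class ToyFrameCache {
 public:
  // Append one feature frame (computed elsewhere).
  void Push(const std::vector<float> &frame) { frames_.push_back(frame); }

  // Return n frames starting at the absolute index frame_index and discard
  // every frame before frame_index, mirroring the fbank_->Pop() call above.
  // Assumes frame_index never goes backwards.
  std::vector<std::vector<float>> GetFrames(int32_t frame_index, int32_t n) {
    int32_t discard_num = frame_index - last_frame_index_;
    for (int32_t i = 0; i < discard_num; ++i) {
      frames_.pop_front();
    }
    last_frame_index_ = frame_index;

    // frames_[0] now corresponds to absolute index frame_index.
    return std::vector<std::vector<float>>(frames_.begin(),
                                           frames_.begin() + n);
  }

 private:
  std::deque<std::vector<float>> frames_;
  int32_t last_frame_index_ = 0;
};

int main() {
  ToyFrameCache cache;
  for (int32_t i = 0; i != 10; ++i) cache.Push({static_cast<float>(i)});
  auto a = cache.GetFrames(0, 4);  // frames 0..3, nothing discarded yet
  auto b = cache.GetFrames(4, 4);  // frames 4..7, frames 0..3 are dropped
  std::printf("%zu %zu\n", a.size(), b.size());
  return 0;
}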
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) @@ -151,11 +157,11 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) FeatureExtractor::~FeatureExtractor() = default; void FeatureExtractor::AcceptWaveform(int32_t sampling_rate, - const float *waveform, int32_t n) { + const float *waveform, int32_t n) const { impl_->AcceptWaveform(sampling_rate, waveform, n); } -void FeatureExtractor::InputFinished() { impl_->InputFinished(); } +void FeatureExtractor::InputFinished() const { impl_->InputFinished(); } int32_t FeatureExtractor::NumFramesReady() const { return impl_->NumFramesReady(); @@ -170,8 +176,6 @@ std::vector FeatureExtractor::GetFrames(int32_t frame_index, return impl_->GetFrames(frame_index, n); } -void FeatureExtractor::Reset() { impl_->Reset(); } - int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index 831f221e..d4eaffda 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -14,9 +14,12 @@ namespace sherpa_onnx { struct FeatureExtractorConfig { + // Sampling rate used by the feature extractor. If it is different from + // the sampling rate of the input waveform, we will do resampling inside. int32_t sampling_rate = 16000; + + // Feature dimension int32_t feature_dim = 80; - int32_t max_feature_vectors = -1; std::string ToString() const; @@ -36,7 +39,8 @@ class FeatureExtractor { the range [-1, 1]. @param n Number of entries in waveform */ - void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); + void AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const; /** * InputFinished() tells the class you won't be providing any @@ -44,7 +48,7 @@ class FeatureExtractor { * of features, in the case where snip-edges == false; it also * affects the return value of IsLastFrame(). */ - void InputFinished(); + void InputFinished() const; int32_t NumFramesReady() const; @@ -62,8 +66,6 @@ class FeatureExtractor { */ std::vector GetFrames(int32_t frame_index, int32_t n) const; - void Reset(); - /// Return feature dim of this extractor int32_t FeatureDim() const; diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc new file mode 100644 index 00000000..30a4154a --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -0,0 +1,163 @@ +// sherpa-onnx/csrc/offline-recognizer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-recognizer.h" + +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/offline-transducer-model.h" +#include "sherpa-onnx/csrc/pad-sequence.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +static OfflineRecognitionResult Convert( + const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table, + int32_t frame_shift_ms, int32_t subsampling_factor) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.timestamps.size()); + + std::string text; + for (auto i : src.tokens) { + auto sym = sym_table[i]; + text.append(sym); + + r.tokens.push_back(std::move(sym)); + } + r.text = std::move(text); + + float frame_shift_s = frame_shift_ms / 1000. 
* subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + return r; +} + +void OfflineRecognizerConfig::Register(ParseOptions *po) { + feat_config.Register(po); + model_config.Register(po); + + po->Register("decoding-method", &decoding_method, + "decoding method," + "Valid values: greedy_search."); +} + +bool OfflineRecognizerConfig::Validate() const { + return model_config.Validate(); +} + +std::string OfflineRecognizerConfig::ToString() const { + std::ostringstream os; + + os << "OfflineRecognizerConfig("; + os << "feat_config=" << feat_config.ToString() << ", "; + os << "model_config=" << model_config.ToString() << ", "; + os << "decoding_method=\"" << decoding_method << "\")"; + + return os.str(); +} + +class OfflineRecognizer::Impl { + public: + explicit Impl(const OfflineRecognizerConfig &config) + : config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config_.model_config)) { + if (config_.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(model_.get()); + } else if (config_.decoding_method == "modified_beam_search") { + SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented"); + exit(-1); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t feat_dim = ss[0]->FeatureDim(); + + std::vector features; + + features.reserve(n); + + std::vector> features_vec(n); + std::vector features_length_vec(n); + for (int32_t i = 0; i != n; ++i) { + auto f = ss[i]->GetFrames(); + int32_t num_frames = f.size() / feat_dim; + + features_length_vec[i] = num_frames; + features_vec[i] = std::move(f); + + std::array shape = {num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = &features[i]; + } + + std::array features_length_shape = {n}; + Ort::Value x_length = Ort::Value::CreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, + -23.025850929940457f); + + auto t = model_->RunEncoder(std::move(x), std::move(x_length)); + auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); + + int32_t frame_shift_ms = 10; + for (int32_t i = 0; i != n; ++i) { + auto r = Convert(results[i], symbol_table_, frame_shift_ms, + model_->SubsamplingFactor()); + + ss[i]->SetResult(r); + } + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineRecognizer::~OfflineRecognizer() = default; + +std::unique_ptr OfflineRecognizer::CreateStream() const { + return impl_->CreateStream(); +} + +void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { + impl_->DecodeStreams(ss, n); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-recognizer.h 
b/sherpa-onnx/csrc/offline-recognizer.h new file mode 100644 index 00000000..49423a03 --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -0,0 +1,87 @@ +// sherpa-onnx/csrc/offline-recognizer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ + +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-stream.h" +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineRecognitionResult { + // Recognition results. + // For English, it consists of space separated words. + // For Chinese, it consists of Chinese words without spaces. + std::string text; + + // Decoded results at the token level. + // For instance, for BPE-based models it consists of a list of BPE tokens. + std::vector tokens; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + std::vector timestamps; +}; + +struct OfflineRecognizerConfig { + OfflineFeatureExtractorConfig feat_config; + OfflineTransducerModelConfig model_config; + + std::string decoding_method = "greedy_search"; + // only greedy_search is implemented + // TODO(fangjun): Implement modified_beam_search + + OfflineRecognizerConfig() = default; + OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, + const OfflineTransducerModelConfig &model_config, + const std::string &decoding_method) + : feat_config(feat_config), + model_config(model_config), + decoding_method(decoding_method) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class OfflineRecognizer { + public: + ~OfflineRecognizer(); + + explicit OfflineRecognizer(const OfflineRecognizerConfig &config); + + /// Create a stream for decoding. + std::unique_ptr CreateStream() const; + + /** Decode a single stream + * + * @param s The stream to decode. + */ + void DecodeStream(OfflineStream *s) const { + OfflineStream *ss[1] = {s}; + DecodeStreams(ss, 1); + } + + /** Decode a list of streams. + * + * @param ss Pointer to an array of streams. + * @param n Size of the input array. + */ + void DecodeStreams(OfflineStream **ss, int32_t n) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc new file mode 100644 index 00000000..2fd0a8ab --- /dev/null +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -0,0 +1,134 @@ +// sherpa-onnx/csrc/offline-stream.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-stream.h" + +#include + +#include + +#include "kaldi-native-fbank/csrc/online-feature.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include "sherpa-onnx/csrc/resample.h" + +namespace sherpa_onnx { + +void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { + po->Register("sample-rate", &sampling_rate, + "Sampling rate of the input waveform. Must match the one " + "expected by the model. Note: You can have a different " + "sample rate for the input waveform. We will do resampling " + "inside the feature extractor"); + + po->Register("feat-dim", &feature_dim, + "Feature dimension. 
Must match the one expected by the model."); +} + +std::string OfflineFeatureExtractorConfig::ToString() const { + std::ostringstream os; + + os << "OfflineFeatureExtractorConfig("; + os << "sampling_rate=" << sampling_rate << ", "; + os << "feature_dim=" << feature_dim << ")"; + + return os.str(); +} + +class OfflineStream::Impl { + public: + explicit Impl(const OfflineFeatureExtractorConfig &config) { + opts_.frame_opts.dither = 0; + opts_.frame_opts.snip_edges = false; + opts_.frame_opts.samp_freq = config.sampling_rate; + opts_.mel_opts.num_bins = config.feature_dim; + + fbank_ = std::make_unique(opts_); + } + + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { + if (sampling_rate != opts_.frame_opts.samp_freq) { + SHERPA_ONNX_LOGE( + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + sampling_rate, static_cast(opts_.frame_opts.samp_freq)); + + float min_freq = + std::min(sampling_rate, opts_.frame_opts.samp_freq); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + auto resampler = std::make_unique( + sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff, + lowpass_filter_width); + std::vector samples; + resampler->Resample(waveform, n, true, &samples); + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), + samples.size()); + fbank_->InputFinished(); + return; + } + + fbank_->AcceptWaveform(sampling_rate, waveform, n); + fbank_->InputFinished(); + } + + int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } + + std::vector GetFrames() const { + int32_t n = fbank_->NumFramesReady(); + assert(n > 0 && "Please first call AcceptWaveform()"); + + int32_t feature_dim = FeatureDim(); + + std::vector features(n * feature_dim); + + float *p = features.data(); + + for (int32_t i = 0; i != n; ++i) { + const float *f = fbank_->GetFrame(i); + std::copy(f, f + feature_dim, p); + p += feature_dim; + } + + return features; + } + + void SetResult(const OfflineRecognitionResult &r) { r_ = r; } + + const OfflineRecognitionResult &GetResult() const { return r_; } + + private: + std::unique_ptr fbank_; + knf::FbankOptions opts_; + OfflineRecognitionResult r_; +}; + +OfflineStream::OfflineStream( + const OfflineFeatureExtractorConfig &config /*= {}*/) + : impl_(std::make_unique(config)) {} + +OfflineStream::~OfflineStream() = default; + +void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const { + impl_->AcceptWaveform(sampling_rate, waveform, n); +} + +int32_t OfflineStream::FeatureDim() const { return impl_->FeatureDim(); } + +std::vector OfflineStream::GetFrames() const { + return impl_->GetFrames(); +} + +void OfflineStream::SetResult(const OfflineRecognitionResult &r) { + impl_->SetResult(r); +} + +const OfflineRecognitionResult &OfflineStream::GetResult() const { + return impl_->GetResult(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h new file mode 100644 index 00000000..3059c38a --- /dev/null +++ b/sherpa-onnx/csrc/offline-stream.h @@ -0,0 +1,70 @@ +// sherpa-onnx/csrc/offline-stream.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ +#include + +#include +#include +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { +struct OfflineRecognitionResult; + +struct OfflineFeatureExtractorConfig { + // Sampling rate used by the feature extractor. 
If it is different from + // the sampling rate of the input waveform, we will do resampling inside. + int32_t sampling_rate = 16000; + + // Feature dimension + int32_t feature_dim = 80; + + std::string ToString() const; + + void Register(ParseOptions *po); +}; + +class OfflineStream { + public: + explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}); + ~OfflineStream(); + + /** + @param sampling_rate The sampling_rate of the input waveform. If it does + not equal to config.sampling_rate, we will do + resampling inside. + @param waveform Pointer to a 1-D array of size n. It must be normalized to + the range [-1, 1]. + @param n Number of entries in waveform + + Caution: You can only invoke this function once so you have to input + all the samples at once + */ + void AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const; + + /// Return feature dim of this extractor + int32_t FeatureDim() const; + + // Get all the feature frames of this stream in a 1-D array, which is + // flattened from a 2-D array of shape (num_frames, feat_dim). + std::vector GetFrames() const; + + /** Set the recognition result for this stream. */ + void SetResult(const OfflineRecognitionResult &r); + + /** Get the recognition result of this stream */ + const OfflineRecognitionResult &GetResult() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ diff --git a/sherpa-onnx/csrc/offline-transducer-decoder.h b/sherpa-onnx/csrc/offline-transducer-decoder.h new file mode 100644 index 00000000..898fc29c --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-decoder.h @@ -0,0 +1,41 @@ +// sherpa-onnx/csrc/offline-transducer-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +struct OfflineTransducerDecoderResult { + /// The decoded token IDs + std::vector tokens; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + /// Note: The index is after subsampling + std::vector timestamps; +}; + +class OfflineTransducerDecoder { + public: + virtual ~OfflineTransducerDecoder() = default; + + /** Run transducer beam search given the output from the encoder model. + * + * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim) + * @param encoder_out_length A 1-D tensor of shape (N,) containing number + * of valid frames in encoder_out before padding. + * + * @return Return a vector of size `N` containing the decoded results. 
+   */
+  virtual std::vector<OfflineTransducerDecoderResult> Decode(
+      Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
new file mode 100644
index 00000000..6432ff94
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
@@ -0,0 +1,79 @@
+// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/packed-sequence.h"
+#include "sherpa-onnx/csrc/slice.h"
+
+namespace sherpa_onnx {
+
+std::vector<OfflineTransducerDecoderResult>
+OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
+                                             Ort::Value encoder_out_length) {
+  PackedSequence packed_encoder_out = PackPaddedSequence(
+      model_->Allocator(), &encoder_out, &encoder_out_length);
+
+  int32_t batch_size =
+      static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
+
+  int32_t vocab_size = model_->VocabSize();
+  int32_t context_size = model_->ContextSize();
+
+  std::vector<OfflineTransducerDecoderResult> ans(batch_size);
+  for (auto &r : ans) {
+    // 0 is the ID of the blank token
+    r.tokens.resize(context_size, 0);
+  }
+
+  auto decoder_input = model_->BuildDecoderInput(ans, ans.size());
+  Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+
+  int32_t start = 0;
+  int32_t t = 0;
+  for (auto n : packed_encoder_out.batch_sizes) {
+    Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
+    Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n);
+    start += n;
+    Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
+                                         std::move(cur_decoder_out));
+    const float *p_logit = logit.GetTensorData<float>();
+    bool emitted = false;
+    for (int32_t i = 0; i != n; ++i) {
+      auto y = static_cast<int32_t>(std::distance(
+          static_cast<const float *>(p_logit),
+          std::max_element(static_cast<const float *>(p_logit),
+                           static_cast<const float *>(p_logit) + vocab_size)));
+      p_logit += vocab_size;
+      if (y != 0) {
+        ans[i].tokens.push_back(y);
+        ans[i].timestamps.push_back(t);
+        emitted = true;
+      }
+    }
+    if (emitted) {
+      Ort::Value decoder_input = model_->BuildDecoderInput(ans, n);
+      decoder_out = model_->RunDecoder(std::move(decoder_input));
+    }
+    ++t;
+  }
+
+  for (auto &r : ans) {
+    r.tokens = {r.tokens.begin() + context_size, r.tokens.end()};
+  }
+
+  std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
+  for (int32_t i = 0; i != batch_size; ++i) {
+    unsorted_ans[packed_encoder_out.sorted_indexes[i]] = std::move(ans[i]);
+  }
+
+  return unsorted_ans;
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
new file mode 100644
index 00000000..a0175d5c
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
@@ -0,0 +1,29 @@
+// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-model.h"
+
+namespace sherpa_onnx {
+
+class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
+ public:
+  explicit
OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model) + : model_(model) {} + + std::vector Decode( + Ort::Value encoder_out, Ort::Value encoder_out_length) override; + + private: + OfflineTransducerModel *model_; // Not owned +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.cc b/sherpa-onnx/csrc/offline-transducer-model-config.cc new file mode 100644 index 00000000..b66ff303 --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-model-config.cc @@ -0,0 +1,68 @@ +// sherpa-onnx/csrc/offline-transducer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineTransducerModelConfig::Register(ParseOptions *po) { + po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); + po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); + po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); + po->Register("tokens", &tokens, "Path to tokens.txt"); + po->Register("num_threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); +} + +bool OfflineTransducerModelConfig::Validate() const { + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); + return false; + } + + if (!FileExists(encoder_filename)) { + SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); + return false; + } + + if (!FileExists(decoder_filename)) { + SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); + return false; + } + + if (!FileExists(joiner_filename)) { + SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); + return false; + } + + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + return true; +} + +std::string OfflineTransducerModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTransducerModelConfig("; + os << "encoder_filename=\"" << encoder_filename << "\", "; + os << "decoder_filename=\"" << decoder_filename << "\", "; + os << "joiner_filename=\"" << joiner_filename << "\", "; + os << "tokens=\"" << tokens << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.h b/sherpa-onnx/csrc/offline-transducer-model-config.h new file mode 100644 index 00000000..39987bbc --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-model-config.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/offline-transducer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineTransducerModelConfig { + std::string encoder_filename; + std::string decoder_filename; + std::string joiner_filename; + std::string tokens; + int32_t num_threads = 2; + bool debug = false; + + OfflineTransducerModelConfig() = default; + OfflineTransducerModelConfig(const std::string &encoder_filename, + const std::string &decoder_filename, + const std::string &joiner_filename, + const std::string &tokens, int32_t num_threads, + bool debug) + : encoder_filename(encoder_filename), + decoder_filename(decoder_filename), + joiner_filename(joiner_filename), + tokens(tokens), + num_threads(num_threads), + debug(debug) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-transducer-model.cc b/sherpa-onnx/csrc/offline-transducer-model.cc new file mode 100644 index 00000000..3d584b5f --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-model.cc @@ -0,0 +1,238 @@ +// sherpa-onnx/csrc/offline-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-transducer-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +class OfflineTransducerModel::Impl { + public: + explicit Impl(const OfflineTransducerModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_{}, + allocator_{} { + sess_opts_.SetIntraOpNumThreads(config.num_threads); + sess_opts_.SetInterOpNumThreads(config.num_threads); + { + auto buf = ReadFile(config.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } + + std::pair RunEncoder(Ort::Value features, + Ort::Value features_length) { + std::array encoder_inputs = {std::move(features), + std::move(features_length)}; + + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), + encoder_inputs.size(), encoder_output_names_ptr_.data(), + encoder_output_names_ptr_.size()); + + return {std::move(encoder_out[0]), std::move(encoder_out[1])}; + } + + Ort::Value RunDecoder(Ort::Value decoder_input) { + auto decoder_out = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), &decoder_input, 1, + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); + return std::move(decoder_out[0]); + } + + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { + std::array joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = joiner_sess_->Run({}, 
joiner_input_names_ptr_.data(), + joiner_input.data(), joiner_input.size(), + joiner_output_names_ptr_.data(), + joiner_output_names_ptr_.size()); + + return std::move(logit[0]); + } + + int32_t VocabSize() const { return vocab_size_; } + int32_t ContextSize() const { return context_size_; } + int32_t SubsamplingFactor() const { return 4; } + OrtAllocator *Allocator() const { return allocator_; } + + Ort::Value BuildDecoderInput( + const std::vector &results, + int32_t end_index) const { + assert(end_index <= results.size()); + + int32_t batch_size = end_index; + int32_t context_size = ContextSize(); + std::array shape{batch_size, context_size}; + + Ort::Value decoder_input = Ort::Value::CreateTensor( + Allocator(), shape.data(), shape.size()); + int64_t *p = decoder_input.GetTensorMutableData(); + + for (int32_t i = 0; i != batch_size; ++i) { + const auto &r = results[i]; + const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size; + const int64_t *end = r.tokens.data() + r.tokens.size(); + std::copy(begin, end, p); + p += context_size; + } + return decoder_input; + } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); + } + + void InitJoiner(void *model_data, size_t model_data_length) { + joiner_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + } + + private: + OfflineTransducerModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector 
decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + int32_t vocab_size_ = 0; // initialized in InitDecoder + int32_t context_size_ = 0; // initialized in InitDecoder +}; + +OfflineTransducerModel::OfflineTransducerModel( + const OfflineTransducerModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineTransducerModel::~OfflineTransducerModel() = default; + +std::pair OfflineTransducerModel::RunEncoder( + Ort::Value features, Ort::Value features_length) { + return impl_->RunEncoder(std::move(features), std::move(features_length)); +} + +Ort::Value OfflineTransducerModel::RunDecoder(Ort::Value decoder_input) { + return impl_->RunDecoder(std::move(decoder_input)); +} + +Ort::Value OfflineTransducerModel::RunJoiner(Ort::Value encoder_out, + Ort::Value decoder_out) { + return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); +} + +int32_t OfflineTransducerModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OfflineTransducerModel::ContextSize() const { + return impl_->ContextSize(); +} + +int32_t OfflineTransducerModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +OrtAllocator *OfflineTransducerModel::Allocator() const { + return impl_->Allocator(); +} + +Ort::Value OfflineTransducerModel::BuildDecoderInput( + const std::vector &results, + int32_t end_index) const { + return impl_->BuildDecoderInput(results, end_index); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-transducer-model.h b/sherpa-onnx/csrc/offline-transducer-model.h new file mode 100644 index 00000000..f40c82a0 --- /dev/null +++ b/sherpa-onnx/csrc/offline-transducer-model.h @@ -0,0 +1,95 @@ +// sherpa-onnx/csrc/offline-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" + +namespace sherpa_onnx { + +struct OfflineTransducerDecoderResult; + +class OfflineTransducerModel { + public: + explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config); + ~OfflineTransducerModel(); + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * + * @return Return a pair containing: + * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) + * - encoder_out_length: A 1-D tensor of shape (N,) containing number + * of frames in `encoder_out` before padding. + */ + std::pair RunEncoder(Ort::Value features, + Ort::Value features_length); + + /** Run the decoder network. + * + * Caution: We assume there are no recurrent connections in the decoder and + * the decoder is stateless. See + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py + * for an example + * + * @param decoder_input It is usually of shape (N, context_size) + * @return Return a tensor of shape (N, decoder_dim). + */ + Ort::Value RunDecoder(Ort::Value decoder_input); + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. 
A tensor of shape + * (N, joiner_dim). + * @param decoder_out Output of the decoder network. A tensor of shape + * (N, joiner_dim). + * @return Return a tensor of shape (N, vocab_size). In icefall, the last + * last layer of the joint network is `nn.Linear`, + * not `nn.LogSoftmax`. + */ + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out); + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const; + + /** Return the context_size of the decoder model. + */ + int32_t ContextSize() const; + + /** Return the subsampling factor of the model. + */ + int32_t SubsamplingFactor() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + /** Build decoder_input from the current results. + * + * @param results Current decoded results. + * @param end_index We only use results[0:end_index] to build + * the decoder_input. + * @return Return a tensor of shape (results.size(), ContextSize()) + */ + Ort::Value BuildDecoderInput( + const std::vector &results, + int32_t end_index) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 29972b19..25e300ae 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -95,7 +95,7 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data, std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); - fprintf(stderr, "%s\n", os.str().c_str()); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -123,7 +123,7 @@ void OnlineLstmTransducerModel::InitDecoder(void *model_data, std::ostringstream os; os << "---decoder---\n"; PrintModelMetadata(os, meta_data); - fprintf(stderr, "%s\n", os.str().c_str()); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -148,7 +148,7 @@ void OnlineLstmTransducerModel::InitJoiner(void *model_data, std::ostringstream os; os << "---joiner---\n"; PrintModelMetadata(os, meta_data); - fprintf(stderr, "%s\n", os.str().c_str()); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); } } @@ -228,9 +228,6 @@ std::vector OnlineLstmTransducerModel::GetEncoderInitStates() { std::pair> OnlineLstmTransducerModel::RunEncoder(Ort::Value features, std::vector states) { - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::array encoder_inputs = { std::move(features), std::move(states[0]), std::move(states[1])}; diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index da397d67..27c49462 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -20,7 +20,7 @@ class OnlineStream::Impl { feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); } - void InputFinished() { feat_extractor_.InputFinished(); } + void InputFinished() const { feat_extractor_.InputFinished(); } int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady() - start_frame_index_; @@ -68,11 +68,11 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) OnlineStream::~OnlineStream() = default; void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, - int32_t n) { + int32_t n) const { impl_->AcceptWaveform(sampling_rate, 
waveform, n); } -void OnlineStream::InputFinished() { impl_->InputFinished(); } +void OnlineStream::InputFinished() const { impl_->InputFinished(); } int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index 32fe1248..bc1935da 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -27,7 +27,8 @@ class OnlineStream { the range [-1, 1]. @param n Number of entries in waveform */ - void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); + void AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const; /** * InputFinished() tells the class you won't be providing any @@ -35,7 +36,7 @@ class OnlineStream { * of features, in the case where snip-edges == false; it also * affects the return value of IsLastFrame(). */ - void InputFinished(); + void InputFinished() const; int32_t NumFramesReady() const; diff --git a/sherpa-onnx/csrc/online-websocket-client.cc b/sherpa-onnx/csrc/online-websocket-client.cc index 2df87b6c..62a6832b 100644 --- a/sherpa-onnx/csrc/online-websocket-client.cc +++ b/sherpa-onnx/csrc/online-websocket-client.cc @@ -248,14 +248,21 @@ int32_t main(int32_t argc, char *argv[]) { std::string wave_filename = po.GetArg(1); bool is_ok = false; + int32_t actual_sample_rate = -1; std::vector samples = - sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok); + sherpa_onnx::ReadWave(wave_filename, &actual_sample_rate, &is_ok); if (!is_ok) { SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str()); return -1; } + if (actual_sample_rate != sample_rate) { + SHERPA_ONNX_LOGE("Expected sample rate: %d, given %d", sample_rate, + actual_sample_rate); + return -1; + } + asio::io_context io_conn; // for network connections Client c(io_conn, server_ip, server_port, samples, samples_per_message, seconds_per_message); diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index c0748179..4265a76f 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -97,7 +97,7 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, std::ostringstream os; os << "---encoder---\n"; PrintModelMetadata(os, meta_data); - fprintf(stderr, "%s\n", os.str().c_str()); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -123,8 +123,8 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, print(num_encoder_layers_, "num_encoder_layers"); print(cnn_module_kernels_, "cnn_module_kernels"); print(left_context_len_, "left_context_len"); - fprintf(stderr, "T: %d\n", T_); - fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_); + SHERPA_ONNX_LOGE("T: %d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); } } @@ -145,7 +145,7 @@ void OnlineZipformerTransducerModel::InitDecoder(void *model_data, std::ostringstream os; os << "---decoder---\n"; PrintModelMetadata(os, meta_data); - fprintf(stderr, "%s\n", os.str().c_str()); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below @@ -170,7 +170,7 @@ void OnlineZipformerTransducerModel::InitJoiner(void *model_data, std::ostringstream os; os << "---joiner---\n"; PrintModelMetadata(os, meta_data); - fprintf(stderr, "%s\n", os.str().c_str()); + SHERPA_ONNX_LOGE("%s", 
os.str().c_str()); } } @@ -435,9 +435,6 @@ std::vector OnlineZipformerTransducerModel::GetEncoderInitStates() { std::pair> OnlineZipformerTransducerModel::RunEncoder(Ort::Value features, std::vector states) { - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::vector encoder_inputs; encoder_inputs.reserve(1 + states.size()); diff --git a/sherpa-onnx/csrc/packed-sequence.cc b/sherpa-onnx/csrc/packed-sequence.cc index 09f88018..df5b9202 100644 --- a/sherpa-onnx/csrc/packed-sequence.cc +++ b/sherpa-onnx/csrc/packed-sequence.cc @@ -41,7 +41,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator, std::vector l_shape = length->GetTensorTypeAndShapeInfo().GetShape(); assert(v_shape.size() == 3); - assert(l_shape.size() == 3); + assert(l_shape.size() == 1); assert(v_shape[0] == l_shape[0]); std::vector indexes(v_shape[0]); diff --git a/sherpa-onnx/csrc/packed-sequence.h b/sherpa-onnx/csrc/packed-sequence.h index d2d125c7..203f47c3 100644 --- a/sherpa-onnx/csrc/packed-sequence.h +++ b/sherpa-onnx/csrc/packed-sequence.h @@ -13,7 +13,26 @@ namespace sherpa_onnx { struct PackedSequence { std::vector sorted_indexes; std::vector batch_sizes; + + // data is a 2-D tensor of shape (sum(batch_sizes), channels) Ort::Value data{nullptr}; + + // Return a shallow copy of data[start:start+size, :] + Ort::Value Get(int32_t start, int32_t size) { + auto shape = data.GetTensorTypeAndShapeInfo().GetShape(); + + std::array ans_shape{size, shape[1]}; + + float *p = data.GetTensorMutableData(); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + // a shallow copy + return Ort::Value::CreateTensor(memory_info, p + start * shape[1], + size * shape[1], ans_shape.data(), + ans_shape.size()); + } }; /** Similar to torch.nn.utils.rnn.pad_sequence but it supports only diff --git a/sherpa-onnx/csrc/resample.cc b/sherpa-onnx/csrc/resample.cc index 8ef3a1b5..f82c61a9 100644 --- a/sherpa-onnx/csrc/resample.cc +++ b/sherpa-onnx/csrc/resample.cc @@ -46,7 +46,7 @@ I Gcd(I m, I n) { // this function is copied from kaldi/src/base/kaldi-math.h if (m == 0 || n == 0) { if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. - fprintf(stderr, "Undefined GCD since m = 0, n = 0."); + fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n"); exit(-1); } return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m)); diff --git a/sherpa-onnx/csrc/sherpa-onnx-alsa.cc b/sherpa-onnx/csrc/sherpa-onnx-alsa.cc index 730f7618..c468f01b 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-alsa.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-alsa.cc @@ -95,6 +95,10 @@ as the device_name. fprintf(stderr, "%s\n", config.ToString().c_str()); + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } sherpa_onnx::OnlineRecognizer recognizer(config); int32_t expected_sample_rate = config.feat_config.sampling_rate; diff --git a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc index 2f16186d..27d9426e 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc @@ -86,6 +86,11 @@ for a list of pre-trained models to download. 
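Editor's note on the PackedSequence::Get() helper added to sherpa-onnx/csrc/packed-sequence.h earlier in this patch: it does not copy any samples; it wraps a pointer into the existing buffer of `data`, so the returned tensor is only a view. The standalone sketch below is not part of the patch; it assumes only onnxruntime's C++ header and shows why writing through such a view also changes the original tensor.

// Minimal sketch of a shallow Ort::Value view over an existing float buffer.
#include <array>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

#include "onnxruntime_cxx_api.h"

int main() {
  // A fake 2-D tensor of shape (4, 3) stored row-major in a std::vector.
  std::vector<float> data(12);
  std::iota(data.begin(), data.end(), 0.0f);

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // View rows [1, 3): pointer arithmetic only, no copy.
  int64_t start = 1, size = 2, cols = 3;
  std::array<int64_t, 2> shape{size, cols};
  Ort::Value view = Ort::Value::CreateTensor<float>(
      memory_info, data.data() + start * cols, size * cols, shape.data(),
      shape.size());

  // Because it is a view, writing through it modifies the original buffer.
  view.GetTensorMutableData<float>()[0] = 100.0f;
  printf("data[3] = %.1f\n", data[3]);  // prints 100.0
  return 0;
}

The same lifetime caveat as in the patch applies: the view is only valid while the underlying `data` tensor (here, the std::vector) stays alive.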
fprintf(stderr, "%s\n", config.ToString().c_str()); + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + sherpa_onnx::OnlineRecognizer recognizer(config); auto s = recognizer.CreateStream(); diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc new file mode 100644 index 00000000..d4b7529d --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc @@ -0,0 +1,115 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include + +#include // NOLINT +#include +#include + +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include "sherpa-onnx/csrc/offline-stream.h" +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/offline-transducer-model.h" +#include "sherpa-onnx/csrc/pad-sequence.h" +#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/wave-reader.h" + +int main(int32_t argc, char *argv[]) { + if (argc < 6 || argc > 8) { + const char *usage = R"usage( +Usage: + ./bin/sherpa-onnx-offline \ + /path/to/tokens.txt \ + /path/to/encoder.onnx \ + /path/to/decoder.onnx \ + /path/to/joiner.onnx \ + /path/to/foo.wav [num_threads [decoding_method]] + +Default value for num_threads is 2. +Valid values for decoding_method: greedy_search. +foo.wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. +)usage"; + fprintf(stderr, "%s\n", usage); + + return 0; + } + + sherpa_onnx::OfflineRecognizerConfig config; + + config.model_config.tokens = argv[1]; + + config.model_config.debug = false; + config.model_config.encoder_filename = argv[2]; + config.model_config.decoder_filename = argv[3]; + config.model_config.joiner_filename = argv[4]; + + std::string wav_filename = argv[5]; + + config.model_config.num_threads = 2; + if (argc == 7 && atoi(argv[6]) > 0) { + config.model_config.num_threads = atoi(argv[6]); + } + + if (argc == 8) { + config.decoding_method = argv[7]; + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + int32_t sampling_rate = -1; + + bool is_ok = false; + std::vector samples = + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + return -1; + } + fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); + + float duration = samples.size() / static_cast(sampling_rate); + + sherpa_onnx::OfflineRecognizer recognizer(config); + auto s = recognizer.CreateStream(); + + auto begin = std::chrono::steady_clock::now(); + fprintf(stderr, "Started\n"); + + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + recognizer.DecodeStream(s.get()); + + fprintf(stderr, "Done!\n"); + + fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), + s->GetResult().text.c_str()); + + auto end = std::chrono::steady_clock::now(); + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); + + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf 
= elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +} diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index ad499e4c..12a04744 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -26,6 +26,8 @@ Usage: Default value for num_threads is 2. Valid values for decoding_method: greedy_search (default), modified_beam_search. +foo.wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html @@ -59,20 +61,26 @@ for a list of pre-trained models to download. fprintf(stderr, "%s\n", config.ToString().c_str()); + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + sherpa_onnx::OnlineRecognizer recognizer(config); - int32_t expected_sampling_rate = config.feat_config.sampling_rate; + int32_t sampling_rate = -1; bool is_ok = false; std::vector samples = - sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); return -1; } + fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); - float duration = samples.size() / static_cast(expected_sampling_rate); + float duration = samples.size() / static_cast(sampling_rate); fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); fprintf(stderr, "wav duration (s): %.3f\n", duration); @@ -81,12 +89,13 @@ for a list of pre-trained models to download. fprintf(stderr, "Started\n"); auto s = recognizer.CreateStream(); - s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); - std::vector tail_paddings( - static_cast(0.2 * expected_sampling_rate)); - s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(), - tail_paddings.size()); + std::vector tail_paddings(static_cast(0.2 * sampling_rate)); + // Note: We can call AcceptWaveform() multiple times. 
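Editor's note: the comment added above ("We can call AcceptWaveform() multiple times") is the property that makes chunked streaming possible. The sketch below is not part of the patch; it only illustrates how a caller might feed a long recording in 0.1 s pieces using the OnlineRecognizer/OnlineStream calls visible in sherpa-onnx.cc. The body of the IsReady() loop is truncated in this excerpt, so recognizer.DecodeStream() is assumed here; the chunk size and 0.2 s of trailing silence are illustrative only.

#include <algorithm>
#include <cstdint>
#include <vector>

#include "sherpa-onnx/csrc/online-recognizer.h"

// Feed `samples` to the recognizer in small chunks instead of all at once.
void DecodeInChunks(sherpa_onnx::OnlineRecognizer &recognizer,
                    const std::vector<float> &samples, int32_t sampling_rate) {
  auto s = recognizer.CreateStream();

  int32_t chunk = static_cast<int32_t>(0.1f * sampling_rate);
  for (size_t start = 0; start < samples.size(); start += chunk) {
    int32_t n = static_cast<int32_t>(
        std::min<size_t>(chunk, samples.size() - start));
    s->AcceptWaveform(sampling_rate, samples.data() + start, n);

    // Decode whatever is ready so far; partial results could be read here.
    while (recognizer.IsReady(s.get())) {
      recognizer.DecodeStream(s.get());
    }
  }

  // Flush the feature extractor with trailing silence, then signal the end.
  std::vector<float> tail(static_cast<size_t>(0.2f * sampling_rate), 0.0f);
  s->AcceptWaveform(sampling_rate, tail.data(),
                    static_cast<int32_t>(tail.size()));
  s->InputFinished();

  while (recognizer.IsReady(s.get())) {
    recognizer.DecodeStream(s.get());
  }
}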
+ s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); + + // Call InputFinished() to indicate that no audio samples are available s->InputFinished(); while (recognizer.IsReady(s.get())) { diff --git a/sherpa-onnx/csrc/slice-test.cc b/sherpa-onnx/csrc/slice-test.cc index 6f7bde3e..43f97bc8 100644 --- a/sherpa-onnx/csrc/slice-test.cc +++ b/sherpa-onnx/csrc/slice-test.cc @@ -30,4 +30,23 @@ TEST(Slice, Slice3D) { // TODO(fangjun): Check that the results are correct } +TEST(Slice, Slice2D) { + Ort::AllocatorWithDefaultOptions allocator; + std::array shape{5, 8}; + Ort::Value v = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + float *p = v.GetTensorMutableData(); + + std::iota(p, p + shape[0] * shape[1], 0); + + auto v1 = Slice(allocator, &v, 1, 3); + auto v2 = Slice(allocator, &v, 0, 2); + + Print2D(&v); + Print2D(&v1); + Print2D(&v2); + + // TODO(fangjun): Check that the results are correct +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/slice.cc b/sherpa-onnx/csrc/slice.cc index 189f8517..6d15e0f5 100644 --- a/sherpa-onnx/csrc/slice.cc +++ b/sherpa-onnx/csrc/slice.cc @@ -24,7 +24,7 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, assert(0 <= dim1_start); assert(dim1_start < dim1_end); - assert(dim1_end < shape[1]); + assert(dim1_end <= shape[1]); const T *src = v->GetTensorData(); @@ -46,8 +46,35 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, return ans; } +template +Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, + int32_t dim0_start, int32_t dim0_end) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + assert(shape.size() == 2); + + assert(0 <= dim0_start); + assert(dim0_start < dim0_end); + assert(dim0_end <= shape[0]); + + const T *src = v->GetTensorData(); + + std::array ans_shape{dim0_end - dim0_start, shape[1]}; + + Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + const T *start = v->GetTensorData() + dim0_start * shape[1]; + const T *end = v->GetTensorData() + dim0_end * shape[1]; + T *dst = ans.GetTensorMutableData(); + std::copy(start, end, dst); + + return ans; +} + template Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, int32_t dim1_end); +template Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, + int32_t dim0_start, int32_t dim0_end); + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/slice.h b/sherpa-onnx/csrc/slice.h index fb406cae..21a93f1a 100644 --- a/sherpa-onnx/csrc/slice.h +++ b/sherpa-onnx/csrc/slice.h @@ -8,12 +8,12 @@ namespace sherpa_onnx { -/** Get a deep copy by slicing v. +/** Get a deep copy by slicing a 3-D tensor v. * - * It returns v[dim0_start:dim0_end, dim1_start:dim1_end] + * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :] * * @param allocator - * @param v A 3-D tensor. Its data type is T. + * @param v A 2-D tensor. Its data type is T. * @param dim0_start Start index of the first dimension.. * @param dim0_end End index of the first dimension.. * @param dim1_start Start index of the second dimension. @@ -26,6 +26,23 @@ template Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, int32_t dim1_end); + +/** Get a deep copy by slicing a 2-D tensor v. + * + * It returns v[dim0_start:dim0_end, :] + * + * @param allocator + * @param v A 2-D tensor. Its data type is T. 
+ * @param dim0_start Start index of the first dimension.. + * @param dim0_end End index of the first dimension.. + * + * @return Return a 2-D tensor of shape + * (dim0_end-dim0_start, v.shape[1]) + */ +template +Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, + int32_t dim0_start, int32_t dim0_end); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SLICE_H_ diff --git a/sherpa-onnx/csrc/wave-reader.cc b/sherpa-onnx/csrc/wave-reader.cc index 2223641d..98d34fb7 100644 --- a/sherpa-onnx/csrc/wave-reader.cc +++ b/sherpa-onnx/csrc/wave-reader.cc @@ -6,10 +6,11 @@ #include #include -#include #include #include +#include "sherpa-onnx/csrc/macros.h" + namespace sherpa_onnx { namespace { // see http://soundfile.sapp.org/doc/WaveFormat/ @@ -20,26 +21,34 @@ struct WaveHeader { bool Validate() const { // F F I R if (chunk_id != 0x46464952) { + SHERPA_ONNX_LOGE("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id); return false; } // E V A W if (format != 0x45564157) { + SHERPA_ONNX_LOGE("Expected format WAVE. Given: 0x%08x\n", format); return false; } if (subchunk1_id != 0x20746d66) { + SHERPA_ONNX_LOGE("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n", + subchunk1_id); return false; } if (subchunk1_size != 16) { // 16 for PCM + SHERPA_ONNX_LOGE("Expected subchunk1_size 16. Given: %d\n", + subchunk1_size); return false; } if (audio_format != 1) { // 1 for PCM + SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n", audio_format); return false; } if (num_channels != 1) { // we support only single channel for now + SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n", num_channels); return false; } if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) { @@ -51,6 +60,8 @@ struct WaveHeader { } if (bits_per_sample != 16) { // we support only 16 bits per sample + SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n", + bits_per_sample); return false; } @@ -62,7 +73,7 @@ struct WaveHeader { // and // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf void SeekToDataChunk(std::istream &is) { - // a t a d + // a t a d while (is && subchunk2_id != 0x61746164) { // const char *p = reinterpret_cast(&subchunk2_id); // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], @@ -91,7 +102,7 @@ static_assert(sizeof(WaveHeader) == 44, ""); // Read a wave file of mono-channel. // Return its samples normalized to the range [-1, 1). -std::vector ReadWaveImpl(std::istream &is, float expected_sample_rate, +std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, bool *is_ok) { WaveHeader header; is.read(reinterpret_cast(&header), sizeof(header)); @@ -111,10 +122,7 @@ std::vector ReadWaveImpl(std::istream &is, float expected_sample_rate, return {}; } - if (expected_sample_rate != header.sample_rate) { - *is_ok = false; - return {}; - } + *sampling_rate = header.sample_rate; // header.subchunk2_size contains the number of bytes in the data. 
// As we assume each sample contains two bytes, so it is divided by 2 here @@ -137,15 +145,15 @@ std::vector ReadWaveImpl(std::istream &is, float expected_sample_rate, } // namespace -std::vector ReadWave(const std::string &filename, - float expected_sample_rate, bool *is_ok) { +std::vector ReadWave(const std::string &filename, int32_t *sampling_rate, + bool *is_ok) { std::ifstream is(filename, std::ifstream::binary); - return ReadWave(is, expected_sample_rate, is_ok); + return ReadWave(is, sampling_rate, is_ok); } -std::vector ReadWave(std::istream &is, float expected_sample_rate, +std::vector ReadWave(std::istream &is, int32_t *sampling_rate, bool *is_ok) { - auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok); + auto samples = ReadWaveImpl(is, sampling_rate, is_ok); return samples; } diff --git a/sherpa-onnx/csrc/wave-reader.h b/sherpa-onnx/csrc/wave-reader.h index dfec9807..98e956ab 100644 --- a/sherpa-onnx/csrc/wave-reader.h +++ b/sherpa-onnx/csrc/wave-reader.h @@ -13,17 +13,17 @@ namespace sherpa_onnx { /** Read a wave file with expected sample rate. - @param filename Path to a wave file. It MUST be single channel, PCM encoded. - @param expected_sample_rate Expected sample rate of the wave file. If the - sample rate don't match, it throws an exception. + @param filename Path to a wave file. It MUST be single channel, 16-bit + PCM encoded. + @param sampling_rate On return, it contains the sampling rate of the file. @param is_ok On return it is true if the reading succeeded; false otherwise. @return Return wave samples normalized to the range [-1, 1). */ -std::vector ReadWave(const std::string &filename, - float expected_sample_rate, bool *is_ok); +std::vector ReadWave(const std::string &filename, int32_t *sampling_rate, + bool *is_ok); -std::vector ReadWave(std::istream &is, float expected_sample_rate, +std::vector ReadWave(std::istream &is, int32_t *sampling_rate, bool *is_ok); } // namespace sherpa_onnx diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 451ffa3d..bb81ec58 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -11,6 +11,7 @@ #include "jni.h" // NOLINT #include +#include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" @@ -43,14 +44,18 @@ class SherpaOnnx { stream_(recognizer_.CreateStream()) { } - void AcceptWaveform(int32_t sample_rate, const float *samples, - int32_t n) const { + void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) { + if (input_sample_rate_ == -1) { + input_sample_rate_ = sample_rate; + } + stream_->AcceptWaveform(sample_rate, samples, n); } void InputFinished() const { - std::vector tail_padding(16000 * 0.32, 0); - stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size()); + std::vector tail_padding(input_sample_rate_ * 0.32, 0); + stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(), + tail_padding.size()); stream_->InputFinished(); } @@ -70,6 +75,7 @@ class SherpaOnnx { private: sherpa_onnx::OnlineRecognizer recognizer_; std::unique_ptr stream_; + int32_t input_sample_rate_ = -1; }; static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -276,17 +282,24 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( return env->NewStringUTF(text.c_str()); } +// see +// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables +static jobject NewInteger(JNIEnv *env, int32_t value) { + jclass cls = env->FindClass("java/lang/Integer"); + jmethodID constructor = env->GetMethodID(cls, "", "(I)V"); + 
return env->NewObject(cls, constructor, value); +} + SHERPA_ONNX_EXTERN_C -JNIEXPORT jfloatArray JNICALL +JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( - JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename, - jfloat expected_sample_rate) { + JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) { const char *p_filename = env->GetStringUTFChars(filename, nullptr); #if __ANDROID_API__ >= 9 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); - return nullptr; + exit(-1); } std::vector buffer = sherpa_onnx::ReadFile(mgr, p_filename); @@ -297,16 +310,25 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( #endif bool is_ok = false; + int32_t sampling_rate = -1; std::vector samples = - sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok); + sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok); env->ReleaseStringUTFChars(filename, p_filename); if (!is_ok) { - return nullptr; + SHERPA_ONNX_LOGE("Failed to read %s", p_filename); + exit(-1); } jfloatArray ans = env->NewFloatArray(samples.size()); env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data()); - return ans; + + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 2, env->FindClass("java/lang/Object"), nullptr); + + env->SetObjectArrayElement(obj_arr, 0, ans); + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate)); + + return obj_arr; } diff --git a/sherpa-onnx/python/csrc/features.cc b/sherpa-onnx/python/csrc/features.cc index c5601a2a..5139398f 100644 --- a/sherpa-onnx/python/csrc/features.cc +++ b/sherpa-onnx/python/csrc/features.cc @@ -11,12 +11,10 @@ namespace sherpa_onnx { static void PybindFeatureExtractorConfig(py::module *m) { using PyClass = FeatureExtractorConfig; py::class_(*m, "FeatureExtractorConfig") - .def(py::init(), - py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80, - py::arg("max_feature_vectors") = -1) + .def(py::init(), py::arg("sampling_rate") = 16000, + py::arg("feature_dim") = 80) .def_readwrite("sampling_rate", &PyClass::sampling_rate) .def_readwrite("feature_dim", &PyClass::feature_dim) - .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index e8fd64a7..ce1a6afa 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -34,7 +34,6 @@ class OnlineRecognizer(object): rule3_min_utterance_length: int = 20, decoding_method: str = "greedy_search", max_active_paths: int = 4, - max_feature_vectors: int = -1, ): """ Please refer to @@ -82,9 +81,6 @@ class OnlineRecognizer(object): max_active_paths: Use only when decoding_method is modified_beam_search. It specifies the maximum number of active paths during beam search. - max_feature_vectors: - Number of feature vectors to cache. -1 means to cache all feature - frames that have been processed. 
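Editor's note on the NewInteger() helper added to sherpa-onnx/jni/jni.cc above: it calls FindClass()/GetMethodID() on every invocation, which is fine for a function that runs once per wave file. If it were ever used in a hot path, the IDs could be cached. The sketch below is not part of the patch; NewIntegerCached() is a hypothetical name and the code assumes single-threaded use from the JNI side.

#include <cstdint>

#include "jni.h"  // NOLINT

// Hypothetical variant of NewInteger() that looks up java/lang/Integer and
// its (I)V constructor only once. A global reference is needed so the cached
// jclass stays valid across JNI calls; jmethodIDs remain valid as long as the
// class is not unloaded. Not thread-safe by design.
static jobject NewIntegerCached(JNIEnv *env, int32_t value) {
  static jclass cls = nullptr;
  static jmethodID ctor = nullptr;
  if (cls == nullptr) {
    jclass local = env->FindClass("java/lang/Integer");
    cls = reinterpret_cast<jclass>(env->NewGlobalRef(local));
    env->DeleteLocalRef(local);
    ctor = env->GetMethodID(cls, "<init>", "(I)V");
  }
  return env->NewObject(cls, ctor, value);
}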
""" _assert_file_exists(tokens) _assert_file_exists(encoder) @@ -104,7 +100,6 @@ class OnlineRecognizer(object): feat_config = FeatureExtractorConfig( sampling_rate=sample_rate, feature_dim=feature_dim, - max_feature_vectors=max_feature_vectors, ) endpoint_config = EndpointConfig( diff --git a/sherpa-onnx/python/tests/test_feature_extractor_config.py b/sherpa-onnx/python/tests/test_feature_extractor_config.py index e12f808a..f3e83f63 100644 --- a/sherpa-onnx/python/tests/test_feature_extractor_config.py +++ b/sherpa-onnx/python/tests/test_feature_extractor_config.py @@ -8,18 +8,18 @@ import unittest -import sherpa_onnx +import _sherpa_onnx class TestFeatureExtractorConfig(unittest.TestCase): def test_default_constructor(self): - config = sherpa_onnx.FeatureExtractorConfig() + config = _sherpa_onnx.FeatureExtractorConfig() assert config.sampling_rate == 16000, config.sampling_rate assert config.feature_dim == 80, config.feature_dim print(config) def test_constructor(self): - config = sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40) + config = _sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40) assert config.sampling_rate == 8000, config.sampling_rate assert config.feature_dim == 40, config.feature_dim print(config) diff --git a/sherpa-onnx/python/tests/test_online_transducer_model_config.py b/sherpa-onnx/python/tests/test_online_transducer_model_config.py index 1b9010db..6c41bb4f 100644 --- a/sherpa-onnx/python/tests/test_online_transducer_model_config.py +++ b/sherpa-onnx/python/tests/test_online_transducer_model_config.py @@ -8,21 +8,23 @@ import unittest -import sherpa_onnx +import _sherpa_onnx class TestOnlineTransducerModelConfig(unittest.TestCase): def test_constructor(self): - config = sherpa_onnx.OnlineTransducerModelConfig( + config = _sherpa_onnx.OnlineTransducerModelConfig( encoder_filename="encoder.onnx", decoder_filename="decoder.onnx", joiner_filename="joiner.onnx", + tokens="tokens.txt", num_threads=8, debug=True, ) assert config.encoder_filename == "encoder.onnx", config.encoder_filename assert config.decoder_filename == "decoder.onnx", config.decoder_filename assert config.joiner_filename == "joiner.onnx", config.joiner_filename + assert config.tokens == "tokens.txt", config.tokens assert config.num_threads == 8, config.num_threads assert config.debug is True, config.debug print(config)