diff --git a/.github/scripts/test-online-transducer.sh b/.github/scripts/test-online-transducer.sh new file mode 100755 index 00000000..2c2c3815 --- /dev/null +++ b/.github/scripts/test-online-transducer.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +log "------------------------------------------------------------" +log "Run LSTM transducer (English)" +log "------------------------------------------------------------" + +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17 + +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" +popd + +waves=( +$repo/test_wavs/1089-134686-0001.wav +$repo/test_wavs/1221-135766-0001.wav +$repo/test_wavs/1221-135766-0002.wav +) + +for wave in ${waves[@]}; do + time $EXE \ + $repo/tokens.txt \ + $repo/encoder-epoch-99-avg-1.onnx \ + $repo/decoder-epoch-99-avg-1.onnx \ + $repo/joiner-epoch-99-avg-1.onnx \ + $wave \ + 4 +done + +rm -rf $repo diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml new file mode 100644 index 00000000..08a9c4f0 --- /dev/null +++ b/.github/workflows/linux.yaml @@ -0,0 +1,71 @@ +name: linux + +on: + push: + branches: + - master + paths: + - '.github/workflows/linux.yaml' + - '.github/scripts/test-online-transducer.sh' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + pull_request: + branches: + - master + paths: + - '.github/workflows/linux.yaml' + - '.github/scripts/test-online-transducer.sh' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + +concurrency: + group: linux-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + linux: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Configure CMake + shell: bash + run: | + mkdir build + cd build + cmake -D CMAKE_BUILD_TYPE=Release .. 
+
+    - name: Build sherpa-onnx for ubuntu
+      shell: bash
+      run: |
+        cd build
+        make -j2
+
+        ls -lh lib
+        ls -lh bin
+
+    - name: Display dependencies of sherpa-onnx for linux
+      shell: bash
+      run: |
+        file build/bin/sherpa-onnx
+        readelf -d build/bin/sherpa-onnx
+
+    - name: Test online transducer
+      shell: bash
+      run: |
+        export PATH=$PWD/build/bin:$PATH
+        export EXE=sherpa-onnx
+
+        .github/scripts/test-online-transducer.sh
diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml
new file mode 100644
index 00000000..d6897552
--- /dev/null
+++ b/.github/workflows/macos.yaml
@@ -0,0 +1,73 @@
+name: macos
+
+on:
+  push:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/macos.yaml'
+      - '.github/scripts/test-online-transducer.sh'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+  pull_request:
+    branches:
+      - master
+    paths:
+      - '.github/workflows/macos.yaml'
+      - '.github/scripts/test-online-transducer.sh'
+      - 'CMakeLists.txt'
+      - 'cmake/**'
+      - 'sherpa-onnx/csrc/*'
+
+concurrency:
+  group: macos-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+
+jobs:
+  macos:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos-latest]
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Configure CMake
+        shell: bash
+        run: |
+          mkdir build
+          cd build
+          cmake -D CMAKE_BUILD_TYPE=Release ..
+
+      - name: Build sherpa-onnx for macos
+        shell: bash
+        run: |
+          cd build
+          make -j2
+
+          ls -lh lib
+          ls -lh bin
+
+      - name: Display dependencies of sherpa-onnx for macos
+        shell: bash
+        run: |
+          file build/bin/sherpa-onnx
+          otool -L build/bin/sherpa-onnx
+          otool -l build/bin/sherpa-onnx
+
+      - name: Test online transducer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx
+
+          .github/scripts/test-online-transducer.sh
diff --git a/.github/workflows/test-linux-macos-windows.yaml b/.github/workflows/test-linux-macos-windows.yaml
deleted file mode 100644
index 55c0f8bb..00000000
--- a/.github/workflows/test-linux-macos-windows.yaml
+++ /dev/null
@@ -1,207 +0,0 @@
-name: test-linux-macos-windows
-
-on:
-  push:
-    branches:
-      - master
-    paths:
-      - '.github/workflows/test-linux-macos-windows.yaml'
-      - 'CMakeLists.txt'
-      - 'cmake/**'
-      - 'sherpa-onnx/csrc/*'
-  pull_request:
-    branches:
-      - master
-    paths:
-      - '.github/workflows/test-linux-macos-windows.yaml'
-      - 'CMakeLists.txt'
-      - 'cmake/**'
-      - 'sherpa-onnx/csrc/*'
-
-concurrency:
-  group: test-linux-macos-windows-${{ github.ref }}
-  cancel-in-progress: true
-
-permissions:
-  contents: read
-
-jobs:
-  test-linux-macos-windows:
-    runs-on: ${{ matrix.os }}
-    strategy:
-      fail-fast: false
-      matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
-
-    steps:
-      - uses: actions/checkout@v2
-        with:
-          fetch-depth: 0
-
-      # see https://github.com/microsoft/setup-msbuild
-      - name: Add msbuild to PATH
-        if: startsWith(matrix.os, 'windows')
-        uses: microsoft/setup-msbuild@v1.0.2
-
-      - name: Download pretrained model and test-data (English)
-        shell: bash
-        run: |
-          git lfs install
-          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
-          cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
-          ls -lh exp/onnx/*.onnx
-          git lfs pull --include "exp/onnx/*.onnx"
-          ls -lh exp/onnx/*.onnx
-
-      - name: Download pretrained model and test-data (Chinese)
-        shell: bash
-        run: |
-          GIT_LFS_SKIP_SMUDGE=1 git clone
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2 - cd icefall_asr_wenetspeech_pruned_transducer_stateless2 - ls -lh exp/*.onnx - git lfs pull --include "exp/*.onnx" - ls -lh exp/*.onnx - - - name: Configure CMake - shell: bash - run: | - mkdir build - cd build - cmake -D CMAKE_BUILD_TYPE=Release .. - - - name: Build sherpa-onnx for ubuntu/macos - if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos') - shell: bash - run: | - cd build - make VERBOSE=1 -j3 - - - name: Build sherpa-onnx for Windows - if: startsWith(matrix.os, 'windows') - shell: bash - run: | - cmake --build ./build --config Release - - - name: Run tests for ubuntu/macos (English) - if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos') - shell: bash - run: | - time ./build/bin/sherpa-onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav - - time ./build/bin/sherpa-onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav - - time ./build/bin/sherpa-onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav - - - name: Run tests for Windows (English) - if: startsWith(matrix.os, 'windows') - shell: bash - run: | - ./build/bin/Release/sherpa-onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ - 
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav - - ./build/bin/Release/sherpa-onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav - - ./build/bin/Release/sherpa-onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav - - - name: Run tests for ubuntu/macos (Chinese) - if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos') - shell: bash - run: | - time ./build/bin/sherpa-onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav - - time ./build/bin/sherpa-onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav - - time ./build/bin/sherpa-onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx 
\ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav - - - name: Run tests for windows (Chinese) - if: startsWith(matrix.os, 'windows') - shell: bash - run: | - ./build/bin/Release/sherpa-onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav - - ./build/bin/Release/sherpa-onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav - - ./build/bin/Release/sherpa-onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml new file mode 100644 index 00000000..42e4f70d --- /dev/null +++ b/.github/workflows/windows-x64.yaml @@ -0,0 +1,80 @@ +name: windows-x64 + +on: + push: + branches: + - master + paths: + - '.github/workflows/windows-x64.yaml' + - '.github/scripts/test-online-transducer.sh' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + pull_request: + branches: + - master + paths: + - '.github/workflows/windows-x64.yaml' + - '.github/scripts/test-online-transducer.sh' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + +concurrency: + group: windows-x64-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + windows_x64: + runs-on: ${{ matrix.os }} + name: ${{ matrix.vs-version }} + strategy: + fail-fast: false + matrix: + include: + - vs-version: vs2015 + toolset-version: v140 + os: windows-2019 + + - vs-version: vs2017 + toolset-version: v141 + os: 
windows-2019 + + - vs-version: vs2019 + toolset-version: v142 + os: windows-2022 + + - vs-version: vs2022 + toolset-version: v143 + os: windows-2022 + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Configure CMake + shell: bash + run: | + mkdir build + cd build + cmake -T ${{ matrix.toolset-version}},host=x64 -A x64 -D CMAKE_BUILD_TYPE=Release .. + + - name: Build sherpa-onnx for windows + shell: bash + run: | + cd build + cmake --build . --config Release -- -m:2 + + ls -lh ./bin/Release/sherpa-onnx.exe + + - name: Test sherpa-onnx for Windows x64 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx.exe + + .github/scripts/test-online-transducer.sh diff --git a/.gitignore b/.gitignore index 06854d86..066c5948 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ build onnxruntime-* icefall-* run.sh +sherpa-onnx-* diff --git a/README.md b/README.md index 1ed0f234..83a5c336 100644 --- a/README.md +++ b/README.md @@ -2,89 +2,7 @@ Documentation: -Try it in colab: -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tmQbdlYeTl_klmtaGiUb7a7ZPz-AkBSH?usp=sharing) - See This repo uses [onnxruntime](https://github.com/microsoft/onnxruntime) and does not depend on libtorch. - -We provide exported models in onnx format and they can be downloaded using -the following links: - -- English: -- Chinese: - -**NOTE**: We provide only non-streaming models at present. - - -**HINT**: The script for exporting the English model can be found at - - -**HINT**: The script for exporting the Chinese model can be found at - - -## Build for Linux/macOS - -```bash -git clone https://github.com/k2-fsa/sherpa-onnx -cd sherpa-onnx -mkdir build -cd build -cmake -DCMAKE_BUILD_TYPE=Release .. -make -j6 -cd .. -``` - -## Build for Windows - -```bash -git clone https://github.com/k2-fsa/sherpa-onnx -cd sherpa-onnx -mkdir build -cd build -cmake -DCMAKE_BUILD_TYPE=Release .. -cmake --build . --config Release -cd .. -``` - -## Download the pretrained model (English) - -```bash -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 -cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 -git lfs pull --include "exp/onnx/*.onnx" -cd .. - -./build/bin/sherpa-onnx --help - -./build/bin/sherpa-onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav -``` - -## Download the pretrained model (Chinese) - -```bash -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2 -cd icefall_asr_wenetspeech_pruned_transducer_stateless2 -git lfs pull --include "exp/*.onnx" -cd .. 
- -./build/bin/sherpa-onnx --help - -./build/bin/sherpa-onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \ - ./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav -``` diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 072b3a6e..dbc4461e 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -2,7 +2,11 @@ include_directories(${CMAKE_SOURCE_DIR}) add_executable(sherpa-onnx decode.cc - rnnt-model.cc + features.cc + online-lstm-transducer-model.cc + online-transducer-model-config.cc + online-transducer-model.cc + onnx-utils.cc sherpa-onnx.cc symbol-table.cc wave-reader.cc @@ -13,5 +17,5 @@ target_link_libraries(sherpa-onnx kaldi-native-fbank-core ) -# add_executable(sherpa-show-onnx-info show-onnx-info.cc) -# target_link_libraries(sherpa-show-onnx-info onnxruntime) +add_executable(sherpa-onnx-show-info show-onnx-info.cc) +target_link_libraries(sherpa-onnx-show-info onnxruntime) diff --git a/sherpa-onnx/csrc/decode.cc b/sherpa-onnx/csrc/decode.cc index 5e5cf65b..5f0cb0f1 100644 --- a/sherpa-onnx/csrc/decode.cc +++ b/sherpa-onnx/csrc/decode.cc @@ -1,84 +1,79 @@ -/** - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
+// sherpa/csrc/decode.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
 
 #include "sherpa-onnx/csrc/decode.h"
 
 #include
 #include
+#include <utility>
 #include
 
 namespace sherpa_onnx {
 
-std::vector<int64_t> GreedySearch(RnntModel &model,  // NOLINT
-                                  const Ort::Value &encoder_out) {
+static Ort::Value Clone(Ort::Value *v) {
+  auto type_and_shape = v->GetTensorTypeAndShapeInfo();
+  std::vector<int64_t> shape = type_and_shape.GetShape();
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  return Ort::Value::CreateTensor(memory_info,
+                                  v->GetTensorMutableData<float>(),
+                                  type_and_shape.GetElementCount(),
+                                  shape.data(), shape.size());
+}
+
+static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
+  std::vector<int64_t> encoder_out_shape =
+      encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  assert(encoder_out_shape[0] == 1);
+
+  int32_t encoder_out_dim = encoder_out_shape[2];
+
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  std::array<int64_t, 2> shape{1, encoder_out_dim};
+
+  return Ort::Value::CreateTensor(
+      memory_info,
+      encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim,
+      encoder_out_dim, shape.data(), shape.size());
+}
+
+void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
+                  std::vector<int64_t> *hyp) {
   std::vector<int64_t> encoder_out_shape =
       encoder_out.GetTensorTypeAndShapeInfo().GetShape();
 
-  assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented");
-
-  Ort::Value projected_encoder_out =
-      model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(),
-                                 encoder_out_shape[1], encoder_out_shape[2]);
-
-  const float *p_projected_encoder_out =
-      projected_encoder_out.GetTensorData<float>();
+  if (encoder_out_shape[0] > 1) {
+    fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n",
+            static_cast<int32_t>(encoder_out_shape[0]));
+  }
 
-  int32_t context_size = 2;  // hard-code it to 2
-  int32_t blank_id = 0;      // hard-code it to 0
-  std::vector<int64_t> hyp(context_size, blank_id);
-  std::array<int64_t, 2> decoder_input{blank_id, blank_id};
+  int32_t num_frames = encoder_out_shape[1];
+  int32_t vocab_size = model->VocabSize();
 
-  Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
-
-  std::vector<int64_t> decoder_out_shape =
-      decoder_out.GetTensorTypeAndShapeInfo().GetShape();
-
-  Ort::Value projected_decoder_out = model.RunJoinerDecoderProj(
-      decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
-
-  int32_t joiner_dim =
-      projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1];
-
-  int32_t T = encoder_out_shape[1];
-  for (int32_t t = 0; t != T; ++t) {
-    Ort::Value logit = model.RunJoiner(
-        p_projected_encoder_out + t * joiner_dim,
-        projected_decoder_out.GetTensorData<float>(), joiner_dim);
-
-    int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];
+  Ort::Value decoder_input = model->BuildDecoderInput(*hyp);
+  Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input));
 
+  for (int32_t t = 0; t != num_frames; ++t) {
+    Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
+    Ort::Value logit =
+        model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
     const float *p_logit = logit.GetTensorData<float>();
 
     auto y = static_cast<int32_t>(std::distance(
        static_cast<const float *>(p_logit),
        std::max_element(static_cast<const float *>(p_logit),
                         static_cast<const float *>(p_logit) + vocab_size)));
 
-    if (y != blank_id) {
-      decoder_input[0] = hyp.back();
-      decoder_input[1] = y;
-      hyp.push_back(y);
-      decoder_out = model.RunDecoder(decoder_input.data(), context_size);
-      projected_decoder_out = model.RunJoinerDecoderProj(
-          decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
+    if (y != 0) {
+      hyp->push_back(y);
+      decoder_input = model->BuildDecoderInput(*hyp);
+      decoder_out = model->RunDecoder(std::move(decoder_input));
     }
   }
-
-  return {hyp.begin() + context_size, hyp.end()};
 }
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/decode.h b/sherpa-onnx/csrc/decode.h
index 7511247c..88821573 100644
--- a/sherpa-onnx/csrc/decode.h
+++ b/sherpa-onnx/csrc/decode.h
@@ -1,27 +1,13 @@
-/**
- * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+// sherpa/csrc/decode.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
 
 #ifndef SHERPA_ONNX_CSRC_DECODE_H_
 #define SHERPA_ONNX_CSRC_DECODE_H_
 
 #include <vector>
 
-#include "sherpa-onnx/csrc/rnnt-model.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
 
 namespace sherpa_onnx {
@@ -32,8 +18,8 @@ namespace sherpa_onnx {
  * @param model The transducer model.
  * @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
+ * @param hyp The decoded tokens are appended to it in-place.
  */
-std::vector<int64_t> GreedySearch(RnntModel &model,  // NOLINT
-                                  const Ort::Value &encoder_out);
+void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
+                  std::vector<int64_t> *hyp);
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc
new file mode 100644
index 00000000..7bb47850
--- /dev/null
+++ b/sherpa-onnx/csrc/features.cc
@@ -0,0 +1,79 @@
+// sherpa/csrc/features.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/features.h"
+
+#include
+#include
+#include
+
+namespace sherpa_onnx {
+
+FeatureExtractor::FeatureExtractor() {
+  opts_.frame_opts.dither = 0;
+  opts_.frame_opts.snip_edges = false;
+  opts_.frame_opts.samp_freq = 16000;
+
+  // cache 100 seconds of feature frames, which is more than enough
+  // for real needs
+  opts_.frame_opts.max_feature_vectors = 100 * 100;
+
+  opts_.mel_opts.num_bins = 80;  // feature dim
+
+  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+}
+
+FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts)
+    : opts_(opts) {
+  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+}
+
+void FeatureExtractor::AcceptWaveform(float sampling_rate,
+                                      const float *waveform, int32_t n) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  fbank_->AcceptWaveform(sampling_rate, waveform, n);
+}
+
+void FeatureExtractor::InputFinished() {
+  std::lock_guard<std::mutex> lock(mutex_);
+  fbank_->InputFinished();
+}
+
+int32_t FeatureExtractor::NumFramesReady() const {
+  std::lock_guard<std::mutex> lock(mutex_);
+  return fbank_->NumFramesReady();
+}
+
+bool FeatureExtractor::IsLastFrame(int32_t frame) const {
+  std::lock_guard<std::mutex> lock(mutex_);
+  return fbank_->IsLastFrame(frame);
+}
+
+std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
+                                               int32_t n) const {
+  if (frame_index + n > NumFramesReady()) {
+    fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
+    exit(-1);
+  }
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  int32_t feature_dim = fbank_->Dim();
+  std::vector<float> features(feature_dim * n);
+
+  float *p = features.data();
+
+  for (int32_t i = 0; i != n; ++i) {
+    const float *f = fbank_->GetFrame(i + frame_index);
+    std::copy(f, f + feature_dim, p);
+    p += feature_dim;
+  }
+
+  return features;
+}
+
+void FeatureExtractor::Reset() {
+  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h
new file mode 100644
index 00000000..8e569c2b
--- /dev/null
+++ b/sherpa-onnx/csrc/features.h
@@ -0,0 +1,61 @@
+// sherpa/csrc/features.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_FEATURES_H_
+#define SHERPA_ONNX_CSRC_FEATURES_H_
+
+#include <memory>
+#include <mutex>  // NOLINT
+#include <vector>
+
+#include "kaldi-native-fbank/csrc/online-feature.h"
+
+namespace sherpa_onnx {
+
+class FeatureExtractor {
+ public:
+  FeatureExtractor();
+  explicit FeatureExtractor(const knf::FbankOptions &fbank_opts);
+
+  /**
+     @param sampling_rate The sampling_rate of the input waveform. Should match
+                          the one expected by the feature extractor.
+     @param waveform Pointer to a 1-D array of size n
+     @param n Number of entries in waveform
+   */
+  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
+
+  // InputFinished() tells the class you won't be providing any
+  // more waveform. This will help flush out the last frame or two
+  // of features, in the case where snip-edges == false; it also
+  // affects the return value of IsLastFrame().
+  void InputFinished();
+
+  int32_t NumFramesReady() const;
+
+  // Note: IsLastFrame() will only ever return true if you have called
+  // InputFinished() (and this frame is the last frame).
+  bool IsLastFrame(int32_t frame) const;
+
+  /** Get n frames starting from the given frame index.
+   *
+   * @param frame_index The starting frame index
+   * @param n Number of frames to get.
+   * @return Return a 2-D tensor of shape (n, feature_dim),
+   *         which is flattened into a 1-D vector (in row-major order).
+   */
+  std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
+
+  void Reset();
+  int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
+
+ private:
+  std::unique_ptr<knf::OnlineFbank> fbank_;
+  knf::FbankOptions opts_;
+  mutable std::mutex mutex_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_FEATURES_H_
diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc
new file mode 100644
index 00000000..f55c804a
--- /dev/null
+++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc
@@ -0,0 +1,223 @@
+// sherpa/csrc/online-lstm-transducer-model.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+#define SHERPA_ONNX_READ_META_DATA(dst, src_key)                        \
+  do {                                                                  \
+    auto value =                                                        \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
+    if (!value) {                                                       \
+      fprintf(stderr, "%s does not exist in the metadata\n", src_key);  \
+      exit(-1);                                                         \
+    }                                                                   \
+                                                                        \
+    dst = atoi(value.get());                                            \
+    if (dst <= 0) {                                                     \
+      fprintf(stderr, "Invalid value %d for %s\n", dst, src_key);       \
+      exit(-1);                                                         \
+    }                                                                   \
+  } while (0)
+
+namespace sherpa_onnx {
+
+OnlineLstmTransducerModel::OnlineLstmTransducerModel(
+    const OnlineTransducerModelConfig &config)
+    : env_(ORT_LOGGING_LEVEL_WARNING),
+      config_(config),
+      sess_opts_{},
+      allocator_{} {
+  sess_opts_.SetIntraOpNumThreads(config.num_threads);
+  sess_opts_.SetInterOpNumThreads(config.num_threads);
+
+  InitEncoder(config.encoder_filename);
+  InitDecoder(config.decoder_filename);
+  InitJoiner(config.joiner_filename);
+}
+
+void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
+  encoder_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                &encoder_input_names_ptr_);
+
+  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                 &encoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---encoder---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;
+  SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
+  SHERPA_ONNX_READ_META_DATA(T_, "T");
+  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
+  SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "rnn_hidden_size");
+  SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
+}
+
+void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
+  decoder_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                &decoder_input_names_ptr_);
+
+  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                 &decoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---decoder---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;
+  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
+}
+
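As an aside on the metadata handling above: the values consumed by `SHERPA_ONNX_READ_META_DATA` are ordinary ONNX custom metadata written at export time, so they can be inspected with a few lines of onnxruntime code. The sketch below is illustrative only (the `main` wrapper and the command-line argument are assumptions, not part of this patch; it presumably resembles what the re-enabled `sherpa-onnx-show-info` target does). It reuses `PrintModelMetadata()` from onnx-utils.h:

```cpp
// Illustrative only: dump the custom metadata of an ONNX model.
// Assumes a POSIX build (on Windows the path would need SHERPA_MAYBE_WIDE).
#include <iostream>
#include <sstream>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/onnx-utils.h"

int main(int argc, char *argv[]) {
  if (argc != 2) {
    std::cerr << "usage: " << argv[0] << " /path/to/model.onnx\n";
    return 1;
  }

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions sess_opts;
  Ort::Session sess(env, argv[1], sess_opts);

  std::ostringstream os;
  sherpa_onnx::PrintModelMetadata(os, sess.GetModelMetadata());
  std::cout << os.str();

  return 0;
}
```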
+void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
+  joiner_sess_ = std::make_unique<Ort::Session>(
+      env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
+
+  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
+                &joiner_input_names_ptr_);
+
+  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
+                 &joiner_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---joiner---\n";
+    PrintModelMetadata(os, meta_data);
+    fprintf(stderr, "%s\n", os.str().c_str());
+  }
+}
+
+Ort::Value OnlineLstmTransducerModel::StackStates(
+    const std::vector<Ort::Value> &states) const {
+  fprintf(stderr, "implement me: %s:%d!\n", __func__,
+          static_cast<int32_t>(__LINE__));
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  int64_t a;
+  std::array<int64_t, 3> x_shape{1, 1, 1};
+  Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
+  return x;
+}
+
+std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
+    Ort::Value states) const {
+  fprintf(stderr, "implement me: %s:%d!\n", __func__,
+          static_cast<int32_t>(__LINE__));
+  return {};
+}
+
+std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
+  // Please see
+  // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py#L185
+  // for details
+  constexpr int32_t kBatchSize = 1;
+  std::array<int64_t, 3> h_shape{num_encoder_layers_, kBatchSize, d_model_};
+  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                 h_shape.size());
+
+  std::fill(h.GetTensorMutableData<float>(),
+            h.GetTensorMutableData<float>() +
+                num_encoder_layers_ * kBatchSize * d_model_,
+            0);
+
+  std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
+                                 rnn_hidden_size_};
+  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                 c_shape.size());
+
+  std::fill(c.GetTensorMutableData<float>(),
+            c.GetTensorMutableData<float>() +
+                num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
+            0);
+
+  std::vector<Ort::Value> states;
+
+  states.reserve(2);
+  states.push_back(std::move(h));
+  states.push_back(std::move(c));
+
+  return states;
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
+                                      std::vector<Ort::Value> &states) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  std::array<Ort::Value, 3> encoder_inputs = {
+      std::move(features), std::move(states[0]), std::move(states[1])};
+
+  auto encoder_out = encoder_sess_->Run(
+      {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
+      encoder_inputs.size(), encoder_output_names_ptr_.data(),
+      encoder_output_names_ptr_.size());
+
+  std::vector<Ort::Value> next_states;
+  next_states.reserve(2);
+  next_states.push_back(std::move(encoder_out[1]));
+  next_states.push_back(std::move(encoder_out[2]));
+
+  return {std::move(encoder_out[0]), std::move(next_states)};
+}
+
+Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
+    const std::vector<int64_t> &hyp) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  std::array<int64_t, 2> shape{1, context_size_};
+
+  return Ort::Value::CreateTensor(
+      memory_info,
+      const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
+      context_size_, shape.data(), shape.size());
+}
+
+Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
+  auto decoder_out = decoder_sess_->Run(
+      {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
+      decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
+  return std::move(decoder_out[0]);
+}
+
+Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out,
+                                                Ort::Value decoder_out) {
+  std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
+                                            std::move(decoder_out)};
+  auto logit =
+      joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
+                        joiner_input.size(), joiner_output_names_ptr_.data(),
+                        joiner_output_names_ptr_.size());
+
+  return std::move(logit[0]);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.h b/sherpa-onnx/csrc/online-lstm-transducer-model.h
new file mode 100644
index 00000000..6dc03d8e
--- /dev/null
+++ b/sherpa-onnx/csrc/online-lstm-transducer-model.h
@@ -0,0 +1,91 @@
+// sherpa/csrc/online-lstm-transducer-model.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
+
+namespace sherpa_onnx {
+
+class OnlineLstmTransducerModel : public OnlineTransducerModel {
+ public:
+  explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
+
+  Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
+
+  std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
+
+  std::vector<Ort::Value> GetEncoderInitStates() override;
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
+      Ort::Value features, std::vector<Ort::Value> &states) override;  // NOLINT
+
+  Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
+
+  Ort::Value RunDecoder(Ort::Value decoder_input) override;
+
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
+
+  int32_t ContextSize() const override { return context_size_; }
+
+  int32_t ChunkSize() const override { return T_; }
+
+  int32_t ChunkShift() const override { return decode_chunk_len_; }
+
+  int32_t VocabSize() const override { return vocab_size_; }
+
+ private:
+  void InitEncoder(const std::string &encoder_filename);
+  void InitDecoder(const std::string &decoder_filename);
+  void InitJoiner(const std::string &joiner_filename);
+
+ private:
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+  std::unique_ptr<Ort::Session> joiner_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  std::vector<std::string> joiner_input_names_;
+  std::vector<const char *> joiner_input_names_ptr_;
+
+  std::vector<std::string> joiner_output_names_;
+  std::vector<const char *> joiner_output_names_ptr_;
+
+  OnlineTransducerModelConfig config_;
+
+  int32_t num_encoder_layers_ = 0;
+  int32_t T_ = 0;
+  int32_t decode_chunk_len_ = 0;
+  int32_t rnn_hidden_size_ = 0;
+  int32_t d_model_ = 0;
+  int32_t context_size_ = 0;
+  int32_t vocab_size_ = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
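With the LSTM model class in place, the pieces above compose into a simple chunk-by-chunk decoding loop. The following is a hedged sketch of that flow, not a copy of sherpa-onnx.cc: the model file names, the 4-thread setting, and the handling (here: skipping) of the trailing partial chunk are assumptions for illustration.

```cpp
#include <array>
#include <memory>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/decode.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"

// Sketch: decode one utterance whose 16 kHz samples are already in `samples`.
std::vector<int64_t> DecodeUtterance(const std::vector<float> &samples) {
  sherpa_onnx::OnlineTransducerModelConfig config;
  config.encoder_filename = "encoder-epoch-99-avg-1.onnx";  // assumed paths
  config.decoder_filename = "decoder-epoch-99-avg-1.onnx";
  config.joiner_filename = "joiner-epoch-99-avg-1.onnx";
  config.num_threads = 4;

  std::unique_ptr<sherpa_onnx::OnlineTransducerModel> model =
      sherpa_onnx::OnlineTransducerModel::Create(config);

  sherpa_onnx::FeatureExtractor feat;  // 16 kHz, 80-dim fbank by default
  feat.AcceptWaveform(16000, samples.data(),
                      static_cast<int32_t>(samples.size()));
  feat.InputFinished();

  // The hypothesis starts with ContextSize() blanks so that
  // BuildDecoderInput() always has enough history to look at.
  std::vector<int64_t> hyp(model->ContextSize(), 0);
  std::vector<Ort::Value> states = model->GetEncoderInitStates();

  int32_t chunk_size = model->ChunkSize();
  int32_t chunk_shift = model->ChunkShift();
  int32_t feature_dim = feat.FeatureDim();

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  for (int32_t start = 0; start + chunk_size <= feat.NumFramesReady();
       start += chunk_shift) {
    std::vector<float> frames = feat.GetFrames(start, chunk_size);

    std::array<int64_t, 3> shape{1, chunk_size, feature_dim};
    Ort::Value x = Ort::Value::CreateTensor(
        memory_info, frames.data(), frames.size(), shape.data(), shape.size());

    auto result = model->RunEncoder(std::move(x), states);
    states = std::move(result.second);

    sherpa_onnx::GreedySearch(model.get(), std::move(result.first), &hyp);
  }

  // Drop the leading blanks; the rest maps to text via SymbolTable.
  return {hyp.begin() + model->ContextSize(), hyp.end()};
}
```

Note how the encoder states returned by `RunEncoder` are fed back in on the next chunk; that state round-trip is what makes the model streaming.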
"sherpa-onnx/csrc/online-transducer-model-config.h" + +#include + +namespace sherpa_onnx { + +std::string OnlineTransducerModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineTransducerModelConfig("; + os << "encoder_filename=\"" << encoder_filename << "\", "; + os << "decoder_filename=\"" << decoder_filename << "\", "; + os << "joiner_filename=\"" << joiner_filename << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-model-config.h b/sherpa-onnx/csrc/online-transducer-model-config.h new file mode 100644 index 00000000..ca2e5dbc --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-model-config.h @@ -0,0 +1,23 @@ +// sherpa/csrc/online-transducer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ + +#include + +namespace sherpa_onnx { + +struct OnlineTransducerModelConfig { + std::string encoder_filename; + std::string decoder_filename; + std::string joiner_filename; + int32_t num_threads; + bool debug = false; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc new file mode 100644 index 00000000..27af24e7 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -0,0 +1,64 @@ +// sherpa/csrc/online-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-onnx/csrc/online-transducer-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/online-lstm-transducer-model.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +namespace sherpa_onnx { + +enum class ModelType { + kLstm, + kUnkown, +}; + +static ModelType GetModelType(const OnlineTransducerModelConfig &config) { + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + Ort::SessionOptions sess_opts; + + auto sess = std::make_unique( + env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts); + + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); + if (config.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + fprintf(stderr, "%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; + auto model_type = + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); + if (!model_type) { + fprintf(stderr, "No model_type in the metadata!\n"); + return ModelType::kUnkown; + } + + if (model_type.get() == std::string("lstm")) { + return ModelType::kLstm; + } else { + fprintf(stderr, "Unsupported model_type: %s\n", model_type.get()); + return ModelType::kUnkown; + } +} + +std::unique_ptr OnlineTransducerModel::Create( + const OnlineTransducerModelConfig &config) { + auto model_type = GetModelType(config); + + switch (model_type) { + case ModelType::kLstm: + return std::make_unique(config); + case ModelType::kUnkown: + return nullptr; + } + + // unreachable code + return nullptr; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-model.h b/sherpa-onnx/csrc/online-transducer-model.h new file mode 100644 index 00000000..dfcf9452 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-model.h @@ -0,0 +1,118 @@ +// sherpa/csrc/online-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef 
SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-transducer-model-config.h" + +namespace sherpa_onnx { + +class OnlineTransducerModel { + public: + virtual ~OnlineTransducerModel() = default; + + static std::unique_ptr Create( + const OnlineTransducerModelConfig &config); + + /** Stack a list of individual states into a batch. + * + * It is the inverse operation of `UnStackStates`. + * + * @param states states[i] contains the state for the i-th utterance. + * @return Return a single value representing the batched state. + */ + virtual Ort::Value StackStates( + const std::vector &states) const = 0; + + /** Unstack a batch state into a list of individual states. + * + * It is the inverse operation of `StackStates`. + * + * @param states A batched state. + * @return ans[i] contains the state for the i-th utterance. + */ + virtual std::vector UnStackStates(Ort::Value states) const = 0; + + /** Get the initial encoder states. + * + * @return Return the initial encoder state. + */ + virtual std::vector GetEncoderInitStates() = 0; + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param states Encoder state of the previous chunk. It is changed in-place. + * + * @return Return a tuple containing: + * - encoder_out, a tensor of shape (N, T', encoder_out_dim) + * - next_states Encoder state for the next chunk. + */ + virtual std::pair> RunEncoder( + Ort::Value features, + std::vector &states) = 0; // NOLINT + + virtual Ort::Value BuildDecoderInput(const std::vector &hyp) = 0; + + /** Run the decoder network. + * + * Caution: We assume there are no recurrent connections in the decoder and + * the decoder is stateless. See + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py + * for an example + * + * @param decoder_input It is usually of shape (N, context_size) + * @return Return a tensor of shape (N, decoder_dim). + */ + virtual Ort::Value RunDecoder(Ort::Value decoder_input) = 0; + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. A tensor of shape + * (N, joiner_dim). + * @param decoder_out Output of the decoder network. A tensor of shape + * (N, joiner_dim). + * @return Return a tensor of shape (N, vocab_size). In icefall, the last + * last layer of the joint network is `nn.Linear`, + * not `nn.LogSoftmax`. + */ + virtual Ort::Value RunJoiner(Ort::Value encoder_out, + Ort::Value decoder_out) = 0; + + /** If we are using a stateless decoder and if it contains a + * Conv1D, this function returns the kernel size of the convolution layer. + */ + virtual int32_t ContextSize() const = 0; + + /** We send this number of feature frames to the encoder at a time. */ + virtual int32_t ChunkSize() const = 0; + + /** Number of input frames to discard after each call to RunEncoder. + * + * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6. + * + * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8. + * Then we discard frame 0~5 since chunk_shift is 6. + * In the second call of RunEncoder, we use frames 6~13; and then we discard + * frames 6~11. + * In the third call of RunEncoder, we use frames 12~19; and then we discard + * frames 12~16. 
+ * + * Note: ChunkSize() - ChunkShift() == right context size + */ + virtual int32_t ChunkShift() const = 0; + + virtual int32_t VocabSize() const = 0; + + virtual int32_t SubsamplingFactor() const { return 4; } +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc new file mode 100644 index 00000000..47dd1576 --- /dev/null +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -0,0 +1,49 @@ +// sherpa/csrc/onnx-utils.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-onnx/csrc/onnx-utils.h" + +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +void GetInputNames(Ort::Session *sess, std::vector *input_names, + std::vector *input_names_ptr) { + Ort::AllocatorWithDefaultOptions allocator; + size_t node_count = sess->GetInputCount(); + input_names->resize(node_count); + input_names_ptr->resize(node_count); + for (size_t i = 0; i != node_count; ++i) { + auto tmp = sess->GetInputNameAllocated(i, allocator); + (*input_names)[i] = tmp.get(); + (*input_names_ptr)[i] = (*input_names)[i].c_str(); + } +} + +void GetOutputNames(Ort::Session *sess, std::vector *output_names, + std::vector *output_names_ptr) { + Ort::AllocatorWithDefaultOptions allocator; + size_t node_count = sess->GetOutputCount(); + output_names->resize(node_count); + output_names_ptr->resize(node_count); + for (size_t i = 0; i != node_count; ++i) { + auto tmp = sess->GetOutputNameAllocated(i, allocator); + (*output_names)[i] = tmp.get(); + (*output_names_ptr)[i] = (*output_names)[i].c_str(); + } +} + +void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { + Ort::AllocatorWithDefaultOptions allocator; + std::vector v = + meta_data.GetCustomMetadataMapKeysAllocated(allocator); + for (const auto &key : v) { + auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator); + os << key.get() << "=" << p.get() << "\n"; + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h new file mode 100644 index 00000000..f7f5677e --- /dev/null +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -0,0 +1,60 @@ +// sherpa/csrc/onnx-utils.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ +#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_ + +#ifdef _MSC_VER +// For ToWide() below +#include +#include +#endif + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +#ifdef _MSC_VER +// See +// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t +static std::wstring ToWide(const std::string &s) { + std::wstring_convert> converter; + return converter.from_bytes(s); +} +#define SHERPA_MAYBE_WIDE(s) ToWide(s) +#else +#define SHERPA_MAYBE_WIDE(s) s +#endif + +/** + * Get the input names of a model. + * + * @param sess An onnxruntime session. + * @param input_names. On return, it contains the input names of the model. + * @param input_names_ptr. On return, input_names_ptr[i] contains + * input_names[i].c_str() + */ +void GetInputNames(Ort::Session *sess, std::vector *input_names, + std::vector *input_names_ptr); + +/** + * Get the output names of a model. + * + * @param sess An onnxruntime session. + * @param output_names. On return, it contains the output names of the model. + * @param output_names_ptr. 
On return, output_names_ptr[i] contains + * output_names[i].c_str() + */ +void GetOutputNames(Ort::Session *sess, std::vector *output_names, + std::vector *output_names_ptr); + +void PrintModelMetadata(std::ostream &os, + const Ort::ModelMetadata &meta_data); // NOLINT + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ diff --git a/sherpa-onnx/csrc/rnnt-model.cc b/sherpa-onnx/csrc/rnnt-model.cc deleted file mode 100644 index ead5d023..00000000 --- a/sherpa-onnx/csrc/rnnt-model.cc +++ /dev/null @@ -1,265 +0,0 @@ -/** - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sherpa-onnx/csrc/rnnt-model.h" - -#include -#include -#include - -#ifdef _MSC_VER -// For ToWide() below -#include -#include -#endif - -namespace sherpa_onnx { - -#ifdef _MSC_VER -// See -// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t -static std::wstring ToWide(const std::string &s) { - std::wstring_convert> converter; - return converter.from_bytes(s); -} -#define SHERPA_MAYBE_WIDE(s) ToWide(s) -#else -#define SHERPA_MAYBE_WIDE(s) s -#endif - -/** - * Get the input names of a model. - * - * @param sess An onnxruntime session. - * @param input_names. On return, it contains the input names of the model. - * @param input_names_ptr. On return, input_names_ptr[i] contains - * input_names[i].c_str() - */ -static void GetInputNames(Ort::Session *sess, - std::vector *input_names, - std::vector *input_names_ptr) { - Ort::AllocatorWithDefaultOptions allocator; - size_t node_count = sess->GetInputCount(); - input_names->resize(node_count); - input_names_ptr->resize(node_count); - for (size_t i = 0; i != node_count; ++i) { - auto tmp = sess->GetInputNameAllocated(i, allocator); - (*input_names)[i] = tmp.get(); - (*input_names_ptr)[i] = (*input_names)[i].c_str(); - } -} - -/** - * Get the output names of a model. - * - * @param sess An onnxruntime session. - * @param output_names. On return, it contains the output names of the model. - * @param output_names_ptr. 
On return, output_names_ptr[i] contains - * output_names[i].c_str() - */ -static void GetOutputNames(Ort::Session *sess, - std::vector *output_names, - std::vector *output_names_ptr) { - Ort::AllocatorWithDefaultOptions allocator; - size_t node_count = sess->GetOutputCount(); - output_names->resize(node_count); - output_names_ptr->resize(node_count); - for (size_t i = 0; i != node_count; ++i) { - auto tmp = sess->GetOutputNameAllocated(i, allocator); - (*output_names)[i] = tmp.get(); - (*output_names_ptr)[i] = (*output_names)[i].c_str(); - } -} - -RnntModel::RnntModel(const std::string &encoder_filename, - const std::string &decoder_filename, - const std::string &joiner_filename, - const std::string &joiner_encoder_proj_filename, - const std::string &joiner_decoder_proj_filename, - int32_t num_threads) - : env_(ORT_LOGGING_LEVEL_WARNING) { - sess_opts_.SetIntraOpNumThreads(num_threads); - sess_opts_.SetInterOpNumThreads(num_threads); - - InitEncoder(encoder_filename); - InitDecoder(decoder_filename); - InitJoiner(joiner_filename); - InitJoinerEncoderProj(joiner_encoder_proj_filename); - InitJoinerDecoderProj(joiner_decoder_proj_filename); -} - -void RnntModel::InitEncoder(const std::string &filename) { - encoder_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); - GetInputNames(encoder_sess_.get(), &encoder_input_names_, - &encoder_input_names_ptr_); - - GetOutputNames(encoder_sess_.get(), &encoder_output_names_, - &encoder_output_names_ptr_); -} - -void RnntModel::InitDecoder(const std::string &filename) { - decoder_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); - - GetInputNames(decoder_sess_.get(), &decoder_input_names_, - &decoder_input_names_ptr_); - - GetOutputNames(decoder_sess_.get(), &decoder_output_names_, - &decoder_output_names_ptr_); -} - -void RnntModel::InitJoiner(const std::string &filename) { - joiner_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); - - GetInputNames(joiner_sess_.get(), &joiner_input_names_, - &joiner_input_names_ptr_); - - GetOutputNames(joiner_sess_.get(), &joiner_output_names_, - &joiner_output_names_ptr_); -} - -void RnntModel::InitJoinerEncoderProj(const std::string &filename) { - joiner_encoder_proj_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); - - GetInputNames(joiner_encoder_proj_sess_.get(), - &joiner_encoder_proj_input_names_, - &joiner_encoder_proj_input_names_ptr_); - - GetOutputNames(joiner_encoder_proj_sess_.get(), - &joiner_encoder_proj_output_names_, - &joiner_encoder_proj_output_names_ptr_); -} - -void RnntModel::InitJoinerDecoderProj(const std::string &filename) { - joiner_decoder_proj_sess_ = std::make_unique( - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); - - GetInputNames(joiner_decoder_proj_sess_.get(), - &joiner_decoder_proj_input_names_, - &joiner_decoder_proj_input_names_ptr_); - - GetOutputNames(joiner_decoder_proj_sess_.get(), - &joiner_decoder_proj_output_names_, - &joiner_decoder_proj_output_names_ptr_); -} - -Ort::Value RnntModel::RunEncoder(const float *features, int32_t T, - int32_t feature_dim) { - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::array x_shape{1, T, feature_dim}; - Ort::Value x = - Ort::Value::CreateTensor(memory_info, const_cast(features), - T * feature_dim, x_shape.data(), x_shape.size()); - - std::array x_lens_shape{1}; - int64_t x_lens_tmp = T; - - Ort::Value x_lens = Ort::Value::CreateTensor( - 
memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size()); - - std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)}; - - // Note: We discard encoder_out_lens since we only implement - // batch==1. - auto encoder_out = encoder_sess_->Run( - {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), - encoder_inputs.size(), encoder_output_names_ptr_.data(), - encoder_output_names_ptr_.size()); - return std::move(encoder_out[0]); -} -Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T, - int32_t encoder_out_dim) { - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - std::array<int64_t, 2> in_shape{T, encoder_out_dim}; - Ort::Value in = Ort::Value::CreateTensor( - memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim, - in_shape.data(), in_shape.size()); - - auto encoder_proj_out = joiner_encoder_proj_sess_->Run( - {}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1, - joiner_encoder_proj_output_names_ptr_.data(), - joiner_encoder_proj_output_names_ptr_.size()); - return std::move(encoder_proj_out[0]); -} - -Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input, - int32_t context_size) { - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 - std::array<int64_t, 2> shape{batch_size, context_size}; - Ort::Value in = Ort::Value::CreateTensor( - memory_info, const_cast<int64_t *>(decoder_input), - batch_size * context_size, shape.data(), shape.size()); - - auto decoder_out = decoder_sess_->Run( - {}, decoder_input_names_ptr_.data(), &in, 1, - decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); - return std::move(decoder_out[0]); -} - -Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out, - int32_t decoder_out_dim) { - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 - std::array<int64_t, 2> shape{batch_size, decoder_out_dim}; - Ort::Value in = Ort::Value::CreateTensor( - memory_info, const_cast<float *>(decoder_out), - batch_size * decoder_out_dim, shape.data(), shape.size()); - - auto decoder_proj_out = joiner_decoder_proj_sess_->Run( - {}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1, - joiner_decoder_proj_output_names_ptr_.data(), - joiner_decoder_proj_output_names_ptr_.size()); - return std::move(decoder_proj_out[0]); -} - -Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out, - const float *projected_decoder_out, - int32_t joiner_dim) { - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 - std::array<int64_t, 2> shape{batch_size, joiner_dim}; - - Ort::Value enc = Ort::Value::CreateTensor( - memory_info, const_cast<float *>(projected_encoder_out), - batch_size * joiner_dim, shape.data(), shape.size()); - - Ort::Value dec = Ort::Value::CreateTensor( - memory_info, const_cast<float *>(projected_decoder_out), - batch_size * joiner_dim, shape.data(), shape.size()); - - std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)}; - - auto logit = joiner_sess_->Run( - {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(), - joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size()); - - return std::move(logit[0]); -} - -} // namespace sherpa_onnx
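A detail worth noting in the Run* methods above: Ort::Value::CreateTensor(memory_info, ptr, ...) wraps caller-owned memory without copying, which is why the const pointers are const_cast away and why every buffer must outlive the Run() call that consumes it. A minimal sketch of the pattern, with illustrative names and dimensions that are not part of this diff:

#include <array>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

Ort::Value WrapAsTensor(std::vector<float> *buf, int64_t num_rows,
                        int64_t num_cols) {
  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  std::array<int64_t, 2> shape{num_rows, num_cols};
  // No copy is made: *buf must stay alive while the Ort::Value is in use.
  return Ort::Value::CreateTensor(memory_info, buf->data(), buf->size(),
                                  shape.data(), shape.size());
}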
diff --git a/sherpa-onnx/csrc/rnnt-model.h b/sherpa-onnx/csrc/rnnt-model.h deleted file mode 100644 index 9068d2cb..00000000 --- a/sherpa-onnx/csrc/rnnt-model.h +++ /dev/null @@ -1,148 +0,0 @@ -/** - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_ -#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "onnxruntime_cxx_api.h" // NOLINT - -namespace sherpa_onnx { - -class RnntModel { - public: - /** - * @param encoder_filename Path to the encoder model - * @param decoder_filename Path to the decoder model - * @param joiner_filename Path to the joiner model - * @param joiner_encoder_proj_filename Path to the joiner encoder_proj model - * @param joiner_decoder_proj_filename Path to the joiner decoder_proj model - * @param num_threads Number of threads to use to run the models - */ - RnntModel(const std::string &encoder_filename, - const std::string &decoder_filename, - const std::string &joiner_filename, - const std::string &joiner_encoder_proj_filename, - const std::string &joiner_decoder_proj_filename, - int32_t num_threads); - - /** Run the encoder model. - * - * @TODO(fangjun): Support batch_size > 1 - * - * @param features A tensor of shape (batch_size, T, feature_dim) - * @param T Number of feature frames - * @param feature_dim Dimension of the feature. - * - * @return Return a tensor of shape (batch_size, T', encoder_out_dim) - */ - Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim); - - /** Run the joiner encoder_proj model. - * - * @param encoder_out A tensor of shape (T, encoder_out_dim) - * @param T Number of frames in encoder_out. - * @param encoder_out_dim Dimension of encoder_out. - * - * @return Return a tensor of shape (T, joiner_dim) - * - */ - Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T, - int32_t encoder_out_dim); - - /** Run the decoder model. - * - * @TODO(fangjun): Support batch_size > 1 - * - * @param decoder_input A tensor of shape (batch_size, context_size). - * @return Return a tensor of shape (batch_size, 1, decoder_out_dim) - */ - Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size); - - /** Run joiner decoder_proj model. - * - * @TODO(fangjun): Support batch_size > 1 - * - * @param decoder_out A tensor of shape (batch_size, decoder_out_dim) - * @param decoder_out_dim Output dimension of the decoder_out. - * - * @return Return a tensor of shape (batch_size, joiner_dim); - */ - Ort::Value RunJoinerDecoderProj(const float *decoder_out, - int32_t decoder_out_dim); - - /** Run the joiner model. - * - * @TODO(fangjun): Support batch_size > 1 - * - * @param projected_encoder_out A tensor of shape (batch_size, joiner_dim). - * @param projected_decoder_out A tensor of shape (batch_size, joiner_dim).
- * - * @return Return a tensor of shape (batch_size, vocab_size) - */ - Ort::Value RunJoiner(const float *projected_encoder_out, - const float *projected_decoder_out, int32_t joiner_dim); - - private: - void InitEncoder(const std::string &encoder_filename); - void InitDecoder(const std::string &decoder_filename); - void InitJoiner(const std::string &joiner_filename); - void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename); - void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename); - - private: - Ort::Env env_; - Ort::SessionOptions sess_opts_; - std::unique_ptr<Ort::Session> encoder_sess_; - std::unique_ptr<Ort::Session> decoder_sess_; - std::unique_ptr<Ort::Session> joiner_sess_; - std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_; - std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_; - - std::vector<std::string> encoder_input_names_; - std::vector<const char *> encoder_input_names_ptr_; - std::vector<std::string> encoder_output_names_; - std::vector<const char *> encoder_output_names_ptr_; - - std::vector<std::string> decoder_input_names_; - std::vector<const char *> decoder_input_names_ptr_; - std::vector<std::string> decoder_output_names_; - std::vector<const char *> decoder_output_names_ptr_; - - std::vector<std::string> joiner_input_names_; - std::vector<const char *> joiner_input_names_ptr_; - std::vector<std::string> joiner_output_names_; - std::vector<const char *> joiner_output_names_ptr_; - - std::vector<std::string> joiner_encoder_proj_input_names_; - std::vector<const char *> joiner_encoder_proj_input_names_ptr_; - std::vector<std::string> joiner_encoder_proj_output_names_; - std::vector<const char *> joiner_encoder_proj_output_names_ptr_; - - std::vector<std::string> joiner_decoder_proj_input_names_; - std::vector<const char *> joiner_decoder_proj_input_names_ptr_; - std::vector<std::string> joiner_decoder_proj_output_names_; - std::vector<const char *> joiner_decoder_proj_output_names_ptr_; -}; - -} // namespace sherpa_onnx - -#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_
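For orientation, the five sessions of the class being removed compose as follows during decoding: the encoder output passes through joiner_encoder_proj once, the decoder output passes through joiner_decoder_proj whenever the context changes, and the joiner combines the two projections per frame. A hedged sketch of a single greedy-search step against this API (the actual implementation lived in decode.cc; the function name and dimension parameters here are illustrative only):

#include <algorithm>
#include <iterator>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/rnnt-model.h"

// One greedy step for batch_size == 1: run the current context through
// decoder + decoder_proj, combine with one projected encoder frame in the
// joiner, and return the argmax token id.
static int64_t GreedyStepSketch(sherpa_onnx::RnntModel *model,
                                const float *projected_frame,  // (1, joiner_dim)
                                const int64_t *context, int32_t context_size,
                                int32_t decoder_out_dim, int32_t joiner_dim,
                                int32_t vocab_size) {
  Ort::Value decoder_out = model->RunDecoder(context, context_size);
  Ort::Value decoder_proj = model->RunJoinerDecoderProj(
      decoder_out.GetTensorData<float>(), decoder_out_dim);
  Ort::Value logit = model->RunJoiner(
      projected_frame, decoder_proj.GetTensorData<float>(), joiner_dim);
  const float *p = logit.GetTensorData<float>();
  return std::distance(p, std::max_element(p, p + vocab_size));
}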
diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 8ac8f2d8..02d984e2 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -1,78 +1,22 @@ -/** - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// sherpa-onnx/csrc/sherpa-onnx.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#include <chrono> // NOLINT #include <iostream> #include <string> #include <vector> #include "kaldi-native-fbank/csrc/online-feature.h" #include "sherpa-onnx/csrc/decode.h" -#include "sherpa-onnx/csrc/rnnt-model.h" +#include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/wave-reader.h" -/** Compute fbank features of the input wave filename. - * - * @param wav_filename. Path to a mono wave file. - * @param expected_sampling_rate Expected sampling rate of the input wave file. - * @param num_frames On return, it contains the number of feature frames. - * @return Return the computed feature of shape (num_frames, feature_dim) - * stored in row-major. - */ -static std::vector<float> ComputeFeatures(const std::string &wav_filename, - float expected_sampling_rate, - int32_t *num_frames) { - std::vector<float> samples = - sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate); - - float duration = samples.size() / expected_sampling_rate; - - std::cout << "wav filename: " << wav_filename << "\n"; - std::cout << "wav duration (s): " << duration << "\n"; - - knf::FbankOptions opts; - opts.frame_opts.dither = 0; - opts.frame_opts.snip_edges = false; - opts.frame_opts.samp_freq = expected_sampling_rate; - - int32_t feature_dim = 80; - - opts.mel_opts.num_bins = feature_dim; - - knf::OnlineFbank fbank(opts); - fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); - fbank.InputFinished(); - - *num_frames = fbank.NumFramesReady(); - - std::vector<float> features(*num_frames * feature_dim); - float *p = features.data(); - - for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) { - const float *f = fbank.GetFrame(i); - std::copy(f, f + feature_dim, p); - } - - return features; -} - int main(int32_t argc, char *argv[]) { - if (argc < 8 || argc > 9) { + if (argc < 6 || argc > 7) { const char *usage = R"usage( Usage: ./bin/sherpa-onnx \ @@ -80,12 +24,11 @@ Usage: /path/to/encoder.onnx \ /path/to/decoder.onnx \ /path/to/joiner.onnx \ - /path/to/joiner_encoder_proj.onnx \ - /path/to/joiner_decoder_proj.onnx \ /path/to/foo.wav [num_threads] -You can download pre-trained models from the following repository: -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. )usage"; std::cerr << usage << "\n"; @@ -93,37 +36,102 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat } std::string tokens = argv[1]; - std::string encoder = argv[2]; - std::string decoder = argv[3]; - std::string joiner = argv[4]; - std::string joiner_encoder_proj = argv[5]; - std::string joiner_decoder_proj = argv[6]; - std::string wav_filename = argv[7]; - int32_t num_threads = 4; - if (argc == 9) { - num_threads = atoi(argv[8]); + sherpa_onnx::OnlineTransducerModelConfig config; + config.debug = true; + config.encoder_filename = argv[2]; + config.decoder_filename = argv[3]; + config.joiner_filename = argv[4]; + std::string wav_filename = argv[5]; + + config.num_threads = 2; + if (argc == 7) { + config.num_threads = atoi(argv[6]); } + std::cout << config.ToString().c_str() << "\n"; + + auto model = sherpa_onnx::OnlineTransducerModel::Create(config); sherpa_onnx::SymbolTable sym(tokens); - int32_t num_frames; - auto features = ComputeFeatures(wav_filename, 16000, &num_frames); - int32_t feature_dim = features.size() / num_frames; + Ort::AllocatorWithDefaultOptions allocator; - sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj, - joiner_decoder_proj, num_threads); - Ort::Value encoder_out = - model.RunEncoder(features.data(), num_frames, feature_dim); + int32_t chunk_size = model->ChunkSize(); + int32_t chunk_shift = model->ChunkShift(); - auto hyp = sherpa_onnx::GreedySearch(model, encoder_out); + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::string text; - for (auto i : hyp) { - text += sym[i]; + std::vector<Ort::Value> states = model->GetEncoderInitStates(); + + std::vector<int64_t> hyp(model->ContextSize(), 0); + + int32_t expected_sampling_rate = 16000; + + bool is_ok = false; + std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); + + if (!is_ok) { + std::cerr << "Failed to read " << wav_filename << "\n"; + return -1; } + const float duration = samples.size() / static_cast<float>(expected_sampling_rate); + + std::cout << "wav filename: " << wav_filename << "\n"; + std::cout << "wav duration (s): " << duration << "\n"; + + auto begin = std::chrono::steady_clock::now(); + std::cout << "Started!\n"; + + sherpa_onnx::FeatureExtractor feat_extractor; + feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), + samples.size()); + + std::vector<float> tail_paddings( + static_cast<int>(0.2 * expected_sampling_rate)); + feat_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(), + tail_paddings.size()); + feat_extractor.InputFinished(); + + int32_t num_frames = feat_extractor.NumFramesReady(); + int32_t feature_dim = feat_extractor.FeatureDim(); + + std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; + + for (int32_t start = 0; start + chunk_size < num_frames; + start += chunk_shift) { + std::vector<float> features = feat_extractor.GetFrames(start, chunk_size); + + Ort::Value x = + Ort::Value::CreateTensor(memory_info, features.data(), features.size(), + x_shape.data(), x_shape.size()); + auto pair = model->RunEncoder(std::move(x), states); + states = std::move(pair.second); + sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp); } + std::string text; + for (size_t i = model->ContextSize(); i != hyp.size(); ++i) { + text += sym[hyp[i]]; } + + std::cout << "Done!\n"; + std::cout << "Recognition result for " << wav_filename << "\n" << text << "\n"; + auto end = std::chrono::steady_clock::now(); + float elapsed_seconds = + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) + .count() / + 1000.; + + std::cout << "num threads: " << config.num_threads << "\n"; + + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + return 0; }
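The loop above is the core of the streaming setup: frames [start, start + chunk_size) are fed per step, the window advances by chunk_shift frames while the encoder states carry context across windows, and the 0.2 s of trailing silence ensures the last speech frames still fill a complete chunk. The windowing arithmetic in isolation, as a sketch that is not part of this diff:

#include <cstdint>
#include <cstdio>

void ShowWindows(int32_t num_frames, int32_t chunk_size, int32_t chunk_shift) {
  // Mirrors the loop condition in main(): a window is processed only if
  // all chunk_size frames are available.
  for (int32_t start = 0; start + chunk_size < num_frames;
       start += chunk_shift) {
    std::printf("feed frames [%d, %d)\n", start, start + chunk_size);
  }
}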
diff --git a/sherpa-onnx/csrc/show-onnx-info.cc b/sherpa-onnx/csrc/show-onnx-info.cc index 3ee78fbb..ef2766c7 100644 --- a/sherpa-onnx/csrc/show-onnx-info.cc +++ b/sherpa-onnx/csrc/show-onnx-info.cc @@ -1,20 +1,7 @@ -/** - * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// sherpa-onnx/csrc/show-onnx-info.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + #include <iostream> #include <string> diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 50ffe961..3d4f2b78 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -1,20 +1,6 @@ -/** - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// sherpa-onnx/csrc/symbol-table.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation #include "sherpa-onnx/csrc/symbol-table.h" diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index 46044cfd..fdcde41e 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -1,20 +1,6 @@ -/** - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// sherpa-onnx/csrc/symbol-table.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ #define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ diff --git a/sherpa-onnx/csrc/wave-reader.cc b/sherpa-onnx/csrc/wave-reader.cc index e04e024a..cdc80f81 100644 --- a/sherpa-onnx/csrc/wave-reader.cc +++ b/sherpa-onnx/csrc/wave-reader.cc @@ -1,20 +1,6 @@ -/** - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ +// sherpa-onnx/csrc/wave-reader.cc +// +// Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/wave-reader.h" @@ -31,19 +17,44 @@ namespace { // Note: We assume little endian here // TODO(fangjun): Support big endian struct WaveHeader { - void Validate() const { - // F F I R - assert(chunk_id == 0x46464952); - assert(chunk_size == 36 + subchunk2_size); - // E V A W - assert(format == 0x45564157); - assert(subchunk1_id == 0x20746d66); - assert(subchunk1_size == 16); // 16 for PCM - assert(audio_format == 1); // 1 for PCM - assert(num_channels == 1); // we support only single channel for now - assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8); - assert(block_align == num_channels * bits_per_sample / 8); - assert(bits_per_sample == 16); // we support only 16 bits per sample + bool Validate() const { + // F F I R + if (chunk_id != 0x46464952) { + return false; + } + // E V A W + if (format != 0x45564157) { + return false; + } + + if (subchunk1_id != 0x20746d66) { + return false; + } + + if (subchunk1_size != 16) { // 16 for PCM + return false; + } + + if (audio_format != 1) { // 1 for PCM + return false; + } + + if (num_channels != 1) { // we support only single channel for now + return false; + } + if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) { + return false; + } + + if (block_align != (num_channels * bits_per_sample / 8)) { + return false; + } + + if (bits_per_sample != 16) { // we support only 16 bits per sample + return false; + } + + return true; } // See @@ -52,7 +63,7 @@ struct WaveHeader { // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf void SeekToDataChunk(std::istream &is) { // a t a d - while (subchunk2_id != 0x61746164) { + while (is && subchunk2_id != 0x61746164) { // const char *p = reinterpret_cast<const char *>(&subchunk2_id); // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], // p[1], p[2], p[3], subchunk2_size); @@ -80,44 +91,61 @@ static_assert(sizeof(WaveHeader) == 44, ""); // Read a wave file of mono-channel. // Return its samples normalized to the range [-1, 1). -std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) { +std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, + bool *is_ok) { WaveHeader header; is.read(reinterpret_cast<char *>(&header), sizeof(header)); - assert(static_cast<bool>(is)); - header.Validate(); + if (!is) { + *is_ok = false; + return {}; + } + + if (!header.Validate()) { + *is_ok = false; + return {}; + } header.SeekToDataChunk(is); + if (!is) { + *is_ok = false; + return {}; + } - - *sample_rate = header.sample_rate; + if (expected_sample_rate != header.sample_rate) { + *is_ok = false; + return {}; + } // header.subchunk2_size contains the number of bytes in the data. // Since we assume each sample occupies two bytes, we divide by 2 here. std::vector<int16_t> samples(header.subchunk2_size / 2); is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size); - - assert(static_cast<bool>(is)); + if (!is) { + *is_ok = false; + return {}; + } std::vector<float> ans(samples.size()); for (int32_t i = 0; i != ans.size(); ++i) { ans[i] = samples[i] / 32768.; } + *is_ok = true; return ans; } } // namespace std::vector<float> ReadWave(const std::string &filename, - float expected_sample_rate) { std::ifstream is(filename, std::ifstream::binary); - float sample_rate; - auto samples = ReadWaveImpl(is, &sample_rate); - if (expected_sample_rate != sample_rate) { - std::cerr << "Expected sample rate: " << expected_sample_rate - << ". Given: " << sample_rate << ".\n"; - exit(-1); - } + return ReadWave(is, expected_sample_rate, is_ok); +} + +std::vector<float> ReadWave(std::istream &is, float expected_sample_rate, + bool *is_ok) { + auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok); return samples; }
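The magic constants checked in Validate() above are the ASCII chunk tags read as little-endian 32-bit integers, which is why the comments spell them backwards ("F F I R" for "RIFF"). A small demonstration, as a sketch that is not part of this diff:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const char riff[4] = {'R', 'I', 'F', 'F'};
  uint32_t v = 0;
  std::memcpy(&v, riff, sizeof(v));
  // On a little-endian host 'R' (0x52) lands in the lowest byte, so the
  // four bytes read back as 0x46464952 -- the value Validate() compares.
  std::printf("0x%08X\n", v);
  return 0;
}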
Given: " << sample_rate << ".\n"; - exit(-1); - } + return ReadWave(is, expected_sample_rate, is_ok); +} + +std::vector ReadWave(std::istream &is, float expected_sample_rate, + bool *is_ok) { + auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok); return samples; } diff --git a/sherpa-onnx/csrc/wave-reader.h b/sherpa-onnx/csrc/wave-reader.h index 7db5c1f9..fb5c68c1 100644 --- a/sherpa-onnx/csrc/wave-reader.h +++ b/sherpa-onnx/csrc/wave-reader.h @@ -1,20 +1,6 @@ -/** - * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// sherpa/csrc/wave-reader.h +// +// Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_ #define SHERPA_ONNX_CSRC_WAVE_READER_H_ @@ -30,11 +16,15 @@ namespace sherpa_onnx { @param filename Path to a wave file. It MUST be single channel, PCM encoded. @param expected_sample_rate Expected sample rate of the wave file. If the sample rate don't match, it throws an exception. + @param is_ok On return it is true if the reading succeeded; false otherwise. @return Return wave samples normalized to the range [-1, 1). */ std::vector ReadWave(const std::string &filename, - float expected_sample_rate); + float expected_sample_rate, bool *is_ok); + +std::vector ReadWave(std::istream &is, float expected_sample_rate, + bool *is_ok); } // namespace sherpa_onnx