diff --git a/.github/scripts/test-online-paraformer.sh b/.github/scripts/test-online-paraformer.sh new file mode 100755 index 00000000..93574e3f --- /dev/null +++ b/.github/scripts/test-online-paraformer.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +echo "EXE is $EXE" +echo "PATH: $PATH" + +which $EXE + +log "------------------------------------------------------------" +log "Run streaming Paraformer" +log "------------------------------------------------------------" + +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" +ls -lh *.onnx +popd + +time $EXE \ + --tokens=$repo/tokens.txt \ + --paraformer-encoder=$repo/encoder.onnx \ + --paraformer-decoder=$repo/decoder.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/8k.wav + +time $EXE \ + --tokens=$repo/tokens.txt \ + --paraformer-encoder=$repo/encoder.int8.onnx \ + --paraformer-decoder=$repo/decoder.int8.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/8k.wav + +rm -rf $repo diff --git a/.github/workflows/linux-gpu.yaml b/.github/workflows/linux-gpu.yaml index 7b14ac2b..25350a31 100644 --- a/.github/workflows/linux-gpu.yaml +++ b/.github/workflows/linux-gpu.yaml @@ -9,6 +9,7 @@ on: paths: - '.github/workflows/linux-gpu.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -22,6 +23,7 @@ on: paths: - '.github/workflows/linux-gpu.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -85,6 +87,14 @@ jobs: file build/bin/sherpa-onnx readelf -d build/bin/sherpa-onnx + - name: Test online paraformer + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx + + .github/scripts/test-online-paraformer.sh + - name: Test offline Whisper shell: bash run: | diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index a03602ba..2c026fc0 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -9,6 +9,7 @@ on: paths: - '.github/workflows/linux.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -22,6 +23,7 @@ on: paths: - '.github/workflows/linux.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -84,6 +86,14 @@ jobs: file build/bin/sherpa-onnx readelf -d build/bin/sherpa-onnx + - name: Test online paraformer + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx + + .github/scripts/test-online-paraformer.sh + - name: Test 
offline Whisper shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index cebc5ac6..f3b11a5d 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -7,6 +7,7 @@ on: paths: - '.github/workflows/macos.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -18,6 +19,7 @@ on: paths: - '.github/workflows/macos.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -82,6 +84,14 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test online paraformer + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx + + .github/scripts/test-online-paraformer.sh + - name: Test offline Whisper shell: bash run: | diff --git a/.github/workflows/test-pip-install.yaml b/.github/workflows/test-pip-install.yaml index 34a15360..01fdb4c6 100644 --- a/.github/workflows/test-pip-install.yaml +++ b/.github/workflows/test-pip-install.yaml @@ -58,7 +58,6 @@ jobs: sherpa-onnx-microphone-offline --help sherpa-onnx-offline-websocket-server --help - sherpa-onnx-offline-websocket-client --help sherpa-onnx-online-websocket-server --help sherpa-onnx-online-websocket-client --help diff --git a/.github/workflows/test-python-offline-websocket-server.yaml b/.github/workflows/test-python-offline-websocket-server.yaml index 7ec4e29d..d7ea4dde 100644 --- a/.github/workflows/test-python-offline-websocket-server.yaml +++ b/.github/workflows/test-python-offline-websocket-server.yaml @@ -84,14 +84,14 @@ jobs: if: matrix.model_type == 'paraformer' shell: bash run: | - GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 - cd sherpa-onnx-paraformer-zh-2023-03-28 + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en + cd sherpa-onnx-paraformer-bilingual-zh-en git lfs pull --include "*.onnx" cd .. 
python3 ./python-api-examples/non_streaming_server.py \ - --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \ - --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt & + --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \ + --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt & echo "sleep 10 seconds to wait the server start" sleep 10 @@ -101,16 +101,16 @@ jobs: shell: bash run: | python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \ - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \ - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \ - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \ + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \ + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \ + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \ - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \ - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \ - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \ + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \ + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \ + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav - name: Start server for nemo_ctc models if: matrix.model_type == 'nemo_ctc' diff --git a/.github/workflows/test-python-online-websocket-server.yaml b/.github/workflows/test-python-online-websocket-server.yaml index c7e3319d..7616afa3 100644 --- a/.github/workflows/test-python-online-websocket-server.yaml +++ b/.github/workflows/test-python-online-websocket-server.yaml @@ -24,7 +24,7 @@ jobs: matrix: os: [ubuntu-latest, windows-latest, macos-latest] python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] - model_type: ["transducer"] + model_type: ["transducer", "paraformer"] steps: - uses: actions/checkout@v2 @@ -71,3 +71,36 @@ jobs: run: | python3 ./python-api-examples/online-websocket-client-decode-file.py \ ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav + + - name: Start server for paraformer models + if: matrix.model_type == 'paraformer' + shell: bash + run: | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + cd sherpa-onnx-streaming-paraformer-bilingual-zh-en + git lfs pull --include "*.onnx" + cd .. 
+ + python3 ./python-api-examples/streaming_server.py \ + --tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ + --paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ + --paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx & + + echo "sleep 10 seconds to wait the server start" + sleep 10 + + - name: Start client for paraformer models + if: matrix.model_type == 'paraformer' + shell: bash + run: | + python3 ./python-api-examples/online-websocket-client-decode-file.py \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav + + python3 ./python-api-examples/online-websocket-client-decode-file.py \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav + + python3 ./python-api-examples/online-websocket-client-decode-file.py \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav + + python3 ./python-api-examples/online-websocket-client-decode-file.py \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav diff --git a/.github/workflows/windows-x64-cuda.yaml b/.github/workflows/windows-x64-cuda.yaml index 24b8158d..17e53d8b 100644 --- a/.github/workflows/windows-x64-cuda.yaml +++ b/.github/workflows/windows-x64-cuda.yaml @@ -9,6 +9,7 @@ on: paths: - '.github/workflows/windows-x64-cuda.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -20,6 +21,7 @@ on: paths: - '.github/workflows/windows-x64-cuda.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -74,6 +76,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test online paraformer for windows x64 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx.exe + + .github/scripts/test-online-paraformer.sh + - name: Test offline Whisper for windows x64 shell: bash run: | diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 83b80de9..c63dbae3 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -9,6 +9,7 @@ on: paths: - '.github/workflows/windows-x64.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -20,6 +21,7 @@ on: paths: - '.github/workflows/windows-x64.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -75,6 +77,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test online paraformer for windows x64 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx.exe + + .github/scripts/test-online-paraformer.sh + - name: Test offline Whisper for windows x64 shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index d181e22c..b39a1ddc 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -7,6 +7,7 @@ on: paths: - '.github/workflows/windows-x86.yaml' - '.github/scripts/test-online-transducer.sh' + - 
'.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -18,6 +19,7 @@ on: paths: - '.github/workflows/windows-x86.yaml' - '.github/scripts/test-online-transducer.sh' + - '.github/scripts/test-online-paraformer.sh' - '.github/scripts/test-offline-transducer.sh' - '.github/scripts/test-offline-ctc.sh' - 'CMakeLists.txt' @@ -73,6 +75,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test online paraformer for windows x86 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx.exe + + .github/scripts/test-online-paraformer.sh + - name: Test offline Whisper for windows x86 shell: bash run: | diff --git a/CMakeLists.txt b/CMakeLists.txt index abb77787..c6086fa3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.7.3") +set(SHERPA_ONNX_VERSION "1.7.4") # Disable warning about # diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index 7d3502fa..cbfaa760 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -37,14 +37,14 @@ python3 ./python-api-examples/non_streaming_server.py \ (2) Use a non-streaming paraformer cd /path/to/sherpa-onnx -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 -cd sherpa-onnx-paraformer-zh-2023-03-28 +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en +cd sherpa-onnx-paraformer-bilingual-zh-en/ git lfs pull --include "*.onnx" cd .. python3 ./python-api-examples/non_streaming_server.py \ - --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \ - --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt + --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \ + --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt (3) Use a non-streaming CTC model from NeMo diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index e2e1dc55..eff85427 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -5,16 +5,41 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe file(s) with a streaming model. 
Usage: - ./online-decode-files.py \ - /path/to/foo.wav \ - /path/to/bar.wav \ - /path/to/16kHz.wav \ - /path/to/8kHz.wav + +(1) Streaming transducer + +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 +cd sherpa-onnx-streaming-zipformer-en-2023-06-26 +git lfs pull --include "*.onnx" + +./python-api-examples/online-decode-files.py \ + --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \ + --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav + +(2) Streaming paraformer + +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en +cd sherpa-onnx-streaming-paraformer-bilingual-zh-en +git lfs pull --include "*.onnx" + +./python-api-examples/online-decode-files.py \ + --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ + --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ + --paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav Please refer to https://k2-fsa.github.io/sherpa/onnx/index.html -to install sherpa-onnx and to download the pre-trained models -used in this file. +to install sherpa-onnx and to download streaming pre-trained models. 
""" import argparse import time @@ -41,19 +66,31 @@ def get_args(): parser.add_argument( "--encoder", type=str, - help="Path to the encoder model", + help="Path to the transducer encoder model", ) parser.add_argument( "--decoder", type=str, - help="Path to the decoder model", + help="Path to the transducer decoder model", ) parser.add_argument( "--joiner", type=str, - help="Path to the joiner model", + help="Path to the transducer joiner model", + ) + + parser.add_argument( + "--paraformer-encoder", + type=str, + help="Path to the paraformer encoder model", + ) + + parser.add_argument( + "--paraformer-decoder", + type=str, + help="Path to the paraformer decoder model", ) parser.add_argument( @@ -200,24 +237,42 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]: def main(): args = get_args() - assert_file_exists(args.encoder) - assert_file_exists(args.decoder) - assert_file_exists(args.joiner) assert_file_exists(args.tokens) - recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( - tokens=args.tokens, - encoder=args.encoder, - decoder=args.decoder, - joiner=args.joiner, - num_threads=args.num_threads, - provider=args.provider, - sample_rate=16000, - feature_dim=80, - decoding_method=args.decoding_method, - max_active_paths=args.max_active_paths, - context_score=args.context_score, - ) + if args.encoder: + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + assert not args.paraformer_encoder, args.paraformer_encoder + assert not args.paraformer_decoder, args.paraformer_decoder + + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + decoding_method=args.decoding_method, + max_active_paths=args.max_active_paths, + context_score=args.context_score, + ) + elif args.paraformer_encoder: + recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( + tokens=args.tokens, + encoder=args.paraformer_encoder, + decoder=args.paraformer_decoder, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) + else: + raise ValueError("Please provide a model") print("Started!") start_time = time.time() @@ -243,7 +298,7 @@ def main(): s.accept_waveform(sample_rate, samples) - tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) s.accept_waveform(sample_rate, tail_paddings) s.input_finished() diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py index c707a70c..33d4e5ee 100755 --- a/python-api-examples/streaming_server.py +++ b/python-api-examples/streaming_server.py @@ -16,9 +16,9 @@ Example: (1) Without a certificate python3 ./python-api-examples/streaming_server.py \ - --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ - --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ - --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ + --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ + --joiner 
./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt (2) With a certificate @@ -32,9 +32,9 @@ python3 ./python-api-examples/streaming_server.py \ (b) Start the server python3 ./python-api-examples/streaming_server.py \ - --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ - --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ - --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ + --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ + --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ --certificate ./python-api-examples/web/cert.pem @@ -113,24 +113,33 @@ def setup_logger( def add_model_args(parser: argparse.ArgumentParser): parser.add_argument( - "--encoder-model", + "--encoder", type=str, - required=True, - help="Path to the encoder model", + help="Path to the transducer encoder model", ) parser.add_argument( - "--decoder-model", + "--decoder", type=str, - required=True, - help="Path to the decoder model.", + help="Path to the transducer decoder model.", ) parser.add_argument( - "--joiner-model", + "--joiner", type=str, - required=True, - help="Path to the joiner model.", + help="Path to the transducer joiner model.", + ) + + parser.add_argument( + "--paraformer-encoder", + type=str, + help="Path to the paraformer encoder model", + ) + + parser.add_argument( + "--paraformer-decoder", + type=str, + help="Path to the transducer decoder model.", ) parser.add_argument( @@ -323,22 +332,40 @@ def get_args(): def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: - recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( - tokens=args.tokens, - encoder=args.encoder_model, - decoder=args.decoder_model, - joiner=args.joiner_model, - num_threads=args.num_threads, - sample_rate=args.sample_rate, - feature_dim=args.feat_dim, - decoding_method=args.decoding_method, - max_active_paths=args.num_active_paths, - enable_endpoint_detection=args.use_endpoint != 0, - rule1_min_trailing_silence=args.rule1_min_trailing_silence, - rule2_min_trailing_silence=args.rule2_min_trailing_silence, - rule3_min_utterance_length=args.rule3_min_utterance_length, - provider=args.provider, - ) + if args.encoder: + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + max_active_paths=args.num_active_paths, + enable_endpoint_detection=args.use_endpoint != 0, + rule1_min_trailing_silence=args.rule1_min_trailing_silence, + rule2_min_trailing_silence=args.rule2_min_trailing_silence, + rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, + ) + elif args.paraformer_encoder: + recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( + tokens=args.tokens, + encoder=args.paraformer_encoder, + decoder=args.paraformer_decoder, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + 
feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + enable_endpoint_detection=args.use_endpoint != 0, + rule1_min_trailing_silence=args.rule1_min_trailing_silence, + rule2_min_trailing_silence=args.rule2_min_trailing_silence, + rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, + ) + else: + raise ValueError("Please provide a model") return recognizer @@ -654,11 +681,25 @@ Go back to /streaming_record.html def check_args(args): - assert Path(args.encoder_model).is_file(), f"{args.encoder_model} does not exist" + if args.encoder: + assert Path(args.encoder).is_file(), f"{args.encoder} does not exist" - assert Path(args.decoder_model).is_file(), f"{args.decoder_model} does not exist" + assert Path(args.decoder).is_file(), f"{args.decoder} does not exist" - assert Path(args.joiner_model).is_file(), f"{args.joiner_model} does not exist" + assert Path(args.joiner).is_file(), f"{args.joiner} does not exist" + + assert args.paraformer_encoder is None, args.paraformer_encoder + assert args.paraformer_decoder is None, args.paraformer_decoder + elif args.paraformer_encoder: + assert Path( + args.paraformer_encoder + ).is_file(), f"{args.paraformer_encoder} does not exist" + + assert Path( + args.paraformer_decoder + ).is_file(), f"{args.paraformer_decoder} does not exist" + else: + raise ValueError("Please provide a model") if not Path(args.tokens).is_file(): raise ValueError(f"{args.tokens} does not exist") diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index cb4953c5..b9bac58c 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -46,6 +46,8 @@ set(sources online-lm.cc online-lstm-transducer-model.cc online-model-config.cc + online-paraformer-model-config.cc + online-paraformer-model.cc online-recognizer-impl.cc online-recognizer.cc online-rnn-lm.cc diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 7f804684..51500e11 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -39,7 +39,7 @@ std::string FeatureExtractorConfig::ToString() const { class FeatureExtractor::Impl { public: - explicit Impl(const FeatureExtractorConfig &config) { + explicit Impl(const FeatureExtractorConfig &config) : config_(config) { opts_.frame_opts.dither = 0; opts_.frame_opts.snip_edges = false; opts_.frame_opts.samp_freq = config.sampling_rate; @@ -50,6 +50,19 @@ class FeatureExtractor::Impl { } void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { + if (config_.normalize_samples) { + AcceptWaveformImpl(sampling_rate, waveform, n); + } else { + std::vector<float> buf(n); + for (int32_t i = 0; i != n; ++i) { + buf[i] = waveform[i] * 32768; + } + AcceptWaveformImpl(sampling_rate, buf.data(), n); + } + } + + void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform, + int32_t n) { std::lock_guard<std::mutex> lock(mutex_); if (resampler_) { @@ -146,6 +159,7 @@ class FeatureExtractor::Impl { private: std::unique_ptr<knf::OnlineFbank> fbank_; knf::FbankOptions opts_; + FeatureExtractorConfig config_; mutable std::mutex mutex_; std::unique_ptr<LinearResample> resampler_; int32_t last_frame_index_ = 0; diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index d4eaffda..497dd01c 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -21,6 +21,13 @@ struct FeatureExtractorConfig { // Feature dimension int32_t feature_dim = 80; + // Set internally by some models, e.g., paraformer sets it to false.
+ // This parameter is not exposed to users from the commandline + // If true, the feature extractor expects inputs to be normalized to + // the range [-1, 1]. + // If false, we will multiply the inputs by 32768 + bool normalize_samples = true; + std::string ToString() const; void Register(ParseOptions *po); diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 7a4416b5..9c1f8c49 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -12,6 +12,7 @@ namespace sherpa_onnx { void OnlineModelConfig::Register(ParseOptions *po) { transducer.Register(po); + paraformer.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -41,6 +42,10 @@ bool OnlineModelConfig::Validate() const { return false; } + if (!paraformer.encoder.empty()) { + return paraformer.Validate(); + } + return transducer.Validate(); } @@ -49,6 +54,7 @@ std::string OnlineModelConfig::ToString() const { os << "OnlineModelConfig("; os << "transducer=" << transducer.ToString() << ", "; + os << "paraformer=" << paraformer.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 34e7b1e4..2afd6617 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -6,12 +6,14 @@ #include <string> +#include "sherpa-onnx/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" namespace sherpa_onnx { struct OnlineModelConfig { OnlineTransducerModelConfig transducer; + OnlineParaformerModelConfig paraformer; std::string tokens; int32_t num_threads = 1; bool debug = false; @@ -28,9 +30,11 @@ struct OnlineModelConfig { OnlineModelConfig() = default; OnlineModelConfig(const OnlineTransducerModelConfig &transducer, + const OnlineParaformerModelConfig &paraformer, const std::string &tokens, int32_t num_threads, bool debug, const std::string &provider, const std::string &model_type) : transducer(transducer), + paraformer(paraformer), tokens(tokens), num_threads(num_threads), debug(debug), diff --git a/sherpa-onnx/csrc/online-paraformer-decoder.h b/sherpa-onnx/csrc/online-paraformer-decoder.h new file mode 100644 index 00000000..9f675275 --- /dev/null +++ b/sherpa-onnx/csrc/online-paraformer-decoder.h @@ -0,0 +1,23 @@ +// sherpa-onnx/csrc/online-paraformer-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ + +#include <vector> + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +struct OnlineParaformerDecoderResult { + /// The decoded token IDs + std::vector<int32_t> tokens; + + int32_t last_non_blank_frame_index = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ diff --git a/sherpa-onnx/csrc/online-paraformer-model-config.cc b/sherpa-onnx/csrc/online-paraformer-model-config.cc new file mode 100644 index 00000000..a93fe299 --- /dev/null +++ b/sherpa-onnx/csrc/online-paraformer-model-config.cc @@ -0,0 +1,43 @@ +// sherpa-onnx/csrc/online-paraformer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-paraformer-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void
OnlineParaformerModelConfig::Register(ParseOptions *po) { + po->Register("paraformer-encoder", &encoder, + "Path to encoder.onnx of paraformer."); + po->Register("paraformer-decoder", &decoder, + "Path to decoder.onnx of paraformer."); +} + +bool OnlineParaformerModelConfig::Validate() const { + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str()); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str()); + return false; + } + + return true; +} + +std::string OnlineParaformerModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineParaformerModelConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-paraformer-model-config.h b/sherpa-onnx/csrc/online-paraformer-model-config.h new file mode 100644 index 00000000..29f33e45 --- /dev/null +++ b/sherpa-onnx/csrc/online-paraformer-model-config.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/online-paraformer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OnlineParaformerModelConfig { + std::string encoder; + std::string decoder; + + OnlineParaformerModelConfig() = default; + + OnlineParaformerModelConfig(const std::string &encoder, + const std::string &decoder) + : encoder(encoder), decoder(decoder) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/online-paraformer-model.cc b/sherpa-onnx/csrc/online-paraformer-model.cc new file mode 100644 index 00000000..2d6a410e --- /dev/null +++ b/sherpa-onnx/csrc/online-paraformer-model.cc @@ -0,0 +1,249 @@ +// sherpa-onnx/csrc/online-paraformer-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-paraformer-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OnlineParaformerModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.paraformer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.paraformer.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.paraformer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.paraformer.decoder); + InitDecoder(buf.data(), buf.size()); + } + } +#endif + + std::vector ForwardEncoder(Ort::Value features, + Ort::Value features_length) { + std::array inputs = {std::move(features), + 
std::move(features_length)}; + + return encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(), + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); + } + + std::vector ForwardDecoder(Ort::Value encoder_out, + Ort::Value encoder_out_length, + Ort::Value acoustic_embedding, + Ort::Value acoustic_embedding_length, + std::vector states) { + std::vector decoder_inputs; + decoder_inputs.reserve(4 + states.size()); + + decoder_inputs.push_back(std::move(encoder_out)); + decoder_inputs.push_back(std::move(encoder_out_length)); + decoder_inputs.push_back(std::move(acoustic_embedding)); + decoder_inputs.push_back(std::move(acoustic_embedding_length)); + + for (auto &v : states) { + decoder_inputs.push_back(std::move(v)); + } + + return decoder_sess_->Run({}, decoder_input_names_ptr_.data(), + decoder_inputs.data(), decoder_inputs.size(), + decoder_output_names_ptr_.data(), + decoder_output_names_ptr_.size()); + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t LfrWindowSize() const { return lfr_window_size_; } + + int32_t LfrWindowShift() const { return lfr_window_shift_; } + + int32_t EncoderOutputSize() const { return encoder_output_size_; } + + int32_t DecoderKernelSize() const { return decoder_kernel_size_; } + + int32_t DecoderNumBlocks() const { return decoder_num_blocks_; } + + const std::vector &NegativeMean() const { return neg_mean_; } + + const std::vector &InverseStdDev() const { return inv_stddev_; } + + OrtAllocator *Allocator() const { return allocator_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size"); + SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift"); + SHERPA_ONNX_READ_META_DATA(encoder_output_size_, "encoder_output_size"); + SHERPA_ONNX_READ_META_DATA(decoder_num_blocks_, "decoder_num_blocks"); + SHERPA_ONNX_READ_META_DATA(decoder_kernel_size_, "decoder_kernel_size"); + + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean"); + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev"); + + float scale = std::sqrt(encoder_output_size_); + for (auto &f : inv_stddev_) { + f *= scale; + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + private: + OnlineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + 
std::unique_ptr decoder_sess_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector neg_mean_; + std::vector inv_stddev_; + + int32_t vocab_size_ = 0; // initialized in Init + int32_t lfr_window_size_ = 0; + int32_t lfr_window_shift_ = 0; + + int32_t encoder_output_size_ = 0; + int32_t decoder_num_blocks_ = 0; + int32_t decoder_kernel_size_ = 0; +}; + +OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OnlineParaformerModel::OnlineParaformerModel(AAssetManager *mgr, + const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OnlineParaformerModel::~OnlineParaformerModel() = default; + +std::vector OnlineParaformerModel::ForwardEncoder( + Ort::Value features, Ort::Value features_length) const { + return impl_->ForwardEncoder(std::move(features), std::move(features_length)); +} + +std::vector OnlineParaformerModel::ForwardDecoder( + Ort::Value encoder_out, Ort::Value encoder_out_length, + Ort::Value acoustic_embedding, Ort::Value acoustic_embedding_length, + std::vector states) const { + return impl_->ForwardDecoder( + std::move(encoder_out), std::move(encoder_out_length), + std::move(acoustic_embedding), std::move(acoustic_embedding_length), + std::move(states)); +} + +int32_t OnlineParaformerModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OnlineParaformerModel::LfrWindowSize() const { + return impl_->LfrWindowSize(); +} +int32_t OnlineParaformerModel::LfrWindowShift() const { + return impl_->LfrWindowShift(); +} + +int32_t OnlineParaformerModel::EncoderOutputSize() const { + return impl_->EncoderOutputSize(); +} + +int32_t OnlineParaformerModel::DecoderKernelSize() const { + return impl_->DecoderKernelSize(); +} + +int32_t OnlineParaformerModel::DecoderNumBlocks() const { + return impl_->DecoderNumBlocks(); +} + +const std::vector &OnlineParaformerModel::NegativeMean() const { + return impl_->NegativeMean(); +} +const std::vector &OnlineParaformerModel::InverseStdDev() const { + return impl_->InverseStdDev(); +} + +OrtAllocator *OnlineParaformerModel::Allocator() const { + return impl_->Allocator(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-paraformer-model.h b/sherpa-onnx/csrc/online-paraformer-model.h new file mode 100644 index 00000000..3c018a72 --- /dev/null +++ b/sherpa-onnx/csrc/online-paraformer-model.h @@ -0,0 +1,76 @@ +// sherpa-onnx/csrc/online-paraformer-model.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-model-config.h" + +namespace sherpa_onnx { + +class OnlineParaformerModel { + public: + explicit OnlineParaformerModel(const OnlineModelConfig &config); + +#if __ANDROID_API__ >= 9 + OnlineParaformerModel(AAssetManager *mgr, const OnlineModelConfig &config); +#endif + + ~OnlineParaformerModel(); + + std::vector ForwardEncoder(Ort::Value features, + Ort::Value features_length) const; + + std::vector ForwardDecoder(Ort::Value encoder_out, + Ort::Value encoder_out_length, + Ort::Value acoustic_embedding, + Ort::Value acoustic_embedding_length, + std::vector 
states) const; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const; + + /** It is lfr_m in config.yaml + */ + int32_t LfrWindowSize() const; + + /** It is lfr_n in config.yaml + */ + int32_t LfrWindowShift() const; + + int32_t EncoderOutputSize() const; + + int32_t DecoderKernelSize() const; + int32_t DecoderNumBlocks() const; + + /** Return negative mean for CMVN + */ + const std::vector &NegativeMean() const; + + /** Return inverse stddev for CMVN + */ + const std::vector &InverseStdDev() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index a9e545dd..1eb16c03 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" namespace sherpa_onnx { @@ -14,6 +15,10 @@ std::unique_ptr OnlineRecognizerImpl::Create( return std::make_unique(config); } + if (!config.model_config.paraformer.encoder.empty()) { + return std::make_unique(config); + } + SHERPA_ONNX_LOGE("Please specify a model"); exit(-1); } @@ -25,6 +30,10 @@ std::unique_ptr OnlineRecognizerImpl::Create( return std::make_unique(mgr, config); } + if (!config.model_config.paraformer.encoder.empty()) { + return std::make_unique(mgr, config); + } + SHERPA_ONNX_LOGE("Please specify a model"); exit(-1); } diff --git a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h index 8b574a4d..515c9d9e 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-impl.h @@ -26,8 +26,6 @@ class OnlineRecognizerImpl { virtual ~OnlineRecognizerImpl() = default; - virtual void InitOnlineStream(OnlineStream *stream) const = 0; - virtual std::unique_ptr CreateStream() const = 0; virtual std::unique_ptr CreateStream( diff --git a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h new file mode 100644 index 00000000..ae209633 --- /dev/null +++ b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h @@ -0,0 +1,465 @@ +// sherpa-onnx/csrc/online-recognizer-paraformer-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-lm.h" +#include "sherpa-onnx/csrc/online-paraformer-decoder.h" +#include "sherpa-onnx/csrc/online-paraformer-model.h" +#include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/online-recognizer.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +static OnlineRecognizerResult Convert(const OnlineParaformerDecoderResult &src, + const SymbolTable &sym_table) { + OnlineRecognizerResult r; + r.tokens.reserve(src.tokens.size()); + + std::string text; + + // When the current token ends with "@@" we set mergeable to true + bool mergeable = false; + + for (int32_t i = 0; i != src.tokens.size(); ++i) { + auto sym = 
sym_table[src.tokens[i]]; + r.tokens.push_back(sym); + + if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) { + // sym does not end with "@@" + const uint8_t *p = reinterpret_cast(sym.c_str()); + if (p[0] < 0x80) { + // an ascii + if (mergeable) { + mergeable = false; + text.append(sym); + } else { + text.append(" "); + text.append(sym); + } + } else { + // not an ascii + mergeable = false; + + if (i > 0) { + const uint8_t *p = reinterpret_cast( + sym_table[src.tokens[i - 1]].c_str()); + if (p[0] < 0x80) { + // put a space between ascii and non-ascii + text.append(" "); + } + } + text.append(sym); + } + } else { + // this sym ends with @@ + sym = std::string(sym.data(), sym.size() - 2); + if (mergeable) { + text.append(sym); + } else { + text.append(" "); + text.append(sym); + mergeable = true; + } + } + } + r.text = std::move(text); + + return r; +} + +// y[i] += x[i] * scale +static void ScaleAddInPlace(const float *x, int32_t n, float scale, float *y) { + for (int32_t i = 0; i != n; ++i) { + y[i] += x[i] * scale; + } +} + +// y[i] = x[i] * scale +static void Scale(const float *x, int32_t n, float scale, float *y) { + for (int32_t i = 0; i != n; ++i) { + y[i] = x[i] * scale; + } +} + +class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config) + : config_(config), + model_(config.model_config), + sym_(config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (config.decoding_method != "greedy_search") { + SHERPA_ONNX_LOGE( + "Unsupported decoding method: %s. Support only greedy_search at " + "present", + config.decoding_method.c_str()); + exit(-1); + } + + // Paraformer models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + +#if __ANDROID_API__ >= 9 + explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr, + const OnlineRecognizerConfig &config) + : config_(config), + model_(mgr, config.model_config), + sym_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (config.decoding_method == "greedy_search") { + // add greedy search decoder + // SHERPA_ONNX_LOGE("to be implemented"); + // exit(-1); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config.decoding_method.c_str()); + exit(-1); + } + + // Paraformer models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } +#endif + OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) = + delete; + + OnlineRecognizerParaformerImpl operator=( + const OnlineRecognizerParaformerImpl &) = delete; + + std::unique_ptr CreateStream() const override { + auto stream = std::make_unique(config_.feat_config); + + OnlineParaformerDecoderResult r; + stream->SetParaformerResult(r); + + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + chunk_size_ < s->NumFramesReady(); + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + // TODO(fangjun): Support batch size > 1 + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + auto decoder_result = s->GetParaformerResult(); + + return Convert(decoder_result, sym_); + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + 
return false; + } + + const auto &result = s->GetParaformerResult(); + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + int32_t trailing_silence_frames = + num_processed_frames - result.last_non_blank_frame_index; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + OnlineParaformerDecoderResult r; + s->SetParaformerResult(r); + + // the internal model caches are not reset + + // Note: We only update counters. The underlying audio samples + // are not discarded. + s->Reset(); + } + + private: + void DecodeStream(OnlineStream *s) const { + const auto num_processed_frames = s->GetNumProcessedFrames(); + std::vector frames = s->GetFrames(num_processed_frames, chunk_size_); + s->GetNumProcessedFrames() += chunk_size_ - 1; + + frames = ApplyLFR(frames); + ApplyCMVN(&frames); + PositionalEncoding(&frames, num_processed_frames / model_.LfrWindowShift()); + + int32_t feat_dim = model_.NegativeMean().size(); + + // We have scaled inv_stddev by sqrt(encoder_output_size) + // so the following line can be commented out + // frames *= encoder_output_size ** 0.5 + + // add overlap chunk + std::vector &feat_cache = s->GetParaformerFeatCache(); + if (feat_cache.empty()) { + int32_t n = (left_chunk_size_ + right_chunk_size_) * feat_dim; + feat_cache.resize(n, 0); + } + + frames.insert(frames.begin(), feat_cache.begin(), feat_cache.end()); + std::copy(frames.end() - feat_cache.size(), frames.end(), + feat_cache.begin()); + + int32_t num_frames = frames.size() / feat_dim; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{1, num_frames, feat_dim}; + Ort::Value x = + Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(), + x_shape.data(), x_shape.size()); + + int64_t x_len_shape = 1; + int32_t x_len_val = num_frames; + + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &x_len_val, 1, &x_len_shape, 1); + + auto encoder_out_vec = + model_.ForwardEncoder(std::move(x), std::move(x_length)); + + // CIF search + auto &encoder_out = encoder_out_vec[0]; + auto &encoder_out_len = encoder_out_vec[1]; + auto &alpha = encoder_out_vec[2]; + + float *p_alpha = alpha.GetTensorMutableData(); + + std::vector alpha_shape = + alpha.GetTensorTypeAndShapeInfo().GetShape(); + + std::fill(p_alpha, p_alpha + left_chunk_size_, 0); + std::fill(p_alpha + alpha_shape[1] - right_chunk_size_, + p_alpha + alpha_shape[1], 0); + + const float *p_encoder_out = encoder_out.GetTensorData(); + + std::vector encoder_out_shape = + encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + + std::vector &initial_hidden = s->GetParaformerEncoderOutCache(); + if (initial_hidden.empty()) { + initial_hidden.resize(encoder_out_shape[2]); + } + + std::vector &alpha_cache = s->GetParaformerAlphaCache(); + if (alpha_cache.empty()) { + alpha_cache.resize(1); + } + + std::vector acoustic_embedding; + acoustic_embedding.reserve(encoder_out_shape[1] * encoder_out_shape[2]); + + float threshold = 1.0; + + float integrate = alpha_cache[0]; + + for (int32_t i = 0; i != encoder_out_shape[1]; ++i) { + float this_alpha = p_alpha[i]; + if (integrate + this_alpha < threshold) { + integrate += this_alpha; + ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2], + encoder_out_shape[2], this_alpha, + initial_hidden.data()); + continue; + } + + // fire + ScaleAddInPlace(p_encoder_out + i 
* encoder_out_shape[2], + encoder_out_shape[2], threshold - integrate, + initial_hidden.data()); + acoustic_embedding.insert(acoustic_embedding.end(), + initial_hidden.begin(), initial_hidden.end()); + integrate += this_alpha - threshold; + + Scale(p_encoder_out + i * encoder_out_shape[2], encoder_out_shape[2], + integrate, initial_hidden.data()); + } + + alpha_cache[0] = integrate; + + if (acoustic_embedding.empty()) { + return; + } + + auto &states = s->GetStates(); + if (states.empty()) { + states.reserve(model_.DecoderNumBlocks()); + + std::array shape{1, model_.EncoderOutputSize(), + model_.DecoderKernelSize() - 1}; + + int32_t num_bytes = sizeof(float) * shape[0] * shape[1] * shape[2]; + + for (int32_t i = 0; i != model_.DecoderNumBlocks(); ++i) { + Ort::Value this_state = Ort::Value::CreateTensor( + model_.Allocator(), shape.data(), shape.size()); + + memset(this_state.GetTensorMutableData(), 0, num_bytes); + + states.push_back(std::move(this_state)); + } + } + + int32_t num_tokens = acoustic_embedding.size() / initial_hidden.size(); + std::array acoustic_embedding_shape{ + 1, num_tokens, static_cast(initial_hidden.size())}; + + Ort::Value acoustic_embedding_tensor = Ort::Value::CreateTensor( + memory_info, acoustic_embedding.data(), acoustic_embedding.size(), + acoustic_embedding_shape.data(), acoustic_embedding_shape.size()); + + std::array acoustic_embedding_length_shape{1}; + Ort::Value acoustic_embedding_length_tensor = Ort::Value::CreateTensor( + memory_info, &num_tokens, 1, acoustic_embedding_length_shape.data(), + acoustic_embedding_length_shape.size()); + + auto decoder_out_vec = model_.ForwardDecoder( + std::move(encoder_out), std::move(encoder_out_len), + std::move(acoustic_embedding_tensor), + std::move(acoustic_embedding_length_tensor), std::move(states)); + + states.reserve(model_.DecoderNumBlocks()); + for (int32_t i = 2; i != decoder_out_vec.size(); ++i) { + // TODO(fangjun): When we change chunk_size_, we need to + // slice decoder_out_vec[i] accordingly. 
+ states.push_back(std::move(decoder_out_vec[i])); + } + + const auto &sample_ids = decoder_out_vec[1]; + const int64_t *p_sample_ids = sample_ids.GetTensorData(); + + bool non_blank_detected = false; + + auto &result = s->GetParaformerResult(); + + for (int32_t i = 0; i != num_tokens; ++i) { + int32_t t = p_sample_ids[i]; + if (t == 0) { + continue; + } + + non_blank_detected = true; + result.tokens.push_back(t); + } + + if (non_blank_detected) { + result.last_non_blank_frame_index = num_processed_frames; + } + } + + std::vector ApplyLFR(const std::vector &in) const { + int32_t lfr_window_size = model_.LfrWindowSize(); + int32_t lfr_window_shift = model_.LfrWindowShift(); + int32_t in_feat_dim = config_.feat_config.feature_dim; + + int32_t in_num_frames = in.size() / in_feat_dim; + int32_t out_num_frames = + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; + int32_t out_feat_dim = in_feat_dim * lfr_window_size; + + std::vector out(out_num_frames * out_feat_dim); + + const float *p_in = in.data(); + float *p_out = out.data(); + + for (int32_t i = 0; i != out_num_frames; ++i) { + std::copy(p_in, p_in + out_feat_dim, p_out); + + p_out += out_feat_dim; + p_in += lfr_window_shift * in_feat_dim; + } + + return out; + } + + void ApplyCMVN(std::vector *v) const { + const std::vector &neg_mean = model_.NegativeMean(); + const std::vector &inv_stddev = model_.InverseStdDev(); + + int32_t dim = neg_mean.size(); + int32_t num_frames = v->size() / dim; + + float *p = v->data(); + + for (int32_t i = 0; i != num_frames; ++i) { + for (int32_t k = 0; k != dim; ++k) { + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]; + } + + p += dim; + } + } + + void PositionalEncoding(std::vector *v, int32_t t_offset) const { + int32_t lfr_window_size = model_.LfrWindowSize(); + int32_t in_feat_dim = config_.feat_config.feature_dim; + + int32_t feat_dim = in_feat_dim * lfr_window_size; + int32_t T = v->size() / feat_dim; + + // log(10000)/(7*80/2-1) == 0.03301197265941284 + // 7 is lfr_window_size + // 80 is in_feat_dim + // 7*80 is feat_dim + constexpr float kScale = -0.03301197265941284; + + for (int32_t t = 0; t != T; ++t) { + float *p = v->data() + t * feat_dim; + + int32_t offset = t + 1 + t_offset; + + for (int32_t d = 0; d < feat_dim / 2; ++d) { + float inv_timescale = offset * std::exp(d * kScale); + + float sin_d = std::sin(inv_timescale); + float cos_d = std::cos(inv_timescale); + + p[d] += sin_d; + p[d + feat_dim / 2] += cos_d; + } + } + } + + private: + OnlineRecognizerConfig config_; + OnlineParaformerModel model_; + SymbolTable sym_; + Endpoint endpoint_; + + // 0.61 seconds + int32_t chunk_size_ = 61; + // (61 - 7) / 6 + 1 = 10 + + int32_t left_chunk_size_ = 5; + int32_t right_chunk_size_ = 5; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index a5d2d815..625d02b1 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } #endif - void InitOnlineStream(OnlineStream *stream) const override { - auto r = decoder_->GetEmptyResult(); - - if (config_.decoding_method == "modified_beam_search" && - nullptr != stream->GetContextGraph()) { - // r.hyps has only one element. 
- for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { - it->second.context_state = stream->GetContextGraph()->Root(); - } - } - - stream->SetResult(r); - stream->SetStates(model_->GetEncoderInitStates()); - } - std::unique_ptr<OnlineStream> CreateStream() const override { auto stream = std::make_unique<OnlineStream>(config_.feat_config); InitOnlineStream(stream.get()); @@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } bool IsEndpoint(OnlineStream *s) const override { - if (!config_.enable_endpoint) return false; + if (!config_.enable_endpoint) { + return false; + } + int32_t num_processed_frames = s->GetNumProcessedFrames(); // frame shift is 10 milliseconds @@ -244,6 +232,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { s->Reset(); } + private: + void InitOnlineStream(OnlineStream *stream) const { + auto r = decoder_->GetEmptyResult(); + + if (config_.decoding_method == "modified_beam_search" && + nullptr != stream->GetContextGraph()) { + // r.hyps has only one element. + for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { + it->second.context_state = stream->GetContextGraph()->Root(); + } + } + + stream->SetResult(r); + stream->SetStates(model_->GetEncoderInitStates()); + } + private: OnlineRecognizerConfig config_; std::unique_ptr<OnlineTransducerModel> model_; diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index e0593ff6..8960ed13 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -47,6 +47,14 @@ class OnlineStream::Impl { OnlineTransducerDecoderResult &GetResult() { return result_; } + void SetParaformerResult(const OnlineParaformerDecoderResult &r) { + paraformer_result_ = r; + } + + OnlineParaformerDecoderResult &GetParaformerResult() { + return paraformer_result_; + } + int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } void SetStates(std::vector<Ort::Value> states) { @@ -57,6 +65,18 @@ const ContextGraphPtr &GetContextGraph() const { return context_graph_; } + std::vector<float> &GetParaformerFeatCache() { + return paraformer_feat_cache_; + } + + std::vector<float> &GetParaformerEncoderOutCache() { + return paraformer_encoder_out_cache_; + } + + std::vector<float> &GetParaformerAlphaCache() { + return paraformer_alpha_cache_; + } + private: FeatureExtractor feat_extractor_; /// For contextual-biasing @@ -65,6 +85,10 @@ int32_t start_frame_index_ = 0; // never reset OnlineTransducerDecoderResult result_; std::vector<Ort::Value> states_; + std::vector<float> paraformer_feat_cache_; + std::vector<float> paraformer_encoder_out_cache_; + std::vector<float> paraformer_alpha_cache_; + OnlineParaformerDecoderResult paraformer_result_; }; OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, @@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { return impl_->GetResult(); } +void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) { + impl_->SetParaformerResult(r); +} + +OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() { + return impl_->GetParaformerResult(); +} + void OnlineStream::SetStates(std::vector<Ort::Value> states) { impl_->SetStates(std::move(states)); } @@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const { return impl_->GetContextGraph(); } +std::vector<float> &OnlineStream::GetParaformerFeatCache() { + return impl_->GetParaformerFeatCache(); +} + +std::vector<float> &OnlineStream::GetParaformerEncoderOutCache() { + return impl_->GetParaformerEncoderOutCache(); +} + +std::vector<float>
&OnlineStream::GetParaformerAlphaCache() { + return impl_->GetParaformerAlphaCache(); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index 60dce950..ae920c1d 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -11,6 +11,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/context-graph.h" #include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-paraformer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" namespace sherpa_onnx { @@ -70,6 +71,9 @@ class OnlineStream { void SetResult(const OnlineTransducerDecoderResult &r); OnlineTransducerDecoderResult &GetResult(); + void SetParaformerResult(const OnlineParaformerDecoderResult &r); + OnlineParaformerDecoderResult &GetParaformerResult(); + void SetStates(std::vector<Ort::Value> states); std::vector<Ort::Value> &GetStates(); @@ -80,6 +84,11 @@ class OnlineStream { */ const ContextGraphPtr &GetContextGraph() const; + // for streaming paraformer + std::vector<float> &GetParaformerFeatCache(); + std::vector<float> &GetParaformerEncoderOutCache(); + std::vector<float> &GetParaformerAlphaCache(); + private: class Impl; std::unique_ptr<Impl> impl_; diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 8d527e90..9e771fc5 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -12,8 +12,8 @@ #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/online-stream.h" -#include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/wave-reader.h" typedef struct { @@ -80,7 +80,7 @@ for a list of pre-trained models to download. bool is_ok = false; const std::vector<float> samples = - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); @@ -92,14 +92,14 @@ for a list of pre-trained models to download. auto s = recognizer.CreateStream(); s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); - std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate)); + std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate)); // Note: We can call AcceptWaveform() multiple times. - s->AcceptWaveform( - sampling_rate, tail_paddings.data(), tail_paddings.size()); + s->AcceptWaveform(sampling_rate, tail_paddings.data(), + tail_paddings.size()); // Call InputFinished() to indicate that no audio samples are available s->InputFinished(); - ss.push_back({ std::move(s), duration, 0 }); + ss.push_back({std::move(s), duration, 0}); } std::vector<sherpa_onnx::OnlineStream *> ready_streams; @@ -112,8 +112,9 @@ for a list of pre-trained models to download.
} else if (s.elapsed_seconds == 0) { const auto end = std::chrono::steady_clock::now(); const float elapsed_seconds = - std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) - .count() / 1000.; + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) + .count() / + 1000.; s.elapsed_seconds = elapsed_seconds; } } diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 28612924..d61e4303 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx offline-whisper-model-config.cc online-lm-config.cc online-model-config.cc + online-paraformer-model-config.cc online-recognizer.cc online-stream.cc online-transducer-model-config.cc diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index 677d3b1f..7e37a87c 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -1,6 +1,6 @@ // sherpa-onnx/python/csrc/online-model-config.cc // -// Copyright (c) 2023 by manyeyes +// Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/python/csrc/online-model-config.h" @@ -9,21 +9,26 @@ #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" namespace sherpa_onnx { void PybindOnlineModelConfig(py::module *m) { PybindOnlineTransducerModelConfig(m); + PybindOnlineParaformerModelConfig(m); using PyClass = OnlineModelConfig; py::class_<PyClass>(*m, "OnlineModelConfig") - .def(py::init<const OnlineTransducerModelConfig &, const std::string &, int32_t, bool, const std::string &, const std::string &>(), + .def(py::init<const OnlineTransducerModelConfig &, const OnlineParaformerModelConfig &, const std::string &, int32_t, bool, const std::string &, const std::string &>(), py::arg("transducer") = OnlineTransducerModelConfig(), + py::arg("paraformer") = OnlineParaformerModelConfig(), py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("model_type") = "") .def_readwrite("transducer", &PyClass::transducer) + .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) diff --git a/sherpa-onnx/python/csrc/online-model-config.h b/sherpa-onnx/python/csrc/online-model-config.h index 73154fc9..3624a104 100644 --- a/sherpa-onnx/python/csrc/online-model-config.h +++ b/sherpa-onnx/python/csrc/online-model-config.h @@ -1,6 +1,6 @@ // sherpa-onnx/python/csrc/online-model-config.h // -// Copyright (c) 2023 by manyeyes +// Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ #define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/online-paraformer-model-config.cc b/sherpa-onnx/python/csrc/online-paraformer-model-config.cc new file mode 100644 index 00000000..84895acb --- /dev/null +++ b/sherpa-onnx/python/csrc/online-paraformer-model-config.cc @@ -0,0 +1,24 @@ +// sherpa-onnx/python/csrc/online-paraformer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" + +#include <string> +#include <vector> + +#include "sherpa-onnx/csrc/online-paraformer-model-config.h" + +namespace sherpa_onnx { + +void PybindOnlineParaformerModelConfig(py::module *m) { + using PyClass = OnlineParaformerModelConfig; + py::class_<PyClass>(*m, "OnlineParaformerModelConfig") + .def(py::init<const std::string &, const std::string &>(), + py::arg("encoder"), py::arg("decoder")) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) +
.def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-paraformer-model-config.h b/sherpa-onnx/python/csrc/online-paraformer-model-config.h new file mode 100644 index 00000000..ad1dc1d7 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-paraformer-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/online-paraformer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOnlineParaformerModelConfig(py::module *m); + +} + +#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index 34a907ce..c130d87c 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), py::arg("enable_endpoint"), py::arg("decoding_method"), - py::arg("max_active_paths"), py::arg("context_score")) + py::arg("max_active_paths") = 4, py::arg("context_score") = 0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config) diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index c49e1b43..55e789ba 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -6,6 +6,7 @@ from _sherpa_onnx import ( EndpointConfig, FeatureExtractorConfig, OnlineModelConfig, + OnlineParaformerModelConfig, OnlineRecognizer as _Recognizer, OnlineRecognizerConfig, OnlineStream, @@ -32,7 +33,7 @@ class OnlineRecognizer(object): encoder: str, decoder: str, joiner: str, - num_threads: int = 4, + num_threads: int = 2, sample_rate: float = 16000, feature_dim: int = 80, enable_endpoint_detection: bool = False, @@ -144,6 +145,109 @@ class OnlineRecognizer(object): self.config = recognizer_config return self + @classmethod + def from_paraformer( + cls, + tokens: str, + encoder: str, + decoder: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + provider: str = "cpu", + ): + """ + Please refer to + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + encoder: + Path to ``encoder.onnx``. + decoder: + Path to ``decoder.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True.
If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(encoder) + _assert_file_exists(decoder) + + assert num_threads > 0, num_threads + + paraformer_config = OnlineParaformerModelConfig( + encoder=encoder, + decoder=decoder, + ) + + model_config = OnlineModelConfig( + paraformer=paraformer_config, + tokens=tokens, + num_threads=num_threads, + provider=provider, + model_type="paraformer", + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + def create_stream(self, contexts_list: Optional[List[List[int]]] = None): if contexts_list is None: return self.recognizer.create_stream()
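(Not part of the patch: a minimal usage sketch of the new OnlineRecognizer.from_paraformer API added above. It assumes the sherpa-onnx-streaming-paraformer-bilingual-zh-en model from the test script has been downloaded next to the script; the repo path, the 16 kHz rate, and the silence input are placeholders, and the stream/decode calls follow the existing sherpa_onnx Python wrapper.)

#!/usr/bin/env python3
# Usage sketch (assumption): drive the streaming Paraformer through the Python API.
import numpy as np
import sherpa_onnx

repo = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en"  # placeholder path

recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
    tokens=f"{repo}/tokens.txt",
    encoder=f"{repo}/encoder.onnx",
    decoder=f"{repo}/decoder.onnx",
    num_threads=2,
)

stream = recognizer.create_stream()

# Feed 16 kHz float32 samples in chunks; one second of silence is used here as a
# stand-in for audio read from a wave file or a microphone.
samples = np.zeros(16000, dtype=np.float32)
stream.accept_waveform(16000, samples)

# Decode whatever is ready so far and print the partial result.
while recognizer.is_ready(stream):
    recognizer.decode_stream(stream)
print(recognizer.get_result(stream))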