diff --git a/.github/scripts/test-online-paraformer.sh b/.github/scripts/test-online-paraformer.sh
new file mode 100755
index 00000000..93574e3f
--- /dev/null
+++ b/.github/scripts/test-online-paraformer.sh
@@ -0,0 +1,53 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+echo "EXE is $EXE"
+echo "PATH: $PATH"
+
+which $EXE
+
+log "------------------------------------------------------------"
+log "Run streaming Paraformer"
+log "------------------------------------------------------------"
+
+repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.onnx"
+ls -lh *.onnx
+popd
+
+time $EXE \
+ --tokens=$repo/tokens.txt \
+ --paraformer-encoder=$repo/encoder.onnx \
+ --paraformer-decoder=$repo/decoder.onnx \
+ --num-threads=2 \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav \
+ $repo/test_wavs/3.wav \
+ $repo/test_wavs/8k.wav
+
+time $EXE \
+ --tokens=$repo/tokens.txt \
+ --paraformer-encoder=$repo/encoder.int8.onnx \
+ --paraformer-decoder=$repo/decoder.int8.onnx \
+ --num-threads=2 \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav \
+ $repo/test_wavs/3.wav \
+ $repo/test_wavs/8k.wav
+
+rm -rf $repo
diff --git a/.github/workflows/linux-gpu.yaml b/.github/workflows/linux-gpu.yaml
index 7b14ac2b..25350a31 100644
--- a/.github/workflows/linux-gpu.yaml
+++ b/.github/workflows/linux-gpu.yaml
@@ -9,6 +9,7 @@ on:
paths:
- '.github/workflows/linux-gpu.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -22,6 +23,7 @@ on:
paths:
- '.github/workflows/linux-gpu.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -85,6 +87,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
+ - name: Test online paraformer
+ shell: bash
+ run: |
+ export PATH=$PWD/build/bin:$PATH
+ export EXE=sherpa-onnx
+
+ .github/scripts/test-online-paraformer.sh
+
- name: Test offline Whisper
shell: bash
run: |
diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml
index a03602ba..2c026fc0 100644
--- a/.github/workflows/linux.yaml
+++ b/.github/workflows/linux.yaml
@@ -9,6 +9,7 @@ on:
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -22,6 +23,7 @@ on:
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -84,6 +86,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
+ - name: Test online paraformer
+ shell: bash
+ run: |
+ export PATH=$PWD/build/bin:$PATH
+ export EXE=sherpa-onnx
+
+ .github/scripts/test-online-paraformer.sh
+
- name: Test offline Whisper
shell: bash
run: |
diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml
index cebc5ac6..f3b11a5d 100644
--- a/.github/workflows/macos.yaml
+++ b/.github/workflows/macos.yaml
@@ -7,6 +7,7 @@ on:
paths:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -18,6 +19,7 @@ on:
paths:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -82,6 +84,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
+ - name: Test online paraformer
+ shell: bash
+ run: |
+ export PATH=$PWD/build/bin:$PATH
+ export EXE=sherpa-onnx
+
+ .github/scripts/test-online-paraformer.sh
+
- name: Test offline Whisper
shell: bash
run: |
diff --git a/.github/workflows/test-pip-install.yaml b/.github/workflows/test-pip-install.yaml
index 34a15360..01fdb4c6 100644
--- a/.github/workflows/test-pip-install.yaml
+++ b/.github/workflows/test-pip-install.yaml
@@ -58,7 +58,6 @@ jobs:
sherpa-onnx-microphone-offline --help
sherpa-onnx-offline-websocket-server --help
- sherpa-onnx-offline-websocket-client --help
sherpa-onnx-online-websocket-server --help
sherpa-onnx-online-websocket-client --help
diff --git a/.github/workflows/test-python-offline-websocket-server.yaml b/.github/workflows/test-python-offline-websocket-server.yaml
index 7ec4e29d..d7ea4dde 100644
--- a/.github/workflows/test-python-offline-websocket-server.yaml
+++ b/.github/workflows/test-python-offline-websocket-server.yaml
@@ -84,14 +84,14 @@ jobs:
if: matrix.model_type == 'paraformer'
shell: bash
run: |
- GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
- cd sherpa-onnx-paraformer-zh-2023-03-28
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
+ cd sherpa-onnx-paraformer-bilingual-zh-en
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
- --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
- --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt &
+ --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \
+ --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt &
echo "sleep 10 seconds to wait the server start"
sleep 10
@@ -101,16 +101,16 @@ jobs:
shell: bash
run: |
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
- ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
- ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
- ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
- ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
+ ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \
+ ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \
+ ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \
+ ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
- ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
- ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
- ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
- ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
+ ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \
+ ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \
+ ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \
+ ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav
- name: Start server for nemo_ctc models
if: matrix.model_type == 'nemo_ctc'
diff --git a/.github/workflows/test-python-online-websocket-server.yaml b/.github/workflows/test-python-online-websocket-server.yaml
index c7e3319d..7616afa3 100644
--- a/.github/workflows/test-python-online-websocket-server.yaml
+++ b/.github/workflows/test-python-online-websocket-server.yaml
@@ -24,7 +24,7 @@ jobs:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
- model_type: ["transducer"]
+ model_type: ["transducer", "paraformer"]
steps:
- uses: actions/checkout@v2
@@ -71,3 +71,36 @@ jobs:
run: |
python3 ./python-api-examples/online-websocket-client-decode-file.py \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
+
+ - name: Start server for paraformer models
+ if: matrix.model_type == 'paraformer'
+ shell: bash
+ run: |
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
+ cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
+ git lfs pull --include "*.onnx"
+ cd ..
+
+ python3 ./python-api-examples/streaming_server.py \
+ --tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
+ --paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
+ --paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx &
+
+ echo "sleep 10 seconds to wait the server start"
+ sleep 10
+
+ - name: Start client for paraformer models
+ if: matrix.model_type == 'paraformer'
+ shell: bash
+ run: |
+ python3 ./python-api-examples/online-websocket-client-decode-file.py \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
+
+ python3 ./python-api-examples/online-websocket-client-decode-file.py \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav
+
+ python3 ./python-api-examples/online-websocket-client-decode-file.py \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav
+
+ python3 ./python-api-examples/online-websocket-client-decode-file.py \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav
diff --git a/.github/workflows/windows-x64-cuda.yaml b/.github/workflows/windows-x64-cuda.yaml
index 24b8158d..17e53d8b 100644
--- a/.github/workflows/windows-x64-cuda.yaml
+++ b/.github/workflows/windows-x64-cuda.yaml
@@ -9,6 +9,7 @@ on:
paths:
- '.github/workflows/windows-x64-cuda.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -20,6 +21,7 @@ on:
paths:
- '.github/workflows/windows-x64-cuda.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -74,6 +76,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
+ - name: Test online paraformer for windows x64
+ shell: bash
+ run: |
+ export PATH=$PWD/build/bin/Release:$PATH
+ export EXE=sherpa-onnx.exe
+
+ .github/scripts/test-online-paraformer.sh
+
- name: Test offline Whisper for windows x64
shell: bash
run: |
diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml
index 83b80de9..c63dbae3 100644
--- a/.github/workflows/windows-x64.yaml
+++ b/.github/workflows/windows-x64.yaml
@@ -9,6 +9,7 @@ on:
paths:
- '.github/workflows/windows-x64.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -20,6 +21,7 @@ on:
paths:
- '.github/workflows/windows-x64.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -75,6 +77,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
+ - name: Test online paraformer for windows x64
+ shell: bash
+ run: |
+ export PATH=$PWD/build/bin/Release:$PATH
+ export EXE=sherpa-onnx.exe
+
+ .github/scripts/test-online-paraformer.sh
+
- name: Test offline Whisper for windows x64
shell: bash
run: |
diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml
index d181e22c..b39a1ddc 100644
--- a/.github/workflows/windows-x86.yaml
+++ b/.github/workflows/windows-x86.yaml
@@ -7,6 +7,7 @@ on:
paths:
- '.github/workflows/windows-x86.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -18,6 +19,7 @@ on:
paths:
- '.github/workflows/windows-x86.yaml'
- '.github/scripts/test-online-transducer.sh'
+ - '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
@@ -73,6 +75,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
+ - name: Test online paraformer for windows x86
+ shell: bash
+ run: |
+ export PATH=$PWD/build/bin/Release:$PATH
+ export EXE=sherpa-onnx.exe
+
+ .github/scripts/test-online-paraformer.sh
+
- name: Test offline Whisper for windows x86
shell: bash
run: |
diff --git a/CMakeLists.txt b/CMakeLists.txt
index abb77787..c6086fa3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)
-set(SHERPA_ONNX_VERSION "1.7.3")
+set(SHERPA_ONNX_VERSION "1.7.4")
# Disable warning about
#
diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py
index 7d3502fa..cbfaa760 100755
--- a/python-api-examples/non_streaming_server.py
+++ b/python-api-examples/non_streaming_server.py
@@ -37,14 +37,14 @@ python3 ./python-api-examples/non_streaming_server.py \
(2) Use a non-streaming paraformer
cd /path/to/sherpa-onnx
-GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
-cd sherpa-onnx-paraformer-zh-2023-03-28
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
+cd sherpa-onnx-paraformer-bilingual-zh-en/
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
- --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
- --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt
+ --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \
+ --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt
(3) Use a non-streaming CTC model from NeMo
diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py
index e2e1dc55..eff85427 100755
--- a/python-api-examples/online-decode-files.py
+++ b/python-api-examples/online-decode-files.py
@@ -5,16 +5,41 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a streaming model.
Usage:
- ./online-decode-files.py \
- /path/to/foo.wav \
- /path/to/bar.wav \
- /path/to/16kHz.wav \
- /path/to/8kHz.wav
+
+(1) Streaming transducer
+
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
+cd sherpa-onnx-streaming-zipformer-en-2023-06-26
+git lfs pull --include "*.onnx"
+
+./python-api-examples/online-decode-files.py \
+ --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
+ --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
+ --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
+ --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
+ ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
+ ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
+ ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
+
+(2) Streaming paraformer
+
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
+cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
+git lfs pull --include "*.onnx"
+
+./python-api-examples/online-decode-files.py \
+ --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
+ --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
+ --paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
+ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
-to install sherpa-onnx and to download the pre-trained models
-used in this file.
+to install sherpa-onnx and to download streaming pre-trained models.
"""
import argparse
import time
@@ -41,19 +66,31 @@ def get_args():
parser.add_argument(
"--encoder",
type=str,
- help="Path to the encoder model",
+ help="Path to the transducer encoder model",
)
parser.add_argument(
"--decoder",
type=str,
- help="Path to the decoder model",
+ help="Path to the transducer decoder model",
)
parser.add_argument(
"--joiner",
type=str,
- help="Path to the joiner model",
+ help="Path to the transducer joiner model",
+ )
+
+ parser.add_argument(
+ "--paraformer-encoder",
+ type=str,
+ help="Path to the paraformer encoder model",
+ )
+
+ parser.add_argument(
+ "--paraformer-decoder",
+ type=str,
+ help="Path to the paraformer decoder model",
)
parser.add_argument(
@@ -200,24 +237,42 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
def main():
args = get_args()
- assert_file_exists(args.encoder)
- assert_file_exists(args.decoder)
- assert_file_exists(args.joiner)
assert_file_exists(args.tokens)
- recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
- tokens=args.tokens,
- encoder=args.encoder,
- decoder=args.decoder,
- joiner=args.joiner,
- num_threads=args.num_threads,
- provider=args.provider,
- sample_rate=16000,
- feature_dim=80,
- decoding_method=args.decoding_method,
- max_active_paths=args.max_active_paths,
- context_score=args.context_score,
- )
+ if args.encoder:
+ assert_file_exists(args.encoder)
+ assert_file_exists(args.decoder)
+ assert_file_exists(args.joiner)
+
+ assert not args.paraformer_encoder, args.paraformer_encoder
+ assert not args.paraformer_decoder, args.paraformer_decoder
+
+ recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
+ tokens=args.tokens,
+ encoder=args.encoder,
+ decoder=args.decoder,
+ joiner=args.joiner,
+ num_threads=args.num_threads,
+ provider=args.provider,
+ sample_rate=16000,
+ feature_dim=80,
+ decoding_method=args.decoding_method,
+ max_active_paths=args.max_active_paths,
+ context_score=args.context_score,
+ )
+ elif args.paraformer_encoder:
+ recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
+ tokens=args.tokens,
+ encoder=args.paraformer_encoder,
+ decoder=args.paraformer_decoder,
+ num_threads=args.num_threads,
+ provider=args.provider,
+ sample_rate=16000,
+ feature_dim=80,
+ decoding_method="greedy_search",
+ )
+ else:
+ raise ValueError("Please provide a model")
print("Started!")
start_time = time.time()
@@ -243,7 +298,7 @@ def main():
s.accept_waveform(sample_rate, samples)
- tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
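+ # Use ~0.66 seconds of trailing silence (instead of 0.2) so that the
+ # streaming paraformer, which processes audio in chunks of roughly 0.61 s
+ # (see chunk_size_ in the C++ implementation), presumably has enough
+ # padding to flush the final chunk.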
+ tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py
index c707a70c..33d4e5ee 100755
--- a/python-api-examples/streaming_server.py
+++ b/python-api-examples/streaming_server.py
@@ -16,9 +16,9 @@ Example:
(1) Without a certificate
python3 ./python-api-examples/streaming_server.py \
- --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
- --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
- --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+ --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+ --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+ --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
(2) With a certificate
@@ -32,9 +32,9 @@ python3 ./python-api-examples/streaming_server.py \
(b) Start the server
python3 ./python-api-examples/streaming_server.py \
- --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
- --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
- --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+ --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+ --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+ --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
--certificate ./python-api-examples/web/cert.pem
@@ -113,24 +113,33 @@ def setup_logger(
def add_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
- "--encoder-model",
+ "--encoder",
type=str,
- required=True,
- help="Path to the encoder model",
+ help="Path to the transducer encoder model",
)
parser.add_argument(
- "--decoder-model",
+ "--decoder",
type=str,
- required=True,
- help="Path to the decoder model.",
+ help="Path to the transducer decoder model.",
)
parser.add_argument(
- "--joiner-model",
+ "--joiner",
type=str,
- required=True,
- help="Path to the joiner model.",
+ help="Path to the transducer joiner model.",
+ )
+
+ parser.add_argument(
+ "--paraformer-encoder",
+ type=str,
+ help="Path to the paraformer encoder model",
+ )
+
+ parser.add_argument(
+ "--paraformer-decoder",
+ type=str,
+ help="Path to the transducer decoder model.",
)
parser.add_argument(
@@ -323,22 +332,40 @@ def get_args():
def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
- recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
- tokens=args.tokens,
- encoder=args.encoder_model,
- decoder=args.decoder_model,
- joiner=args.joiner_model,
- num_threads=args.num_threads,
- sample_rate=args.sample_rate,
- feature_dim=args.feat_dim,
- decoding_method=args.decoding_method,
- max_active_paths=args.num_active_paths,
- enable_endpoint_detection=args.use_endpoint != 0,
- rule1_min_trailing_silence=args.rule1_min_trailing_silence,
- rule2_min_trailing_silence=args.rule2_min_trailing_silence,
- rule3_min_utterance_length=args.rule3_min_utterance_length,
- provider=args.provider,
- )
+ if args.encoder:
+ recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
+ tokens=args.tokens,
+ encoder=args.encoder,
+ decoder=args.decoder,
+ joiner=args.joiner,
+ num_threads=args.num_threads,
+ sample_rate=args.sample_rate,
+ feature_dim=args.feat_dim,
+ decoding_method=args.decoding_method,
+ max_active_paths=args.num_active_paths,
+ enable_endpoint_detection=args.use_endpoint != 0,
+ rule1_min_trailing_silence=args.rule1_min_trailing_silence,
+ rule2_min_trailing_silence=args.rule2_min_trailing_silence,
+ rule3_min_utterance_length=args.rule3_min_utterance_length,
+ provider=args.provider,
+ )
+ elif args.paraformer_encoder:
+ recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
+ tokens=args.tokens,
+ encoder=args.paraformer_encoder,
+ decoder=args.paraformer_decoder,
+ num_threads=args.num_threads,
+ sample_rate=args.sample_rate,
+ feature_dim=args.feat_dim,
+ decoding_method=args.decoding_method,
+ enable_endpoint_detection=args.use_endpoint != 0,
+ rule1_min_trailing_silence=args.rule1_min_trailing_silence,
+ rule2_min_trailing_silence=args.rule2_min_trailing_silence,
+ rule3_min_utterance_length=args.rule3_min_utterance_length,
+ provider=args.provider,
+ )
+ else:
+ raise ValueError("Please provide a model")
return recognizer
@@ -654,11 +681,25 @@ Go back to /streaming_record.html
def check_args(args):
- assert Path(args.encoder_model).is_file(), f"{args.encoder_model} does not exist"
+ if args.encoder:
+ assert Path(args.encoder).is_file(), f"{args.encoder} does not exist"
- assert Path(args.decoder_model).is_file(), f"{args.decoder_model} does not exist"
+ assert Path(args.decoder).is_file(), f"{args.decoder} does not exist"
- assert Path(args.joiner_model).is_file(), f"{args.joiner_model} does not exist"
+ assert Path(args.joiner).is_file(), f"{args.joiner} does not exist"
+
+ assert args.paraformer_encoder is None, args.paraformer_encoder
+ assert args.paraformer_decoder is None, args.paraformer_decoder
+ elif args.paraformer_encoder:
+ assert Path(
+ args.paraformer_encoder
+ ).is_file(), f"{args.paraformer_encoder} does not exist"
+
+ assert Path(
+ args.paraformer_decoder
+ ).is_file(), f"{args.paraformer_decoder} does not exist"
+ else:
+ raise ValueError("Please provide a model")
if not Path(args.tokens).is_file():
raise ValueError(f"{args.tokens} does not exist")
diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index cb4953c5..b9bac58c 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -46,6 +46,8 @@ set(sources
online-lm.cc
online-lstm-transducer-model.cc
online-model-config.cc
+ online-paraformer-model-config.cc
+ online-paraformer-model.cc
online-recognizer-impl.cc
online-recognizer.cc
online-rnn-lm.cc
diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc
index 7f804684..51500e11 100644
--- a/sherpa-onnx/csrc/features.cc
+++ b/sherpa-onnx/csrc/features.cc
@@ -39,7 +39,7 @@ std::string FeatureExtractorConfig::ToString() const {
class FeatureExtractor::Impl {
public:
- explicit Impl(const FeatureExtractorConfig &config) {
+ explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -50,6 +50,19 @@ class FeatureExtractor::Impl {
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
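+ // Some models (e.g., the streaming paraformer) expect features computed
+ // from samples in the 16-bit PCM range. When normalize_samples is false we
+ // therefore rescale the incoming [-1, 1] floats by 32768 before feature
+ // extraction.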
+ if (config_.normalize_samples) {
+ AcceptWaveformImpl(sampling_rate, waveform, n);
+ } else {
+ std::vector<float> buf(n);
+ for (int32_t i = 0; i != n; ++i) {
+ buf[i] = waveform[i] * 32768;
+ }
+ AcceptWaveformImpl(sampling_rate, buf.data(), n);
+ }
+ }
+
+ void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
+ int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
if (resampler_) {
@@ -146,6 +159,7 @@ class FeatureExtractor::Impl {
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
+ FeatureExtractorConfig config_;
mutable std::mutex mutex_;
std::unique_ptr<LinearResample> resampler_;
int32_t last_frame_index_ = 0;
diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h
index d4eaffda..497dd01c 100644
--- a/sherpa-onnx/csrc/features.h
+++ b/sherpa-onnx/csrc/features.h
@@ -21,6 +21,13 @@ struct FeatureExtractorConfig {
// Feature dimension
int32_t feature_dim = 80;
+ // Set internally by some models, e.g., paraformer sets it to false.
+ // This parameter is not exposed to users from the commandline
+ // If true, the feature extractor expects inputs to be normalized to
+ // the range [-1, 1].
+ // If false, we will multiply the inputs by 32768
+ bool normalize_samples = true;
+
std::string ToString() const;
void Register(ParseOptions *po);
diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc
index 7a4416b5..9c1f8c49 100644
--- a/sherpa-onnx/csrc/online-model-config.cc
+++ b/sherpa-onnx/csrc/online-model-config.cc
@@ -12,6 +12,7 @@ namespace sherpa_onnx {
void OnlineModelConfig::Register(ParseOptions *po) {
transducer.Register(po);
+ paraformer.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -41,6 +42,10 @@ bool OnlineModelConfig::Validate() const {
return false;
}
+ if (!paraformer.encoder.empty()) {
+ return paraformer.Validate();
+ }
+
return transducer.Validate();
}
@@ -49,6 +54,7 @@ std::string OnlineModelConfig::ToString() const {
os << "OnlineModelConfig(";
os << "transducer=" << transducer.ToString() << ", ";
+ os << "paraformer=" << paraformer.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h
index 34e7b1e4..2afd6617 100644
--- a/sherpa-onnx/csrc/online-model-config.h
+++ b/sherpa-onnx/csrc/online-model-config.h
@@ -6,12 +6,14 @@
#include <string>
+#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
struct OnlineModelConfig {
OnlineTransducerModelConfig transducer;
+ OnlineParaformerModelConfig paraformer;
std::string tokens;
int32_t num_threads = 1;
bool debug = false;
@@ -28,9 +30,11 @@ struct OnlineModelConfig {
OnlineModelConfig() = default;
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
+ const OnlineParaformerModelConfig ¶former,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
+ paraformer(paraformer),
tokens(tokens),
num_threads(num_threads),
debug(debug),
diff --git a/sherpa-onnx/csrc/online-paraformer-decoder.h b/sherpa-onnx/csrc/online-paraformer-decoder.h
new file mode 100644
index 00000000..9f675275
--- /dev/null
+++ b/sherpa-onnx/csrc/online-paraformer-decoder.h
@@ -0,0 +1,23 @@
+// sherpa-onnx/csrc/online-paraformer-decoder.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
+#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
+
+#include <vector>
+
+#include "onnxruntime_cxx_api.h" // NOLINT
+
+namespace sherpa_onnx {
+
+struct OnlineParaformerDecoderResult {
+ /// The decoded token IDs
+ std::vector<int64_t> tokens;
+
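+ // Index (in 10 ms feature frames) of the most recent frame that produced a
+ // non-blank token; used by endpoint detection to measure trailing silence.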
+ int32_t last_non_blank_frame_index = 0;
+};
+
+} // namespace sherpa_onnx
+
+#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
diff --git a/sherpa-onnx/csrc/online-paraformer-model-config.cc b/sherpa-onnx/csrc/online-paraformer-model-config.cc
new file mode 100644
index 00000000..a93fe299
--- /dev/null
+++ b/sherpa-onnx/csrc/online-paraformer-model-config.cc
@@ -0,0 +1,43 @@
+// sherpa-onnx/csrc/online-paraformer-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OnlineParaformerModelConfig::Register(ParseOptions *po) {
+ po->Register("paraformer-encoder", &encoder,
+ "Path to encoder.onnx of paraformer.");
+ po->Register("paraformer-decoder", &decoder,
+ "Path to decoder.onnx of paraformer.");
+}
+
+bool OnlineParaformerModelConfig::Validate() const {
+ if (!FileExists(encoder)) {
+ SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str());
+ return false;
+ }
+
+ if (!FileExists(decoder)) {
+ SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str());
+ return false;
+ }
+
+ return true;
+}
+
+std::string OnlineParaformerModelConfig::ToString() const {
+ std::ostringstream os;
+
+ os << "OnlineParaformerModelConfig(";
+ os << "encoder=\"" << encoder << "\", ";
+ os << "decoder=\"" << decoder << "\")";
+
+ return os.str();
+}
+
+} // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-paraformer-model-config.h b/sherpa-onnx/csrc/online-paraformer-model-config.h
new file mode 100644
index 00000000..29f33e45
--- /dev/null
+++ b/sherpa-onnx/csrc/online-paraformer-model-config.h
@@ -0,0 +1,31 @@
+// sherpa-onnx/csrc/online-paraformer-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OnlineParaformerModelConfig {
+ std::string encoder;
+ std::string decoder;
+
+ OnlineParaformerModelConfig() = default;
+
+ OnlineParaformerModelConfig(const std::string &encoder,
+ const std::string &decoder)
+ : encoder(encoder), decoder(decoder) {}
+
+ void Register(ParseOptions *po);
+ bool Validate() const;
+
+ std::string ToString() const;
+};
+
+} // namespace sherpa_onnx
+
+#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/csrc/online-paraformer-model.cc b/sherpa-onnx/csrc/online-paraformer-model.cc
new file mode 100644
index 00000000..2d6a410e
--- /dev/null
+++ b/sherpa-onnx/csrc/online-paraformer-model.cc
@@ -0,0 +1,249 @@
+// sherpa-onnx/csrc/online-paraformer-model.cc
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-paraformer-model.h"
+
+#include <cmath>
+#include <memory>
+#include <utility>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OnlineParaformerModel::Impl {
+ public:
+ explicit Impl(const OnlineModelConfig &config)
+ : config_(config),
+ env_(ORT_LOGGING_LEVEL_ERROR),
+ sess_opts_(GetSessionOptions(config)),
+ allocator_{} {
+ {
+ auto buf = ReadFile(config.paraformer.encoder);
+ InitEncoder(buf.data(), buf.size());
+ }
+
+ {
+ auto buf = ReadFile(config.paraformer.decoder);
+ InitDecoder(buf.data(), buf.size());
+ }
+ }
+
+#if __ANDROID_API__ >= 9
+ Impl(AAssetManager *mgr, const OnlineModelConfig &config)
+ : config_(config),
+ env_(ORT_LOGGING_LEVEL_WARNING),
+ sess_opts_(GetSessionOptions(config)),
+ allocator_{} {
+ {
+ auto buf = ReadFile(mgr, config.paraformer.encoder);
+ InitEncoder(buf.data(), buf.size());
+ }
+
+ {
+ auto buf = ReadFile(mgr, config.paraformer.decoder);
+ InitDecoder(buf.data(), buf.size());
+ }
+ }
+#endif
+
+ std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
+ Ort::Value features_length) {
+ std::array<Ort::Value, 2> inputs = {std::move(features),
+ std::move(features_length)};
+
+ return encoder_sess_->Run(
+ {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
+ encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
+ }
+
+ std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
+ Ort::Value encoder_out_length,
+ Ort::Value acoustic_embedding,
+ Ort::Value acoustic_embedding_length,
+ std::vector<Ort::Value> states) {
+ std::vector<Ort::Value> decoder_inputs;
+ decoder_inputs.reserve(4 + states.size());
+
+ decoder_inputs.push_back(std::move(encoder_out));
+ decoder_inputs.push_back(std::move(encoder_out_length));
+ decoder_inputs.push_back(std::move(acoustic_embedding));
+ decoder_inputs.push_back(std::move(acoustic_embedding_length));
+
+ for (auto &v : states) {
+ decoder_inputs.push_back(std::move(v));
+ }
+
+ return decoder_sess_->Run({}, decoder_input_names_ptr_.data(),
+ decoder_inputs.data(), decoder_inputs.size(),
+ decoder_output_names_ptr_.data(),
+ decoder_output_names_ptr_.size());
+ }
+
+ int32_t VocabSize() const { return vocab_size_; }
+
+ int32_t LfrWindowSize() const { return lfr_window_size_; }
+
+ int32_t LfrWindowShift() const { return lfr_window_shift_; }
+
+ int32_t EncoderOutputSize() const { return encoder_output_size_; }
+
+ int32_t DecoderKernelSize() const { return decoder_kernel_size_; }
+
+ int32_t DecoderNumBlocks() const { return decoder_num_blocks_; }
+
+ const std::vector<float> &NegativeMean() const { return neg_mean_; }
+
+ const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
+
+ OrtAllocator *Allocator() const { return allocator_; }
+
+ private:
+ void InitEncoder(void *model_data, size_t model_data_length) {
+ encoder_sess_ = std::make_unique<Ort::Session>(
+ env_, model_data, model_data_length, sess_opts_);
+
+ GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+ &encoder_input_names_ptr_);
+
+ GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+ &encoder_output_names_ptr_);
+
+ // get meta data
+ Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+ if (config_.debug) {
+ std::ostringstream os;
+ PrintModelMetadata(os, meta_data);
+ SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+ }
+
+ Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
+ SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+ SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size");
+ SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift");
+ SHERPA_ONNX_READ_META_DATA(encoder_output_size_, "encoder_output_size");
+ SHERPA_ONNX_READ_META_DATA(decoder_num_blocks_, "decoder_num_blocks");
+ SHERPA_ONNX_READ_META_DATA(decoder_kernel_size_, "decoder_kernel_size");
+
+ SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean");
+ SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
+
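+ // Fold the encoder's input scaling (features are multiplied by
+ // sqrt(encoder_output_size) before entering the encoder) into inv_stddev,
+ // so CMVN and the scaling are applied in a single pass over the features.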
+ float scale = std::sqrt(encoder_output_size_);
+ for (auto &f : inv_stddev_) {
+ f *= scale;
+ }
+ }
+
+ void InitDecoder(void *model_data, size_t model_data_length) {
+ decoder_sess_ = std::make_unique<Ort::Session>(
+ env_, model_data, model_data_length, sess_opts_);
+
+ GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+ &decoder_input_names_ptr_);
+
+ GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+ &decoder_output_names_ptr_);
+ }
+
+ private:
+ OnlineModelConfig config_;
+ Ort::Env env_;
+ Ort::SessionOptions sess_opts_;
+ Ort::AllocatorWithDefaultOptions allocator_;
+
+ std::unique_ptr<Ort::Session> encoder_sess_;
+
+ std::vector<std::string> encoder_input_names_;
+ std::vector<const char *> encoder_input_names_ptr_;
+
+ std::vector<std::string> encoder_output_names_;
+ std::vector<const char *> encoder_output_names_ptr_;
+
+ std::unique_ptr<Ort::Session> decoder_sess_;
+
+ std::vector<std::string> decoder_input_names_;
+ std::vector<const char *> decoder_input_names_ptr_;
+
+ std::vector<std::string> decoder_output_names_;
+ std::vector<const char *> decoder_output_names_ptr_;
+
+ std::vector<float> neg_mean_;
+ std::vector<float> inv_stddev_;
+
+ int32_t vocab_size_ = 0; // initialized in Init
+ int32_t lfr_window_size_ = 0;
+ int32_t lfr_window_shift_ = 0;
+
+ int32_t encoder_output_size_ = 0;
+ int32_t decoder_num_blocks_ = 0;
+ int32_t decoder_kernel_size_ = 0;
+};
+
+OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config)
+ : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OnlineParaformerModel::OnlineParaformerModel(AAssetManager *mgr,
+ const OnlineModelConfig &config)
+ : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OnlineParaformerModel::~OnlineParaformerModel() = default;
+
+std::vector<Ort::Value> OnlineParaformerModel::ForwardEncoder(
+ Ort::Value features, Ort::Value features_length) const {
+ return impl_->ForwardEncoder(std::move(features), std::move(features_length));
+}
+
+std::vector<Ort::Value> OnlineParaformerModel::ForwardDecoder(
+ Ort::Value encoder_out, Ort::Value encoder_out_length,
+ Ort::Value acoustic_embedding, Ort::Value acoustic_embedding_length,
+ std::vector<Ort::Value> states) const {
+ return impl_->ForwardDecoder(
+ std::move(encoder_out), std::move(encoder_out_length),
+ std::move(acoustic_embedding), std::move(acoustic_embedding_length),
+ std::move(states));
+}
+
+int32_t OnlineParaformerModel::VocabSize() const { return impl_->VocabSize(); }
+
+int32_t OnlineParaformerModel::LfrWindowSize() const {
+ return impl_->LfrWindowSize();
+}
+int32_t OnlineParaformerModel::LfrWindowShift() const {
+ return impl_->LfrWindowShift();
+}
+
+int32_t OnlineParaformerModel::EncoderOutputSize() const {
+ return impl_->EncoderOutputSize();
+}
+
+int32_t OnlineParaformerModel::DecoderKernelSize() const {
+ return impl_->DecoderKernelSize();
+}
+
+int32_t OnlineParaformerModel::DecoderNumBlocks() const {
+ return impl_->DecoderNumBlocks();
+}
+
+const std::vector<float> &OnlineParaformerModel::NegativeMean() const {
+ return impl_->NegativeMean();
+}
+const std::vector<float> &OnlineParaformerModel::InverseStdDev() const {
+ return impl_->InverseStdDev();
+}
+
+OrtAllocator *OnlineParaformerModel::Allocator() const {
+ return impl_->Allocator();
+}
+
+} // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-paraformer-model.h b/sherpa-onnx/csrc/online-paraformer-model.h
new file mode 100644
index 00000000..3c018a72
--- /dev/null
+++ b/sherpa-onnx/csrc/online-paraformer-model.h
@@ -0,0 +1,76 @@
+// sherpa-onnx/csrc/online-paraformer-model.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h" // NOLINT
+#include "sherpa-onnx/csrc/online-model-config.h"
+
+namespace sherpa_onnx {
+
+class OnlineParaformerModel {
+ public:
+ explicit OnlineParaformerModel(const OnlineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+ OnlineParaformerModel(AAssetManager *mgr, const OnlineModelConfig &config);
+#endif
+
+ ~OnlineParaformerModel();
+
+ std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
+ Ort::Value features_length) const;
+
+ std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
+ Ort::Value encoder_out_length,
+ Ort::Value acoustic_embedding,
+ Ort::Value acoustic_embedding_length,
+ std::vector<Ort::Value> states) const;
+
+ /** Return the vocabulary size of the model
+ */
+ int32_t VocabSize() const;
+
+ /** It is lfr_m in config.yaml
+ */
+ int32_t LfrWindowSize() const;
+
+ /** It is lfr_n in config.yaml
+ */
+ int32_t LfrWindowShift() const;
+
+ int32_t EncoderOutputSize() const;
+
+ int32_t DecoderKernelSize() const;
+ int32_t DecoderNumBlocks() const;
+
+ /** Return negative mean for CMVN
+ */
+ const std::vector<float> &NegativeMean() const;
+
+ /** Return inverse stddev for CMVN
+ */
+ const std::vector<float> &InverseStdDev() const;
+
+ /** Return an allocator for allocating memory
+ */
+ OrtAllocator *Allocator() const;
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace sherpa_onnx
+
+#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc
index a9e545dd..1eb16c03 100644
--- a/sherpa-onnx/csrc/online-recognizer-impl.cc
+++ b/sherpa-onnx/csrc/online-recognizer-impl.cc
@@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
namespace sherpa_onnx {
@@ -14,6 +15,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
}
+ if (!config.model_config.paraformer.encoder.empty()) {
+ return std::make_unique<OnlineRecognizerParaformerImpl>(config);
+ }
+
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
}
@@ -25,6 +30,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
}
+ if (!config.model_config.paraformer.encoder.empty()) {
+ return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
+ }
+
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
}
diff --git a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h
index 8b574a4d..515c9d9e 100644
--- a/sherpa-onnx/csrc/online-recognizer-impl.h
+++ b/sherpa-onnx/csrc/online-recognizer-impl.h
@@ -26,8 +26,6 @@ class OnlineRecognizerImpl {
virtual ~OnlineRecognizerImpl() = default;
- virtual void InitOnlineStream(OnlineStream *stream) const = 0;
-
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream(
diff --git a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
new file mode 100644
index 00000000..ae209633
--- /dev/null
+++ b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
@@ -0,0 +1,465 @@
+// sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-lm.h"
+#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
+#include "sherpa-onnx/csrc/online-paraformer-model.h"
+#include "sherpa-onnx/csrc/online-recognizer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+
+namespace sherpa_onnx {
+
+static OnlineRecognizerResult Convert(const OnlineParaformerDecoderResult &src,
+ const SymbolTable &sym_table) {
+ OnlineRecognizerResult r;
+ r.tokens.reserve(src.tokens.size());
+
+ std::string text;
+
+ // When the current token ends with "@@" we set mergeable to true
+ bool mergeable = false;
+
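+ // Assemble the text: a token ending with "@@" is a word piece that is
+ // merged with the following token (e.g., "ap@@" followed by "ple" yields
+ // "apple"); a space is inserted between separate ASCII words and between
+ // ASCII and non-ASCII (e.g., Chinese) tokens, which are otherwise
+ // concatenated directly.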
+ for (int32_t i = 0; i != src.tokens.size(); ++i) {
+ auto sym = sym_table[src.tokens[i]];
+ r.tokens.push_back(sym);
+
+ if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) {
+ // sym does not end with "@@"
+ const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
+ if (p[0] < 0x80) {
+ // an ascii
+ if (mergeable) {
+ mergeable = false;
+ text.append(sym);
+ } else {
+ text.append(" ");
+ text.append(sym);
+ }
+ } else {
+ // not an ascii
+ mergeable = false;
+
+ if (i > 0) {
+ const uint8_t *p = reinterpret_cast<const uint8_t *>(
+ sym_table[src.tokens[i - 1]].c_str());
+ if (p[0] < 0x80) {
+ // put a space between ascii and non-ascii
+ text.append(" ");
+ }
+ }
+ text.append(sym);
+ }
+ } else {
+ // this sym ends with @@
+ sym = std::string(sym.data(), sym.size() - 2);
+ if (mergeable) {
+ text.append(sym);
+ } else {
+ text.append(" ");
+ text.append(sym);
+ mergeable = true;
+ }
+ }
+ }
+ r.text = std::move(text);
+
+ return r;
+}
+
+// y[i] += x[i] * scale
+static void ScaleAddInPlace(const float *x, int32_t n, float scale, float *y) {
+ for (int32_t i = 0; i != n; ++i) {
+ y[i] += x[i] * scale;
+ }
+}
+
+// y[i] = x[i] * scale
+static void Scale(const float *x, int32_t n, float scale, float *y) {
+ for (int32_t i = 0; i != n; ++i) {
+ y[i] = x[i] * scale;
+ }
+}
+
+class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
+ public:
+ explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config)
+ : config_(config),
+ model_(config.model_config),
+ sym_(config.model_config.tokens),
+ endpoint_(config_.endpoint_config) {
+ if (config.decoding_method != "greedy_search") {
+ SHERPA_ONNX_LOGE(
+ "Unsupported decoding method: %s. Support only greedy_search at "
+ "present",
+ config.decoding_method.c_str());
+ exit(-1);
+ }
+
+ // Paraformer models assume input samples are in the range
+ // [-32768, 32767], so we set normalize_samples to false
+ config_.feat_config.normalize_samples = false;
+ }
+
+#if __ANDROID_API__ >= 9
+ explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr,
+ const OnlineRecognizerConfig &config)
+ : config_(config),
+ model_(mgr, config.model_config),
+ sym_(mgr, config.model_config.tokens),
+ endpoint_(config_.endpoint_config) {
+ if (config.decoding_method == "greedy_search") {
+ // add greedy search decoder
+ // SHERPA_ONNX_LOGE("to be implemented");
+ // exit(-1);
+ } else {
+ SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+ config.decoding_method.c_str());
+ exit(-1);
+ }
+
+ // Paraformer models assume input samples are in the range
+ // [-32768, 32767], so we set normalize_samples to false
+ config_.feat_config.normalize_samples = false;
+ }
+#endif
+ OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) =
+ delete;
+
+ OnlineRecognizerParaformerImpl operator=(
+ const OnlineRecognizerParaformerImpl &) = delete;
+
+ std::unique_ptr<OnlineStream> CreateStream() const override {
+ auto stream = std::make_unique<OnlineStream>(config_.feat_config);
+
+ OnlineParaformerDecoderResult r;
+ stream->SetParaformerResult(r);
+
+ return stream;
+ }
+
+ bool IsReady(OnlineStream *s) const override {
+ return s->GetNumProcessedFrames() + chunk_size_ < s->NumFramesReady();
+ }
+
+ void DecodeStreams(OnlineStream **ss, int32_t n) const override {
+ // TODO(fangjun): Support batch size > 1
+ for (int32_t i = 0; i != n; ++i) {
+ DecodeStream(ss[i]);
+ }
+ }
+
+ OnlineRecognizerResult GetResult(OnlineStream *s) const override {
+ auto decoder_result = s->GetParaformerResult();
+
+ return Convert(decoder_result, sym_);
+ }
+
+ bool IsEndpoint(OnlineStream *s) const override {
+ if (!config_.enable_endpoint) {
+ return false;
+ }
+
+ const auto &result = s->GetParaformerResult();
+
+ int32_t num_processed_frames = s->GetNumProcessedFrames();
+
+ // frame shift is 10 milliseconds
+ float frame_shift_in_seconds = 0.01;
+
+ int32_t trailing_silence_frames =
+ num_processed_frames - result.last_non_blank_frame_index;
+
+ return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
+ frame_shift_in_seconds);
+ }
+
+ void Reset(OnlineStream *s) const override {
+ OnlineParaformerDecoderResult r;
+ s->SetParaformerResult(r);
+
+ // the internal model caches are not reset
+
+ // Note: We only update counters. The underlying audio samples
+ // are not discarded.
+ s->Reset();
+ }
+
+ private:
+ void DecodeStream(OnlineStream *s) const {
+ const auto num_processed_frames = s->GetNumProcessedFrames();
+ std::vector<float> frames = s->GetFrames(num_processed_frames, chunk_size_);
+ s->GetNumProcessedFrames() += chunk_size_ - 1;
+
+ frames = ApplyLFR(frames);
+ ApplyCMVN(&frames);
+ PositionalEncoding(&frames, num_processed_frames / model_.LfrWindowShift());
+
+ int32_t feat_dim = model_.NegativeMean().size();
+
+ // We have scaled inv_stddev by sqrt(encoder_output_size)
+ // so the following line can be commented out
+ // frames *= encoder_output_size ** 0.5
+
+ // add overlap chunk
+ std::vector<float> &feat_cache = s->GetParaformerFeatCache();
+ if (feat_cache.empty()) {
+ int32_t n = (left_chunk_size_ + right_chunk_size_) * feat_dim;
+ feat_cache.resize(n, 0);
+ }
+
+ frames.insert(frames.begin(), feat_cache.begin(), feat_cache.end());
+ std::copy(frames.end() - feat_cache.size(), frames.end(),
+ feat_cache.begin());
+
+ int32_t num_frames = frames.size() / feat_dim;
+
+ auto memory_info =
+ Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+ std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
+ Ort::Value x =
+ Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(),
+ x_shape.data(), x_shape.size());
+
+ int64_t x_len_shape = 1;
+ int32_t x_len_val = num_frames;
+
+ Ort::Value x_length =
+ Ort::Value::CreateTensor(memory_info, &x_len_val, 1, &x_len_shape, 1);
+
+ auto encoder_out_vec =
+ model_.ForwardEncoder(std::move(x), std::move(x_length));
+
+ // CIF search
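+ // Continuous integrate-and-fire (CIF): alpha[i] is a per-frame weight
+ // predicted by the encoder. Alpha-weighted encoder frames are accumulated
+ // into initial_hidden; when the accumulated weight reaches the threshold
+ // (1.0) we "fire", emitting one acoustic embedding (roughly one token's
+ // worth of acoustic information) and carrying the remainder over to the
+ // next frame. Both the running weight and the partial sum are cached on
+ // the stream so integration continues across chunks.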
+ auto &encoder_out = encoder_out_vec[0];
+ auto &encoder_out_len = encoder_out_vec[1];
+ auto &alpha = encoder_out_vec[2];
+
+ float *p_alpha = alpha.GetTensorMutableData<float>();
+
+ std::vector<int64_t> alpha_shape =
+ alpha.GetTensorTypeAndShapeInfo().GetShape();
+
+ std::fill(p_alpha, p_alpha + left_chunk_size_, 0);
+ std::fill(p_alpha + alpha_shape[1] - right_chunk_size_,
+ p_alpha + alpha_shape[1], 0);
+
+ const float *p_encoder_out = encoder_out.GetTensorData<float>();
+
+ std::vector<int64_t> encoder_out_shape =
+ encoder_out.GetTensorTypeAndShapeInfo().GetShape();
+
+ std::vector<float> &initial_hidden = s->GetParaformerEncoderOutCache();
+ if (initial_hidden.empty()) {
+ initial_hidden.resize(encoder_out_shape[2]);
+ }
+
+ std::vector<float> &alpha_cache = s->GetParaformerAlphaCache();
+ if (alpha_cache.empty()) {
+ alpha_cache.resize(1);
+ }
+
+ std::vector<float> acoustic_embedding;
+ acoustic_embedding.reserve(encoder_out_shape[1] * encoder_out_shape[2]);
+
+ float threshold = 1.0;
+
+ float integrate = alpha_cache[0];
+
+ for (int32_t i = 0; i != encoder_out_shape[1]; ++i) {
+ float this_alpha = p_alpha[i];
+ if (integrate + this_alpha < threshold) {
+ integrate += this_alpha;
+ ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
+ encoder_out_shape[2], this_alpha,
+ initial_hidden.data());
+ continue;
+ }
+
+ // fire
+ ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
+ encoder_out_shape[2], threshold - integrate,
+ initial_hidden.data());
+ acoustic_embedding.insert(acoustic_embedding.end(),
+ initial_hidden.begin(), initial_hidden.end());
+ integrate += this_alpha - threshold;
+
+ Scale(p_encoder_out + i * encoder_out_shape[2], encoder_out_shape[2],
+ integrate, initial_hidden.data());
+ }
+
+ alpha_cache[0] = integrate;
+
+ if (acoustic_embedding.empty()) {
+ return;
+ }
+
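+ // states holds one cache tensor per decoder block, of shape
+ // [1, encoder_output_size, kernel_size - 1]. It is zero-initialized for
+ // the first chunk and refilled from the decoder's output caches after
+ // each call.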
+ auto &states = s->GetStates();
+ if (states.empty()) {
+ states.reserve(model_.DecoderNumBlocks());
+
+ std::array<int64_t, 3> shape{1, model_.EncoderOutputSize(),
+ model_.DecoderKernelSize() - 1};
+
+ int32_t num_bytes = sizeof(float) * shape[0] * shape[1] * shape[2];
+
+ for (int32_t i = 0; i != model_.DecoderNumBlocks(); ++i) {
+ Ort::Value this_state = Ort::Value::CreateTensor<float>(
+ model_.Allocator(), shape.data(), shape.size());
+
+ memset(this_state.GetTensorMutableData<float>(), 0, num_bytes);
+
+ states.push_back(std::move(this_state));
+ }
+ }
+
+ int32_t num_tokens = acoustic_embedding.size() / initial_hidden.size();
+ std::array<int64_t, 3> acoustic_embedding_shape{
+ 1, num_tokens, static_cast<int64_t>(initial_hidden.size())};
+
+ Ort::Value acoustic_embedding_tensor = Ort::Value::CreateTensor(
+ memory_info, acoustic_embedding.data(), acoustic_embedding.size(),
+ acoustic_embedding_shape.data(), acoustic_embedding_shape.size());
+
+ std::array<int64_t, 1> acoustic_embedding_length_shape{1};
+ Ort::Value acoustic_embedding_length_tensor = Ort::Value::CreateTensor(
+ memory_info, &num_tokens, 1, acoustic_embedding_length_shape.data(),
+ acoustic_embedding_length_shape.size());
+
+ auto decoder_out_vec = model_.ForwardDecoder(
+ std::move(encoder_out), std::move(encoder_out_len),
+ std::move(acoustic_embedding_tensor),
+ std::move(acoustic_embedding_length_tensor), std::move(states));
+
+ states.reserve(model_.DecoderNumBlocks());
+ for (int32_t i = 2; i != decoder_out_vec.size(); ++i) {
+ // TODO(fangjun): When we change chunk_size_, we need to
+ // slice decoder_out_vec[i] accordingly.
+ states.push_back(std::move(decoder_out_vec[i]));
+ }
+
+ const auto &sample_ids = decoder_out_vec[1];
+ const int64_t *p_sample_ids = sample_ids.GetTensorData<int64_t>();
+
+ bool non_blank_detected = false;
+
+ auto &result = s->GetParaformerResult();
+
+ for (int32_t i = 0; i != num_tokens; ++i) {
+ int32_t t = p_sample_ids[i];
+ if (t == 0) {
+ continue;
+ }
+
+ non_blank_detected = true;
+ result.tokens.push_back(t);
+ }
+
+ if (non_blank_detected) {
+ result.last_non_blank_frame_index = num_processed_frames;
+ }
+ }
+
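+ // Low-frame-rate (LFR) stacking: each output frame is the concatenation of
+ // lfr_window_size consecutive input frames, and consecutive output frames
+ // are lfr_window_shift input frames apart. For this model that means the
+ // 80-dim fbank frames are stacked 7 at a time with a shift of 6, giving
+ // 560-dim frames at roughly 60 ms per frame.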
+ std::vector<float> ApplyLFR(const std::vector<float> &in) const {
+ int32_t lfr_window_size = model_.LfrWindowSize();
+ int32_t lfr_window_shift = model_.LfrWindowShift();
+ int32_t in_feat_dim = config_.feat_config.feature_dim;
+
+ int32_t in_num_frames = in.size() / in_feat_dim;
+ int32_t out_num_frames =
+ (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
+ int32_t out_feat_dim = in_feat_dim * lfr_window_size;
+
+ std::vector<float> out(out_num_frames * out_feat_dim);
+
+ const float *p_in = in.data();
+ float *p_out = out.data();
+
+ for (int32_t i = 0; i != out_num_frames; ++i) {
+ std::copy(p_in, p_in + out_feat_dim, p_out);
+
+ p_out += out_feat_dim;
+ p_in += lfr_window_shift * in_feat_dim;
+ }
+
+ return out;
+ }
+
+ void ApplyCMVN(std::vector<float> *v) const {
+ const std::vector<float> &neg_mean = model_.NegativeMean();
+ const std::vector<float> &inv_stddev = model_.InverseStdDev();
+
+ int32_t dim = neg_mean.size();
+ int32_t num_frames = v->size() / dim;
+
+ float *p = v->data();
+
+ for (int32_t i = 0; i != num_frames; ++i) {
+ for (int32_t k = 0; k != dim; ++k) {
+ p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
+ }
+
+ p += dim;
+ }
+ }
+
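+ // Add standard sinusoidal positional encodings to the LFR features:
+ // for position `offset` and dimension d in [0, feat_dim/2),
+ //   p[d]              += sin(offset * 10000^(-d / (feat_dim/2 - 1)))
+ //   p[d + feat_dim/2] += cos(offset * 10000^(-d / (feat_dim/2 - 1)))
+ // t_offset is the number of LFR frames already consumed in earlier chunks,
+ // so the encoding stays continuous across chunk boundaries.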
+ void PositionalEncoding(std::vector<float> *v, int32_t t_offset) const {
+ int32_t lfr_window_size = model_.LfrWindowSize();
+ int32_t in_feat_dim = config_.feat_config.feature_dim;
+
+ int32_t feat_dim = in_feat_dim * lfr_window_size;
+ int32_t T = v->size() / feat_dim;
+
+ // log(10000)/(7*80/2-1) == 0.03301197265941284
+ // 7 is lfr_window_size
+ // 80 is in_feat_dim
+ // 7*80 is feat_dim
+ constexpr float kScale = -0.03301197265941284;
+
+ for (int32_t t = 0; t != T; ++t) {
+ float *p = v->data() + t * feat_dim;
+
+ int32_t offset = t + 1 + t_offset;
+
+ for (int32_t d = 0; d < feat_dim / 2; ++d) {
+ float inv_timescale = offset * std::exp(d * kScale);
+
+ float sin_d = std::sin(inv_timescale);
+ float cos_d = std::cos(inv_timescale);
+
+ p[d] += sin_d;
+ p[d + feat_dim / 2] += cos_d;
+ }
+ }
+ }
+
+ private:
+ OnlineRecognizerConfig config_;
+ OnlineParaformerModel model_;
+ SymbolTable sym_;
+ Endpoint endpoint_;
+
+ // 0.61 seconds
+ int32_t chunk_size_ = 61;
+ // (61 - 7) / 6 + 1 = 10
+
+ int32_t left_chunk_size_ = 5;
+ int32_t right_chunk_size_ = 5;
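+ // left_chunk_size_ + right_chunk_size_ LFR frames from the end of the
+ // previous chunk are prepended as overlap context (see feat_cache in
+ // DecodeStream), and the alpha weights of the first left_chunk_size_ and
+ // last right_chunk_size_ frames are zeroed so that boundary frames do not
+ // fire in the CIF search.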
+};
+
+} // namespace sherpa_onnx
+
+#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
index a5d2d815..625d02b1 100644
--- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
+++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
#endif
- void InitOnlineStream(OnlineStream *stream) const override {
- auto r = decoder_->GetEmptyResult();
-
- if (config_.decoding_method == "modified_beam_search" &&
- nullptr != stream->GetContextGraph()) {
- // r.hyps has only one element.
- for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
- it->second.context_state = stream->GetContextGraph()->Root();
- }
- }
-
- stream->SetResult(r);
- stream->SetStates(model_->GetEncoderInitStates());
- }
-
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
InitOnlineStream(stream.get());
@@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
bool IsEndpoint(OnlineStream *s) const override {
- if (!config_.enable_endpoint) return false;
+ if (!config_.enable_endpoint) {
+ return false;
+ }
+
int32_t num_processed_frames = s->GetNumProcessedFrames();
// frame shift is 10 milliseconds
@@ -244,6 +232,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
s->Reset();
}
+ private:
+ void InitOnlineStream(OnlineStream *stream) const {
+ auto r = decoder_->GetEmptyResult();
+
+ if (config_.decoding_method == "modified_beam_search" &&
+ nullptr != stream->GetContextGraph()) {
+ // r.hyps has only one element.
+ for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
+ it->second.context_state = stream->GetContextGraph()->Root();
+ }
+ }
+
+ stream->SetResult(r);
+ stream->SetStates(model_->GetEncoderInitStates());
+ }
+
private:
OnlineRecognizerConfig config_;
std::unique_ptr model_;
diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc
index e0593ff6..8960ed13 100644
--- a/sherpa-onnx/csrc/online-stream.cc
+++ b/sherpa-onnx/csrc/online-stream.cc
@@ -47,6 +47,14 @@ class OnlineStream::Impl {
OnlineTransducerDecoderResult &GetResult() { return result_; }
+ void SetParaformerResult(const OnlineParaformerDecoderResult &r) {
+ paraformer_result_ = r;
+ }
+
+ OnlineParaformerDecoderResult &GetParaformerResult() {
+ return paraformer_result_;
+ }
+
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
void SetStates(std::vector<Ort::Value> states) {
@@ -57,6 +65,18 @@ class OnlineStream::Impl {
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
+ std::vector<float> &GetParaformerFeatCache() {
+ return paraformer_feat_cache_;
+ }
+
+ std::vector<float> &GetParaformerEncoderOutCache() {
+ return paraformer_encoder_out_cache_;
+ }
+
+ std::vector<float> &GetParaformerAlphaCache() {
+ return paraformer_alpha_cache_;
+ }
+
private:
FeatureExtractor feat_extractor_;
/// For contextual-biasing
@@ -65,6 +85,10 @@ class OnlineStream::Impl {
int32_t start_frame_index_ = 0; // never reset
OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
+ std::vector<float> paraformer_feat_cache_;
+ std::vector<float> paraformer_encoder_out_cache_;
+ std::vector<float> paraformer_alpha_cache_;
+ OnlineParaformerDecoderResult paraformer_result_;
};
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
@@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}
+void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) {
+ impl_->SetParaformerResult(r);
+}
+
+OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() {
+ return impl_->GetParaformerResult();
+}
+
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
impl_->SetStates(std::move(states));
}
@@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
+std::vector<float> &OnlineStream::GetParaformerFeatCache() {
+ return impl_->GetParaformerFeatCache();
+}
+
+std::vector<float> &OnlineStream::GetParaformerEncoderOutCache() {
+ return impl_->GetParaformerEncoderOutCache();
+}
+
+std::vector<float> &OnlineStream::GetParaformerAlphaCache() {
+ return impl_->GetParaformerAlphaCache();
+}
+
} // namespace sherpa_onnx
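
The per-stream paraformer result stored here is just a growing list of integer token ids (filled in by the decode loop in the recognizer impl earlier in this patch); mapping them back to text only needs the tokens.txt symbol table described in the from_paraformer docstring later in this patch. A minimal illustrative sketch, assuming the two-column "symbol integer_id" format; the file name and id handling are placeholders, not part of the patch.

# Illustrative only: turn paraformer token ids into text using tokens.txt.
def load_symbol_table(tokens_file):
    id2sym = {}
    with open(tokens_file, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if len(fields) == 2:
                sym, idx = fields
                id2sym[int(idx)] = sym
    return id2sym

def tokens_to_text(token_ids, id2sym):
    # BPE pieces mark word boundaries with "▁"; Chinese characters concatenate directly.
    return "".join(id2sym[i] for i in token_ids).replace("▁", " ").strip()
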
diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h
index 60dce950..ae920c1d 100644
--- a/sherpa-onnx/csrc/online-stream.h
+++ b/sherpa-onnx/csrc/online-stream.h
@@ -11,6 +11,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/features.h"
+#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
namespace sherpa_onnx {
@@ -70,6 +71,9 @@ class OnlineStream {
void SetResult(const OnlineTransducerDecoderResult &r);
OnlineTransducerDecoderResult &GetResult();
+ void SetParaformerResult(const OnlineParaformerDecoderResult &r);
+ OnlineParaformerDecoderResult &GetParaformerResult();
+
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();
@@ -80,6 +84,11 @@ class OnlineStream {
*/
const ContextGraphPtr &GetContextGraph() const;
+ // for streaming paraformer
+ std::vector<float> &GetParaformerFeatCache();
+ std::vector<float> &GetParaformerEncoderOutCache();
+ std::vector<float> &GetParaformerAlphaCache();
+
private:
class Impl;
std::unique_ptr<Impl> impl_;
diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc
index 8d527e90..9e771fc5 100644
--- a/sherpa-onnx/csrc/sherpa-onnx.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx.cc
@@ -12,8 +12,8 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
-#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/parse-options.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
typedef struct {
@@ -80,7 +80,7 @@ for a list of pre-trained models to download.
bool is_ok = false;
const std::vector<float> samples =
- sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
+ sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
@@ -92,14 +92,14 @@ for a list of pre-trained models to download.
auto s = recognizer.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
- std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate));
+ std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
// Note: We can call AcceptWaveform() multiple times.
- s->AcceptWaveform(
- sampling_rate, tail_paddings.data(), tail_paddings.size());
+ s->AcceptWaveform(sampling_rate, tail_paddings.data(),
+ tail_paddings.size());
// Call InputFinished() to indicate that no audio samples are available
s->InputFinished();
- ss.push_back({ std::move(s), duration, 0 });
+ ss.push_back({std::move(s), duration, 0});
}
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
@@ -112,8 +112,9 @@ for a list of pre-trained models to download.
} else if (s.elapsed_seconds == 0) {
const auto end = std::chrono::steady_clock::now();
const float elapsed_seconds =
- std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
- .count() / 1000.;
+ std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+ .count() /
+ 1000.;
s.elapsed_seconds = elapsed_seconds;
}
}
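
The tail padding in sherpa-onnx.cc grows from 0.3 s to 0.8 s so that the last partial chunk of audio is still pushed through the streaming paraformer before InputFinished(); without enough trailing silence the final frames may never form a complete chunk. A minimal sketch of the equivalent padding in Python (the 16 kHz sample rate is just an example value):

import numpy as np

sample_rate = 16000
# 0.8 seconds of trailing silence appended after the real audio,
# mirroring the C++ change above.
tail_padding = np.zeros(int(0.8 * sample_rate), dtype=np.float32)
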
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index 28612924..d61e4303 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx
offline-whisper-model-config.cc
online-lm-config.cc
online-model-config.cc
+ online-paraformer-model-config.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc
diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc
index 677d3b1f..7e37a87c 100644
--- a/sherpa-onnx/python/csrc/online-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-model-config.cc
@@ -1,6 +1,6 @@
// sherpa-onnx/python/csrc/online-model-config.cc
//
-// Copyright (c) 2023 by manyeyes
+// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-model-config.h"
@@ -9,21 +9,26 @@
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
void PybindOnlineModelConfig(py::module *m) {
PybindOnlineTransducerModelConfig(m);
+ PybindOnlineParaformerModelConfig(m);
using PyClass = OnlineModelConfig;
py::class_(*m, "OnlineModelConfig")
- .def(py::init(),
py::arg("transducer") = OnlineTransducerModelConfig(),
+ py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
+ .def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
diff --git a/sherpa-onnx/python/csrc/online-model-config.h b/sherpa-onnx/python/csrc/online-model-config.h
index 73154fc9..3624a104 100644
--- a/sherpa-onnx/python/csrc/online-model-config.h
+++ b/sherpa-onnx/python/csrc/online-model-config.h
@@ -1,6 +1,6 @@
// sherpa-onnx/python/csrc/online-model-config.h
//
-// Copyright (c) 2023 by manyeyes
+// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/python/csrc/online-paraformer-model-config.cc b/sherpa-onnx/python/csrc/online-paraformer-model-config.cc
new file mode 100644
index 00000000..84895acb
--- /dev/null
+++ b/sherpa-onnx/python/csrc/online-paraformer-model-config.cc
@@ -0,0 +1,24 @@
+// sherpa-onnx/python/csrc/online-paraformer-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineParaformerModelConfig(py::module *m) {
+ using PyClass = OnlineParaformerModelConfig;
+ py::class_(*m, "OnlineParaformerModelConfig")
+ .def(py::init(),
+ py::arg("encoder"), py::arg("decoder"))
+ .def_readwrite("encoder", &PyClass::encoder)
+ .def_readwrite("decoder", &PyClass::decoder)
+ .def("__str__", &PyClass::ToString);
+}
+
+} // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/online-paraformer-model-config.h b/sherpa-onnx/python/csrc/online-paraformer-model-config.h
new file mode 100644
index 00000000..ad1dc1d7
--- /dev/null
+++ b/sherpa-onnx/python/csrc/online-paraformer-model-config.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/online-paraformer-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineParaformerModelConfig(py::module *m);
+
+}
+
+#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc
index 34a907ce..c130d87c 100644
--- a/sherpa-onnx/python/csrc/online-recognizer.cc
+++ b/sherpa-onnx/python/csrc/online-recognizer.cc
@@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("enable_endpoint"), py::arg("decoding_method"),
- py::arg("max_active_paths"), py::arg("context_score"))
+ py::arg("max_active_paths") = 4, py::arg("context_score") = 0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
index c49e1b43..55e789ba 100644
--- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
@@ -6,6 +6,7 @@ from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineModelConfig,
+ OnlineParaformerModelConfig,
OnlineRecognizer as _Recognizer,
OnlineRecognizerConfig,
OnlineStream,
@@ -32,7 +33,7 @@ class OnlineRecognizer(object):
encoder: str,
decoder: str,
joiner: str,
- num_threads: int = 4,
+ num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
enable_endpoint_detection: bool = False,
@@ -144,6 +145,109 @@ class OnlineRecognizer(object):
self.config = recognizer_config
return self
+ @classmethod
+ def from_paraformer(
+ cls,
+ tokens: str,
+ encoder: str,
+ decoder: str,
+ num_threads: int = 2,
+ sample_rate: float = 16000,
+ feature_dim: int = 80,
+ enable_endpoint_detection: bool = False,
+ rule1_min_trailing_silence: float = 2.4,
+ rule2_min_trailing_silence: float = 1.2,
+ rule3_min_utterance_length: float = 20.0,
+ decoding_method: str = "greedy_search",
+ provider: str = "cpu",
+ ):
+ """
+ Please refer to
+ `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
+ to download pre-trained models for different languages, e.g., Chinese,
+ English, etc.
+
+ Args:
+ tokens:
+ Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+ columns::
+
+ symbol integer_id
+
+ encoder:
+ Path to ``encoder.onnx``.
+ decoder:
+ Path to ``decoder.onnx``.
+ num_threads:
+ Number of threads for neural network computation.
+ sample_rate:
+ Sample rate of the training data used to train the model.
+ feature_dim:
+ Dimension of the feature used to train the model.
+ enable_endpoint_detection:
+ True to enable endpoint detection. False to disable endpoint
+ detection.
+ rule1_min_trailing_silence:
+ Used only when enable_endpoint_detection is True. If the duration
+ of trailing silence in seconds is larger than this value, we assume
+ an endpoint is detected.
+ rule2_min_trailing_silence:
+ Used only when enable_endpoint_detection is True. If we have decoded
+ something that is nonsilence and if the duration of trailing silence
+ in seconds is larger than this value, we assume an endpoint is
+ detected.
+ rule3_min_utterance_length:
+ Used only when enable_endpoint_detection is True. If the utterance
+ length in seconds is larger than this value, we assume an endpoint
+ is detected.
+ decoding_method:
+ The only valid value is greedy_search.
+ provider:
+ onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+ """
+ self = cls.__new__(cls)
+ _assert_file_exists(tokens)
+ _assert_file_exists(encoder)
+ _assert_file_exists(decoder)
+
+ assert num_threads > 0, num_threads
+
+ paraformer_config = OnlineParaformerModelConfig(
+ encoder=encoder,
+ decoder=decoder,
+ )
+
+ model_config = OnlineModelConfig(
+ paraformer=paraformer_config,
+ tokens=tokens,
+ num_threads=num_threads,
+ provider=provider,
+ model_type="paraformer",
+ )
+
+ feat_config = FeatureExtractorConfig(
+ sampling_rate=sample_rate,
+ feature_dim=feature_dim,
+ )
+
+ endpoint_config = EndpointConfig(
+ rule1_min_trailing_silence=rule1_min_trailing_silence,
+ rule2_min_trailing_silence=rule2_min_trailing_silence,
+ rule3_min_utterance_length=rule3_min_utterance_length,
+ )
+
+ recognizer_config = OnlineRecognizerConfig(
+ feat_config=feat_config,
+ model_config=model_config,
+ endpoint_config=endpoint_config,
+ enable_endpoint=enable_endpoint_detection,
+ decoding_method=decoding_method,
+ )
+
+ self.recognizer = _Recognizer(recognizer_config)
+ self.config = recognizer_config
+ return self
+
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
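
Putting the pieces together, a minimal streaming-decode sketch using the new from_paraformer helper is shown below. The file paths and the all-zero waveform are placeholders; the stream and recognizer methods used here (create_stream, accept_waveform, is_ready, decode_stream, input_finished, get_result) are the existing sherpa_onnx Python API.

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
    tokens="tokens.txt",      # placeholder path
    encoder="encoder.onnx",   # placeholder path
    decoder="decoder.onnx",   # placeholder path
    num_threads=2,
)

stream = recognizer.create_stream()

# Feed audio in small pieces to simulate streaming; `samples` would normally
# come from a microphone or a decoded wave file.
samples = np.zeros(16000, dtype=np.float32)
chunk = 1600  # 0.1 second at 16 kHz
for start in range(0, len(samples), chunk):
    stream.accept_waveform(16000, samples[start:start + chunk])
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

# Flush with tail padding and signal that no more audio will arrive.
stream.accept_waveform(16000, np.zeros(int(0.8 * 16000), dtype=np.float32))
stream.input_finished()
while recognizer.is_ready(stream):
    recognizer.decode_stream(stream)

print(recognizer.get_result(stream))
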