Support streaming paraformer (#263)
This commit is contained in:
53
.github/scripts/test-online-paraformer.sh
vendored
Executable file
53
.github/scripts/test-online-paraformer.sh
vendored
Executable file
@@ -0,0 +1,53 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "EXE is $EXE"
|
||||||
|
echo "PATH: $PATH"
|
||||||
|
|
||||||
|
which $EXE
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run streaming Paraformer"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
ls -lh *.onnx
|
||||||
|
popd
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--paraformer-encoder=$repo/encoder.onnx \
|
||||||
|
--paraformer-decoder=$repo/decoder.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/2.wav \
|
||||||
|
$repo/test_wavs/3.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--paraformer-encoder=$repo/encoder.int8.onnx \
|
||||||
|
--paraformer-decoder=$repo/decoder.int8.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
$repo/test_wavs/0.wav \
|
||||||
|
$repo/test_wavs/1.wav \
|
||||||
|
$repo/test_wavs/2.wav \
|
||||||
|
$repo/test_wavs/3.wav \
|
||||||
|
$repo/test_wavs/8k.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
10
.github/workflows/linux-gpu.yaml
vendored
10
.github/workflows/linux-gpu.yaml
vendored
@@ -9,6 +9,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/linux-gpu.yaml'
|
- '.github/workflows/linux-gpu.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -22,6 +23,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/linux-gpu.yaml'
|
- '.github/workflows/linux-gpu.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -85,6 +87,14 @@ jobs:
|
|||||||
file build/bin/sherpa-onnx
|
file build/bin/sherpa-onnx
|
||||||
readelf -d build/bin/sherpa-onnx
|
readelf -d build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test online paraformer
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx
|
||||||
|
|
||||||
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper
|
- name: Test offline Whisper
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
10
.github/workflows/linux.yaml
vendored
10
.github/workflows/linux.yaml
vendored
@@ -9,6 +9,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/linux.yaml'
|
- '.github/workflows/linux.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -22,6 +23,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/linux.yaml'
|
- '.github/workflows/linux.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -84,6 +86,14 @@ jobs:
|
|||||||
file build/bin/sherpa-onnx
|
file build/bin/sherpa-onnx
|
||||||
readelf -d build/bin/sherpa-onnx
|
readelf -d build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test online paraformer
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx
|
||||||
|
|
||||||
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper
|
- name: Test offline Whisper
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
10
.github/workflows/macos.yaml
vendored
10
.github/workflows/macos.yaml
vendored
@@ -7,6 +7,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/macos.yaml'
|
- '.github/workflows/macos.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -18,6 +19,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/macos.yaml'
|
- '.github/workflows/macos.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -82,6 +84,14 @@ jobs:
|
|||||||
otool -L build/bin/sherpa-onnx
|
otool -L build/bin/sherpa-onnx
|
||||||
otool -l build/bin/sherpa-onnx
|
otool -l build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test online paraformer
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx
|
||||||
|
|
||||||
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper
|
- name: Test offline Whisper
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
1
.github/workflows/test-pip-install.yaml
vendored
1
.github/workflows/test-pip-install.yaml
vendored
@@ -58,7 +58,6 @@ jobs:
|
|||||||
sherpa-onnx-microphone-offline --help
|
sherpa-onnx-microphone-offline --help
|
||||||
|
|
||||||
sherpa-onnx-offline-websocket-server --help
|
sherpa-onnx-offline-websocket-server --help
|
||||||
sherpa-onnx-offline-websocket-client --help
|
|
||||||
|
|
||||||
sherpa-onnx-online-websocket-server --help
|
sherpa-onnx-online-websocket-server --help
|
||||||
sherpa-onnx-online-websocket-client --help
|
sherpa-onnx-online-websocket-client --help
|
||||||
|
|||||||
@@ -84,14 +84,14 @@ jobs:
|
|||||||
if: matrix.model_type == 'paraformer'
|
if: matrix.model_type == 'paraformer'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
|
||||||
cd sherpa-onnx-paraformer-zh-2023-03-28
|
cd sherpa-onnx-paraformer-bilingual-zh-en
|
||||||
git lfs pull --include "*.onnx"
|
git lfs pull --include "*.onnx"
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
python3 ./python-api-examples/non_streaming_server.py \
|
python3 ./python-api-examples/non_streaming_server.py \
|
||||||
--paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
|
--paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \
|
||||||
--tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt &
|
--tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt &
|
||||||
|
|
||||||
echo "sleep 10 seconds to wait the server start"
|
echo "sleep 10 seconds to wait the server start"
|
||||||
sleep 10
|
sleep 10
|
||||||
@@ -101,16 +101,16 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
|
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
|
||||||
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
|
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \
|
||||||
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
|
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \
|
||||||
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
|
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \
|
||||||
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
|
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav
|
||||||
|
|
||||||
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
|
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
|
||||||
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
|
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \
|
||||||
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
|
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \
|
||||||
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
|
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \
|
||||||
./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
|
./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav
|
||||||
|
|
||||||
- name: Start server for nemo_ctc models
|
- name: Start server for nemo_ctc models
|
||||||
if: matrix.model_type == 'nemo_ctc'
|
if: matrix.model_type == 'nemo_ctc'
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||||
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
|
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
|
||||||
model_type: ["transducer"]
|
model_type: ["transducer", "paraformer"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
@@ -71,3 +71,36 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
python3 ./python-api-examples/online-websocket-client-decode-file.py \
|
python3 ./python-api-examples/online-websocket-client-decode-file.py \
|
||||||
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
|
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
|
||||||
|
|
||||||
|
- name: Start server for paraformer models
|
||||||
|
if: matrix.model_type == 'paraformer'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||||
|
cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
python3 ./python-api-examples/streaming_server.py \
|
||||||
|
--tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
|
||||||
|
--paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
|
||||||
|
--paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx &
|
||||||
|
|
||||||
|
echo "sleep 10 seconds to wait the server start"
|
||||||
|
sleep 10
|
||||||
|
|
||||||
|
- name: Start client for paraformer models
|
||||||
|
if: matrix.model_type == 'paraformer'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python3 ./python-api-examples/online-websocket-client-decode-file.py \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
|
||||||
|
|
||||||
|
python3 ./python-api-examples/online-websocket-client-decode-file.py \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav
|
||||||
|
|
||||||
|
python3 ./python-api-examples/online-websocket-client-decode-file.py \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav
|
||||||
|
|
||||||
|
python3 ./python-api-examples/online-websocket-client-decode-file.py \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav
|
||||||
|
|||||||
10
.github/workflows/windows-x64-cuda.yaml
vendored
10
.github/workflows/windows-x64-cuda.yaml
vendored
@@ -9,6 +9,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/windows-x64-cuda.yaml'
|
- '.github/workflows/windows-x64-cuda.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -20,6 +21,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/windows-x64-cuda.yaml'
|
- '.github/workflows/windows-x64-cuda.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -74,6 +76,14 @@ jobs:
|
|||||||
|
|
||||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
- name: Test online paraformer for windows x64
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx.exe
|
||||||
|
|
||||||
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper for windows x64
|
- name: Test offline Whisper for windows x64
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
10
.github/workflows/windows-x64.yaml
vendored
10
.github/workflows/windows-x64.yaml
vendored
@@ -9,6 +9,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/windows-x64.yaml'
|
- '.github/workflows/windows-x64.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -20,6 +21,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/windows-x64.yaml'
|
- '.github/workflows/windows-x64.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -75,6 +77,14 @@ jobs:
|
|||||||
|
|
||||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
- name: Test online paraformer for windows x64
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx.exe
|
||||||
|
|
||||||
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper for windows x64
|
- name: Test offline Whisper for windows x64
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
10
.github/workflows/windows-x86.yaml
vendored
10
.github/workflows/windows-x86.yaml
vendored
@@ -7,6 +7,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/windows-x86.yaml'
|
- '.github/workflows/windows-x86.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -18,6 +19,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/windows-x86.yaml'
|
- '.github/workflows/windows-x86.yaml'
|
||||||
- '.github/scripts/test-online-transducer.sh'
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- '.github/scripts/test-online-paraformer.sh'
|
||||||
- '.github/scripts/test-offline-transducer.sh'
|
- '.github/scripts/test-offline-transducer.sh'
|
||||||
- '.github/scripts/test-offline-ctc.sh'
|
- '.github/scripts/test-offline-ctc.sh'
|
||||||
- 'CMakeLists.txt'
|
- 'CMakeLists.txt'
|
||||||
@@ -73,6 +75,14 @@ jobs:
|
|||||||
|
|
||||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
- name: Test online paraformer for windows x86
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx.exe
|
||||||
|
|
||||||
|
.github/scripts/test-online-paraformer.sh
|
||||||
|
|
||||||
- name: Test offline Whisper for windows x86
|
- name: Test offline Whisper for windows x86
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
|
||||||
project(sherpa-onnx)
|
project(sherpa-onnx)
|
||||||
|
|
||||||
set(SHERPA_ONNX_VERSION "1.7.3")
|
set(SHERPA_ONNX_VERSION "1.7.4")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -37,14 +37,14 @@ python3 ./python-api-examples/non_streaming_server.py \
|
|||||||
(2) Use a non-streaming paraformer
|
(2) Use a non-streaming paraformer
|
||||||
|
|
||||||
cd /path/to/sherpa-onnx
|
cd /path/to/sherpa-onnx
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
|
||||||
cd sherpa-onnx-paraformer-zh-2023-03-28
|
cd sherpa-onnx-paraformer-bilingual-zh-en/
|
||||||
git lfs pull --include "*.onnx"
|
git lfs pull --include "*.onnx"
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
python3 ./python-api-examples/non_streaming_server.py \
|
python3 ./python-api-examples/non_streaming_server.py \
|
||||||
--paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
|
--paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \
|
||||||
--tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt
|
--tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt
|
||||||
|
|
||||||
(3) Use a non-streaming CTC model from NeMo
|
(3) Use a non-streaming CTC model from NeMo
|
||||||
|
|
||||||
|
|||||||
@@ -5,16 +5,41 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe
|
|||||||
file(s) with a streaming model.
|
file(s) with a streaming model.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
./online-decode-files.py \
|
|
||||||
/path/to/foo.wav \
|
(1) Streaming transducer
|
||||||
/path/to/bar.wav \
|
|
||||||
/path/to/16kHz.wav \
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
|
||||||
/path/to/8kHz.wav
|
cd sherpa-onnx-streaming-zipformer-en-2023-06-26
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
|
||||||
|
./python-api-examples/online-decode-files.py \
|
||||||
|
--tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
|
||||||
|
--encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
--decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
--joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(2) Streaming paraformer
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||||
|
cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
|
||||||
|
./python-api-examples/online-decode-files.py \
|
||||||
|
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
|
||||||
|
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
|
||||||
|
--paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
|
||||||
|
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
|
||||||
|
|
||||||
Please refer to
|
Please refer to
|
||||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||||
to install sherpa-onnx and to download the pre-trained models
|
to install sherpa-onnx and to download streaming pre-trained models.
|
||||||
used in this file.
|
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
@@ -41,19 +66,31 @@ def get_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder",
|
"--encoder",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to the encoder model",
|
help="Path to the transducer encoder model",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder",
|
"--decoder",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to the decoder model",
|
help="Path to the transducer decoder model",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--joiner",
|
"--joiner",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to the joiner model",
|
help="Path to the transducer joiner model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--paraformer-encoder",
|
||||||
|
type=str,
|
||||||
|
help="Path to the paraformer encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--paraformer-decoder",
|
||||||
|
type=str,
|
||||||
|
help="Path to the paraformer decoder model",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -200,24 +237,42 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
assert_file_exists(args.encoder)
|
|
||||||
assert_file_exists(args.decoder)
|
|
||||||
assert_file_exists(args.joiner)
|
|
||||||
assert_file_exists(args.tokens)
|
assert_file_exists(args.tokens)
|
||||||
|
|
||||||
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
|
if args.encoder:
|
||||||
tokens=args.tokens,
|
assert_file_exists(args.encoder)
|
||||||
encoder=args.encoder,
|
assert_file_exists(args.decoder)
|
||||||
decoder=args.decoder,
|
assert_file_exists(args.joiner)
|
||||||
joiner=args.joiner,
|
|
||||||
num_threads=args.num_threads,
|
assert not args.paraformer_encoder, args.paraformer_encoder
|
||||||
provider=args.provider,
|
assert not args.paraformer_decoder, args.paraformer_decoder
|
||||||
sample_rate=16000,
|
|
||||||
feature_dim=80,
|
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
|
||||||
decoding_method=args.decoding_method,
|
tokens=args.tokens,
|
||||||
max_active_paths=args.max_active_paths,
|
encoder=args.encoder,
|
||||||
context_score=args.context_score,
|
decoder=args.decoder,
|
||||||
)
|
joiner=args.joiner,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
provider=args.provider,
|
||||||
|
sample_rate=16000,
|
||||||
|
feature_dim=80,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
max_active_paths=args.max_active_paths,
|
||||||
|
context_score=args.context_score,
|
||||||
|
)
|
||||||
|
elif args.paraformer_encoder:
|
||||||
|
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
|
||||||
|
tokens=args.tokens,
|
||||||
|
encoder=args.paraformer_encoder,
|
||||||
|
decoder=args.paraformer_decoder,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
provider=args.provider,
|
||||||
|
sample_rate=16000,
|
||||||
|
feature_dim=80,
|
||||||
|
decoding_method="greedy_search",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Please provide a model")
|
||||||
|
|
||||||
print("Started!")
|
print("Started!")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -243,7 +298,7 @@ def main():
|
|||||||
|
|
||||||
s.accept_waveform(sample_rate, samples)
|
s.accept_waveform(sample_rate, samples)
|
||||||
|
|
||||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
|
||||||
s.accept_waveform(sample_rate, tail_paddings)
|
s.accept_waveform(sample_rate, tail_paddings)
|
||||||
|
|
||||||
s.input_finished()
|
s.input_finished()
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ Example:
|
|||||||
(1) Without a certificate
|
(1) Without a certificate
|
||||||
|
|
||||||
python3 ./python-api-examples/streaming_server.py \
|
python3 ./python-api-examples/streaming_server.py \
|
||||||
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
|
--encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
|
||||||
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
|
--decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
|
||||||
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
|
--joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
|
||||||
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
|
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
|
||||||
|
|
||||||
(2) With a certificate
|
(2) With a certificate
|
||||||
@@ -32,9 +32,9 @@ python3 ./python-api-examples/streaming_server.py \
|
|||||||
(b) Start the server
|
(b) Start the server
|
||||||
|
|
||||||
python3 ./python-api-examples/streaming_server.py \
|
python3 ./python-api-examples/streaming_server.py \
|
||||||
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
|
--encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
|
||||||
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
|
--decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
|
||||||
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
|
--joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
|
||||||
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
|
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
|
||||||
--certificate ./python-api-examples/web/cert.pem
|
--certificate ./python-api-examples/web/cert.pem
|
||||||
|
|
||||||
@@ -113,24 +113,33 @@ def setup_logger(
|
|||||||
|
|
||||||
def add_model_args(parser: argparse.ArgumentParser):
|
def add_model_args(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-model",
|
"--encoder",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
help="Path to the transducer encoder model",
|
||||||
help="Path to the encoder model",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder-model",
|
"--decoder",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
help="Path to the transducer decoder model.",
|
||||||
help="Path to the decoder model.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--joiner-model",
|
"--joiner",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
help="Path to the transducer joiner model.",
|
||||||
help="Path to the joiner model.",
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--paraformer-encoder",
|
||||||
|
type=str,
|
||||||
|
help="Path to the paraformer encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--paraformer-decoder",
|
||||||
|
type=str,
|
||||||
|
help="Path to the transducer decoder model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -323,22 +332,40 @@ def get_args():
|
|||||||
|
|
||||||
|
|
||||||
def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
|
def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
|
||||||
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
|
if args.encoder:
|
||||||
tokens=args.tokens,
|
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
|
||||||
encoder=args.encoder_model,
|
tokens=args.tokens,
|
||||||
decoder=args.decoder_model,
|
encoder=args.encoder,
|
||||||
joiner=args.joiner_model,
|
decoder=args.decoder,
|
||||||
num_threads=args.num_threads,
|
joiner=args.joiner,
|
||||||
sample_rate=args.sample_rate,
|
num_threads=args.num_threads,
|
||||||
feature_dim=args.feat_dim,
|
sample_rate=args.sample_rate,
|
||||||
decoding_method=args.decoding_method,
|
feature_dim=args.feat_dim,
|
||||||
max_active_paths=args.num_active_paths,
|
decoding_method=args.decoding_method,
|
||||||
enable_endpoint_detection=args.use_endpoint != 0,
|
max_active_paths=args.num_active_paths,
|
||||||
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
enable_endpoint_detection=args.use_endpoint != 0,
|
||||||
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
||||||
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
||||||
provider=args.provider,
|
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
||||||
)
|
provider=args.provider,
|
||||||
|
)
|
||||||
|
elif args.paraformer_encoder:
|
||||||
|
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
|
||||||
|
tokens=args.tokens,
|
||||||
|
encoder=args.paraformer_encoder,
|
||||||
|
decoder=args.paraformer_decoder,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
feature_dim=args.feat_dim,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
enable_endpoint_detection=args.use_endpoint != 0,
|
||||||
|
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
||||||
|
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
||||||
|
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
||||||
|
provider=args.provider,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Please provide a model")
|
||||||
|
|
||||||
return recognizer
|
return recognizer
|
||||||
|
|
||||||
@@ -654,11 +681,25 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a>
|
|||||||
|
|
||||||
|
|
||||||
def check_args(args):
|
def check_args(args):
|
||||||
assert Path(args.encoder_model).is_file(), f"{args.encoder_model} does not exist"
|
if args.encoder:
|
||||||
|
assert Path(args.encoder).is_file(), f"{args.encoder} does not exist"
|
||||||
|
|
||||||
assert Path(args.decoder_model).is_file(), f"{args.decoder_model} does not exist"
|
assert Path(args.decoder).is_file(), f"{args.decoder} does not exist"
|
||||||
|
|
||||||
assert Path(args.joiner_model).is_file(), f"{args.joiner_model} does not exist"
|
assert Path(args.joiner).is_file(), f"{args.joiner} does not exist"
|
||||||
|
|
||||||
|
assert args.paraformer_encoder is None, args.paraformer_encoder
|
||||||
|
assert args.paraformer_decoder is None, args.paraformer_decoder
|
||||||
|
elif args.paraformer_encoder:
|
||||||
|
assert Path(
|
||||||
|
args.paraformer_encoder
|
||||||
|
).is_file(), f"{args.paraformer_encoder} does not exist"
|
||||||
|
|
||||||
|
assert Path(
|
||||||
|
args.paraformer_decoder
|
||||||
|
).is_file(), f"{args.paraformer_decoder} does not exist"
|
||||||
|
else:
|
||||||
|
raise ValueError("Please provide a model")
|
||||||
|
|
||||||
if not Path(args.tokens).is_file():
|
if not Path(args.tokens).is_file():
|
||||||
raise ValueError(f"{args.tokens} does not exist")
|
raise ValueError(f"{args.tokens} does not exist")
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ set(sources
|
|||||||
online-lm.cc
|
online-lm.cc
|
||||||
online-lstm-transducer-model.cc
|
online-lstm-transducer-model.cc
|
||||||
online-model-config.cc
|
online-model-config.cc
|
||||||
|
online-paraformer-model-config.cc
|
||||||
|
online-paraformer-model.cc
|
||||||
online-recognizer-impl.cc
|
online-recognizer-impl.cc
|
||||||
online-recognizer.cc
|
online-recognizer.cc
|
||||||
online-rnn-lm.cc
|
online-rnn-lm.cc
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ std::string FeatureExtractorConfig::ToString() const {
|
|||||||
|
|
||||||
class FeatureExtractor::Impl {
|
class FeatureExtractor::Impl {
|
||||||
public:
|
public:
|
||||||
explicit Impl(const FeatureExtractorConfig &config) {
|
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
||||||
opts_.frame_opts.dither = 0;
|
opts_.frame_opts.dither = 0;
|
||||||
opts_.frame_opts.snip_edges = false;
|
opts_.frame_opts.snip_edges = false;
|
||||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||||
@@ -50,6 +50,19 @@ class FeatureExtractor::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||||
|
if (config_.normalize_samples) {
|
||||||
|
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||||
|
} else {
|
||||||
|
std::vector<float> buf(n);
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
buf[i] = waveform[i] * 32768;
|
||||||
|
}
|
||||||
|
AcceptWaveformImpl(sampling_rate, buf.data(), n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
|
||||||
|
int32_t n) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
|
||||||
if (resampler_) {
|
if (resampler_) {
|
||||||
@@ -146,6 +159,7 @@ class FeatureExtractor::Impl {
|
|||||||
private:
|
private:
|
||||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
knf::FbankOptions opts_;
|
knf::FbankOptions opts_;
|
||||||
|
FeatureExtractorConfig config_;
|
||||||
mutable std::mutex mutex_;
|
mutable std::mutex mutex_;
|
||||||
std::unique_ptr<LinearResample> resampler_;
|
std::unique_ptr<LinearResample> resampler_;
|
||||||
int32_t last_frame_index_ = 0;
|
int32_t last_frame_index_ = 0;
|
||||||
|
|||||||
@@ -21,6 +21,13 @@ struct FeatureExtractorConfig {
|
|||||||
// Feature dimension
|
// Feature dimension
|
||||||
int32_t feature_dim = 80;
|
int32_t feature_dim = 80;
|
||||||
|
|
||||||
|
// Set internally by some models, e.g., paraformer sets it to false.
|
||||||
|
// This parameter is not exposed to users from the commandline
|
||||||
|
// If true, the feature extractor expects inputs to be normalized to
|
||||||
|
// the range [-1, 1].
|
||||||
|
// If false, we will multiply the inputs by 32768
|
||||||
|
bool normalize_samples = true;
|
||||||
|
|
||||||
std::string ToString() const;
|
std::string ToString() const;
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
void OnlineModelConfig::Register(ParseOptions *po) {
|
void OnlineModelConfig::Register(ParseOptions *po) {
|
||||||
transducer.Register(po);
|
transducer.Register(po);
|
||||||
|
paraformer.Register(po);
|
||||||
|
|
||||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||||
|
|
||||||
@@ -41,6 +42,10 @@ bool OnlineModelConfig::Validate() const {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!paraformer.encoder.empty()) {
|
||||||
|
return paraformer.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
return transducer.Validate();
|
return transducer.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,6 +54,7 @@ std::string OnlineModelConfig::ToString() const {
|
|||||||
|
|
||||||
os << "OnlineModelConfig(";
|
os << "OnlineModelConfig(";
|
||||||
os << "transducer=" << transducer.ToString() << ", ";
|
os << "transducer=" << transducer.ToString() << ", ";
|
||||||
|
os << "paraformer=" << paraformer.ToString() << ", ";
|
||||||
os << "tokens=\"" << tokens << "\", ";
|
os << "tokens=\"" << tokens << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||||
|
|||||||
@@ -6,12 +6,14 @@
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
struct OnlineModelConfig {
|
struct OnlineModelConfig {
|
||||||
OnlineTransducerModelConfig transducer;
|
OnlineTransducerModelConfig transducer;
|
||||||
|
OnlineParaformerModelConfig paraformer;
|
||||||
std::string tokens;
|
std::string tokens;
|
||||||
int32_t num_threads = 1;
|
int32_t num_threads = 1;
|
||||||
bool debug = false;
|
bool debug = false;
|
||||||
@@ -28,9 +30,11 @@ struct OnlineModelConfig {
|
|||||||
|
|
||||||
OnlineModelConfig() = default;
|
OnlineModelConfig() = default;
|
||||||
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
|
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
|
||||||
|
const OnlineParaformerModelConfig ¶former,
|
||||||
const std::string &tokens, int32_t num_threads, bool debug,
|
const std::string &tokens, int32_t num_threads, bool debug,
|
||||||
const std::string &provider, const std::string &model_type)
|
const std::string &provider, const std::string &model_type)
|
||||||
: transducer(transducer),
|
: transducer(transducer),
|
||||||
|
paraformer(paraformer),
|
||||||
tokens(tokens),
|
tokens(tokens),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
debug(debug),
|
debug(debug),
|
||||||
|
|||||||
23
sherpa-onnx/csrc/online-paraformer-decoder.h
Normal file
23
sherpa-onnx/csrc/online-paraformer-decoder.h
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
// sherpa-onnx/csrc/online-paraformer-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OnlineParaformerDecoderResult {
|
||||||
|
/// The decoded token IDs
|
||||||
|
std::vector<int32_t> tokens;
|
||||||
|
|
||||||
|
int32_t last_non_blank_frame_index = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
|
||||||
43
sherpa-onnx/csrc/online-paraformer-model-config.cc
Normal file
43
sherpa-onnx/csrc/online-paraformer-model-config.cc
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
// sherpa-onnx/csrc/online-paraformer-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OnlineParaformerModelConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register("paraformer-encoder", &encoder,
|
||||||
|
"Path to encoder.onnx of paraformer.");
|
||||||
|
po->Register("paraformer-decoder", &decoder,
|
||||||
|
"Path to decoder.onnx of paraformer.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OnlineParaformerModelConfig::Validate() const {
|
||||||
|
if (!FileExists(encoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(decoder)) {
|
||||||
|
SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OnlineParaformerModelConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OnlineParaformerModelConfig(";
|
||||||
|
os << "encoder=\"" << encoder << "\", ";
|
||||||
|
os << "decoder=\"" << decoder << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
31
sherpa-onnx/csrc/online-paraformer-model-config.h
Normal file
31
sherpa-onnx/csrc/online-paraformer-model-config.h
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
// sherpa-onnx/csrc/online-paraformer-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OnlineParaformerModelConfig {
|
||||||
|
std::string encoder;
|
||||||
|
std::string decoder;
|
||||||
|
|
||||||
|
OnlineParaformerModelConfig() = default;
|
||||||
|
|
||||||
|
OnlineParaformerModelConfig(const std::string &encoder,
|
||||||
|
const std::string &decoder)
|
||||||
|
: encoder(encoder), decoder(decoder) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||||
249
sherpa-onnx/csrc/online-paraformer-model.cc
Normal file
249
sherpa-onnx/csrc/online-paraformer-model.cc
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
// sherpa-onnx/csrc/online-paraformer-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-paraformer-model.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineParaformerModel::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OnlineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.paraformer.encoder);
|
||||||
|
InitEncoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.paraformer.decoder);
|
||||||
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
Impl(AAssetManager *mgr, const OnlineModelConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_WARNING),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.paraformer.encoder);
|
||||||
|
InitEncoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(mgr, config.paraformer.decoder);
|
||||||
|
InitDecoder(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
|
||||||
|
Ort::Value features_length) {
|
||||||
|
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||||
|
std::move(features_length)};
|
||||||
|
|
||||||
|
return encoder_sess_->Run(
|
||||||
|
{}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
|
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
|
||||||
|
Ort::Value encoder_out_length,
|
||||||
|
Ort::Value acoustic_embedding,
|
||||||
|
Ort::Value acoustic_embedding_length,
|
||||||
|
std::vector<Ort::Value> states) {
|
||||||
|
std::vector<Ort::Value> decoder_inputs;
|
||||||
|
decoder_inputs.reserve(4 + states.size());
|
||||||
|
|
||||||
|
decoder_inputs.push_back(std::move(encoder_out));
|
||||||
|
decoder_inputs.push_back(std::move(encoder_out_length));
|
||||||
|
decoder_inputs.push_back(std::move(acoustic_embedding));
|
||||||
|
decoder_inputs.push_back(std::move(acoustic_embedding_length));
|
||||||
|
|
||||||
|
for (auto &v : states) {
|
||||||
|
decoder_inputs.push_back(std::move(v));
|
||||||
|
}
|
||||||
|
|
||||||
|
return decoder_sess_->Run({}, decoder_input_names_ptr_.data(),
|
||||||
|
decoder_inputs.data(), decoder_inputs.size(),
|
||||||
|
decoder_output_names_ptr_.data(),
|
||||||
|
decoder_output_names_ptr_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t VocabSize() const { return vocab_size_; }
|
||||||
|
|
||||||
|
int32_t LfrWindowSize() const { return lfr_window_size_; }
|
||||||
|
|
||||||
|
int32_t LfrWindowShift() const { return lfr_window_shift_; }
|
||||||
|
|
||||||
|
int32_t EncoderOutputSize() const { return encoder_output_size_; }
|
||||||
|
|
||||||
|
int32_t DecoderKernelSize() const { return decoder_kernel_size_; }
|
||||||
|
|
||||||
|
int32_t DecoderNumBlocks() const { return decoder_num_blocks_; }
|
||||||
|
|
||||||
|
const std::vector<float> &NegativeMean() const { return neg_mean_; }
|
||||||
|
|
||||||
|
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
|
||||||
|
|
||||||
|
OrtAllocator *Allocator() const { return allocator_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitEncoder(void *model_data, size_t model_data_length) {
|
||||||
|
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, model_data, model_data_length, sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||||
|
&encoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||||
|
&encoder_output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(encoder_output_size_, "encoder_output_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(decoder_num_blocks_, "decoder_num_blocks");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(decoder_kernel_size_, "decoder_kernel_size");
|
||||||
|
|
||||||
|
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean");
|
||||||
|
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
|
||||||
|
|
||||||
|
float scale = std::sqrt(encoder_output_size_);
|
||||||
|
for (auto &f : inv_stddev_) {
|
||||||
|
f *= scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitDecoder(void *model_data, size_t model_data_length) {
|
||||||
|
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, model_data, model_data_length, sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||||
|
&decoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||||
|
&decoder_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OnlineModelConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_input_names_;
|
||||||
|
std::vector<const char *> encoder_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_output_names_;
|
||||||
|
std::vector<const char *> encoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_input_names_;
|
||||||
|
std::vector<const char *> decoder_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_output_names_;
|
||||||
|
std::vector<const char *> decoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<float> neg_mean_;
|
||||||
|
std::vector<float> inv_stddev_;
|
||||||
|
|
||||||
|
int32_t vocab_size_ = 0; // initialized in Init
|
||||||
|
int32_t lfr_window_size_ = 0;
|
||||||
|
int32_t lfr_window_shift_ = 0;
|
||||||
|
|
||||||
|
int32_t encoder_output_size_ = 0;
|
||||||
|
int32_t decoder_num_blocks_ = 0;
|
||||||
|
int32_t decoder_kernel_size_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OnlineParaformerModel::OnlineParaformerModel(AAssetManager *mgr,
|
||||||
|
const OnlineModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
OnlineParaformerModel::~OnlineParaformerModel() = default;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> OnlineParaformerModel::ForwardEncoder(
|
||||||
|
Ort::Value features, Ort::Value features_length) const {
|
||||||
|
return impl_->ForwardEncoder(std::move(features), std::move(features_length));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> OnlineParaformerModel::ForwardDecoder(
|
||||||
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
|
Ort::Value acoustic_embedding, Ort::Value acoustic_embedding_length,
|
||||||
|
std::vector<Ort::Value> states) const {
|
||||||
|
return impl_->ForwardDecoder(
|
||||||
|
std::move(encoder_out), std::move(encoder_out_length),
|
||||||
|
std::move(acoustic_embedding), std::move(acoustic_embedding_length),
|
||||||
|
std::move(states));
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OnlineParaformerModel::VocabSize() const { return impl_->VocabSize(); }
|
||||||
|
|
||||||
|
int32_t OnlineParaformerModel::LfrWindowSize() const {
|
||||||
|
return impl_->LfrWindowSize();
|
||||||
|
}
|
||||||
|
int32_t OnlineParaformerModel::LfrWindowShift() const {
|
||||||
|
return impl_->LfrWindowShift();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OnlineParaformerModel::EncoderOutputSize() const {
|
||||||
|
return impl_->EncoderOutputSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OnlineParaformerModel::DecoderKernelSize() const {
|
||||||
|
return impl_->DecoderKernelSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OnlineParaformerModel::DecoderNumBlocks() const {
|
||||||
|
return impl_->DecoderNumBlocks();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<float> &OnlineParaformerModel::NegativeMean() const {
|
||||||
|
return impl_->NegativeMean();
|
||||||
|
}
|
||||||
|
const std::vector<float> &OnlineParaformerModel::InverseStdDev() const {
|
||||||
|
return impl_->InverseStdDev();
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *OnlineParaformerModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
76
sherpa-onnx/csrc/online-paraformer-model.h
Normal file
76
sherpa-onnx/csrc/online-paraformer-model.h
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
// sherpa-onnx/csrc/online-paraformer-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineParaformerModel {
|
||||||
|
public:
|
||||||
|
explicit OnlineParaformerModel(const OnlineModelConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OnlineParaformerModel(AAssetManager *mgr, const OnlineModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
~OnlineParaformerModel();
|
||||||
|
|
||||||
|
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
|
||||||
|
Ort::Value features_length) const;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
|
||||||
|
Ort::Value encoder_out_length,
|
||||||
|
Ort::Value acoustic_embedding,
|
||||||
|
Ort::Value acoustic_embedding_length,
|
||||||
|
std::vector<Ort::Value> states) const;
|
||||||
|
|
||||||
|
/** Return the vocabulary size of the model
|
||||||
|
*/
|
||||||
|
int32_t VocabSize() const;
|
||||||
|
|
||||||
|
/** It is lfr_m in config.yaml
|
||||||
|
*/
|
||||||
|
int32_t LfrWindowSize() const;
|
||||||
|
|
||||||
|
/** It is lfr_n in config.yaml
|
||||||
|
*/
|
||||||
|
int32_t LfrWindowShift() const;
|
||||||
|
|
||||||
|
int32_t EncoderOutputSize() const;
|
||||||
|
|
||||||
|
int32_t DecoderKernelSize() const;
|
||||||
|
int32_t DecoderNumBlocks() const;
|
||||||
|
|
||||||
|
/** Return negative mean for CMVN
|
||||||
|
*/
|
||||||
|
const std::vector<float> &NegativeMean() const;
|
||||||
|
|
||||||
|
/** Return inverse stddev for CMVN
|
||||||
|
*/
|
||||||
|
const std::vector<float> &InverseStdDev() const;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
|
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -14,6 +15,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
|
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.paraformer.encoder.empty()) {
|
||||||
|
return std::make_unique<OnlineRecognizerParaformerImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_LOGE("Please specify a model");
|
SHERPA_ONNX_LOGE("Please specify a model");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
@@ -25,6 +30,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
|
|||||||
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
|
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!config.model_config.paraformer.encoder.empty()) {
|
||||||
|
return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_LOGE("Please specify a model");
|
SHERPA_ONNX_LOGE("Please specify a model");
|
||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,8 +26,6 @@ class OnlineRecognizerImpl {
|
|||||||
|
|
||||||
virtual ~OnlineRecognizerImpl() = default;
|
virtual ~OnlineRecognizerImpl() = default;
|
||||||
|
|
||||||
virtual void InitOnlineStream(OnlineStream *stream) const = 0;
|
|
||||||
|
|
||||||
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
||||||
|
|
||||||
virtual std::unique_ptr<OnlineStream> CreateStream(
|
virtual std::unique_ptr<OnlineStream> CreateStream(
|
||||||
|
|||||||
465
sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
Normal file
465
sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
Normal file
@@ -0,0 +1,465 @@
|
|||||||
|
// sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-lm.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-paraformer-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static OnlineRecognizerResult Convert(const OnlineParaformerDecoderResult &src,
|
||||||
|
const SymbolTable &sym_table) {
|
||||||
|
OnlineRecognizerResult r;
|
||||||
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
|
||||||
|
std::string text;
|
||||||
|
|
||||||
|
// When the current token ends with "@@" we set mergeable to true
|
||||||
|
bool mergeable = false;
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != src.tokens.size(); ++i) {
|
||||||
|
auto sym = sym_table[src.tokens[i]];
|
||||||
|
r.tokens.push_back(sym);
|
||||||
|
|
||||||
|
if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) {
|
||||||
|
// sym does not end with "@@"
|
||||||
|
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
|
||||||
|
if (p[0] < 0x80) {
|
||||||
|
// an ascii
|
||||||
|
if (mergeable) {
|
||||||
|
mergeable = false;
|
||||||
|
text.append(sym);
|
||||||
|
} else {
|
||||||
|
text.append(" ");
|
||||||
|
text.append(sym);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// not an ascii
|
||||||
|
mergeable = false;
|
||||||
|
|
||||||
|
if (i > 0) {
|
||||||
|
const uint8_t *p = reinterpret_cast<const uint8_t *>(
|
||||||
|
sym_table[src.tokens[i - 1]].c_str());
|
||||||
|
if (p[0] < 0x80) {
|
||||||
|
// put a space between ascii and non-ascii
|
||||||
|
text.append(" ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
text.append(sym);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// this sym ends with @@
|
||||||
|
sym = std::string(sym.data(), sym.size() - 2);
|
||||||
|
if (mergeable) {
|
||||||
|
text.append(sym);
|
||||||
|
} else {
|
||||||
|
text.append(" ");
|
||||||
|
text.append(sym);
|
||||||
|
mergeable = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.text = std::move(text);
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
// y[i] += x[i] * scale
|
||||||
|
static void ScaleAddInPlace(const float *x, int32_t n, float scale, float *y) {
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
y[i] += x[i] * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// y[i] = x[i] * scale
|
||||||
|
static void Scale(const float *x, int32_t n, float scale, float *y) {
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
y[i] = x[i] * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
|
||||||
|
public:
|
||||||
|
explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
model_(config.model_config),
|
||||||
|
sym_(config.model_config.tokens),
|
||||||
|
endpoint_(config_.endpoint_config) {
|
||||||
|
if (config.decoding_method != "greedy_search") {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Unsupported decoding method: %s. Support only greedy_search at "
|
||||||
|
"present",
|
||||||
|
config.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Paraformer models assume input samples are in the range
|
||||||
|
// [-32768, 32767], so we set normalize_samples to false
|
||||||
|
config_.feat_config.normalize_samples = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr,
|
||||||
|
const OnlineRecognizerConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
model_(mgr, config.model_config),
|
||||||
|
sym_(mgr, config.model_config.tokens),
|
||||||
|
endpoint_(config_.endpoint_config) {
|
||||||
|
if (config.decoding_method == "greedy_search") {
|
||||||
|
// add greedy search decoder
|
||||||
|
// SHERPA_ONNX_LOGE("to be implemented");
|
||||||
|
// exit(-1);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
|
config.decoding_method.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Paraformer models assume input samples are in the range
|
||||||
|
// [-32768, 32767], so we set normalize_samples to false
|
||||||
|
config_.feat_config.normalize_samples = false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) =
|
||||||
|
delete;
|
||||||
|
|
||||||
|
OnlineRecognizerParaformerImpl operator=(
|
||||||
|
const OnlineRecognizerParaformerImpl &) = delete;
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||||
|
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||||
|
|
||||||
|
OnlineParaformerDecoderResult r;
|
||||||
|
stream->SetParaformerResult(r);
|
||||||
|
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsReady(OnlineStream *s) const override {
|
||||||
|
return s->GetNumProcessedFrames() + chunk_size_ < s->NumFramesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||||
|
// TODO(fangjun): Support batch size > 1
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
DecodeStream(ss[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
|
||||||
|
auto decoder_result = s->GetParaformerResult();
|
||||||
|
|
||||||
|
return Convert(decoder_result, sym_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsEndpoint(OnlineStream *s) const override {
|
||||||
|
if (!config_.enable_endpoint) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &result = s->GetParaformerResult();
|
||||||
|
|
||||||
|
int32_t num_processed_frames = s->GetNumProcessedFrames();
|
||||||
|
|
||||||
|
// frame shift is 10 milliseconds
|
||||||
|
float frame_shift_in_seconds = 0.01;
|
||||||
|
|
||||||
|
int32_t trailing_silence_frames =
|
||||||
|
num_processed_frames - result.last_non_blank_frame_index;
|
||||||
|
|
||||||
|
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
|
||||||
|
frame_shift_in_seconds);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reset(OnlineStream *s) const override {
|
||||||
|
OnlineParaformerDecoderResult r;
|
||||||
|
s->SetParaformerResult(r);
|
||||||
|
|
||||||
|
// the internal model caches are not reset
|
||||||
|
|
||||||
|
// Note: We only update counters. The underlying audio samples
|
||||||
|
// are not discarded.
|
||||||
|
s->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void DecodeStream(OnlineStream *s) const {
|
||||||
|
const auto num_processed_frames = s->GetNumProcessedFrames();
|
||||||
|
std::vector<float> frames = s->GetFrames(num_processed_frames, chunk_size_);
|
||||||
|
s->GetNumProcessedFrames() += chunk_size_ - 1;
|
||||||
|
|
||||||
|
frames = ApplyLFR(frames);
|
||||||
|
ApplyCMVN(&frames);
|
||||||
|
PositionalEncoding(&frames, num_processed_frames / model_.LfrWindowShift());
|
||||||
|
|
||||||
|
int32_t feat_dim = model_.NegativeMean().size();
|
||||||
|
|
||||||
|
// We have scaled inv_stddev by sqrt(encoder_output_size)
|
||||||
|
// so the following line can be commented out
|
||||||
|
// frames *= encoder_output_size ** 0.5
|
||||||
|
|
||||||
|
// add overlap chunk
|
||||||
|
std::vector<float> &feat_cache = s->GetParaformerFeatCache();
|
||||||
|
if (feat_cache.empty()) {
|
||||||
|
int32_t n = (left_chunk_size_ + right_chunk_size_) * feat_dim;
|
||||||
|
feat_cache.resize(n, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
frames.insert(frames.begin(), feat_cache.begin(), feat_cache.end());
|
||||||
|
std::copy(frames.end() - feat_cache.size(), frames.end(),
|
||||||
|
feat_cache.begin());
|
||||||
|
|
||||||
|
int32_t num_frames = frames.size() / feat_dim;
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
|
||||||
|
Ort::Value x =
|
||||||
|
Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(),
|
||||||
|
x_shape.data(), x_shape.size());
|
||||||
|
|
||||||
|
int64_t x_len_shape = 1;
|
||||||
|
int32_t x_len_val = num_frames;
|
||||||
|
|
||||||
|
Ort::Value x_length =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &x_len_val, 1, &x_len_shape, 1);
|
||||||
|
|
||||||
|
auto encoder_out_vec =
|
||||||
|
model_.ForwardEncoder(std::move(x), std::move(x_length));
|
||||||
|
|
||||||
|
// CIF search
|
||||||
|
auto &encoder_out = encoder_out_vec[0];
|
||||||
|
auto &encoder_out_len = encoder_out_vec[1];
|
||||||
|
auto &alpha = encoder_out_vec[2];
|
||||||
|
|
||||||
|
float *p_alpha = alpha.GetTensorMutableData<float>();
|
||||||
|
|
||||||
|
std::vector<int64_t> alpha_shape =
|
||||||
|
alpha.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
std::fill(p_alpha, p_alpha + left_chunk_size_, 0);
|
||||||
|
std::fill(p_alpha + alpha_shape[1] - right_chunk_size_,
|
||||||
|
p_alpha + alpha_shape[1], 0);
|
||||||
|
|
||||||
|
const float *p_encoder_out = encoder_out.GetTensorData<float>();
|
||||||
|
|
||||||
|
std::vector<int64_t> encoder_out_shape =
|
||||||
|
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
std::vector<float> &initial_hidden = s->GetParaformerEncoderOutCache();
|
||||||
|
if (initial_hidden.empty()) {
|
||||||
|
initial_hidden.resize(encoder_out_shape[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> &alpha_cache = s->GetParaformerAlphaCache();
|
||||||
|
if (alpha_cache.empty()) {
|
||||||
|
alpha_cache.resize(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> acoustic_embedding;
|
||||||
|
acoustic_embedding.reserve(encoder_out_shape[1] * encoder_out_shape[2]);
|
||||||
|
|
||||||
|
float threshold = 1.0;
|
||||||
|
|
||||||
|
float integrate = alpha_cache[0];
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != encoder_out_shape[1]; ++i) {
|
||||||
|
float this_alpha = p_alpha[i];
|
||||||
|
if (integrate + this_alpha < threshold) {
|
||||||
|
integrate += this_alpha;
|
||||||
|
ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
|
||||||
|
encoder_out_shape[2], this_alpha,
|
||||||
|
initial_hidden.data());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fire
|
||||||
|
ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
|
||||||
|
encoder_out_shape[2], threshold - integrate,
|
||||||
|
initial_hidden.data());
|
||||||
|
acoustic_embedding.insert(acoustic_embedding.end(),
|
||||||
|
initial_hidden.begin(), initial_hidden.end());
|
||||||
|
integrate += this_alpha - threshold;
|
||||||
|
|
||||||
|
Scale(p_encoder_out + i * encoder_out_shape[2], encoder_out_shape[2],
|
||||||
|
integrate, initial_hidden.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
alpha_cache[0] = integrate;
|
||||||
|
|
||||||
|
if (acoustic_embedding.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &states = s->GetStates();
|
||||||
|
if (states.empty()) {
|
||||||
|
states.reserve(model_.DecoderNumBlocks());
|
||||||
|
|
||||||
|
std::array<int64_t, 3> shape{1, model_.EncoderOutputSize(),
|
||||||
|
model_.DecoderKernelSize() - 1};
|
||||||
|
|
||||||
|
int32_t num_bytes = sizeof(float) * shape[0] * shape[1] * shape[2];
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != model_.DecoderNumBlocks(); ++i) {
|
||||||
|
Ort::Value this_state = Ort::Value::CreateTensor<float>(
|
||||||
|
model_.Allocator(), shape.data(), shape.size());
|
||||||
|
|
||||||
|
memset(this_state.GetTensorMutableData<float>(), 0, num_bytes);
|
||||||
|
|
||||||
|
states.push_back(std::move(this_state));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t num_tokens = acoustic_embedding.size() / initial_hidden.size();
|
||||||
|
std::array<int64_t, 3> acoustic_embedding_shape{
|
||||||
|
1, num_tokens, static_cast<int32_t>(initial_hidden.size())};
|
||||||
|
|
||||||
|
Ort::Value acoustic_embedding_tensor = Ort::Value::CreateTensor(
|
||||||
|
memory_info, acoustic_embedding.data(), acoustic_embedding.size(),
|
||||||
|
acoustic_embedding_shape.data(), acoustic_embedding_shape.size());
|
||||||
|
|
||||||
|
std::array<int64_t, 1> acoustic_embedding_length_shape{1};
|
||||||
|
Ort::Value acoustic_embedding_length_tensor = Ort::Value::CreateTensor(
|
||||||
|
memory_info, &num_tokens, 1, acoustic_embedding_length_shape.data(),
|
||||||
|
acoustic_embedding_length_shape.size());
|
||||||
|
|
||||||
|
auto decoder_out_vec = model_.ForwardDecoder(
|
||||||
|
std::move(encoder_out), std::move(encoder_out_len),
|
||||||
|
std::move(acoustic_embedding_tensor),
|
||||||
|
std::move(acoustic_embedding_length_tensor), std::move(states));
|
||||||
|
|
||||||
|
states.reserve(model_.DecoderNumBlocks());
|
||||||
|
for (int32_t i = 2; i != decoder_out_vec.size(); ++i) {
|
||||||
|
// TODO(fangjun): When we change chunk_size_, we need to
|
||||||
|
// slice decoder_out_vec[i] accordingly.
|
||||||
|
states.push_back(std::move(decoder_out_vec[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &sample_ids = decoder_out_vec[1];
|
||||||
|
const int64_t *p_sample_ids = sample_ids.GetTensorData<int64_t>();
|
||||||
|
|
||||||
|
bool non_blank_detected = false;
|
||||||
|
|
||||||
|
auto &result = s->GetParaformerResult();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != num_tokens; ++i) {
|
||||||
|
int32_t t = p_sample_ids[i];
|
||||||
|
if (t == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
non_blank_detected = true;
|
||||||
|
result.tokens.push_back(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (non_blank_detected) {
|
||||||
|
result.last_non_blank_frame_index = num_processed_frames;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
|
||||||
|
int32_t lfr_window_size = model_.LfrWindowSize();
|
||||||
|
int32_t lfr_window_shift = model_.LfrWindowShift();
|
||||||
|
int32_t in_feat_dim = config_.feat_config.feature_dim;
|
||||||
|
|
||||||
|
int32_t in_num_frames = in.size() / in_feat_dim;
|
||||||
|
int32_t out_num_frames =
|
||||||
|
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
|
||||||
|
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
|
||||||
|
|
||||||
|
std::vector<float> out(out_num_frames * out_feat_dim);
|
||||||
|
|
||||||
|
const float *p_in = in.data();
|
||||||
|
float *p_out = out.data();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != out_num_frames; ++i) {
|
||||||
|
std::copy(p_in, p_in + out_feat_dim, p_out);
|
||||||
|
|
||||||
|
p_out += out_feat_dim;
|
||||||
|
p_in += lfr_window_shift * in_feat_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ApplyCMVN(std::vector<float> *v) const {
|
||||||
|
const std::vector<float> &neg_mean = model_.NegativeMean();
|
||||||
|
const std::vector<float> &inv_stddev = model_.InverseStdDev();
|
||||||
|
|
||||||
|
int32_t dim = neg_mean.size();
|
||||||
|
int32_t num_frames = v->size() / dim;
|
||||||
|
|
||||||
|
float *p = v->data();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != num_frames; ++i) {
|
||||||
|
for (int32_t k = 0; k != dim; ++k) {
|
||||||
|
p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
p += dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void PositionalEncoding(std::vector<float> *v, int32_t t_offset) const {
|
||||||
|
int32_t lfr_window_size = model_.LfrWindowSize();
|
||||||
|
int32_t in_feat_dim = config_.feat_config.feature_dim;
|
||||||
|
|
||||||
|
int32_t feat_dim = in_feat_dim * lfr_window_size;
|
||||||
|
int32_t T = v->size() / feat_dim;
|
||||||
|
|
||||||
|
// log(10000)/(7*80/2-1) == 0.03301197265941284
|
||||||
|
// 7 is lfr_window_size
|
||||||
|
// 80 is in_feat_dim
|
||||||
|
// 7*80 is feat_dim
|
||||||
|
constexpr float kScale = -0.03301197265941284;
|
||||||
|
|
||||||
|
for (int32_t t = 0; t != T; ++t) {
|
||||||
|
float *p = v->data() + t * feat_dim;
|
||||||
|
|
||||||
|
int32_t offset = t + 1 + t_offset;
|
||||||
|
|
||||||
|
for (int32_t d = 0; d < feat_dim / 2; ++d) {
|
||||||
|
float inv_timescale = offset * std::exp(d * kScale);
|
||||||
|
|
||||||
|
float sin_d = std::sin(inv_timescale);
|
||||||
|
float cos_d = std::cos(inv_timescale);
|
||||||
|
|
||||||
|
p[d] += sin_d;
|
||||||
|
p[d + feat_dim / 2] += cos_d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OnlineRecognizerConfig config_;
|
||||||
|
OnlineParaformerModel model_;
|
||||||
|
SymbolTable sym_;
|
||||||
|
Endpoint endpoint_;
|
||||||
|
|
||||||
|
// 0.61 seconds
|
||||||
|
int32_t chunk_size_ = 61;
|
||||||
|
// (61 - 7) / 6 + 1 = 10
|
||||||
|
|
||||||
|
int32_t left_chunk_size_ = 5;
|
||||||
|
int32_t right_chunk_size_ = 5;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
|
||||||
@@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void InitOnlineStream(OnlineStream *stream) const override {
|
|
||||||
auto r = decoder_->GetEmptyResult();
|
|
||||||
|
|
||||||
if (config_.decoding_method == "modified_beam_search" &&
|
|
||||||
nullptr != stream->GetContextGraph()) {
|
|
||||||
// r.hyps has only one element.
|
|
||||||
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
|
||||||
it->second.context_state = stream->GetContextGraph()->Root();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stream->SetResult(r);
|
|
||||||
stream->SetStates(model_->GetEncoderInitStates());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||||
InitOnlineStream(stream.get());
|
InitOnlineStream(stream.get());
|
||||||
@@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool IsEndpoint(OnlineStream *s) const override {
|
bool IsEndpoint(OnlineStream *s) const override {
|
||||||
if (!config_.enable_endpoint) return false;
|
if (!config_.enable_endpoint) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t num_processed_frames = s->GetNumProcessedFrames();
|
int32_t num_processed_frames = s->GetNumProcessedFrames();
|
||||||
|
|
||||||
// frame shift is 10 milliseconds
|
// frame shift is 10 milliseconds
|
||||||
@@ -244,6 +232,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
s->Reset();
|
s->Reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitOnlineStream(OnlineStream *stream) const {
|
||||||
|
auto r = decoder_->GetEmptyResult();
|
||||||
|
|
||||||
|
if (config_.decoding_method == "modified_beam_search" &&
|
||||||
|
nullptr != stream->GetContextGraph()) {
|
||||||
|
// r.hyps has only one element.
|
||||||
|
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
||||||
|
it->second.context_state = stream->GetContextGraph()->Root();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stream->SetResult(r);
|
||||||
|
stream->SetStates(model_->GetEncoderInitStates());
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OnlineRecognizerConfig config_;
|
OnlineRecognizerConfig config_;
|
||||||
std::unique_ptr<OnlineTransducerModel> model_;
|
std::unique_ptr<OnlineTransducerModel> model_;
|
||||||
|
|||||||
@@ -47,6 +47,14 @@ class OnlineStream::Impl {
|
|||||||
|
|
||||||
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
||||||
|
|
||||||
|
void SetParaformerResult(const OnlineParaformerDecoderResult &r) {
|
||||||
|
paraformer_result_ = r;
|
||||||
|
}
|
||||||
|
|
||||||
|
OnlineParaformerDecoderResult &GetParaformerResult() {
|
||||||
|
return paraformer_result_;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
|
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
|
||||||
|
|
||||||
void SetStates(std::vector<Ort::Value> states) {
|
void SetStates(std::vector<Ort::Value> states) {
|
||||||
@@ -57,6 +65,18 @@ class OnlineStream::Impl {
|
|||||||
|
|
||||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||||
|
|
||||||
|
std::vector<float> &GetParaformerFeatCache() {
|
||||||
|
return paraformer_feat_cache_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> &GetParaformerEncoderOutCache() {
|
||||||
|
return paraformer_encoder_out_cache_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> &GetParaformerAlphaCache() {
|
||||||
|
return paraformer_alpha_cache_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FeatureExtractor feat_extractor_;
|
FeatureExtractor feat_extractor_;
|
||||||
/// For contextual-biasing
|
/// For contextual-biasing
|
||||||
@@ -65,6 +85,10 @@ class OnlineStream::Impl {
|
|||||||
int32_t start_frame_index_ = 0; // never reset
|
int32_t start_frame_index_ = 0; // never reset
|
||||||
OnlineTransducerDecoderResult result_;
|
OnlineTransducerDecoderResult result_;
|
||||||
std::vector<Ort::Value> states_;
|
std::vector<Ort::Value> states_;
|
||||||
|
std::vector<float> paraformer_feat_cache_;
|
||||||
|
std::vector<float> paraformer_encoder_out_cache_;
|
||||||
|
std::vector<float> paraformer_alpha_cache_;
|
||||||
|
OnlineParaformerDecoderResult paraformer_result_;
|
||||||
};
|
};
|
||||||
|
|
||||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||||
@@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
|
|||||||
return impl_->GetResult();
|
return impl_->GetResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stores a Paraformer decoding result on the stream (forwards to Impl).
void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) {
  impl_->SetParaformerResult(r);
}
|
||||||
|
|
||||||
|
// Returns a mutable reference to the stream's Paraformer decoding result.
OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() {
  return impl_->GetParaformerResult();
}
|
||||||
|
|
||||||
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
|
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
|
||||||
impl_->SetStates(std::move(states));
|
impl_->SetStates(std::move(states));
|
||||||
}
|
}
|
||||||
@@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
|
|||||||
return impl_->GetContextGraph();
|
return impl_->GetContextGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a mutable reference to the cached left/right context features
// used by the streaming Paraformer recognizer (forwards to Impl).
std::vector<float> &OnlineStream::GetParaformerFeatCache() {
  return impl_->GetParaformerFeatCache();
}
|
||||||
|
|
||||||
|
// Returns a mutable reference to the partially-integrated encoder-output
// cache used by the Paraformer CIF search (forwards to Impl).
std::vector<float> &OnlineStream::GetParaformerEncoderOutCache() {
  return impl_->GetParaformerEncoderOutCache();
}
|
||||||
|
|
||||||
|
// Returns a mutable reference to the leftover CIF alpha weight carried
// across chunks (forwards to Impl).
std::vector<float> &OnlineStream::GetParaformerAlphaCache() {
  return impl_->GetParaformerAlphaCache();
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/context-graph.h"
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
#include "sherpa-onnx/csrc/features.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -70,6 +71,9 @@ class OnlineStream {
|
|||||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||||
OnlineTransducerDecoderResult &GetResult();
|
OnlineTransducerDecoderResult &GetResult();
|
||||||
|
|
||||||
|
void SetParaformerResult(const OnlineParaformerDecoderResult &r);
|
||||||
|
OnlineParaformerDecoderResult &GetParaformerResult();
|
||||||
|
|
||||||
void SetStates(std::vector<Ort::Value> states);
|
void SetStates(std::vector<Ort::Value> states);
|
||||||
std::vector<Ort::Value> &GetStates();
|
std::vector<Ort::Value> &GetStates();
|
||||||
|
|
||||||
@@ -80,6 +84,11 @@ class OnlineStream {
|
|||||||
*/
|
*/
|
||||||
const ContextGraphPtr &GetContextGraph() const;
|
const ContextGraphPtr &GetContextGraph() const;
|
||||||
|
|
||||||
|
// for streaming paraformer
|
||||||
|
std::vector<float> &GetParaformerFeatCache();
|
||||||
|
std::vector<float> &GetParaformerEncoderOutCache();
|
||||||
|
std::vector<float> &GetParaformerAlphaCache();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
|
|||||||
@@ -12,8 +12,8 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||||
#include "sherpa-onnx/csrc/online-stream.h"
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
|
||||||
#include "sherpa-onnx/csrc/parse-options.h"
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@@ -80,7 +80,7 @@ for a list of pre-trained models to download.
|
|||||||
|
|
||||||
bool is_ok = false;
|
bool is_ok = false;
|
||||||
const std::vector<float> samples =
|
const std::vector<float> samples =
|
||||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||||
|
|
||||||
if (!is_ok) {
|
if (!is_ok) {
|
||||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||||
@@ -92,14 +92,14 @@ for a list of pre-trained models to download.
|
|||||||
auto s = recognizer.CreateStream();
|
auto s = recognizer.CreateStream();
|
||||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||||
|
|
||||||
std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate));
|
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
|
||||||
// Note: We can call AcceptWaveform() multiple times.
|
// Note: We can call AcceptWaveform() multiple times.
|
||||||
s->AcceptWaveform(
|
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
|
||||||
sampling_rate, tail_paddings.data(), tail_paddings.size());
|
tail_paddings.size());
|
||||||
|
|
||||||
// Call InputFinished() to indicate that no audio samples are available
|
// Call InputFinished() to indicate that no audio samples are available
|
||||||
s->InputFinished();
|
s->InputFinished();
|
||||||
ss.push_back({ std::move(s), duration, 0 });
|
ss.push_back({std::move(s), duration, 0});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
|
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
|
||||||
@@ -112,8 +112,9 @@ for a list of pre-trained models to download.
|
|||||||
} else if (s.elapsed_seconds == 0) {
|
} else if (s.elapsed_seconds == 0) {
|
||||||
const auto end = std::chrono::steady_clock::now();
|
const auto end = std::chrono::steady_clock::now();
|
||||||
const float elapsed_seconds =
|
const float elapsed_seconds =
|
||||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||||
.count() / 1000.;
|
.count() /
|
||||||
|
1000.;
|
||||||
s.elapsed_seconds = elapsed_seconds;
|
s.elapsed_seconds = elapsed_seconds;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx
|
|||||||
offline-whisper-model-config.cc
|
offline-whisper-model-config.cc
|
||||||
online-lm-config.cc
|
online-lm-config.cc
|
||||||
online-model-config.cc
|
online-model-config.cc
|
||||||
|
online-paraformer-model-config.cc
|
||||||
online-recognizer.cc
|
online-recognizer.cc
|
||||||
online-stream.cc
|
online-stream.cc
|
||||||
online-transducer-model-config.cc
|
online-transducer-model-config.cc
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/python/csrc/online-model-config.cc
|
// sherpa-onnx/python/csrc/online-model-config.cc
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 by manyeyes
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
#include "sherpa-onnx/python/csrc/online-model-config.h"
|
#include "sherpa-onnx/python/csrc/online-model-config.h"
|
||||||
|
|
||||||
@@ -9,21 +9,26 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
void PybindOnlineModelConfig(py::module *m) {
|
void PybindOnlineModelConfig(py::module *m) {
|
||||||
PybindOnlineTransducerModelConfig(m);
|
PybindOnlineTransducerModelConfig(m);
|
||||||
|
PybindOnlineParaformerModelConfig(m);
|
||||||
|
|
||||||
using PyClass = OnlineModelConfig;
|
using PyClass = OnlineModelConfig;
|
||||||
py::class_<PyClass>(*m, "OnlineModelConfig")
|
py::class_<PyClass>(*m, "OnlineModelConfig")
|
||||||
.def(py::init<const OnlineTransducerModelConfig &, std::string &, int32_t,
|
.def(py::init<const OnlineTransducerModelConfig &,
|
||||||
|
const OnlineParaformerModelConfig &, std::string &, int32_t,
|
||||||
bool, const std::string &, const std::string &>(),
|
bool, const std::string &, const std::string &>(),
|
||||||
py::arg("transducer") = OnlineTransducerModelConfig(),
|
py::arg("transducer") = OnlineTransducerModelConfig(),
|
||||||
|
py::arg("paraformer") = OnlineParaformerModelConfig(),
|
||||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||||
.def_readwrite("transducer", &PyClass::transducer)
|
.def_readwrite("transducer", &PyClass::transducer)
|
||||||
|
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||||
.def_readwrite("tokens", &PyClass::tokens)
|
.def_readwrite("tokens", &PyClass::tokens)
|
||||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
.def_readwrite("debug", &PyClass::debug)
|
.def_readwrite("debug", &PyClass::debug)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/python/csrc/online-model-config.h
|
// sherpa-onnx/python/csrc/online-model-config.h
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 by manyeyes
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
|
||||||
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
|
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
|
||||||
|
|||||||
24
sherpa-onnx/python/csrc/online-paraformer-model-config.cc
Normal file
24
sherpa-onnx/python/csrc/online-paraformer-model-config.cc
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// sherpa-onnx/python/csrc/online-paraformer-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Exposes OnlineParaformerModelConfig to Python as
// _sherpa_onnx.OnlineParaformerModelConfig, with a two-argument
// constructor (encoder, decoder) and read-write attributes mirroring
// the C++ struct's fields.
void PybindOnlineParaformerModelConfig(py::module *m) {
  using PyClass = OnlineParaformerModelConfig;
  py::class_<PyClass>(*m, "OnlineParaformerModelConfig")
      .def(py::init<const std::string &, const std::string &>(),
           py::arg("encoder"), py::arg("decoder"))
      .def_readwrite("encoder", &PyClass::encoder)
      .def_readwrite("decoder", &PyClass::decoder)
      .def("__str__", &PyClass::ToString);
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/online-paraformer-model-config.h
Normal file
16
sherpa-onnx/python/csrc/online-paraformer-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/online-paraformer-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOnlineParaformerModelConfig(py::module *m);
|
||||||
|
|
||||||
|
}  // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||||
@@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
||||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||||
py::arg("max_active_paths"), py::arg("context_score"))
|
py::arg("max_active_paths") = 4, py::arg("context_score") = 0)
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from _sherpa_onnx import (
|
|||||||
EndpointConfig,
|
EndpointConfig,
|
||||||
FeatureExtractorConfig,
|
FeatureExtractorConfig,
|
||||||
OnlineModelConfig,
|
OnlineModelConfig,
|
||||||
|
OnlineParaformerModelConfig,
|
||||||
OnlineRecognizer as _Recognizer,
|
OnlineRecognizer as _Recognizer,
|
||||||
OnlineRecognizerConfig,
|
OnlineRecognizerConfig,
|
||||||
OnlineStream,
|
OnlineStream,
|
||||||
@@ -32,7 +33,7 @@ class OnlineRecognizer(object):
|
|||||||
encoder: str,
|
encoder: str,
|
||||||
decoder: str,
|
decoder: str,
|
||||||
joiner: str,
|
joiner: str,
|
||||||
num_threads: int = 4,
|
num_threads: int = 2,
|
||||||
sample_rate: float = 16000,
|
sample_rate: float = 16000,
|
||||||
feature_dim: int = 80,
|
feature_dim: int = 80,
|
||||||
enable_endpoint_detection: bool = False,
|
enable_endpoint_detection: bool = False,
|
||||||
@@ -144,6 +145,109 @@ class OnlineRecognizer(object):
|
|||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_paraformer(
|
||||||
|
cls,
|
||||||
|
tokens: str,
|
||||||
|
encoder: str,
|
||||||
|
decoder: str,
|
||||||
|
num_threads: int = 2,
|
||||||
|
sample_rate: float = 16000,
|
||||||
|
feature_dim: int = 80,
|
||||||
|
enable_endpoint_detection: bool = False,
|
||||||
|
rule1_min_trailing_silence: float = 2.4,
|
||||||
|
rule2_min_trailing_silence: float = 1.2,
|
||||||
|
rule3_min_utterance_length: float = 20.0,
|
||||||
|
decoding_method: str = "greedy_search",
|
||||||
|
provider: str = "cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Please refer to
|
||||||
|
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
||||||
|
to download pre-trained models for different languages, e.g., Chinese,
|
||||||
|
English, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens:
|
||||||
|
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||||
|
columns::
|
||||||
|
|
||||||
|
symbol integer_id
|
||||||
|
|
||||||
|
encoder:
|
||||||
|
Path to ``encoder.onnx``.
|
||||||
|
decoder:
|
||||||
|
Path to ``decoder.onnx``.
|
||||||
|
num_threads:
|
||||||
|
Number of threads for neural network computation.
|
||||||
|
sample_rate:
|
||||||
|
Sample rate of the training data used to train the model.
|
||||||
|
feature_dim:
|
||||||
|
Dimension of the feature used to train the model.
|
||||||
|
enable_endpoint_detection:
|
||||||
|
True to enable endpoint detection. False to disable endpoint
|
||||||
|
detection.
|
||||||
|
rule1_min_trailing_silence:
|
||||||
|
Used only when enable_endpoint_detection is True. If the duration
|
||||||
|
of trailing silence in seconds is larger than this value, we assume
|
||||||
|
an endpoint is detected.
|
||||||
|
rule2_min_trailing_silence:
|
||||||
|
Used only when enable_endpoint_detection is True. If we have decoded
|
||||||
|
something that is nonsilence and if the duration of trailing silence
|
||||||
|
in seconds is larger than this value, we assume an endpoint is
|
||||||
|
detected.
|
||||||
|
rule3_min_utterance_length:
|
||||||
|
Used only when enable_endpoint_detection is True. If the utterance
|
||||||
|
length in seconds is larger than this value, we assume an endpoint
|
||||||
|
is detected.
|
||||||
|
decoding_method:
|
||||||
|
The only valid value is greedy_search.
|
||||||
|
provider:
|
||||||
|
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||||
|
"""
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
_assert_file_exists(tokens)
|
||||||
|
_assert_file_exists(encoder)
|
||||||
|
_assert_file_exists(decoder)
|
||||||
|
|
||||||
|
assert num_threads > 0, num_threads
|
||||||
|
|
||||||
|
paraformer_config = OnlineParaformerModelConfig(
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = OnlineModelConfig(
|
||||||
|
paraformer=paraformer_config,
|
||||||
|
tokens=tokens,
|
||||||
|
num_threads=num_threads,
|
||||||
|
provider=provider,
|
||||||
|
model_type="paraformer",
|
||||||
|
)
|
||||||
|
|
||||||
|
feat_config = FeatureExtractorConfig(
|
||||||
|
sampling_rate=sample_rate,
|
||||||
|
feature_dim=feature_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
endpoint_config = EndpointConfig(
|
||||||
|
rule1_min_trailing_silence=rule1_min_trailing_silence,
|
||||||
|
rule2_min_trailing_silence=rule2_min_trailing_silence,
|
||||||
|
rule3_min_utterance_length=rule3_min_utterance_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
recognizer_config = OnlineRecognizerConfig(
|
||||||
|
feat_config=feat_config,
|
||||||
|
model_config=model_config,
|
||||||
|
endpoint_config=endpoint_config,
|
||||||
|
enable_endpoint=enable_endpoint_detection,
|
||||||
|
decoding_method=decoding_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
self.config = recognizer_config
|
||||||
|
return self
|
||||||
|
|
||||||
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||||
if contexts_list is None:
|
if contexts_list is None:
|
||||||
return self.recognizer.create_stream()
|
return self.recognizer.create_stream()
|
||||||
|
|||||||
Reference in New Issue
Block a user