Add online LSTM transducer model (#25)
This commit is contained in:
47
.github/scripts/test-online-transducer.sh
vendored
Executable file
47
.github/scripts/test-online-transducer.sh
vendored
Executable file
@@ -0,0 +1,47 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "EXE is $EXE"
|
||||||
|
echo "PATH: $PATH"
|
||||||
|
|
||||||
|
which $EXE
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run LSTM transducer (English)"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17
|
||||||
|
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
popd
|
||||||
|
|
||||||
|
waves=(
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
)
|
||||||
|
|
||||||
|
for wave in ${waves[@]}; do
|
||||||
|
time $EXE \
|
||||||
|
$repo/tokens.txt \
|
||||||
|
$repo/encoder-epoch-99-avg-1.onnx \
|
||||||
|
$repo/decoder-epoch-99-avg-1.onnx \
|
||||||
|
$repo/joiner-epoch-99-avg-1.onnx \
|
||||||
|
$wave \
|
||||||
|
4
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
71
.github/workflows/linux.yaml
vendored
Normal file
71
.github/workflows/linux.yaml
vendored
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
name: linux
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/linux.yaml'
|
||||||
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/linux.yaml'
|
||||||
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: linux-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
linux:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Configure CMake
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
cmake -D CMAKE_BUILD_TYPE=Release ..
|
||||||
|
|
||||||
|
- name: Build sherpa-onnx for ubuntu
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd build
|
||||||
|
make -j2
|
||||||
|
|
||||||
|
ls -lh lib
|
||||||
|
ls -lh bin
|
||||||
|
|
||||||
|
- name: Display dependencies of sherpa-onnx for linux
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
file build/bin/sherpa-onnx
|
||||||
|
readelf -d build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test online transducer
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx
|
||||||
|
|
||||||
|
.github/scripts/test-online-transducer.sh
|
||||||
73
.github/workflows/macos.yaml
vendored
Normal file
73
.github/workflows/macos.yaml
vendored
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
name: macos
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/macos.yaml'
|
||||||
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/macos.yaml'
|
||||||
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: macos-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
macos:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [macos-latest]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Configure CMake
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
cmake -D CMAKE_BUILD_TYPE=Release ..
|
||||||
|
|
||||||
|
- name: Build sherpa for macos
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd build
|
||||||
|
make -j2
|
||||||
|
|
||||||
|
ls -lh lib
|
||||||
|
ls -lh bin
|
||||||
|
|
||||||
|
|
||||||
|
- name: Display dependencies of sherpa-onnx for macos
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
file bin/sherpa-onnx
|
||||||
|
otool -L build/bin/sherpa-onnx
|
||||||
|
otool -l build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test online transducer
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx
|
||||||
|
|
||||||
|
.github/scripts/test-online-transducer.sh
|
||||||
207
.github/workflows/test-linux-macos-windows.yaml
vendored
207
.github/workflows/test-linux-macos-windows.yaml
vendored
@@ -1,207 +0,0 @@
|
|||||||
name: test-linux-macos-windows
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
paths:
|
|
||||||
- '.github/workflows/test-linux-macos-windows.yaml'
|
|
||||||
- 'CMakeLists.txt'
|
|
||||||
- 'cmake/**'
|
|
||||||
- 'sherpa-onnx/csrc/*'
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
paths:
|
|
||||||
- '.github/workflows/test-linux-macos-windows.yaml'
|
|
||||||
- 'CMakeLists.txt'
|
|
||||||
- 'cmake/**'
|
|
||||||
- 'sherpa-onnx/csrc/*'
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: test-linux-macos-windows-${{ github.ref }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test-linux-macos-windows:
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
# see https://github.com/microsoft/setup-msbuild
|
|
||||||
- name: Add msbuild to PATH
|
|
||||||
if: startsWith(matrix.os, 'windows')
|
|
||||||
uses: microsoft/setup-msbuild@v1.0.2
|
|
||||||
|
|
||||||
- name: Download pretrained model and test-data (English)
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
git lfs install
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
|
||||||
cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
|
||||||
ls -lh exp/onnx/*.onnx
|
|
||||||
git lfs pull --include "exp/onnx/*.onnx"
|
|
||||||
ls -lh exp/onnx/*.onnx
|
|
||||||
|
|
||||||
- name: Download pretrained model and test-data (Chinese)
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
|
|
||||||
cd icefall_asr_wenetspeech_pruned_transducer_stateless2
|
|
||||||
ls -lh exp/*.onnx
|
|
||||||
git lfs pull --include "exp/*.onnx"
|
|
||||||
ls -lh exp/*.onnx
|
|
||||||
|
|
||||||
- name: Configure CMake
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
mkdir build
|
|
||||||
cd build
|
|
||||||
cmake -D CMAKE_BUILD_TYPE=Release ..
|
|
||||||
|
|
||||||
- name: Build sherpa-onnx for ubuntu/macos
|
|
||||||
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
cd build
|
|
||||||
make VERBOSE=1 -j3
|
|
||||||
|
|
||||||
- name: Build sherpa-onnx for Windows
|
|
||||||
if: startsWith(matrix.os, 'windows')
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
cmake --build ./build --config Release
|
|
||||||
|
|
||||||
- name: Run tests for ubuntu/macos (English)
|
|
||||||
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
time ./build/bin/sherpa-onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
|
||||||
|
|
||||||
time ./build/bin/sherpa-onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
|
|
||||||
|
|
||||||
time ./build/bin/sherpa-onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
|
|
||||||
|
|
||||||
- name: Run tests for Windows (English)
|
|
||||||
if: startsWith(matrix.os, 'windows')
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
./build/bin/Release/sherpa-onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
|
||||||
|
|
||||||
./build/bin/Release/sherpa-onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
|
|
||||||
|
|
||||||
./build/bin/Release/sherpa-onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
|
|
||||||
|
|
||||||
- name: Run tests for ubuntu/macos (Chinese)
|
|
||||||
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
time ./build/bin/sherpa-onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
|
|
||||||
|
|
||||||
time ./build/bin/sherpa-onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav
|
|
||||||
|
|
||||||
time ./build/bin/sherpa-onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav
|
|
||||||
|
|
||||||
- name: Run tests for windows (Chinese)
|
|
||||||
if: startsWith(matrix.os, 'windows')
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
./build/bin/Release/sherpa-onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
|
|
||||||
|
|
||||||
./build/bin/Release/sherpa-onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000001.wav
|
|
||||||
|
|
||||||
./build/bin/Release/sherpa-onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000002.wav
|
|
||||||
80
.github/workflows/windows-x64.yaml
vendored
Normal file
80
.github/workflows/windows-x64.yaml
vendored
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
name: windows-x64
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/windows-x64.yaml'
|
||||||
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/windows-x64.yaml'
|
||||||
|
- '.github/scripts/test-online-transducer.sh'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: windows-x64-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
windows_x64:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
name: ${{ matrix.vs-version }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- vs-version: vs2015
|
||||||
|
toolset-version: v140
|
||||||
|
os: windows-2019
|
||||||
|
|
||||||
|
- vs-version: vs2017
|
||||||
|
toolset-version: v141
|
||||||
|
os: windows-2019
|
||||||
|
|
||||||
|
- vs-version: vs2019
|
||||||
|
toolset-version: v142
|
||||||
|
os: windows-2022
|
||||||
|
|
||||||
|
- vs-version: vs2022
|
||||||
|
toolset-version: v143
|
||||||
|
os: windows-2022
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Configure CMake
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
cmake -T ${{ matrix.toolset-version}},host=x64 -A x64 -D CMAKE_BUILD_TYPE=Release ..
|
||||||
|
|
||||||
|
- name: Build sherpa-onnx for windows
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd build
|
||||||
|
cmake --build . --config Release -- -m:2
|
||||||
|
|
||||||
|
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||||
|
|
||||||
|
- name: Test sherpa-onnx for Windows x64
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin/Release:$PATH
|
||||||
|
export EXE=sherpa-onnx.exe
|
||||||
|
|
||||||
|
.github/scripts/test-online-transducer.sh
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,3 +4,4 @@ build
|
|||||||
onnxruntime-*
|
onnxruntime-*
|
||||||
icefall-*
|
icefall-*
|
||||||
run.sh
|
run.sh
|
||||||
|
sherpa-onnx-*
|
||||||
|
|||||||
82
README.md
82
README.md
@@ -2,89 +2,7 @@
|
|||||||
|
|
||||||
Documentation: <https://k2-fsa.github.io/sherpa/onnx/index.html>
|
Documentation: <https://k2-fsa.github.io/sherpa/onnx/index.html>
|
||||||
|
|
||||||
Try it in colab:
|
|
||||||
[](https://colab.research.google.com/drive/1tmQbdlYeTl_klmtaGiUb7a7ZPz-AkBSH?usp=sharing)
|
|
||||||
|
|
||||||
See <https://github.com/k2-fsa/sherpa>
|
See <https://github.com/k2-fsa/sherpa>
|
||||||
|
|
||||||
This repo uses [onnxruntime](https://github.com/microsoft/onnxruntime) and
|
This repo uses [onnxruntime](https://github.com/microsoft/onnxruntime) and
|
||||||
does not depend on libtorch.
|
does not depend on libtorch.
|
||||||
|
|
||||||
We provide exported models in onnx format and they can be downloaded using
|
|
||||||
the following links:
|
|
||||||
|
|
||||||
- English: <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13>
|
|
||||||
- Chinese: <https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2>
|
|
||||||
|
|
||||||
**NOTE**: We provide only non-streaming models at present.
|
|
||||||
|
|
||||||
|
|
||||||
**HINT**: The script for exporting the English model can be found at
|
|
||||||
<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>
|
|
||||||
|
|
||||||
**HINT**: The script for exporting the Chinese model can be found at
|
|
||||||
<https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py>
|
|
||||||
|
|
||||||
## Build for Linux/macOS
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/k2-fsa/sherpa-onnx
|
|
||||||
cd sherpa-onnx
|
|
||||||
mkdir build
|
|
||||||
cd build
|
|
||||||
cmake -DCMAKE_BUILD_TYPE=Release ..
|
|
||||||
make -j6
|
|
||||||
cd ..
|
|
||||||
```
|
|
||||||
|
|
||||||
## Build for Windows
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/k2-fsa/sherpa-onnx
|
|
||||||
cd sherpa-onnx
|
|
||||||
mkdir build
|
|
||||||
cd build
|
|
||||||
cmake -DCMAKE_BUILD_TYPE=Release ..
|
|
||||||
cmake --build . --config Release
|
|
||||||
cd ..
|
|
||||||
```
|
|
||||||
|
|
||||||
## Download the pretrained model (English)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
|
||||||
cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
|
||||||
git lfs pull --include "exp/onnx/*.onnx"
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
./build/bin/sherpa-onnx --help
|
|
||||||
|
|
||||||
./build/bin/sherpa-onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
|
||||||
```
|
|
||||||
|
|
||||||
## Download the pretrained model (Chinese)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
|
|
||||||
cd icefall_asr_wenetspeech_pruned_transducer_stateless2
|
|
||||||
git lfs pull --include "exp/*.onnx"
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
./build/bin/sherpa-onnx --help
|
|
||||||
|
|
||||||
./build/bin/sherpa-onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/tokens.txt \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/encoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/decoder-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/joiner_decoder_proj-epoch-10-avg-2.onnx \
|
|
||||||
./icefall_asr_wenetspeech_pruned_transducer_stateless2/test_wavs/DEV_T0000000000.wav
|
|
||||||
```
|
|
||||||
|
|||||||
@@ -2,7 +2,11 @@ include_directories(${CMAKE_SOURCE_DIR})
|
|||||||
|
|
||||||
add_executable(sherpa-onnx
|
add_executable(sherpa-onnx
|
||||||
decode.cc
|
decode.cc
|
||||||
rnnt-model.cc
|
features.cc
|
||||||
|
online-lstm-transducer-model.cc
|
||||||
|
online-transducer-model-config.cc
|
||||||
|
online-transducer-model.cc
|
||||||
|
onnx-utils.cc
|
||||||
sherpa-onnx.cc
|
sherpa-onnx.cc
|
||||||
symbol-table.cc
|
symbol-table.cc
|
||||||
wave-reader.cc
|
wave-reader.cc
|
||||||
@@ -13,5 +17,5 @@ target_link_libraries(sherpa-onnx
|
|||||||
kaldi-native-fbank-core
|
kaldi-native-fbank-core
|
||||||
)
|
)
|
||||||
|
|
||||||
# add_executable(sherpa-show-onnx-info show-onnx-info.cc)
|
add_executable(sherpa-onnx-show-info show-onnx-info.cc)
|
||||||
# target_link_libraries(sherpa-show-onnx-info onnxruntime)
|
target_link_libraries(sherpa-onnx-show-info onnxruntime)
|
||||||
|
|||||||
@@ -1,84 +1,79 @@
|
|||||||
/**
|
// sherpa/csrc/decode.cc
|
||||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
//
|
||||||
*
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/decode.h"
|
#include "sherpa-onnx/csrc/decode.h"
|
||||||
|
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
|
static Ort::Value Clone(Ort::Value *v) {
|
||||||
const Ort::Value &encoder_out) {
|
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
|
||||||
|
std::vector<int64_t> shape = type_and_shape.GetShape();
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
|
||||||
|
type_and_shape.GetElementCount(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
|
||||||
|
std::vector<int64_t> encoder_out_shape =
|
||||||
|
encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
assert(encoder_out_shape[0] == 1);
|
||||||
|
|
||||||
|
int32_t encoder_out_dim = encoder_out_shape[2];
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 2> shape{1, encoder_out_dim};
|
||||||
|
|
||||||
|
return Ort::Value::CreateTensor(
|
||||||
|
memory_info,
|
||||||
|
encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim,
|
||||||
|
encoder_out_dim, shape.data(), shape.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
|
||||||
|
std::vector<int64_t> *hyp) {
|
||||||
std::vector<int64_t> encoder_out_shape =
|
std::vector<int64_t> encoder_out_shape =
|
||||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented");
|
|
||||||
Ort::Value projected_encoder_out =
|
|
||||||
model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(),
|
|
||||||
encoder_out_shape[1], encoder_out_shape[2]);
|
|
||||||
|
|
||||||
const float *p_projected_encoder_out =
|
if (encoder_out_shape[0] > 1) {
|
||||||
projected_encoder_out.GetTensorData<float>();
|
fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n",
|
||||||
|
static_cast<int32_t>(encoder_out_shape[0]));
|
||||||
|
}
|
||||||
|
|
||||||
int32_t context_size = 2; // hard-code it to 2
|
int32_t num_frames = encoder_out_shape[1];
|
||||||
int32_t blank_id = 0; // hard-code it to 0
|
int32_t vocab_size = model->VocabSize();
|
||||||
std::vector<int32_t> hyp(context_size, blank_id);
|
|
||||||
std::array<int64_t, 2> decoder_input{blank_id, blank_id};
|
|
||||||
|
|
||||||
Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
|
Ort::Value decoder_input = model->BuildDecoderInput(*hyp);
|
||||||
|
Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input));
|
||||||
std::vector<int64_t> decoder_out_shape =
|
|
||||||
decoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
|
||||||
|
|
||||||
Ort::Value projected_decoder_out = model.RunJoinerDecoderProj(
|
|
||||||
decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
|
|
||||||
|
|
||||||
int32_t joiner_dim =
|
|
||||||
projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1];
|
|
||||||
|
|
||||||
int32_t T = encoder_out_shape[1];
|
|
||||||
for (int32_t t = 0; t != T; ++t) {
|
|
||||||
Ort::Value logit = model.RunJoiner(
|
|
||||||
p_projected_encoder_out + t * joiner_dim,
|
|
||||||
projected_decoder_out.GetTensorData<float>(), joiner_dim);
|
|
||||||
|
|
||||||
int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];
|
|
||||||
|
|
||||||
|
for (int32_t t = 0; t != num_frames; ++t) {
|
||||||
|
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
|
||||||
|
Ort::Value logit =
|
||||||
|
model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
|
||||||
const float *p_logit = logit.GetTensorData<float>();
|
const float *p_logit = logit.GetTensorData<float>();
|
||||||
|
|
||||||
auto y = static_cast<int32_t>(std::distance(
|
auto y = static_cast<int32_t>(std::distance(
|
||||||
static_cast<const float *>(p_logit),
|
static_cast<const float *>(p_logit),
|
||||||
std::max_element(static_cast<const float *>(p_logit),
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
static_cast<const float *>(p_logit) + vocab_size)));
|
static_cast<const float *>(p_logit) + vocab_size)));
|
||||||
|
if (y != 0) {
|
||||||
if (y != blank_id) {
|
hyp->push_back(y);
|
||||||
decoder_input[0] = hyp.back();
|
decoder_input = model->BuildDecoderInput(*hyp);
|
||||||
decoder_input[1] = y;
|
decoder_out = model->RunDecoder(std::move(decoder_input));
|
||||||
hyp.push_back(y);
|
|
||||||
decoder_out = model.RunDecoder(decoder_input.data(), context_size);
|
|
||||||
projected_decoder_out = model.RunJoinerDecoderProj(
|
|
||||||
decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {hyp.begin() + context_size, hyp.end()};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -1,27 +1,13 @@
|
|||||||
/**
|
// sherpa/csrc/decode.h
|
||||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
//
|
||||||
*
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_DECODE_H_
|
#ifndef SHERPA_ONNX_CSRC_DECODE_H_
|
||||||
#define SHERPA_ONNX_CSRC_DECODE_H_
|
#define SHERPA_ONNX_CSRC_DECODE_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/rnnt-model.h"
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -32,8 +18,8 @@ namespace sherpa_onnx {
|
|||||||
* @param model The RnntModel
|
* @param model The RnntModel
|
||||||
* @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
|
* @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
|
||||||
*/
|
*/
|
||||||
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
|
void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
|
||||||
const Ort::Value &encoder_out);
|
std::vector<int64_t> *hyp);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
|||||||
79
sherpa-onnx/csrc/features.cc
Normal file
79
sherpa-onnx/csrc/features.cc
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
// sherpa/csrc/features.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
FeatureExtractor::FeatureExtractor() {
|
||||||
|
opts_.frame_opts.dither = 0;
|
||||||
|
opts_.frame_opts.snip_edges = false;
|
||||||
|
opts_.frame_opts.samp_freq = 16000;
|
||||||
|
|
||||||
|
// cache 100 seconds of feature frames, which is more than enough
|
||||||
|
// for real needs
|
||||||
|
opts_.frame_opts.max_feature_vectors = 100 * 100;
|
||||||
|
|
||||||
|
opts_.mel_opts.num_bins = 80; // feature dim
|
||||||
|
|
||||||
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
|
}
|
||||||
|
|
||||||
|
FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts)
|
||||||
|
: opts_(opts) {
|
||||||
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void FeatureExtractor::AcceptWaveform(float sampling_rate,
|
||||||
|
const float *waveform, int32_t n) {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
void FeatureExtractor::InputFinished() {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
fbank_->InputFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t FeatureExtractor::NumFramesReady() const {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
return fbank_->NumFramesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool FeatureExtractor::IsLastFrame(int32_t frame) const {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
return fbank_->IsLastFrame(frame);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
|
||||||
|
int32_t n) const {
|
||||||
|
if (frame_index + n > NumFramesReady()) {
|
||||||
|
fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
|
||||||
|
int32_t feature_dim = fbank_->Dim();
|
||||||
|
std::vector<float> features(feature_dim * n);
|
||||||
|
|
||||||
|
float *p = features.data();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
const float *f = fbank_->GetFrame(i + frame_index);
|
||||||
|
std::copy(f, f + feature_dim, p);
|
||||||
|
p += feature_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
return features;
|
||||||
|
}
|
||||||
|
|
||||||
|
void FeatureExtractor::Reset() {
|
||||||
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
61
sherpa-onnx/csrc/features.h
Normal file
61
sherpa-onnx/csrc/features.h
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
// sherpa/csrc/features.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_FEATURES_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_FEATURES_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex> // NOLINT
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class FeatureExtractor {
|
||||||
|
public:
|
||||||
|
FeatureExtractor();
|
||||||
|
explicit FeatureExtractor(const knf::FbankOptions &fbank_opts);
|
||||||
|
|
||||||
|
/**
|
||||||
|
@param sampling_rate The sampling_rate of the input waveform. Should match
|
||||||
|
the one expected by the feature extractor.
|
||||||
|
@param waveform Pointer to a 1-D array of size n
|
||||||
|
@param n Number of entries in waveform
|
||||||
|
*/
|
||||||
|
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
|
||||||
|
|
||||||
|
// InputFinished() tells the class you won't be providing any
|
||||||
|
// more waveform. This will help flush out the last frame or two
|
||||||
|
// of features, in the case where snip-edges == false; it also
|
||||||
|
// affects the return value of IsLastFrame().
|
||||||
|
void InputFinished();
|
||||||
|
|
||||||
|
int32_t NumFramesReady() const;
|
||||||
|
|
||||||
|
// Note: IsLastFrame() will only ever return true if you have called
|
||||||
|
// InputFinished() (and this frame is the last frame).
|
||||||
|
bool IsLastFrame(int32_t frame) const;
|
||||||
|
|
||||||
|
/** Get n frames starting from the given frame index.
|
||||||
|
*
|
||||||
|
* @param frame_index The starting frame index
|
||||||
|
* @param n Number of frames to get.
|
||||||
|
* @return Return a 2-D tensor of shape (n, feature_dim).
|
||||||
|
* which is flattened into a 1-D vector (flattened in in row major)
|
||||||
|
*/
|
||||||
|
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
|
||||||
|
|
||||||
|
void Reset();
|
||||||
|
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
|
knf::FbankOptions opts_;
|
||||||
|
mutable std::mutex mutex_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_FEATURES_H_
|
||||||
223
sherpa-onnx/csrc/online-lstm-transducer-model.cc
Normal file
223
sherpa-onnx/csrc/online-lstm-transducer-model.cc
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
// sherpa/csrc/online-lstm-transducer-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
|
||||||
|
do { \
|
||||||
|
auto value = \
|
||||||
|
meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
|
||||||
|
if (!value) { \
|
||||||
|
fprintf(stderr, "%s does not exist in the metadata\n", src_key); \
|
||||||
|
exit(-1); \
|
||||||
|
} \
|
||||||
|
dst = atoi(value.get()); \
|
||||||
|
if (dst <= 0) { \
|
||||||
|
fprintf(stderr, "Invalud value %d for %s\n", dst, src_key); \
|
||||||
|
exit(-1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
OnlineLstmTransducerModel::OnlineLstmTransducerModel(
|
||||||
|
const OnlineTransducerModelConfig &config)
|
||||||
|
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||||
|
config_(config),
|
||||||
|
sess_opts_{},
|
||||||
|
allocator_{} {
|
||||||
|
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||||
|
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||||
|
|
||||||
|
InitEncoder(config.encoder_filename);
|
||||||
|
InitDecoder(config.decoder_filename);
|
||||||
|
InitJoiner(config.joiner_filename);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
|
||||||
|
encoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||||
|
&encoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||||
|
&encoder_output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "---encoder---\n";
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
fprintf(stderr, "%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(T_, "T");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "rnn_hidden_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
|
||||||
|
decoder_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||||
|
&decoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||||
|
&decoder_output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "---decoder---\n";
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
fprintf(stderr, "%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
|
||||||
|
joiner_sess_ = std::make_unique<Ort::Session>(
|
||||||
|
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||||
|
&joiner_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
||||||
|
&joiner_output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "---joiner---\n";
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
fprintf(stderr, "%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value OnlineLstmTransducerModel::StackStates(
|
||||||
|
const std::vector<Ort::Value> &states) const {
|
||||||
|
fprintf(stderr, "implement me: %s:%d!\n", __func__,
|
||||||
|
static_cast<int>(__LINE__));
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
int64_t a;
|
||||||
|
std::array<int64_t, 3> x_shape{1, 1, 1};
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
|
||||||
|
Ort::Value states) const {
|
||||||
|
fprintf(stderr, "implement me: %s:%d!\n", __func__,
|
||||||
|
static_cast<int>(__LINE__));
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
||||||
|
// Please see
|
||||||
|
// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py#L185
|
||||||
|
// for details
|
||||||
|
constexpr int32_t kBatchSize = 1;
|
||||||
|
std::array<int64_t, 3> h_shape{num_encoder_layers_, kBatchSize, d_model_};
|
||||||
|
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||||
|
h_shape.size());
|
||||||
|
|
||||||
|
std::fill(h.GetTensorMutableData<float>(),
|
||||||
|
h.GetTensorMutableData<float>() +
|
||||||
|
num_encoder_layers_ * kBatchSize * d_model_,
|
||||||
|
0);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> c_shape{num_encoder_layers_, kBatchSize,
|
||||||
|
rnn_hidden_size_};
|
||||||
|
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||||
|
c_shape.size());
|
||||||
|
|
||||||
|
std::fill(c.GetTensorMutableData<float>(),
|
||||||
|
c.GetTensorMutableData<float>() +
|
||||||
|
num_encoder_layers_ * kBatchSize * rnn_hidden_size_,
|
||||||
|
0);
|
||||||
|
|
||||||
|
std::vector<Ort::Value> states;
|
||||||
|
|
||||||
|
states.reserve(2);
|
||||||
|
states.push_back(std::move(h));
|
||||||
|
states.push_back(std::move(c));
|
||||||
|
|
||||||
|
return states;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||||
|
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
|
||||||
|
std::vector<Ort::Value> &states) {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<Ort::Value, 3> encoder_inputs = {
|
||||||
|
std::move(features), std::move(states[0]), std::move(states[1])};
|
||||||
|
|
||||||
|
auto encoder_out = encoder_sess_->Run(
|
||||||
|
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
||||||
|
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
||||||
|
encoder_output_names_ptr_.size());
|
||||||
|
|
||||||
|
std::vector<Ort::Value> next_states;
|
||||||
|
next_states.reserve(2);
|
||||||
|
next_states.push_back(std::move(encoder_out[1]));
|
||||||
|
next_states.push_back(std::move(encoder_out[2]));
|
||||||
|
|
||||||
|
return {std::move(encoder_out[0]), std::move(next_states)};
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
|
||||||
|
const std::vector<int64_t> &hyp) {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 2> shape{1, context_size_};
|
||||||
|
|
||||||
|
return Ort::Value::CreateTensor(
|
||||||
|
memory_info,
|
||||||
|
const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
|
||||||
|
context_size_, shape.data(), shape.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
|
||||||
|
auto decoder_out = decoder_sess_->Run(
|
||||||
|
{}, decoder_input_names_ptr_.data(), &decoder_input, 1,
|
||||||
|
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
|
||||||
|
return std::move(decoder_out[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value OnlineLstmTransducerModel::RunJoiner(Ort::Value encoder_out,
|
||||||
|
Ort::Value decoder_out) {
|
||||||
|
std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
|
||||||
|
std::move(decoder_out)};
|
||||||
|
auto logit =
|
||||||
|
joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
|
||||||
|
joiner_input.size(), joiner_output_names_ptr_.data(),
|
||||||
|
joiner_output_names_ptr_.size());
|
||||||
|
|
||||||
|
return std::move(logit[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
91
sherpa-onnx/csrc/online-lstm-transducer-model.h
Normal file
91
sherpa-onnx/csrc/online-lstm-transducer-model.h
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
// sherpa/csrc/online-lstm-transducer-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||||
|
public:
|
||||||
|
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
|
||||||
|
|
||||||
|
Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
|
||||||
|
|
||||||
|
std::vector<Ort::Value> GetEncoderInitStates() override;
|
||||||
|
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||||
|
Ort::Value features, std::vector<Ort::Value> &states) override;
|
||||||
|
|
||||||
|
Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
|
||||||
|
|
||||||
|
Ort::Value RunDecoder(Ort::Value decoder_input) override;
|
||||||
|
|
||||||
|
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
|
||||||
|
|
||||||
|
int32_t ContextSize() const override { return context_size_; }
|
||||||
|
|
||||||
|
int32_t ChunkSize() const override { return T_; }
|
||||||
|
|
||||||
|
int32_t ChunkShift() const override { return decode_chunk_len_; }
|
||||||
|
|
||||||
|
int32_t VocabSize() const override { return vocab_size_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitEncoder(const std::string &encoder_filename);
|
||||||
|
void InitDecoder(const std::string &decoder_filename);
|
||||||
|
void InitJoiner(const std::string &joiner_filename);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> joiner_sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_input_names_;
|
||||||
|
std::vector<const char *> encoder_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_output_names_;
|
||||||
|
std::vector<const char *> encoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_input_names_;
|
||||||
|
std::vector<const char *> decoder_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_output_names_;
|
||||||
|
std::vector<const char *> decoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> joiner_input_names_;
|
||||||
|
std::vector<const char *> joiner_input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> joiner_output_names_;
|
||||||
|
std::vector<const char *> joiner_output_names_ptr_;
|
||||||
|
|
||||||
|
OnlineTransducerModelConfig config_;
|
||||||
|
|
||||||
|
int32_t num_encoder_layers_ = 0;
|
||||||
|
int32_t T_ = 0;
|
||||||
|
int32_t decode_chunk_len_ = 0;
|
||||||
|
int32_t rnn_hidden_size_ = 0;
|
||||||
|
int32_t d_model_ = 0;
|
||||||
|
int32_t context_size_ = 0;
|
||||||
|
int32_t vocab_size_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
|
||||||
23
sherpa-onnx/csrc/online-transducer-model-config.cc
Normal file
23
sherpa-onnx/csrc/online-transducer-model-config.cc
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
// sherpa/csrc/online-transducer-model-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::string OnlineTransducerModelConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OnlineTransducerModelConfig(";
|
||||||
|
os << "encoder_filename=\"" << encoder_filename << "\", ";
|
||||||
|
os << "decoder_filename=\"" << decoder_filename << "\", ";
|
||||||
|
os << "joiner_filename=\"" << joiner_filename << "\", ";
|
||||||
|
os << "num_threads=" << num_threads << ", ";
|
||||||
|
os << "debug=" << (debug ? "True" : "False") << ")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
23
sherpa-onnx/csrc/online-transducer-model-config.h
Normal file
23
sherpa-onnx/csrc/online-transducer-model-config.h
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
// sherpa/csrc/online-transducer-model-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OnlineTransducerModelConfig {
|
||||||
|
std::string encoder_filename;
|
||||||
|
std::string decoder_filename;
|
||||||
|
std::string joiner_filename;
|
||||||
|
int32_t num_threads;
|
||||||
|
bool debug = false;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||||
64
sherpa-onnx/csrc/online-transducer-model.cc
Normal file
64
sherpa-onnx/csrc/online-transducer-model.cc
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
// sherpa/csrc/online-transducer-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
enum class ModelType {
|
||||||
|
kLstm,
|
||||||
|
kUnkown,
|
||||||
|
};
|
||||||
|
|
||||||
|
static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
|
||||||
|
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||||
|
Ort::SessionOptions sess_opts;
|
||||||
|
|
||||||
|
auto sess = std::make_unique<Ort::Session>(
|
||||||
|
env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts);
|
||||||
|
|
||||||
|
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||||
|
if (config.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
fprintf(stderr, "%s\n", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
auto model_type =
|
||||||
|
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||||
|
if (!model_type) {
|
||||||
|
fprintf(stderr, "No model_type in the metadata!\n");
|
||||||
|
return ModelType::kUnkown;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model_type.get() == std::string("lstm")) {
|
||||||
|
return ModelType::kLstm;
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
|
||||||
|
return ModelType::kUnkown;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||||
|
const OnlineTransducerModelConfig &config) {
|
||||||
|
auto model_type = GetModelType(config);
|
||||||
|
|
||||||
|
switch (model_type) {
|
||||||
|
case ModelType::kLstm:
|
||||||
|
return std::make_unique<OnlineLstmTransducerModel>(config);
|
||||||
|
case ModelType::kUnkown:
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// unreachable code
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
118
sherpa-onnx/csrc/online-transducer-model.h
Normal file
118
sherpa-onnx/csrc/online-transducer-model.h
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
// sherpa/csrc/online-transducer-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineTransducerModel {
|
||||||
|
public:
|
||||||
|
virtual ~OnlineTransducerModel() = default;
|
||||||
|
|
||||||
|
static std::unique_ptr<OnlineTransducerModel> Create(
|
||||||
|
const OnlineTransducerModelConfig &config);
|
||||||
|
|
||||||
|
/** Stack a list of individual states into a batch.
|
||||||
|
*
|
||||||
|
* It is the inverse operation of `UnStackStates`.
|
||||||
|
*
|
||||||
|
* @param states states[i] contains the state for the i-th utterance.
|
||||||
|
* @return Return a single value representing the batched state.
|
||||||
|
*/
|
||||||
|
virtual Ort::Value StackStates(
|
||||||
|
const std::vector<Ort::Value> &states) const = 0;
|
||||||
|
|
||||||
|
/** Unstack a batch state into a list of individual states.
|
||||||
|
*
|
||||||
|
* It is the inverse operation of `StackStates`.
|
||||||
|
*
|
||||||
|
* @param states A batched state.
|
||||||
|
* @return ans[i] contains the state for the i-th utterance.
|
||||||
|
*/
|
||||||
|
virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
|
||||||
|
|
||||||
|
/** Get the initial encoder states.
|
||||||
|
*
|
||||||
|
* @return Return the initial encoder state.
|
||||||
|
*/
|
||||||
|
virtual std::vector<Ort::Value> GetEncoderInitStates() = 0;
|
||||||
|
|
||||||
|
/** Run the encoder.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||||
|
* @param states Encoder state of the previous chunk. It is changed in-place.
|
||||||
|
*
|
||||||
|
* @return Return a tuple containing:
|
||||||
|
* - encoder_out, a tensor of shape (N, T', encoder_out_dim)
|
||||||
|
* - next_states Encoder state for the next chunk.
|
||||||
|
*/
|
||||||
|
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||||
|
Ort::Value features,
|
||||||
|
std::vector<Ort::Value> &states) = 0; // NOLINT
|
||||||
|
|
||||||
|
virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
|
||||||
|
|
||||||
|
/** Run the decoder network.
|
||||||
|
*
|
||||||
|
* Caution: We assume there are no recurrent connections in the decoder and
|
||||||
|
* the decoder is stateless. See
|
||||||
|
* https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
|
||||||
|
* for an example
|
||||||
|
*
|
||||||
|
* @param decoder_input It is usually of shape (N, context_size)
|
||||||
|
* @return Return a tensor of shape (N, decoder_dim).
|
||||||
|
*/
|
||||||
|
virtual Ort::Value RunDecoder(Ort::Value decoder_input) = 0;
|
||||||
|
|
||||||
|
/** Run the joint network.
|
||||||
|
*
|
||||||
|
* @param encoder_out Output of the encoder network. A tensor of shape
|
||||||
|
* (N, joiner_dim).
|
||||||
|
* @param decoder_out Output of the decoder network. A tensor of shape
|
||||||
|
* (N, joiner_dim).
|
||||||
|
* @return Return a tensor of shape (N, vocab_size). In icefall, the last
|
||||||
|
* last layer of the joint network is `nn.Linear`,
|
||||||
|
* not `nn.LogSoftmax`.
|
||||||
|
*/
|
||||||
|
virtual Ort::Value RunJoiner(Ort::Value encoder_out,
|
||||||
|
Ort::Value decoder_out) = 0;
|
||||||
|
|
||||||
|
/** If we are using a stateless decoder and if it contains a
|
||||||
|
* Conv1D, this function returns the kernel size of the convolution layer.
|
||||||
|
*/
|
||||||
|
virtual int32_t ContextSize() const = 0;
|
||||||
|
|
||||||
|
/** We send this number of feature frames to the encoder at a time. */
|
||||||
|
virtual int32_t ChunkSize() const = 0;
|
||||||
|
|
||||||
|
/** Number of input frames to discard after each call to RunEncoder.
|
||||||
|
*
|
||||||
|
* For instance, if we have 30 frames, chunk_size=8, chunk_shift=6.
|
||||||
|
*
|
||||||
|
* In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8.
|
||||||
|
* Then we discard frame 0~5 since chunk_shift is 6.
|
||||||
|
* In the second call of RunEncoder, we use frames 6~13; and then we discard
|
||||||
|
* frames 6~11.
|
||||||
|
* In the third call of RunEncoder, we use frames 12~19; and then we discard
|
||||||
|
* frames 12~16.
|
||||||
|
*
|
||||||
|
* Note: ChunkSize() - ChunkShift() == right context size
|
||||||
|
*/
|
||||||
|
virtual int32_t ChunkShift() const = 0;
|
||||||
|
|
||||||
|
virtual int32_t VocabSize() const = 0;
|
||||||
|
|
||||||
|
virtual int32_t SubsamplingFactor() const { return 4; }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
|
||||||
49
sherpa-onnx/csrc/onnx-utils.cc
Normal file
49
sherpa-onnx/csrc/onnx-utils.cc
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
// sherpa/csrc/onnx-utils.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
|
||||||
|
std::vector<const char *> *input_names_ptr) {
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
size_t node_count = sess->GetInputCount();
|
||||||
|
input_names->resize(node_count);
|
||||||
|
input_names_ptr->resize(node_count);
|
||||||
|
for (size_t i = 0; i != node_count; ++i) {
|
||||||
|
auto tmp = sess->GetInputNameAllocated(i, allocator);
|
||||||
|
(*input_names)[i] = tmp.get();
|
||||||
|
(*input_names_ptr)[i] = (*input_names)[i].c_str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
||||||
|
std::vector<const char *> *output_names_ptr) {
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
size_t node_count = sess->GetOutputCount();
|
||||||
|
output_names->resize(node_count);
|
||||||
|
output_names_ptr->resize(node_count);
|
||||||
|
for (size_t i = 0; i != node_count; ++i) {
|
||||||
|
auto tmp = sess->GetOutputNameAllocated(i, allocator);
|
||||||
|
(*output_names)[i] = tmp.get();
|
||||||
|
(*output_names_ptr)[i] = (*output_names)[i].c_str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
std::vector<Ort::AllocatedStringPtr> v =
|
||||||
|
meta_data.GetCustomMetadataMapKeysAllocated(allocator);
|
||||||
|
for (const auto &key : v) {
|
||||||
|
auto p = meta_data.LookupCustomMetadataMapAllocated(key.get(), allocator);
|
||||||
|
os << key.get() << "=" << p.get() << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
60
sherpa-onnx/csrc/onnx-utils.h
Normal file
60
sherpa-onnx/csrc/onnx-utils.h
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// sherpa/csrc/onnx-utils.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
// For ToWide() below
|
||||||
|
#include <codecvt>
|
||||||
|
#include <locale>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <ostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
// See
|
||||||
|
// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
|
||||||
|
static std::wstring ToWide(const std::string &s) {
|
||||||
|
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||||
|
return converter.from_bytes(s);
|
||||||
|
}
|
||||||
|
#define SHERPA_MAYBE_WIDE(s) ToWide(s)
|
||||||
|
#else
|
||||||
|
#define SHERPA_MAYBE_WIDE(s) s
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the input names of a model.
|
||||||
|
*
|
||||||
|
* @param sess An onnxruntime session.
|
||||||
|
* @param input_names. On return, it contains the input names of the model.
|
||||||
|
* @param input_names_ptr. On return, input_names_ptr[i] contains
|
||||||
|
* input_names[i].c_str()
|
||||||
|
*/
|
||||||
|
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
|
||||||
|
std::vector<const char *> *input_names_ptr);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the output names of a model.
|
||||||
|
*
|
||||||
|
* @param sess An onnxruntime session.
|
||||||
|
* @param output_names. On return, it contains the output names of the model.
|
||||||
|
* @param output_names_ptr. On return, output_names_ptr[i] contains
|
||||||
|
* output_names[i].c_str()
|
||||||
|
*/
|
||||||
|
void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
||||||
|
std::vector<const char *> *output_names_ptr);
|
||||||
|
|
||||||
|
void PrintModelMetadata(std::ostream &os,
|
||||||
|
const Ort::ModelMetadata &meta_data); // NOLINT
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||||
@@ -1,265 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
|
||||||
*
|
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "sherpa-onnx/csrc/rnnt-model.h"
|
|
||||||
|
|
||||||
#include <array>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
// For ToWide() below
|
|
||||||
#include <codecvt>
|
|
||||||
#include <locale>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
// See
|
|
||||||
// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
|
|
||||||
static std::wstring ToWide(const std::string &s) {
|
|
||||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
|
||||||
return converter.from_bytes(s);
|
|
||||||
}
|
|
||||||
#define SHERPA_MAYBE_WIDE(s) ToWide(s)
|
|
||||||
#else
|
|
||||||
#define SHERPA_MAYBE_WIDE(s) s
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the input names of a model.
|
|
||||||
*
|
|
||||||
* @param sess An onnxruntime session.
|
|
||||||
* @param input_names. On return, it contains the input names of the model.
|
|
||||||
* @param input_names_ptr. On return, input_names_ptr[i] contains
|
|
||||||
* input_names[i].c_str()
|
|
||||||
*/
|
|
||||||
static void GetInputNames(Ort::Session *sess,
|
|
||||||
std::vector<std::string> *input_names,
|
|
||||||
std::vector<const char *> *input_names_ptr) {
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
|
||||||
size_t node_count = sess->GetInputCount();
|
|
||||||
input_names->resize(node_count);
|
|
||||||
input_names_ptr->resize(node_count);
|
|
||||||
for (size_t i = 0; i != node_count; ++i) {
|
|
||||||
auto tmp = sess->GetInputNameAllocated(i, allocator);
|
|
||||||
(*input_names)[i] = tmp.get();
|
|
||||||
(*input_names_ptr)[i] = (*input_names)[i].c_str();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the output names of a model.
|
|
||||||
*
|
|
||||||
* @param sess An onnxruntime session.
|
|
||||||
* @param output_names. On return, it contains the output names of the model.
|
|
||||||
* @param output_names_ptr. On return, output_names_ptr[i] contains
|
|
||||||
* output_names[i].c_str()
|
|
||||||
*/
|
|
||||||
static void GetOutputNames(Ort::Session *sess,
|
|
||||||
std::vector<std::string> *output_names,
|
|
||||||
std::vector<const char *> *output_names_ptr) {
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
|
||||||
size_t node_count = sess->GetOutputCount();
|
|
||||||
output_names->resize(node_count);
|
|
||||||
output_names_ptr->resize(node_count);
|
|
||||||
for (size_t i = 0; i != node_count; ++i) {
|
|
||||||
auto tmp = sess->GetOutputNameAllocated(i, allocator);
|
|
||||||
(*output_names)[i] = tmp.get();
|
|
||||||
(*output_names_ptr)[i] = (*output_names)[i].c_str();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RnntModel::RnntModel(const std::string &encoder_filename,
|
|
||||||
const std::string &decoder_filename,
|
|
||||||
const std::string &joiner_filename,
|
|
||||||
const std::string &joiner_encoder_proj_filename,
|
|
||||||
const std::string &joiner_decoder_proj_filename,
|
|
||||||
int32_t num_threads)
|
|
||||||
: env_(ORT_LOGGING_LEVEL_WARNING) {
|
|
||||||
sess_opts_.SetIntraOpNumThreads(num_threads);
|
|
||||||
sess_opts_.SetInterOpNumThreads(num_threads);
|
|
||||||
|
|
||||||
InitEncoder(encoder_filename);
|
|
||||||
InitDecoder(decoder_filename);
|
|
||||||
InitJoiner(joiner_filename);
|
|
||||||
InitJoinerEncoderProj(joiner_encoder_proj_filename);
|
|
||||||
InitJoinerDecoderProj(joiner_decoder_proj_filename);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RnntModel::InitEncoder(const std::string &filename) {
|
|
||||||
encoder_sess_ = std::make_unique<Ort::Session>(
|
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
|
||||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
|
||||||
&encoder_input_names_ptr_);
|
|
||||||
|
|
||||||
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
|
||||||
&encoder_output_names_ptr_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RnntModel::InitDecoder(const std::string &filename) {
|
|
||||||
decoder_sess_ = std::make_unique<Ort::Session>(
|
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
|
||||||
|
|
||||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
|
||||||
&decoder_input_names_ptr_);
|
|
||||||
|
|
||||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
|
||||||
&decoder_output_names_ptr_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RnntModel::InitJoiner(const std::string &filename) {
|
|
||||||
joiner_sess_ = std::make_unique<Ort::Session>(
|
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
|
||||||
|
|
||||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
|
||||||
&joiner_input_names_ptr_);
|
|
||||||
|
|
||||||
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
|
||||||
&joiner_output_names_ptr_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
|
|
||||||
joiner_encoder_proj_sess_ = std::make_unique<Ort::Session>(
|
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
|
||||||
|
|
||||||
GetInputNames(joiner_encoder_proj_sess_.get(),
|
|
||||||
&joiner_encoder_proj_input_names_,
|
|
||||||
&joiner_encoder_proj_input_names_ptr_);
|
|
||||||
|
|
||||||
GetOutputNames(joiner_encoder_proj_sess_.get(),
|
|
||||||
&joiner_encoder_proj_output_names_,
|
|
||||||
&joiner_encoder_proj_output_names_ptr_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RnntModel::InitJoinerDecoderProj(const std::string &filename) {
|
|
||||||
joiner_decoder_proj_sess_ = std::make_unique<Ort::Session>(
|
|
||||||
env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_);
|
|
||||||
|
|
||||||
GetInputNames(joiner_decoder_proj_sess_.get(),
|
|
||||||
&joiner_decoder_proj_input_names_,
|
|
||||||
&joiner_decoder_proj_input_names_ptr_);
|
|
||||||
|
|
||||||
GetOutputNames(joiner_decoder_proj_sess_.get(),
|
|
||||||
&joiner_decoder_proj_output_names_,
|
|
||||||
&joiner_decoder_proj_output_names_ptr_);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ort::Value RnntModel::RunEncoder(const float *features, int32_t T,
|
|
||||||
int32_t feature_dim) {
|
|
||||||
auto memory_info =
|
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
std::array<int64_t, 3> x_shape{1, T, feature_dim};
|
|
||||||
Ort::Value x =
|
|
||||||
Ort::Value::CreateTensor(memory_info, const_cast<float *>(features),
|
|
||||||
T * feature_dim, x_shape.data(), x_shape.size());
|
|
||||||
|
|
||||||
std::array<int64_t, 1> x_lens_shape{1};
|
|
||||||
int64_t x_lens_tmp = T;
|
|
||||||
|
|
||||||
Ort::Value x_lens = Ort::Value::CreateTensor(
|
|
||||||
memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size());
|
|
||||||
|
|
||||||
std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)};
|
|
||||||
|
|
||||||
// Note: We discard encoder_out_lens since we only implement
|
|
||||||
// batch==1.
|
|
||||||
auto encoder_out = encoder_sess_->Run(
|
|
||||||
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
|
||||||
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
|
||||||
encoder_output_names_ptr_.size());
|
|
||||||
return std::move(encoder_out[0]);
|
|
||||||
}
|
|
||||||
Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T,
|
|
||||||
int32_t encoder_out_dim) {
|
|
||||||
auto memory_info =
|
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
|
|
||||||
std::array<int64_t, 2> in_shape{T, encoder_out_dim};
|
|
||||||
Ort::Value in = Ort::Value::CreateTensor(
|
|
||||||
memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim,
|
|
||||||
in_shape.data(), in_shape.size());
|
|
||||||
|
|
||||||
auto encoder_proj_out = joiner_encoder_proj_sess_->Run(
|
|
||||||
{}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1,
|
|
||||||
joiner_encoder_proj_output_names_ptr_.data(),
|
|
||||||
joiner_encoder_proj_output_names_ptr_.size());
|
|
||||||
return std::move(encoder_proj_out[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input,
|
|
||||||
int32_t context_size) {
|
|
||||||
auto memory_info =
|
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
|
|
||||||
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
|
|
||||||
std::array<int64_t, 2> shape{batch_size, context_size};
|
|
||||||
Ort::Value in = Ort::Value::CreateTensor(
|
|
||||||
memory_info, const_cast<int64_t *>(decoder_input),
|
|
||||||
batch_size * context_size, shape.data(), shape.size());
|
|
||||||
|
|
||||||
auto decoder_out = decoder_sess_->Run(
|
|
||||||
{}, decoder_input_names_ptr_.data(), &in, 1,
|
|
||||||
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
|
|
||||||
return std::move(decoder_out[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out,
|
|
||||||
int32_t decoder_out_dim) {
|
|
||||||
auto memory_info =
|
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
|
|
||||||
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
|
|
||||||
std::array<int64_t, 2> shape{batch_size, decoder_out_dim};
|
|
||||||
Ort::Value in = Ort::Value::CreateTensor(
|
|
||||||
memory_info, const_cast<float *>(decoder_out),
|
|
||||||
batch_size * decoder_out_dim, shape.data(), shape.size());
|
|
||||||
|
|
||||||
auto decoder_proj_out = joiner_decoder_proj_sess_->Run(
|
|
||||||
{}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1,
|
|
||||||
joiner_decoder_proj_output_names_ptr_.data(),
|
|
||||||
joiner_decoder_proj_output_names_ptr_.size());
|
|
||||||
return std::move(decoder_proj_out[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out,
|
|
||||||
const float *projected_decoder_out,
|
|
||||||
int32_t joiner_dim) {
|
|
||||||
auto memory_info =
|
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
|
|
||||||
std::array<int64_t, 2> shape{batch_size, joiner_dim};
|
|
||||||
|
|
||||||
Ort::Value enc = Ort::Value::CreateTensor(
|
|
||||||
memory_info, const_cast<float *>(projected_encoder_out),
|
|
||||||
batch_size * joiner_dim, shape.data(), shape.size());
|
|
||||||
|
|
||||||
Ort::Value dec = Ort::Value::CreateTensor(
|
|
||||||
memory_info, const_cast<float *>(projected_decoder_out),
|
|
||||||
batch_size * joiner_dim, shape.data(), shape.size());
|
|
||||||
|
|
||||||
std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)};
|
|
||||||
|
|
||||||
auto logit = joiner_sess_->Run(
|
|
||||||
{}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
|
|
||||||
joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());
|
|
||||||
|
|
||||||
return std::move(logit[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
|
||||||
*
|
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
|
||||||
#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
|
||||||
|
|
||||||
class RnntModel {
|
|
||||||
public:
|
|
||||||
/**
|
|
||||||
* @param encoder_filename Path to the encoder model
|
|
||||||
* @param decoder_filename Path to the decoder model
|
|
||||||
* @param joiner_filename Path to the joiner model
|
|
||||||
* @param joiner_encoder_proj_filename Path to the joiner encoder_proj model
|
|
||||||
* @param joiner_decoder_proj_filename Path to the joiner decoder_proj model
|
|
||||||
* @param num_threads Number of threads to use to run the models
|
|
||||||
*/
|
|
||||||
RnntModel(const std::string &encoder_filename,
|
|
||||||
const std::string &decoder_filename,
|
|
||||||
const std::string &joiner_filename,
|
|
||||||
const std::string &joiner_encoder_proj_filename,
|
|
||||||
const std::string &joiner_decoder_proj_filename,
|
|
||||||
int32_t num_threads);
|
|
||||||
|
|
||||||
/** Run the encoder model.
|
|
||||||
*
|
|
||||||
* @TODO(fangjun): Support batch_size > 1
|
|
||||||
*
|
|
||||||
* @param features A tensor of shape (batch_size, T, feature_dim)
|
|
||||||
* @param T Number of feature frames
|
|
||||||
* @param feature_dim Dimension of the feature.
|
|
||||||
*
|
|
||||||
* @return Return a tensor of shape (batch_size, T', encoder_out_dim)
|
|
||||||
*/
|
|
||||||
Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim);
|
|
||||||
|
|
||||||
/** Run the joiner encoder_proj model.
|
|
||||||
*
|
|
||||||
* @param encoder_out A tensor of shape (T, encoder_out_dim)
|
|
||||||
* @param T Number of frames in encoder_out.
|
|
||||||
* @param encoder_out_dim Dimension of encoder_out.
|
|
||||||
*
|
|
||||||
* @return Return a tensor of shape (T, joiner_dim)
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T,
|
|
||||||
int32_t encoder_out_dim);
|
|
||||||
|
|
||||||
/** Run the decoder model.
|
|
||||||
*
|
|
||||||
* @TODO(fangjun): Support batch_size > 1
|
|
||||||
*
|
|
||||||
* @param decoder_input A tensor of shape (batch_size, context_size).
|
|
||||||
* @return Return a tensor of shape (batch_size, 1, decoder_out_dim)
|
|
||||||
*/
|
|
||||||
Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size);
|
|
||||||
|
|
||||||
/** Run joiner decoder_proj model.
|
|
||||||
*
|
|
||||||
* @TODO(fangjun): Support batch_size > 1
|
|
||||||
*
|
|
||||||
* @param decoder_out A tensor of shape (batch_size, decoder_out_dim)
|
|
||||||
* @param decoder_out_dim Output dimension of the decoder_out.
|
|
||||||
*
|
|
||||||
* @return Return a tensor of shape (batch_size, joiner_dim);
|
|
||||||
*/
|
|
||||||
Ort::Value RunJoinerDecoderProj(const float *decoder_out,
|
|
||||||
int32_t decoder_out_dim);
|
|
||||||
|
|
||||||
/** Run the joiner model.
|
|
||||||
*
|
|
||||||
* @TODO(fangjun): Support batch_size > 1
|
|
||||||
*
|
|
||||||
* @param projected_encoder_out A tensor of shape (batch_size, joiner_dim).
|
|
||||||
* @param projected_decoder_out A tensor of shape (batch_size, joiner_dim).
|
|
||||||
*
|
|
||||||
* @return Return a tensor of shape (batch_size, vocab_size)
|
|
||||||
*/
|
|
||||||
Ort::Value RunJoiner(const float *projected_encoder_out,
|
|
||||||
const float *projected_decoder_out, int32_t joiner_dim);
|
|
||||||
|
|
||||||
private:
|
|
||||||
void InitEncoder(const std::string &encoder_filename);
|
|
||||||
void InitDecoder(const std::string &decoder_filename);
|
|
||||||
void InitJoiner(const std::string &joiner_filename);
|
|
||||||
void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename);
|
|
||||||
void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename);
|
|
||||||
|
|
||||||
private:
|
|
||||||
Ort::Env env_;
|
|
||||||
Ort::SessionOptions sess_opts_;
|
|
||||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
|
||||||
std::unique_ptr<Ort::Session> decoder_sess_;
|
|
||||||
std::unique_ptr<Ort::Session> joiner_sess_;
|
|
||||||
std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_;
|
|
||||||
std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_;
|
|
||||||
|
|
||||||
std::vector<std::string> encoder_input_names_;
|
|
||||||
std::vector<const char *> encoder_input_names_ptr_;
|
|
||||||
std::vector<std::string> encoder_output_names_;
|
|
||||||
std::vector<const char *> encoder_output_names_ptr_;
|
|
||||||
|
|
||||||
std::vector<std::string> decoder_input_names_;
|
|
||||||
std::vector<const char *> decoder_input_names_ptr_;
|
|
||||||
std::vector<std::string> decoder_output_names_;
|
|
||||||
std::vector<const char *> decoder_output_names_ptr_;
|
|
||||||
|
|
||||||
std::vector<std::string> joiner_input_names_;
|
|
||||||
std::vector<const char *> joiner_input_names_ptr_;
|
|
||||||
std::vector<std::string> joiner_output_names_;
|
|
||||||
std::vector<const char *> joiner_output_names_ptr_;
|
|
||||||
|
|
||||||
std::vector<std::string> joiner_encoder_proj_input_names_;
|
|
||||||
std::vector<const char *> joiner_encoder_proj_input_names_ptr_;
|
|
||||||
std::vector<std::string> joiner_encoder_proj_output_names_;
|
|
||||||
std::vector<const char *> joiner_encoder_proj_output_names_ptr_;
|
|
||||||
|
|
||||||
std::vector<std::string> joiner_decoder_proj_input_names_;
|
|
||||||
std::vector<const char *> joiner_decoder_proj_input_names_ptr_;
|
|
||||||
std::vector<std::string> joiner_decoder_proj_output_names_;
|
|
||||||
std::vector<const char *> joiner_decoder_proj_output_names_ptr_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
|
||||||
@@ -1,78 +1,22 @@
|
|||||||
/**
|
// sherpa-onnx/csrc/sherpa-onnx.cc
|
||||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
//
|
||||||
*
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
#include <chrono> // NOLINT
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
#include "sherpa-onnx/csrc/decode.h"
|
#include "sherpa-onnx/csrc/decode.h"
|
||||||
#include "sherpa-onnx/csrc/rnnt-model.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
/** Compute fbank features of the input wave filename.
|
|
||||||
*
|
|
||||||
* @param wav_filename. Path to a mono wave file.
|
|
||||||
* @param expected_sampling_rate Expected sampling rate of the input wave file.
|
|
||||||
* @param num_frames On return, it contains the number of feature frames.
|
|
||||||
* @return Return the computed feature of shape (num_frames, feature_dim)
|
|
||||||
* stored in row-major.
|
|
||||||
*/
|
|
||||||
static std::vector<float> ComputeFeatures(const std::string &wav_filename,
|
|
||||||
float expected_sampling_rate,
|
|
||||||
int32_t *num_frames) {
|
|
||||||
std::vector<float> samples =
|
|
||||||
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate);
|
|
||||||
|
|
||||||
float duration = samples.size() / expected_sampling_rate;
|
|
||||||
|
|
||||||
std::cout << "wav filename: " << wav_filename << "\n";
|
|
||||||
std::cout << "wav duration (s): " << duration << "\n";
|
|
||||||
|
|
||||||
knf::FbankOptions opts;
|
|
||||||
opts.frame_opts.dither = 0;
|
|
||||||
opts.frame_opts.snip_edges = false;
|
|
||||||
opts.frame_opts.samp_freq = expected_sampling_rate;
|
|
||||||
|
|
||||||
int32_t feature_dim = 80;
|
|
||||||
|
|
||||||
opts.mel_opts.num_bins = feature_dim;
|
|
||||||
|
|
||||||
knf::OnlineFbank fbank(opts);
|
|
||||||
fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
|
|
||||||
fbank.InputFinished();
|
|
||||||
|
|
||||||
*num_frames = fbank.NumFramesReady();
|
|
||||||
|
|
||||||
std::vector<float> features(*num_frames * feature_dim);
|
|
||||||
float *p = features.data();
|
|
||||||
|
|
||||||
for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) {
|
|
||||||
const float *f = fbank.GetFrame(i);
|
|
||||||
std::copy(f, f + feature_dim, p);
|
|
||||||
}
|
|
||||||
|
|
||||||
return features;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int32_t argc, char *argv[]) {
|
int main(int32_t argc, char *argv[]) {
|
||||||
if (argc < 8 || argc > 9) {
|
if (argc < 6 || argc > 7) {
|
||||||
const char *usage = R"usage(
|
const char *usage = R"usage(
|
||||||
Usage:
|
Usage:
|
||||||
./bin/sherpa-onnx \
|
./bin/sherpa-onnx \
|
||||||
@@ -80,12 +24,11 @@ Usage:
|
|||||||
/path/to/encoder.onnx \
|
/path/to/encoder.onnx \
|
||||||
/path/to/decoder.onnx \
|
/path/to/decoder.onnx \
|
||||||
/path/to/joiner.onnx \
|
/path/to/joiner.onnx \
|
||||||
/path/to/joiner_encoder_proj.onnx \
|
|
||||||
/path/to/joiner_decoder_proj.onnx \
|
|
||||||
/path/to/foo.wav [num_threads]
|
/path/to/foo.wav [num_threads]
|
||||||
|
|
||||||
You can download pre-trained models from the following repository:
|
Please refer to
|
||||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||||
|
for a list of pre-trained models to download.
|
||||||
)usage";
|
)usage";
|
||||||
std::cerr << usage << "\n";
|
std::cerr << usage << "\n";
|
||||||
|
|
||||||
@@ -93,37 +36,102 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string tokens = argv[1];
|
std::string tokens = argv[1];
|
||||||
std::string encoder = argv[2];
|
sherpa_onnx::OnlineTransducerModelConfig config;
|
||||||
std::string decoder = argv[3];
|
config.debug = true;
|
||||||
std::string joiner = argv[4];
|
config.encoder_filename = argv[2];
|
||||||
std::string joiner_encoder_proj = argv[5];
|
config.decoder_filename = argv[3];
|
||||||
std::string joiner_decoder_proj = argv[6];
|
config.joiner_filename = argv[4];
|
||||||
std::string wav_filename = argv[7];
|
std::string wav_filename = argv[5];
|
||||||
int32_t num_threads = 4;
|
|
||||||
if (argc == 9) {
|
config.num_threads = 2;
|
||||||
num_threads = atoi(argv[8]);
|
if (argc == 7) {
|
||||||
|
config.num_threads = atoi(argv[6]);
|
||||||
}
|
}
|
||||||
|
std::cout << config.ToString().c_str() << "\n";
|
||||||
|
|
||||||
|
auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
|
||||||
|
|
||||||
sherpa_onnx::SymbolTable sym(tokens);
|
sherpa_onnx::SymbolTable sym(tokens);
|
||||||
|
|
||||||
int32_t num_frames;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
auto features = ComputeFeatures(wav_filename, 16000, &num_frames);
|
|
||||||
int32_t feature_dim = features.size() / num_frames;
|
|
||||||
|
|
||||||
sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj,
|
int32_t chunk_size = model->ChunkSize();
|
||||||
joiner_decoder_proj, num_threads);
|
int32_t chunk_shift = model->ChunkShift();
|
||||||
Ort::Value encoder_out =
|
|
||||||
model.RunEncoder(features.data(), num_frames, feature_dim);
|
|
||||||
|
|
||||||
auto hyp = sherpa_onnx::GreedySearch(model, encoder_out);
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
std::string text;
|
std::vector<Ort::Value> states = model->GetEncoderInitStates();
|
||||||
for (auto i : hyp) {
|
|
||||||
text += sym[i];
|
std::vector<int64_t> hyp(model->ContextSize(), 0);
|
||||||
|
|
||||||
|
int32_t expected_sampling_rate = 16000;
|
||||||
|
|
||||||
|
bool is_ok = false;
|
||||||
|
std::vector<float> samples =
|
||||||
|
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
|
||||||
|
|
||||||
|
if (!is_ok) {
|
||||||
|
std::cerr << "Failed to read " << wav_filename << "\n";
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const float duration = samples.size() / expected_sampling_rate;
|
||||||
|
|
||||||
|
std::cout << "wav filename: " << wav_filename << "\n";
|
||||||
|
std::cout << "wav duration (s): " << duration << "\n";
|
||||||
|
|
||||||
|
auto begin = std::chrono::steady_clock::now();
|
||||||
|
std::cout << "Started!\n";
|
||||||
|
|
||||||
|
sherpa_onnx::FeatureExtractor feat_extractor;
|
||||||
|
feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
|
||||||
|
samples.size());
|
||||||
|
|
||||||
|
std::vector<float> tail_paddings(
|
||||||
|
static_cast<int>(0.2 * expected_sampling_rate));
|
||||||
|
feat_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
|
||||||
|
tail_paddings.size());
|
||||||
|
feat_extractor.InputFinished();
|
||||||
|
|
||||||
|
int32_t num_frames = feat_extractor.NumFramesReady();
|
||||||
|
int32_t feature_dim = feat_extractor.FeatureDim();
|
||||||
|
|
||||||
|
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
|
||||||
|
|
||||||
|
for (int32_t start = 0; start + chunk_size < num_frames;
|
||||||
|
start += chunk_shift) {
|
||||||
|
std::vector<float> features = feat_extractor.GetFrames(start, chunk_size);
|
||||||
|
|
||||||
|
Ort::Value x =
|
||||||
|
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||||
|
x_shape.data(), x_shape.size());
|
||||||
|
auto pair = model->RunEncoder(std::move(x), states);
|
||||||
|
states = std::move(pair.second);
|
||||||
|
sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp);
|
||||||
|
}
|
||||||
|
std::string text;
|
||||||
|
for (size_t i = model->ContextSize(); i != hyp.size(); ++i) {
|
||||||
|
text += sym[hyp[i]];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "Done!\n";
|
||||||
|
|
||||||
std::cout << "Recognition result for " << wav_filename << "\n"
|
std::cout << "Recognition result for " << wav_filename << "\n"
|
||||||
<< text << "\n";
|
<< text << "\n";
|
||||||
|
|
||||||
|
auto end = std::chrono::steady_clock::now();
|
||||||
|
float elapsed_seconds =
|
||||||
|
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||||
|
.count() /
|
||||||
|
1000.;
|
||||||
|
|
||||||
|
std::cout << "num threads: " << config.num_threads << "\n";
|
||||||
|
|
||||||
|
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||||
|
float rtf = elapsed_seconds / duration;
|
||||||
|
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||||
|
elapsed_seconds, duration, rtf);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,7 @@
|
|||||||
/**
|
// sherpa-onnx/csrc/show-onnx-info.cc
|
||||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
//
|
||||||
*
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,6 @@
|
|||||||
/**
|
// sherpa-onnx/csrc/symbol-table.cc
|
||||||
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
//
|
||||||
*
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,6 @@
|
|||||||
/**
|
// sherpa-onnx/csrc/symbol-table.cc
|
||||||
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
//
|
||||||
*
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
||||||
#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
||||||
|
|||||||
@@ -1,20 +1,6 @@
|
|||||||
/**
|
// sherpa/csrc/wave-reader.cc
|
||||||
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
//
|
||||||
*
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
@@ -31,19 +17,44 @@ namespace {
|
|||||||
// Note: We assume little endian here
|
// Note: We assume little endian here
|
||||||
// TODO(fangjun): Support big endian
|
// TODO(fangjun): Support big endian
|
||||||
struct WaveHeader {
|
struct WaveHeader {
|
||||||
void Validate() const {
|
bool Validate() const {
|
||||||
// F F I R
|
// F F I R
|
||||||
assert(chunk_id == 0x46464952);
|
if (chunk_id != 0x46464952) {
|
||||||
assert(chunk_size == 36 + subchunk2_size);
|
return false;
|
||||||
// E V A W
|
}
|
||||||
assert(format == 0x45564157);
|
// E V A W
|
||||||
assert(subchunk1_id == 0x20746d66);
|
if (format != 0x45564157) {
|
||||||
assert(subchunk1_size == 16); // 16 for PCM
|
return false;
|
||||||
assert(audio_format == 1); // 1 for PCM
|
}
|
||||||
assert(num_channels == 1); // we support only single channel for now
|
|
||||||
assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
|
if (subchunk1_id != 0x20746d66) {
|
||||||
assert(block_align == num_channels * bits_per_sample / 8);
|
return false;
|
||||||
assert(bits_per_sample == 16); // we support only 16 bits per sample
|
}
|
||||||
|
|
||||||
|
if (subchunk1_size != 16) { // 16 for PCM
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (audio_format != 1) { // 1 for PCM
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_channels != 1) { // we support only single channel for now
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (block_align != (num_channels * bits_per_sample / 8)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bits_per_sample != 16) { // we support only 16 bits per sample
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// See
|
// See
|
||||||
@@ -52,7 +63,7 @@ struct WaveHeader {
|
|||||||
// https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
|
// https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
|
||||||
void SeekToDataChunk(std::istream &is) {
|
void SeekToDataChunk(std::istream &is) {
|
||||||
// a t a d
|
// a t a d
|
||||||
while (subchunk2_id != 0x61746164) {
|
while (is && subchunk2_id != 0x61746164) {
|
||||||
// const char *p = reinterpret_cast<const char *>(&subchunk2_id);
|
// const char *p = reinterpret_cast<const char *>(&subchunk2_id);
|
||||||
// printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
|
// printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
|
||||||
// p[1], p[2], p[3], subchunk2_size);
|
// p[1], p[2], p[3], subchunk2_size);
|
||||||
@@ -80,44 +91,61 @@ static_assert(sizeof(WaveHeader) == 44, "");
|
|||||||
|
|
||||||
// Read a wave file of mono-channel.
|
// Read a wave file of mono-channel.
|
||||||
// Return its samples normalized to the range [-1, 1).
|
// Return its samples normalized to the range [-1, 1).
|
||||||
std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
|
std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
|
||||||
|
bool *is_ok) {
|
||||||
WaveHeader header;
|
WaveHeader header;
|
||||||
is.read(reinterpret_cast<char *>(&header), sizeof(header));
|
is.read(reinterpret_cast<char *>(&header), sizeof(header));
|
||||||
assert(static_cast<bool>(is));
|
if (!is) {
|
||||||
header.Validate();
|
*is_ok = false;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!header.Validate()) {
|
||||||
|
*is_ok = false;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
header.SeekToDataChunk(is);
|
header.SeekToDataChunk(is);
|
||||||
|
if (!is) {
|
||||||
|
*is_ok = false;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
*sample_rate = header.sample_rate;
|
if (expected_sample_rate != header.sample_rate) {
|
||||||
|
*is_ok = false;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
// header.subchunk2_size contains the number of bytes in the data.
|
// header.subchunk2_size contains the number of bytes in the data.
|
||||||
// As we assume each sample contains two bytes, so it is divided by 2 here
|
// As we assume each sample contains two bytes, so it is divided by 2 here
|
||||||
std::vector<int16_t> samples(header.subchunk2_size / 2);
|
std::vector<int16_t> samples(header.subchunk2_size / 2);
|
||||||
|
|
||||||
is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
|
is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
|
||||||
|
if (!is) {
|
||||||
assert(static_cast<bool>(is));
|
*is_ok = false;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<float> ans(samples.size());
|
std::vector<float> ans(samples.size());
|
||||||
for (int32_t i = 0; i != ans.size(); ++i) {
|
for (int32_t i = 0; i != ans.size(); ++i) {
|
||||||
ans[i] = samples[i] / 32768.;
|
ans[i] = samples[i] / 32768.;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*is_ok = true;
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::vector<float> ReadWave(const std::string &filename,
|
std::vector<float> ReadWave(const std::string &filename,
|
||||||
float expected_sample_rate) {
|
float expected_sample_rate, bool *is_ok) {
|
||||||
std::ifstream is(filename, std::ifstream::binary);
|
std::ifstream is(filename, std::ifstream::binary);
|
||||||
float sample_rate;
|
return ReadWave(is, expected_sample_rate, is_ok);
|
||||||
auto samples = ReadWaveImpl(is, &sample_rate);
|
}
|
||||||
if (expected_sample_rate != sample_rate) {
|
|
||||||
std::cerr << "Expected sample rate: " << expected_sample_rate
|
std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
|
||||||
<< ". Given: " << sample_rate << ".\n";
|
bool *is_ok) {
|
||||||
exit(-1);
|
auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok);
|
||||||
}
|
|
||||||
return samples;
|
return samples;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,6 @@
|
|||||||
/**
|
// sherpa/csrc/wave-reader.h
|
||||||
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
//
|
||||||
*
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
* See LICENSE for clarification regarding multiple authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_
|
#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||||
#define SHERPA_ONNX_CSRC_WAVE_READER_H_
|
#define SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||||
@@ -30,11 +16,15 @@ namespace sherpa_onnx {
|
|||||||
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
|
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
|
||||||
@param expected_sample_rate Expected sample rate of the wave file. If the
|
@param expected_sample_rate Expected sample rate of the wave file. If the
|
||||||
sample rate don't match, it throws an exception.
|
sample rate don't match, it throws an exception.
|
||||||
|
@param is_ok On return it is true if the reading succeeded; false otherwise.
|
||||||
|
|
||||||
@return Return wave samples normalized to the range [-1, 1).
|
@return Return wave samples normalized to the range [-1, 1).
|
||||||
*/
|
*/
|
||||||
std::vector<float> ReadWave(const std::string &filename,
|
std::vector<float> ReadWave(const std::string &filename,
|
||||||
float expected_sample_rate);
|
float expected_sample_rate, bool *is_ok);
|
||||||
|
|
||||||
|
std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
|
||||||
|
bool *is_ok);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user