Refactor the code (#15)
* code refactoring * Remove reference files * Update README and CI * small fixes * fix style issues * add style check for CI * fix style issues * remove kaldi-native-io
This commit is contained in:
9
.clang-format
Normal file
9
.clang-format
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
---
|
||||||
|
BasedOnStyle: Google
|
||||||
|
---
|
||||||
|
Language: Cpp
|
||||||
|
Cpp11BracedListStyle: true
|
||||||
|
Standard: Cpp11
|
||||||
|
DerivePointerAlignment: false
|
||||||
|
PointerAlignment: Right
|
||||||
|
---
|
||||||
58
.github/workflows/style_check.yaml
vendored
Normal file
58
.github/workflows/style_check.yaml
vendored
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: style_check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/style_check.yaml'
|
||||||
|
- 'sherpa-onnx/**'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/style_check.yaml'
|
||||||
|
- 'sherpa-onnx/**'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: style_check-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
style_check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: [3.8]
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v1
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Check style with cpplint
|
||||||
|
shell: bash
|
||||||
|
working-directory: ${{github.workspace}}
|
||||||
|
run: ./scripts/check_style_cpplint.sh
|
||||||
9
.github/workflows/test-linux-macos.yaml
vendored
9
.github/workflows/test-linux-macos.yaml
vendored
@@ -59,31 +59,28 @@ jobs:
|
|||||||
- name: Run tests for ubuntu/macos (English)
|
- name: Run tests for ubuntu/macos (English)
|
||||||
run: |
|
run: |
|
||||||
time ./build/bin/sherpa-onnx \
|
time ./build/bin/sherpa-onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
greedy \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
time ./build/bin/sherpa-onnx \
|
time ./build/bin/sherpa-onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
greedy \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
|
||||||
|
|
||||||
time ./build/bin/sherpa-onnx \
|
time ./build/bin/sherpa-onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
greedy \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ set(CMAKE_CXX_EXTENSIONS OFF)
|
|||||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
||||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
|
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
|
||||||
|
|
||||||
include(kaldi_native_io)
|
|
||||||
include(kaldi-native-fbank)
|
include(kaldi-native-fbank)
|
||||||
include(onnxruntime)
|
include(onnxruntime)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ the following links:
|
|||||||
**NOTE**: We provide only non-streaming models at present.
|
**NOTE**: We provide only non-streaming models at present.
|
||||||
|
|
||||||
|
|
||||||
|
**HINT**: The script for exporting the English model can be found at
|
||||||
|
<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>
|
||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -34,13 +37,14 @@ cd ..
|
|||||||
git lfs install
|
git lfs install
|
||||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||||
|
|
||||||
|
./build/bin/sherpa-onnx --help
|
||||||
|
|
||||||
./build/bin/sherpa-onnx \
|
./build/bin/sherpa-onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
|
||||||
greedy \
|
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,39 +0,0 @@
|
|||||||
function(download_kaldi_native_io)
|
|
||||||
if(CMAKE_VERSION VERSION_LESS 3.11)
|
|
||||||
# FetchContent is available since 3.11,
|
|
||||||
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
|
|
||||||
# so that it can be used in lower CMake versions.
|
|
||||||
message(STATUS "Use FetchContent provided by sherpa-onnx")
|
|
||||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include(FetchContent)
|
|
||||||
|
|
||||||
set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz")
|
|
||||||
set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6")
|
|
||||||
|
|
||||||
set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
|
||||||
set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
|
||||||
|
|
||||||
FetchContent_Declare(kaldi_native_io
|
|
||||||
URL ${kaldi_native_io_URL}
|
|
||||||
URL_HASH ${kaldi_native_io_HASH}
|
|
||||||
)
|
|
||||||
|
|
||||||
FetchContent_GetProperties(kaldi_native_io)
|
|
||||||
if(NOT kaldi_native_io_POPULATED)
|
|
||||||
message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}")
|
|
||||||
FetchContent_Populate(kaldi_native_io)
|
|
||||||
endif()
|
|
||||||
message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}")
|
|
||||||
message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}")
|
|
||||||
|
|
||||||
add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL)
|
|
||||||
|
|
||||||
target_include_directories(kaldi_native_io_core
|
|
||||||
PUBLIC
|
|
||||||
${kaldi_native_io_SOURCE_DIR}/
|
|
||||||
)
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
download_kaldi_native_io()
|
|
||||||
@@ -10,7 +10,7 @@ function(download_onnxruntime)
|
|||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
if(UNIX AND NOT APPLE)
|
if(UNIX AND NOT APPLE)
|
||||||
set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
|
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
|
||||||
|
|
||||||
# If you don't have access to the internet, you can first download onnxruntime to some directory, and then use
|
# If you don't have access to the internet, you can first download onnxruntime to some directory, and then use
|
||||||
# set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz")
|
# set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz")
|
||||||
|
|||||||
126
scripts/check_style_cpplint.sh
Executable file
126
scripts/check_style_cpplint.sh
Executable file
@@ -0,0 +1,126 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
#
|
||||||
|
# (1) To check files of the last commit
|
||||||
|
# ./scripts/check_style_cpplint.sh
|
||||||
|
#
|
||||||
|
# (2) To check changed files not committed yet
|
||||||
|
# ./scripts/check_style_cpplint.sh 1
|
||||||
|
#
|
||||||
|
# (3) To check all files in the project
|
||||||
|
# ./scripts/check_style_cpplint.sh 2
|
||||||
|
|
||||||
|
|
||||||
|
cpplint_version="1.5.4"
|
||||||
|
cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd)
|
||||||
|
sherpa_onnx_dir=$(cd $cur_dir/.. && pwd)
|
||||||
|
|
||||||
|
build_dir=$sherpa_onnx_dir/build
|
||||||
|
mkdir -p $build_dir
|
||||||
|
|
||||||
|
cpplint_src=$build_dir/cpplint-${cpplint_version}/cpplint.py
|
||||||
|
|
||||||
|
if [ ! -d "$build_dir/cpplint-${cpplint_version}" ]; then
|
||||||
|
pushd $build_dir
|
||||||
|
if command -v wget &> /dev/null; then
|
||||||
|
wget https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz
|
||||||
|
elif command -v curl &> /dev/null; then
|
||||||
|
curl -O -SL https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz
|
||||||
|
else
|
||||||
|
echo "Please install wget or curl to download cpplint"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
tar xf ${cpplint_version}.tar.gz
|
||||||
|
rm ${cpplint_version}.tar.gz
|
||||||
|
|
||||||
|
# cpplint will report the following error for: __host__ __device__ (
|
||||||
|
#
|
||||||
|
# Extra space before ( in function call [whitespace/parens] [4]
|
||||||
|
#
|
||||||
|
# the following patch disables the above error
|
||||||
|
sed -i "3490i\ not Search(r'__host__ __device__\\\s+\\\(', fncall) and" $cpplint_src
|
||||||
|
popd
|
||||||
|
fi
|
||||||
|
|
||||||
|
source $sherpa_onnx_dir/scripts/utils.sh
|
||||||
|
|
||||||
|
# Print "true" when the path given as $1 ends in .cc, .h or .cu,
# and "false" for every other name. The result is written to stdout
# so that callers can capture it with $(...).
function is_source_code_file() {
  if [[ "$1" == *.cc || "$1" == *.h || "$1" == *.cu ]]; then
    echo true
  else
    echo false
  fi
}
|
||||||
|
|
||||||
|
# Run cpplint on the single file given as $1.
#
# If cpplint reports a problem (non-zero exit status), abort is called
# with the file name. Both the path to cpplint.py and the file name are
# quoted so that paths containing whitespace or glob characters are
# passed through as a single argument.
function check_style() {
  python3 "$cpplint_src" "$1" || abort "$1"
}
|
||||||
|
|
||||||
|
# Print, on one line, the names of the files that differ between HEAD^1
# and the working tree (diff filter ACDMRUXB).
function check_last_commit() {
  # The command substitution is deliberately left unquoted: word splitting
  # turns the newline-separated list from git into a space-separated one.
  echo $(git diff --name-only --diff-filter=ACDMRUXB HEAD^1)
}
|
||||||
|
|
||||||
|
# Print, on one line, the tracked files that `git status` reports as
# changed but not yet committed (untracked files are excluded by -uno).
function check_current_dir() {
  # A status line with four fields describes a rename ("R old -> new"),
  # in which case the last field is the current name; otherwise the
  # file name is the second field.
  # The substitution is left unquoted on purpose so the names end up
  # separated by single spaces.
  echo $(git status -s -uno --porcelain | awk '{ print (NF == 4 ? $NF : $2) }')
}
|
||||||
|
|
||||||
|
# Collect the candidate files according to the mode in $1 and run the
# style check on every one that is a C++ source file and exists on disk.
#
#   $1 == 1  files changed in the working tree but not committed
#   $1 == 2  every *.h / *.cc file below sherpa-onnx/
#   other    files touched by the last commit (the default)
function do_check() {
  if [[ "$1" == "1" ]]; then
    echo "Check changed files"
    files=$(check_current_dir)
  elif [[ "$1" == "2" ]]; then
    echo "Check all files"
    files=$(find $sherpa_onnx_dir/sherpa-onnx -name "*.h" -o -name "*.cc")
  else
    echo "Check last commit"
    files=$(check_last_commit)
  fi

  # $files is intentionally unquoted: it is a whitespace-separated list.
  for f in $files; do
    if [[ $(is_source_code_file $f) == true ]]; then
      # Skip names that no longer refer to a regular file (e.g. deleted).
      [[ -f $f ]] && check_style $f
    fi
  done
}
|
||||||
|
|
||||||
|
# Entry point: run the check selected by the optional mode argument and
# print a success banner once do_check returns.
function main() {
  do_check "${1:-}"

  ok "Great! Style check passed!"
}
|
||||||
|
|
||||||
|
cd $sherpa_onnx_dir
|
||||||
|
|
||||||
|
main $1
|
||||||
19
scripts/utils.sh
Normal file
19
scripts/utils.sh
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
default='\033[0m'
|
||||||
|
bold='\033[1m'
|
||||||
|
red='\033[31m'
|
||||||
|
green='\033[32m'
|
||||||
|
|
||||||
|
# Print "[OK] <message>" with the "[OK]" tag in bold green.
#
# The message ($1) is passed as a printf argument rather than being
# spliced into the format string, so a '%' inside it is printed
# literally instead of being interpreted as a conversion.
function ok() {
  printf "${bold}${green}[OK]${default} %s\n" "$1"
}
|
||||||
|
|
||||||
|
# Print "[FAILED] <message>" with the "[FAILED]" tag in bold red.
#
# The message ($1) is passed as a printf argument rather than being
# spliced into the format string, so a '%' inside it (for example in a
# file name) is printed literally.
function error() {
  printf "${bold}${red}[FAILED]${default} %s\n" "$1"
}
|
||||||
|
|
||||||
|
# Print "[FAILED] <message>" with the tag in bold red, then terminate the
# current shell with exit status 1.
#
# The message ($1) is passed as a printf argument rather than being
# spliced into the format string, so a '%' inside it (for example in a
# file name) is printed literally.
function abort() {
  printf "${bold}${red}[FAILED]${default} %s\n" "$1"
  exit 1
}
|
||||||
@@ -1,8 +1,17 @@
|
|||||||
include_directories(${CMAKE_SOURCE_DIR})
|
include_directories(${CMAKE_SOURCE_DIR})
|
||||||
add_executable(sherpa-onnx main.cpp)
|
|
||||||
|
add_executable(sherpa-onnx
|
||||||
|
decode.cc
|
||||||
|
rnnt-model.cc
|
||||||
|
sherpa-onnx.cc
|
||||||
|
symbol-table.cc
|
||||||
|
wave-reader.cc
|
||||||
|
)
|
||||||
|
|
||||||
target_link_libraries(sherpa-onnx
|
target_link_libraries(sherpa-onnx
|
||||||
onnxruntime
|
onnxruntime
|
||||||
kaldi-native-fbank-core
|
kaldi-native-fbank-core
|
||||||
kaldi_native_io_core
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_executable(sherpa-show-onnx-info show-onnx-info.cc)
|
||||||
|
target_link_libraries(sherpa-show-onnx-info onnxruntime)
|
||||||
|
|||||||
84
sherpa-onnx/csrc/decode.cc
Normal file
84
sherpa-onnx/csrc/decode.cc
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/decode.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Greedy search over the frames of `encoder_out`.
//
// The encoder output is projected once up front. For every frame the joiner
// is run on that frame's projection together with the current projected
// decoder output, and the index of the largest logit is taken. Whenever that
// index is not the blank id it is appended to the hypothesis and the decoder
// (plus its projection) is re-run on the last two symbols.
//
// Only the first dimension of `encoder_out` equal to 1 is supported; this is
// checked with an assert. The returned vector holds the emitted ids without
// the leading blank context.
std::vector<int32_t> GreedySearch(RnntModel &model,  // NOLINT
                                  const Ort::Value &encoder_out) {
  const std::vector<int64_t> enc_shape =
      encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  assert(enc_shape[0] == 1 && "Only batch_size=1 is implemented");

  Ort::Value enc_proj = model.RunJoinerEncoderProj(
      encoder_out.GetTensorData<float>(), enc_shape[1], enc_shape[2]);
  const float *enc_proj_data = enc_proj.GetTensorData<float>();

  int32_t context_size = 2;  // hard-code it to 2
  int32_t blank_id = 0;      // hard-code it to 0

  // The hypothesis starts with `context_size` blanks; they are dropped again
  // before returning.
  std::vector<int32_t> hyp(context_size, blank_id);
  std::array<int64_t, 2> decoder_input{blank_id, blank_id};

  Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
  const std::vector<int64_t> dec_shape =
      decoder_out.GetTensorTypeAndShapeInfo().GetShape();

  Ort::Value dec_proj = model.RunJoinerDecoderProj(
      decoder_out.GetTensorData<float>(), dec_shape[2]);

  int32_t joiner_dim = dec_proj.GetTensorTypeAndShapeInfo().GetShape()[1];
  int32_t num_frames = enc_shape[1];

  for (int32_t t = 0; t != num_frames; ++t) {
    Ort::Value logit =
        model.RunJoiner(enc_proj_data + t * joiner_dim,
                        dec_proj.GetTensorData<float>(), joiner_dim);

    int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];

    // argmax over the logits of this frame
    const float *first = logit.GetTensorData<float>();
    const float *last = first + vocab_size;
    auto y = static_cast<int32_t>(std::max_element(first, last) - first);

    if (y == blank_id) {
      continue;  // nothing emitted, the decoder state stays as it is
    }

    decoder_input[0] = hyp.back();
    decoder_input[1] = y;
    hyp.push_back(y);

    decoder_out = model.RunDecoder(decoder_input.data(), context_size);
    dec_proj = model.RunJoinerDecoderProj(decoder_out.GetTensorData<float>(),
                                          dec_shape[2]);
  }

  return std::vector<int32_t>(hyp.begin() + context_size, hyp.end());
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
40
sherpa-onnx/csrc/decode.h
Normal file
40
sherpa-onnx/csrc/decode.h
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_DECODE_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_DECODE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/rnnt-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/** Greedy search for non-streaming ASR.
|
||||||
|
*
|
||||||
|
* @TODO(fangjun) Support batch size > 1
|
||||||
|
*
|
||||||
|
* @param model The RnntModel
|
||||||
|
* @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
|
||||||
|
*/
|
||||||
|
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
|
||||||
|
const Ort::Value &encoder_out);
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_DECODE_H_
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "kaldi_native_io/csrc/kaldi-io.h"
|
|
||||||
#include "kaldi_native_io/csrc/wave-reader.h"
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
|
||||||
|
|
||||||
|
|
||||||
// Read the wave file `filename` and return its sample matrix.
//
// When `log` is true the file name and the matrix shape are printed to
// stdout. If the file cannot be read, a message is written to stderr and the
// process exits with EXIT_FAILURE.
kaldiio::Matrix<float> readWav(std::string filename, bool log = false) {
  if (log) {
    std::cout << "reading " << filename << std::endl;
  }

  bool is_binary = true;
  kaldiio::Input input(filename, &is_binary);
  kaldiio::WaveHolder holder;

  const bool read_ok = holder.Read(input.Stream());
  if (!read_ok) {
    std::cerr << "Failed to read " << filename;
    exit(EXIT_FAILURE);
  }

  auto &samples = holder.Value().Data();

  if (log) {
    std::cout << "wav shape: "
              << "(" << samples.NumRows() << "," << samples.NumCols() << ")"
              << std::endl;
  }

  return samples;
}
|
|
||||||
|
|
||||||
|
|
||||||
// Feed row 0 of `samples` into `fbank` and return all ready frames as one
// flat vector of num_frames * opts.mel_opts.num_bins values (frame after
// frame).
//
// @param fbank    Feature extractor that receives the waveform; it is
//                 modified by this call.
// @param opts     Options supplying the sample rate passed to the extractor
//                 and the number of mel bins copied out of each frame.
// @param samples  Sample matrix; only row 0 is used. Every sample is divided
//                 by 32768 before it is handed to the extractor.
// @param log      When true, print the feature shape and up to the first 20
//                 feature values to stdout.
// @return The extracted features.
std::vector<float> ComputeFeatures(knf::OnlineFbank &fbank,
                                   knf::FbankOptions opts,
                                   kaldiio::Matrix<float> samples,
                                   bool log = false) {
  const int numSamples = samples.NumCols();

  for (int i = 0; i < numSamples; i++) {
    float currentSample = samples.Row(0).Data()[i] / 32768;
    fbank.AcceptWaveform(opts.frame_opts.samp_freq, &currentSample, 1);
  }

  const int32_t num_frames = fbank.NumFramesReady();
  const int32_t num_bins = opts.mel_opts.num_bins;

  std::vector<float> features;
  // The final size is known in advance, so allocate once.
  features.reserve(static_cast<std::vector<float>::size_type>(num_frames) *
                   num_bins);
  for (int32_t i = 0; i != num_frames; ++i) {
    const float *frame = fbank.GetFrame(i);
    features.insert(features.end(), frame, frame + num_bins);
  }

  if (log) {
    std::cout << "done feature extraction" << std::endl;
    std::cout << "extracted fbank shape "
              << "(" << num_frames << "," << num_bins << ")" << std::endl;

    // Print at most the first 20 values; short inputs may yield fewer, and
    // indexing past the end would throw.
    const std::vector<float>::size_type num_to_print =
        features.size() < 20 ? features.size() : 20;
    for (std::vector<float>::size_type i = 0; i < num_to_print; i++) {
      std::cout << features[i] << std::endl;
    }
  }

  return features;
}
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
#include <algorithm>
|
|
||||||
#include <fstream>
|
|
||||||
#include <iostream>
|
|
||||||
#include <math.h>
|
|
||||||
#include <time.h>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/fbank_features.h"
|
|
||||||
#include "sherpa-onnx/csrc/rnnt_beam_search.h"
|
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
|
||||||
char *encoder_path = argv[1];
|
|
||||||
char *decoder_path = argv[2];
|
|
||||||
char *joiner_path = argv[3];
|
|
||||||
char *joiner_encoder_proj_path = argv[4];
|
|
||||||
char *joiner_decoder_proj_path = argv[5];
|
|
||||||
char *token_path = argv[6];
|
|
||||||
std::string search_method = argv[7];
|
|
||||||
char *filename = argv[8];
|
|
||||||
|
|
||||||
// General parameters
|
|
||||||
int numberOfThreads = 16;
|
|
||||||
|
|
||||||
// Initialize fbanks
|
|
||||||
knf::FbankOptions opts;
|
|
||||||
opts.frame_opts.dither = 0;
|
|
||||||
opts.frame_opts.samp_freq = 16000;
|
|
||||||
opts.frame_opts.frame_shift_ms = 10.0f;
|
|
||||||
opts.frame_opts.frame_length_ms = 25.0f;
|
|
||||||
opts.mel_opts.num_bins = 80;
|
|
||||||
opts.frame_opts.window_type = "povey";
|
|
||||||
opts.frame_opts.snip_edges = false;
|
|
||||||
knf::OnlineFbank fbank(opts);
|
|
||||||
|
|
||||||
// set session opts
|
|
||||||
// https://onnxruntime.ai/docs/performance/tune-performance.html
|
|
||||||
session_options.SetIntraOpNumThreads(numberOfThreads);
|
|
||||||
session_options.SetInterOpNumThreads(numberOfThreads);
|
|
||||||
session_options.SetGraphOptimizationLevel(
|
|
||||||
GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
|
||||||
session_options.SetLogSeverityLevel(4);
|
|
||||||
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
|
|
||||||
|
|
||||||
api.CreateTensorRTProviderOptions(&tensorrt_options);
|
|
||||||
std::unique_ptr<OrtTensorRTProviderOptionsV2,
|
|
||||||
decltype(api.ReleaseTensorRTProviderOptions)>
|
|
||||||
rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
|
|
||||||
api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
|
|
||||||
static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get());
|
|
||||||
|
|
||||||
// Define model
|
|
||||||
auto model =
|
|
||||||
get_model(encoder_path, decoder_path, joiner_path,
|
|
||||||
joiner_encoder_proj_path, joiner_decoder_proj_path, token_path);
|
|
||||||
|
|
||||||
std::vector<std::string> filename_list{filename};
|
|
||||||
|
|
||||||
for (auto filename : filename_list) {
|
|
||||||
std::cout << filename << std::endl;
|
|
||||||
auto samples = readWav(filename, true);
|
|
||||||
int numSamples = samples.NumCols();
|
|
||||||
|
|
||||||
auto features = ComputeFeatures(fbank, opts, samples);
|
|
||||||
|
|
||||||
auto tic = std::chrono::high_resolution_clock::now();
|
|
||||||
|
|
||||||
// # === Encoder Out === #
|
|
||||||
int num_frames = features.size() / opts.mel_opts.num_bins;
|
|
||||||
auto encoder_out =
|
|
||||||
model.encoder_forward(features, std::vector<int64_t>{num_frames},
|
|
||||||
std::vector<int64_t>{1, num_frames, 80},
|
|
||||||
std::vector<int64_t>{1}, memory_info);
|
|
||||||
|
|
||||||
// # === Search === #
|
|
||||||
std::vector<std::vector<int32_t>> hyps;
|
|
||||||
if (search_method == "greedy")
|
|
||||||
hyps = GreedySearch(&model, &encoder_out);
|
|
||||||
else {
|
|
||||||
std::cout << "wrong search method!" << std::endl;
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
auto results = hyps2result(model.tokens_map, hyps);
|
|
||||||
|
|
||||||
// # === Print Elapsed Time === #
|
|
||||||
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
||||||
std::chrono::high_resolution_clock::now() - tic);
|
|
||||||
std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds"
|
|
||||||
<< std::endl;
|
|
||||||
std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000)
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
print_hyps(hyps);
|
|
||||||
std::cout << results[0] << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
@@ -1,253 +0,0 @@
|
|||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <iostream>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <sys/stat.h>
|
|
||||||
|
|
||||||
#include "utils_onnx.h"
|
|
||||||
|
|
||||||
|
|
||||||
struct Model
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
const char* encoder_path;
|
|
||||||
const char* decoder_path;
|
|
||||||
const char* joiner_path;
|
|
||||||
const char* joiner_encoder_proj_path;
|
|
||||||
const char* joiner_decoder_proj_path;
|
|
||||||
const char* tokens_path;
|
|
||||||
|
|
||||||
Ort::Session encoder = load_model(encoder_path);
|
|
||||||
Ort::Session decoder = load_model(decoder_path);
|
|
||||||
Ort::Session joiner = load_model(joiner_path);
|
|
||||||
Ort::Session joiner_encoder_proj = load_model(joiner_encoder_proj_path);
|
|
||||||
Ort::Session joiner_decoder_proj = load_model(joiner_decoder_proj_path);
|
|
||||||
std::map<int, std::string> tokens_map = get_token_map(tokens_path);
|
|
||||||
|
|
||||||
int32_t blank_id;
|
|
||||||
int32_t unk_id;
|
|
||||||
int32_t context_size;
|
|
||||||
|
|
||||||
std::vector<Ort::Value> encoder_forward(std::vector<float> in_vector,
|
|
||||||
std::vector<int64_t> in_vector_length,
|
|
||||||
std::vector<int64_t> feature_dims,
|
|
||||||
std::vector<int64_t> feature_length_dims,
|
|
||||||
Ort::MemoryInfo &memory_info){
|
|
||||||
std::vector<Ort::Value> encoder_inputTensors;
|
|
||||||
encoder_inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), feature_dims.data(), feature_dims.size()));
|
|
||||||
encoder_inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector_length.data(), in_vector_length.size(), feature_length_dims.data(), feature_length_dims.size()));
|
|
||||||
|
|
||||||
std::vector<const char*> encoder_inputNames = {encoder.GetInputName(0, allocator), encoder.GetInputName(1, allocator)};
|
|
||||||
std::vector<const char*> encoder_outputNames = {encoder.GetOutputName(0, allocator)};
|
|
||||||
|
|
||||||
auto out = encoder.Run(Ort::RunOptions{nullptr},
|
|
||||||
encoder_inputNames.data(),
|
|
||||||
encoder_inputTensors.data(),
|
|
||||||
encoder_inputTensors.size(),
|
|
||||||
encoder_outputNames.data(),
|
|
||||||
encoder_outputNames.size());
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Ort::Value> decoder_forward(std::vector<int64_t> in_vector,
|
|
||||||
std::vector<int64_t> dims,
|
|
||||||
Ort::MemoryInfo &memory_info){
|
|
||||||
std::vector<Ort::Value> inputTensors;
|
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
|
|
||||||
|
|
||||||
std::vector<const char*> inputNames {decoder.GetInputName(0, allocator)};
|
|
||||||
std::vector<const char*> outputNames {decoder.GetOutputName(0, allocator)};
|
|
||||||
|
|
||||||
auto out = decoder.Run(Ort::RunOptions{nullptr},
|
|
||||||
inputNames.data(),
|
|
||||||
inputTensors.data(),
|
|
||||||
inputTensors.size(),
|
|
||||||
outputNames.data(),
|
|
||||||
outputNames.size());
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Ort::Value> joiner_forward(std::vector<float> projected_encoder_out,
|
|
||||||
std::vector<float> decoder_out,
|
|
||||||
std::vector<int64_t> projected_encoder_out_dims,
|
|
||||||
std::vector<int64_t> decoder_out_dims,
|
|
||||||
Ort::MemoryInfo &memory_info){
|
|
||||||
std::vector<Ort::Value> inputTensors;
|
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, projected_encoder_out.data(), projected_encoder_out.size(), projected_encoder_out_dims.data(), projected_encoder_out_dims.size()));
|
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, decoder_out.data(), decoder_out.size(), decoder_out_dims.data(), decoder_out_dims.size()));
|
|
||||||
std::vector<const char*> inputNames = {joiner.GetInputName(0, allocator), joiner.GetInputName(1, allocator)};
|
|
||||||
std::vector<const char*> outputNames = {joiner.GetOutputName(0, allocator)};
|
|
||||||
|
|
||||||
auto out = joiner.Run(Ort::RunOptions{nullptr},
|
|
||||||
inputNames.data(),
|
|
||||||
inputTensors.data(),
|
|
||||||
inputTensors.size(),
|
|
||||||
outputNames.data(),
|
|
||||||
outputNames.size());
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Ort::Value> joiner_encoder_proj_forward(std::vector<float> in_vector,
|
|
||||||
std::vector<int64_t> dims,
|
|
||||||
Ort::MemoryInfo &memory_info){
|
|
||||||
std::vector<Ort::Value> inputTensors;
|
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
|
|
||||||
|
|
||||||
std::vector<const char*> inputNames {joiner_encoder_proj.GetInputName(0, allocator)};
|
|
||||||
std::vector<const char*> outputNames {joiner_encoder_proj.GetOutputName(0, allocator)};
|
|
||||||
|
|
||||||
auto out = joiner_encoder_proj.Run(Ort::RunOptions{nullptr},
|
|
||||||
inputNames.data(),
|
|
||||||
inputTensors.data(),
|
|
||||||
inputTensors.size(),
|
|
||||||
outputNames.data(),
|
|
||||||
outputNames.size());
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Ort::Value> joiner_decoder_proj_forward(std::vector<float> in_vector,
|
|
||||||
std::vector<int64_t> dims,
|
|
||||||
Ort::MemoryInfo &memory_info){
|
|
||||||
std::vector<Ort::Value> inputTensors;
|
|
||||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
|
|
||||||
|
|
||||||
std::vector<const char*> inputNames {joiner_decoder_proj.GetInputName(0, allocator)};
|
|
||||||
std::vector<const char*> outputNames {joiner_decoder_proj.GetOutputName(0, allocator)};
|
|
||||||
|
|
||||||
auto out = joiner_decoder_proj.Run(Ort::RunOptions{nullptr},
|
|
||||||
inputNames.data(),
|
|
||||||
inputTensors.data(),
|
|
||||||
inputTensors.size(),
|
|
||||||
outputNames.data(),
|
|
||||||
outputNames.size());
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ort::Session load_model(const char* path){
|
|
||||||
struct stat buffer;
|
|
||||||
if (stat(path, &buffer) != 0){
|
|
||||||
std::cout << "File does not exist!: " << path << std::endl;
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
std::cout << "loading " << path << std::endl;
|
|
||||||
Ort::Session onnx_model(env, path, session_options);
|
|
||||||
return onnx_model;
|
|
||||||
}
|
|
||||||
|
|
||||||
void extract_constant_lm_parameters(){
|
|
||||||
/*
|
|
||||||
all_in_one contains these params. We should trace all_in_one and find 'constants_lm' nodes to extract these params
|
|
||||||
For now, these params are set staticaly.
|
|
||||||
in: Ort::Session &all_in_one
|
|
||||||
out: {blank_id, unk_id, context_size}
|
|
||||||
should return std::vector<int32_t>
|
|
||||||
*/
|
|
||||||
blank_id = 0;
|
|
||||||
unk_id = 0;
|
|
||||||
context_size = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<int, std::string> get_token_map(const char* token_path){
|
|
||||||
std::ifstream inFile;
|
|
||||||
inFile.open(token_path);
|
|
||||||
if (inFile.fail())
|
|
||||||
std::cerr << "Could not find token file" << std::endl;
|
|
||||||
|
|
||||||
std::map<int, std::string> token_map;
|
|
||||||
|
|
||||||
std::string line;
|
|
||||||
while (std::getline(inFile, line))
|
|
||||||
{
|
|
||||||
int id;
|
|
||||||
std::string token;
|
|
||||||
|
|
||||||
std::istringstream iss(line);
|
|
||||||
iss >> token;
|
|
||||||
iss >> id;
|
|
||||||
|
|
||||||
token_map[id] = token;
|
|
||||||
}
|
|
||||||
|
|
||||||
return token_map;
|
|
||||||
}
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * Build a Model from an export directory.
 *
 * The five model files are expected under `exp_path` with the fixed
 * "*_simp.onnx" names used below.
 *
 * @param exp_path     Directory containing the exported ONNX files.
 * @param tokens_path  Path of the token table.
 * @return The loaded model, with its decoding constants set.
 */
Model get_model(std::string exp_path, char* tokens_path){
    const std::string encoder_file = exp_path + "/encoder_simp.onnx";
    const std::string decoder_file = exp_path + "/decoder_simp.onnx";
    const std::string joiner_file = exp_path + "/joiner_simp.onnx";
    const std::string joiner_encoder_proj_file = exp_path + "/joiner_encoder_proj_simp.onnx";
    const std::string joiner_decoder_proj_file = exp_path + "/joiner_decoder_proj_simp.onnx";

    // NOTE(review): Model stores these c_str() pointers. They are valid while
    // the sessions are being created, but the five *_path members of the
    // returned Model no longer point at live strings once this function
    // returns, so callers must not read them.
    Model loaded{
        encoder_file.c_str(),
        decoder_file.c_str(),
        joiner_file.c_str(),
        joiner_encoder_proj_file.c_str(),
        joiner_decoder_proj_file.c_str(),
        tokens_path,
    };
    loaded.extract_constant_lm_parameters();

    return loaded;
}
|
|
||||||
|
|
||||||
/**
 * Build a Model from explicitly given file paths.
 *
 * The paths are handed to Model in its declaration order; the decoding
 * constants are set before the model is returned.
 */
Model get_model(char* encoder_path,
                char* decoder_path,
                char* joiner_path,
                char* joiner_encoder_proj_path,
                char* joiner_decoder_proj_path,
                char* tokens_path){
    Model loaded{encoder_path, decoder_path, joiner_path,
                 joiner_encoder_proj_path, joiner_decoder_proj_path,
                 tokens_path};
    loaded.extract_constant_lm_parameters();

    return loaded;
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * Run every sub-model `numWarmup` times on constant dummy inputs.
 *
 * The outputs are discarded. Uses `memory_info`, which is declared outside
 * this function.
 *
 * @param model      Model whose sessions are exercised.
 * @param numWarmup  Number of runs per sub-model.
 */
void doWarmup(Model *model, int numWarmup = 5){
    std::cout << "Warmup is started" << std::endl;

    // Encoder: 500 frames of 80-dim features, all ones.
    const std::vector<float> encoder_features(500 * 80, 1.0);
    for (int run = 0; run < numWarmup; ++run) {
        model->encoder_forward(encoder_features,
                               std::vector<int64_t>{500},
                               std::vector<int64_t>{1, 500, 80},
                               std::vector<int64_t>{1},
                               memory_info);
    }

    // Decoder: two token ids shaped (1, 2).
    const std::vector<int64_t> decoder_tokens{1, 1};
    for (int run = 0; run < numWarmup; ++run) {
        model->decoder_forward(decoder_tokens,
                               std::vector<int64_t>{1, 2},
                               memory_info);
    }

    // Joiner: two 512-dim inputs shaped (1, 1, 1, 512).
    const std::vector<float> joiner_encoder_side(512, 1.0);
    const std::vector<float> joiner_decoder_side(512, 1.0);
    for (int run = 0; run < numWarmup; ++run) {
        model->joiner_forward(joiner_encoder_side,
                              joiner_decoder_side,
                              std::vector<int64_t>{1, 1, 1, 512},
                              std::vector<int64_t>{1, 1, 1, 512},
                              memory_info);
    }

    // Joiner encoder projection: input shaped (100, 512).
    const std::vector<float> encoder_proj_input(100 * 512, 1.0);
    for (int run = 0; run < numWarmup; ++run) {
        model->joiner_encoder_proj_forward(encoder_proj_input,
                                           std::vector<int64_t>{100, 512},
                                           memory_info);
    }

    // Joiner decoder projection: input shaped (1, 512).
    const std::vector<float> decoder_proj_input(512, 1.0);
    for (int run = 0; run < numWarmup; ++run) {
        model->joiner_decoder_proj_forward(decoder_proj_input,
                                           std::vector<int64_t>{1, 512},
                                           memory_info);
    }

    std::cout << "Warmup is done" << std::endl;
}
|
|
||||||
247
sherpa-onnx/csrc/rnnt-model.cc
Normal file
247
sherpa-onnx/csrc/rnnt-model.cc
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "sherpa-onnx/csrc/rnnt-model.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the input names of a model.
|
||||||
|
*
|
||||||
|
* @param sess An onnxruntime session.
|
||||||
|
* @param input_names. On return, it contains the input names of the model.
|
||||||
|
* @param input_names_ptr. On return, input_names_ptr[i] contains
|
||||||
|
* input_names[i].c_str()
|
||||||
|
*/
|
||||||
|
static void GetInputNames(Ort::Session *sess,
|
||||||
|
std::vector<std::string> *input_names,
|
||||||
|
std::vector<const char *> *input_names_ptr) {
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
size_t node_count = sess->GetInputCount();
|
||||||
|
input_names->resize(node_count);
|
||||||
|
input_names_ptr->resize(node_count);
|
||||||
|
for (size_t i = 0; i != node_count; ++i) {
|
||||||
|
auto tmp = sess->GetInputNameAllocated(i, allocator);
|
||||||
|
(*input_names)[i] = tmp.get();
|
||||||
|
(*input_names_ptr)[i] = (*input_names)[i].c_str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the output names of a model.
|
||||||
|
*
|
||||||
|
* @param sess An onnxruntime session.
|
||||||
|
* @param output_names. On return, it contains the output names of the model.
|
||||||
|
* @param output_names_ptr. On return, output_names_ptr[i] contains
|
||||||
|
* output_names[i].c_str()
|
||||||
|
*/
|
||||||
|
static void GetOutputNames(Ort::Session *sess,
|
||||||
|
std::vector<std::string> *output_names,
|
||||||
|
std::vector<const char *> *output_names_ptr) {
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
size_t node_count = sess->GetOutputCount();
|
||||||
|
output_names->resize(node_count);
|
||||||
|
output_names_ptr->resize(node_count);
|
||||||
|
for (size_t i = 0; i != node_count; ++i) {
|
||||||
|
auto tmp = sess->GetOutputNameAllocated(i, allocator);
|
||||||
|
(*output_names)[i] = tmp.get();
|
||||||
|
(*output_names_ptr)[i] = (*output_names)[i].c_str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configures the shared session options, then creates the five sessions in
// the order encoder, decoder, joiner, joiner encoder_proj, joiner
// decoder_proj. See the class declaration for the parameters.
RnntModel::RnntModel(const std::string &encoder_filename,
                     const std::string &decoder_filename,
                     const std::string &joiner_filename,
                     const std::string &joiner_encoder_proj_filename,
                     const std::string &joiner_decoder_proj_filename,
                     int32_t num_threads)
    : env_(ORT_LOGGING_LEVEL_WARNING) {
  // The same thread count is used for intra-op and inter-op parallelism.
  sess_opts_.SetIntraOpNumThreads(num_threads);
  sess_opts_.SetInterOpNumThreads(num_threads);

  // Every Init* call below reads env_ and sess_opts_.
  InitEncoder(encoder_filename);
  InitDecoder(decoder_filename);
  InitJoiner(joiner_filename);
  InitJoinerEncoderProj(joiner_encoder_proj_filename);
  InitJoinerDecoderProj(joiner_decoder_proj_filename);
}
|
||||||
|
|
||||||
|
void RnntModel::InitEncoder(const std::string &filename) {
|
||||||
|
encoder_sess_ =
|
||||||
|
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||||
|
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||||
|
&encoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||||
|
&encoder_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RnntModel::InitDecoder(const std::string &filename) {
|
||||||
|
decoder_sess_ =
|
||||||
|
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||||
|
&decoder_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||||
|
&decoder_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RnntModel::InitJoiner(const std::string &filename) {
|
||||||
|
joiner_sess_ =
|
||||||
|
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||||
|
&joiner_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
||||||
|
&joiner_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
|
||||||
|
joiner_encoder_proj_sess_ =
|
||||||
|
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(joiner_encoder_proj_sess_.get(),
|
||||||
|
&joiner_encoder_proj_input_names_,
|
||||||
|
&joiner_encoder_proj_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(joiner_encoder_proj_sess_.get(),
|
||||||
|
&joiner_encoder_proj_output_names_,
|
||||||
|
&joiner_encoder_proj_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RnntModel::InitJoinerDecoderProj(const std::string &filename) {
|
||||||
|
joiner_decoder_proj_sess_ =
|
||||||
|
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(joiner_decoder_proj_sess_.get(),
|
||||||
|
&joiner_decoder_proj_input_names_,
|
||||||
|
&joiner_decoder_proj_input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(joiner_decoder_proj_sess_.get(),
|
||||||
|
&joiner_decoder_proj_output_names_,
|
||||||
|
&joiner_decoder_proj_output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs the encoder on one utterance (batch size 1) and returns its first
// output. `features` holds T * feature_dim floats; they are wrapped in a
// tensor of shape (1, T, feature_dim), not copied.
Ort::Value RnntModel::RunEncoder(const float *features, int32_t T,
                                 int32_t feature_dim) {
  auto cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  std::array<int64_t, 3> feature_shape{1, T, feature_dim};

  // The length input is a single int64 equal to T.
  int64_t num_frames = T;
  std::array<int64_t, 1> length_shape{1};

  // const_cast: CreateTensor takes a non-const data pointer.
  std::array<Ort::Value, 2> inputs{
      Ort::Value::CreateTensor(cpu_info, const_cast<float *>(features),
                               T * feature_dim, feature_shape.data(),
                               feature_shape.size()),
      Ort::Value::CreateTensor(cpu_info, &num_frames, 1, length_shape.data(),
                               length_shape.size())};

  // Note: any further encoder outputs (e.g. encoder_out_lens) are discarded
  // since only batch==1 is implemented.
  auto outputs = encoder_sess_->Run(
      {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
      encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||||
|
// Runs the joiner encoder_proj model and returns its first output.
// `encoder_out` holds T * encoder_out_dim floats; they are wrapped in a
// tensor of shape (T, encoder_out_dim), not copied.
Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T,
                                           int32_t encoder_out_dim) {
  auto cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  std::array<int64_t, 2> input_shape{T, encoder_out_dim};
  // const_cast: CreateTensor takes a non-const data pointer.
  Ort::Value input = Ort::Value::CreateTensor(
      cpu_info, const_cast<float *>(encoder_out), T * encoder_out_dim,
      input_shape.data(), input_shape.size());

  auto outputs = joiner_encoder_proj_sess_->Run(
      {}, joiner_encoder_proj_input_names_ptr_.data(), &input, 1,
      joiner_encoder_proj_output_names_ptr_.data(),
      joiner_encoder_proj_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||||
|
|
||||||
|
// Runs the decoder and returns its first output. `decoder_input` holds
// context_size int64 token ids; they are wrapped in a tensor of shape
// (1, context_size), not copied.
Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input,
                                 int32_t context_size) {
  auto cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  int32_t batch_size = 1;  // TODO(fangjun): handle the case when it's > 1
  std::array<int64_t, 2> input_shape{batch_size, context_size};

  // const_cast: CreateTensor takes a non-const data pointer.
  Ort::Value input = Ort::Value::CreateTensor(
      cpu_info, const_cast<int64_t *>(decoder_input),
      batch_size * context_size, input_shape.data(), input_shape.size());

  auto outputs = decoder_sess_->Run(
      {}, decoder_input_names_ptr_.data(), &input, 1,
      decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||||
|
|
||||||
|
// Runs the joiner decoder_proj model and returns its first output.
// `decoder_out` holds decoder_out_dim floats; they are wrapped in a tensor of
// shape (1, decoder_out_dim), not copied.
Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out,
                                           int32_t decoder_out_dim) {
  auto cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  int32_t batch_size = 1;  // TODO(fangjun): handle the case when it's > 1
  std::array<int64_t, 2> input_shape{batch_size, decoder_out_dim};

  // const_cast: CreateTensor takes a non-const data pointer.
  Ort::Value input = Ort::Value::CreateTensor(
      cpu_info, const_cast<float *>(decoder_out),
      batch_size * decoder_out_dim, input_shape.data(), input_shape.size());

  auto outputs = joiner_decoder_proj_sess_->Run(
      {}, joiner_decoder_proj_input_names_ptr_.data(), &input, 1,
      joiner_decoder_proj_output_names_ptr_.data(),
      joiner_decoder_proj_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||||
|
|
||||||
|
// Runs the joiner and returns its first output. Each input pointer holds
// joiner_dim floats; both are wrapped in tensors of shape (1, joiner_dim),
// not copied.
Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out,
                                const float *projected_decoder_out,
                                int32_t joiner_dim) {
  auto cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  int32_t batch_size = 1;  // TODO(fangjun): handle the case when it's > 1
  // Both inputs share this shape.
  std::array<int64_t, 2> input_shape{batch_size, joiner_dim};

  // const_cast: CreateTensor takes a non-const data pointer.
  std::array<Ort::Value, 2> inputs{
      Ort::Value::CreateTensor(cpu_info,
                               const_cast<float *>(projected_encoder_out),
                               batch_size * joiner_dim, input_shape.data(),
                               input_shape.size()),
      Ort::Value::CreateTensor(cpu_info,
                               const_cast<float *>(projected_decoder_out),
                               batch_size * joiner_dim, input_shape.data(),
                               input_shape.size())};

  auto outputs = joiner_sess_->Run(
      {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
      joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
148
sherpa-onnx/csrc/rnnt-model.h
Normal file
148
sherpa-onnx/csrc/rnnt-model.h
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class RnntModel {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @param encoder_filename Path to the encoder model
|
||||||
|
* @param decoder_filename Path to the decoder model
|
||||||
|
* @param joiner_filename Path to the joiner model
|
||||||
|
* @param joiner_encoder_proj_filename Path to the joiner encoder_proj model
|
||||||
|
* @param joiner_decoder_proj_filename Path to the joiner decoder_proj model
|
||||||
|
* @param num_threads Number of threads to use to run the models
|
||||||
|
*/
|
||||||
|
RnntModel(const std::string &encoder_filename,
|
||||||
|
const std::string &decoder_filename,
|
||||||
|
const std::string &joiner_filename,
|
||||||
|
const std::string &joiner_encoder_proj_filename,
|
||||||
|
const std::string &joiner_decoder_proj_filename,
|
||||||
|
int32_t num_threads);
|
||||||
|
|
||||||
|
/** Run the encoder model.
|
||||||
|
*
|
||||||
|
* @TODO(fangjun): Support batch_size > 1
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (batch_size, T, feature_dim)
|
||||||
|
* @param T Number of feature frames
|
||||||
|
* @param feature_dim Dimension of the feature.
|
||||||
|
*
|
||||||
|
* @return Return a tensor of shape (batch_size, T', encoder_out_dim)
|
||||||
|
*/
|
||||||
|
Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim);
|
||||||
|
|
||||||
|
/** Run the joiner encoder_proj model.
|
||||||
|
*
|
||||||
|
* @param encoder_out A tensor of shape (T, encoder_out_dim)
|
||||||
|
* @param T Number of frames in encoder_out.
|
||||||
|
* @param encoder_out_dim Dimension of encoder_out.
|
||||||
|
*
|
||||||
|
* @return Return a tensor of shape (T, joiner_dim)
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T,
|
||||||
|
int32_t encoder_out_dim);
|
||||||
|
|
||||||
|
/** Run the decoder model.
|
||||||
|
*
|
||||||
|
* @TODO(fangjun): Support batch_size > 1
|
||||||
|
*
|
||||||
|
* @param decoder_input A tensor of shape (batch_size, context_size).
|
||||||
|
* @return Return a tensor of shape (batch_size, 1, decoder_out_dim)
|
||||||
|
*/
|
||||||
|
Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size);
|
||||||
|
|
||||||
|
/** Run joiner decoder_proj model.
|
||||||
|
*
|
||||||
|
* @TODO(fangjun): Support batch_size > 1
|
||||||
|
*
|
||||||
|
* @param decoder_out A tensor of shape (batch_size, decoder_out_dim)
|
||||||
|
* @param decoder_out_dim Output dimension of the decoder_out.
|
||||||
|
*
|
||||||
|
* @return Return a tensor of shape (batch_size, joiner_dim);
|
||||||
|
*/
|
||||||
|
Ort::Value RunJoinerDecoderProj(const float *decoder_out,
|
||||||
|
int32_t decoder_out_dim);
|
||||||
|
|
||||||
|
/** Run the joiner model.
|
||||||
|
*
|
||||||
|
* @TODO(fangjun): Support batch_size > 1
|
||||||
|
*
|
||||||
|
* @param projected_encoder_out A tensor of shape (batch_size, joiner_dim).
|
||||||
|
* @param projected_decoder_out A tensor of shape (batch_size, joiner_dim).
|
||||||
|
*
|
||||||
|
* @return Return a tensor of shape (batch_size, vocab_size)
|
||||||
|
*/
|
||||||
|
Ort::Value RunJoiner(const float *projected_encoder_out,
|
||||||
|
const float *projected_decoder_out, int32_t joiner_dim);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitEncoder(const std::string &encoder_filename);
|
||||||
|
void InitDecoder(const std::string &decoder_filename);
|
||||||
|
void InitJoiner(const std::string &joiner_filename);
|
||||||
|
void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename);
|
||||||
|
void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> joiner_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_;
|
||||||
|
std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> encoder_input_names_;
|
||||||
|
std::vector<const char *> encoder_input_names_ptr_;
|
||||||
|
std::vector<std::string> encoder_output_names_;
|
||||||
|
std::vector<const char *> encoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> decoder_input_names_;
|
||||||
|
std::vector<const char *> decoder_input_names_ptr_;
|
||||||
|
std::vector<std::string> decoder_output_names_;
|
||||||
|
std::vector<const char *> decoder_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> joiner_input_names_;
|
||||||
|
std::vector<const char *> joiner_input_names_ptr_;
|
||||||
|
std::vector<std::string> joiner_output_names_;
|
||||||
|
std::vector<const char *> joiner_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> joiner_encoder_proj_input_names_;
|
||||||
|
std::vector<const char *> joiner_encoder_proj_input_names_ptr_;
|
||||||
|
std::vector<std::string> joiner_encoder_proj_output_names_;
|
||||||
|
std::vector<const char *> joiner_encoder_proj_output_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> joiner_decoder_proj_input_names_;
|
||||||
|
std::vector<const char *> joiner_decoder_proj_input_names_ptr_;
|
||||||
|
std::vector<std::string> joiner_decoder_proj_output_names_;
|
||||||
|
std::vector<const char *> joiner_decoder_proj_output_names_ptr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
#include <vector>
|
|
||||||
#include <iostream>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <time.h>
|
|
||||||
|
|
||||||
#include "models.h"
|
|
||||||
#include "utils.h"
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * Copy a contiguous range out of a float tensor's flat buffer.
 *
 * @param tensor  Tensor whose buffer is read via GetTensorMutableData<float>().
 * @param start   Offset of the first element to copy.
 * @param length  Despite its name, used as the END offset (exclusive): the
 *                elements [start, length) are copied.
 * @return A vector holding the copied elements.
 */
std::vector<float> getEncoderCol(Ort::Value &tensor, int start, int length){
    const float* data = tensor.GetTensorMutableData<float>();
    return std::vector<float>(data + start, data + length);
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * Refresh the decoder input with the most recent tokens of the first
 * hypothesis. Assumes batch size = 1: only hyps[0] is read.
 *
 * The context size is decoder_input.size(). The last context-size tokens of
 * hyps[0] are written into decoder_input in order.
 *
 * If hyps is empty, decoder_input is left untouched. If hyps[0] holds fewer
 * tokens than the context size, only the trailing slots of decoder_input are
 * overwritten (right-aligned) and the leading slots keep their values; the
 * previous code read out of bounds in both cases.
 *
 * @param hyps           Token histories; only the first one is used.
 * @param decoder_input  In/out buffer; its size defines the context size.
 * @return A copy of the updated decoder_input.
 */
std::vector<int64_t> BuildDecoderInput(const std::vector<std::vector<int32_t>> &hyps,
                                       std::vector<int64_t> &decoder_input) {
    if (hyps.empty())
        return decoder_input;

    const std::vector<int32_t> &history = hyps[0];

    // Never copy more tokens than the history actually holds.
    const auto num_copied = static_cast<std::vector<int32_t>::difference_type>(
        std::min(decoder_input.size(), history.size()));

    std::copy(history.end() - num_copied, history.end(),
              decoder_input.end() - num_copied);

    return decoder_input;
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * Greedy transducer decoding for a single utterance (batch size 1).
 *
 * The encoder output is projected once. Then, for every projected frame, the
 * joiner is run against the current projected decoder output and the index of
 * the largest logit is taken. If that index is neither blank_id nor unk_id it
 * is appended to the hypothesis and the decoder and its projection are re-run
 * on the updated context.
 *
 * Uses `memory_info`, `ortVal2Vector`, `getEncoderCol` and
 * `BuildDecoderInput`, which are defined outside this function.
 *
 * @param model        Model providing the sessions and decoding constants.
 * @param encoder_out  Encoder outputs; only element 0 is read.
 * @return One hypothesis. It starts with context_size copies of blank_id,
 *         followed by the emitted token ids.
 */
std::vector<std::vector<int32_t>> GreedySearch(
    Model *model,  // NOLINT
    std::vector<Ort::Value> *encoder_out){
    // Flatten the encoder output; dims 1 and 2 of its shape are used.
    Ort::Value &enc_tensor = encoder_out->at(0);
    const auto enc_shape = enc_tensor.GetTensorTypeAndShapeInfo().GetShape();
    int enc_dim1 = enc_shape[1];
    int enc_dim2 = enc_shape[2];
    auto enc_vector = ortVal2Vector(enc_tensor, enc_dim1 * enc_dim2);

    // # === Greedy Search === #
    int32_t batch_size = 1;
    std::vector<std::vector<int32_t>> hyps(
        batch_size, std::vector<int32_t>(model->context_size, model->blank_id));
    std::vector<int64_t> decoder_input(model->context_size, model->blank_id);

    // Initial decoder pass on the all-blank context, then its projection.
    auto decoder_out = model->decoder_forward(decoder_input,
                                              std::vector<int64_t> {batch_size, model->context_size},
                                              memory_info);
    int decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[2];
    auto decoder_out_vector = ortVal2Vector(decoder_out[0], decoder_out_dim);

    decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,
                                                     std::vector<int64_t> {1, decoder_out_dim},
                                                     memory_info);
    auto proj_dec_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[1];
    auto proj_dec_vector = ortVal2Vector(decoder_out[0], proj_dec_dim);

    // Project all encoder frames in one call.
    auto proj_enc_out = model->joiner_encoder_proj_forward(enc_vector,
                                                           std::vector<int64_t> {enc_dim1, enc_dim2},
                                                           memory_info);
    Ort::Value &proj_enc_tensor = proj_enc_out[0];
    const auto proj_enc_shape = proj_enc_tensor.GetTensorTypeAndShapeInfo().GetShape();
    int num_frames = proj_enc_shape[0];
    int frame_dim = proj_enc_shape[1];
    // Not read below; kept so the sequence of helper calls is unchanged.
    auto proj_enc_vector = ortVal2Vector(proj_enc_tensor, num_frames * frame_dim);

    for (int frame = 0; frame < num_frames; ++frame){
        // Frame `frame` occupies [frame * frame_dim, (frame + 1) * frame_dim).
        int32_t row_begin = frame;
        int32_t row_end = row_begin + 1;
        auto cur_encoder_out = getEncoderCol(proj_enc_tensor, row_begin * frame_dim, row_end * frame_dim);

        auto logits = model->joiner_forward(cur_encoder_out,
                                            proj_dec_vector,
                                            std::vector<int64_t> {1, frame_dim},
                                            std::vector<int64_t> {1, proj_dec_dim},
                                            memory_info);

        int logits_dim = logits[0].GetTensorTypeAndShapeInfo().GetShape()[1];
        auto logits_vector = ortVal2Vector(logits[0], logits_dim);

        // Arg-max over the logits.
        const auto best_it = std::max_element(logits_vector.begin(), logits_vector.end());
        int best = static_cast<int>(std::distance(logits_vector.begin(), best_it));

        // With batch size 1 there is a single hypothesis to extend.
        if (best == model->blank_id || best == model->unk_id)
            continue;
        hyps[0].push_back(best);

        // A token was emitted: refresh the decoder context and its projection.
        decoder_input = BuildDecoderInput(hyps, decoder_input);

        decoder_out = model->decoder_forward(decoder_input,
                                             std::vector<int64_t> {batch_size, model->context_size},
                                             memory_info);
        decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[2];
        decoder_out_vector = ortVal2Vector(decoder_out[0], decoder_out_dim);

        decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,
                                                         std::vector<int64_t> {1, decoder_out_dim},
                                                         memory_info);
        proj_dec_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[1];
        proj_dec_vector = ortVal2Vector(decoder_out[0], proj_dec_dim);
    }

    return hyps;
}
|
|
||||||
|
|
||||||
129
sherpa-onnx/csrc/sherpa-onnx.cc
Normal file
129
sherpa-onnx/csrc/sherpa-onnx.cc
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
|
#include "sherpa-onnx/csrc/decode.h"
|
||||||
|
#include "sherpa-onnx/csrc/rnnt-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
|
/** Compute fbank features of the input wave filename.
|
||||||
|
*
|
||||||
|
* @param wav_filename. Path to a mono wave file.
|
||||||
|
* @param expected_sampling_rate Expected sampling rate of the input wave file.
|
||||||
|
* @param num_frames On return, it contains the number of feature frames.
|
||||||
|
* @return Return the computed feature of shape (num_frames, feature_dim)
|
||||||
|
* stored in row-major.
|
||||||
|
*/
|
||||||
|
static std::vector<float> ComputeFeatures(const std::string &wav_filename,
|
||||||
|
float expected_sampling_rate,
|
||||||
|
int32_t *num_frames) {
|
||||||
|
std::vector<float> samples =
|
||||||
|
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate);
|
||||||
|
|
||||||
|
float duration = samples.size() / expected_sampling_rate;
|
||||||
|
|
||||||
|
std::cout << "wav filename: " << wav_filename << "\n";
|
||||||
|
std::cout << "wav duration (s): " << duration << "\n";
|
||||||
|
|
||||||
|
knf::FbankOptions opts;
|
||||||
|
opts.frame_opts.dither = 0;
|
||||||
|
opts.frame_opts.snip_edges = false;
|
||||||
|
opts.frame_opts.samp_freq = expected_sampling_rate;
|
||||||
|
|
||||||
|
int32_t feature_dim = 80;
|
||||||
|
|
||||||
|
opts.mel_opts.num_bins = feature_dim;
|
||||||
|
|
||||||
|
knf::OnlineFbank fbank(opts);
|
||||||
|
fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
|
||||||
|
fbank.InputFinished();
|
||||||
|
|
||||||
|
*num_frames = fbank.NumFramesReady();
|
||||||
|
|
||||||
|
std::vector<float> features(*num_frames * feature_dim);
|
||||||
|
float *p = features.data();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) {
|
||||||
|
const float *f = fbank.GetFrame(i);
|
||||||
|
std::copy(f, f + feature_dim, p);
|
||||||
|
}
|
||||||
|
|
||||||
|
return features;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int32_t argc, char *argv[]) {
|
||||||
|
if (argc < 8 || argc > 9) {
|
||||||
|
const char *usage = R"usage(
|
||||||
|
Usage:
|
||||||
|
./bin/sherpa-onnx \
|
||||||
|
/path/to/tokens.txt \
|
||||||
|
/path/to/encoder.onnx \
|
||||||
|
/path/to/decoder.onnx \
|
||||||
|
/path/to/joiner.onnx \
|
||||||
|
/path/to/joiner_encoder_proj.ncnn.param \
|
||||||
|
/path/to/joiner_decoder_proj.ncnn.param \
|
||||||
|
/path/to/foo.wav [num_threads]
|
||||||
|
|
||||||
|
You can download pre-trained models from the following repository:
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||||
|
)usage";
|
||||||
|
std::cerr << usage << "\n";
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string tokens = argv[1];
|
||||||
|
std::string encoder = argv[2];
|
||||||
|
std::string decoder = argv[3];
|
||||||
|
std::string joiner = argv[4];
|
||||||
|
std::string joiner_encoder_proj = argv[5];
|
||||||
|
std::string joiner_decoder_proj = argv[6];
|
||||||
|
std::string wav_filename = argv[7];
|
||||||
|
int32_t num_threads = 4;
|
||||||
|
if (argc == 9) {
|
||||||
|
num_threads = atoi(argv[8]);
|
||||||
|
}
|
||||||
|
|
||||||
|
sherpa_onnx::SymbolTable sym(tokens);
|
||||||
|
|
||||||
|
int32_t num_frames;
|
||||||
|
auto features = ComputeFeatures(wav_filename, 16000, &num_frames);
|
||||||
|
int32_t feature_dim = features.size() / num_frames;
|
||||||
|
|
||||||
|
sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj,
|
||||||
|
joiner_decoder_proj, num_threads);
|
||||||
|
Ort::Value encoder_out =
|
||||||
|
model.RunEncoder(features.data(), num_frames, feature_dim);
|
||||||
|
|
||||||
|
auto hyp = sherpa_onnx::GreedySearch(model, encoder_out);
|
||||||
|
|
||||||
|
std::string text;
|
||||||
|
for (auto i : hyp) {
|
||||||
|
text += sym[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "Recognition result for " << wav_filename << "\n"
|
||||||
|
<< text << "\n";
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -15,34 +15,21 @@
|
|||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
knf::FbankOptions opts;
|
std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
|
||||||
opts.frame_opts.dither = 0;
|
std::vector<std::string> providers = Ort::GetAvailableProviders();
|
||||||
opts.mel_opts.num_bins = 10;
|
|
||||||
|
|
||||||
knf::OnlineFbank fbank(opts);
|
|
||||||
for (int32_t i = 0; i < 1600; ++i) {
|
|
||||||
float s = (i * i - i / 2) / 32767.;
|
|
||||||
fbank.AcceptWaveform(16000, &s, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
|
os << "Available providers: ";
|
||||||
int32_t n = fbank.NumFramesReady();
|
std::string sep = "";
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
for (const auto &p : providers) {
|
||||||
const float *frame = fbank.GetFrame(i);
|
os << sep << p;
|
||||||
for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) {
|
sep = ", ";
|
||||||
os << frame[k] << ", ";
|
|
||||||
}
|
|
||||||
os << "\n";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cout << os.str() << "\n";
|
std::cout << os.str() << "\n";
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
78
sherpa-onnx/csrc/symbol-table.cc
Normal file
78
sherpa-onnx/csrc/symbol-table.cc
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/// Populate the table from `filename` by reading whitespace-separated
/// "sym id" pairs until extraction fails.
SymbolTable::SymbolTable(const std::string &filename) {
  // UTF-8 encoding (0xe2 0x96 0x81) of U+2581, the marker that BPE-based
  // models put in front of a token; it is replaced with a plain space.
  static const char kBpeMarker[] = "\xe2\x96\x81";

  std::ifstream is(filename);
  std::string sym;
  int32_t id;
  for (;;) {
    is >> sym >> id;
    if (!is) break;

    // compare() only reports equality when sym has at least 3 bytes and
    // they match the marker exactly.
    if (sym.compare(0, 3, kBpeMarker) == 0) {
      sym.replace(0, 3, " ");
    }

    // Every symbol and every ID may appear only once.
    assert(!sym.empty());
    assert(sym2id_.count(sym) == 0);
    assert(id2sym_.count(id) == 0);

    sym2id_.insert({sym, id});
    id2sym_.insert({id, sym});
  }
  // The loop must have stopped because the input was exhausted.
  assert(is.eof());
}
|
||||||
|
|
||||||
|
std::string SymbolTable::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
char sep = ' ';
|
||||||
|
for (const auto &p : sym2id_) {
|
||||||
|
os << p.first << sep << p.second << "\n";
|
||||||
|
}
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up the symbol stored for `id`. The lookup uses at(), so an ID that
/// is not in the table raises std::out_of_range.
const std::string &SymbolTable::operator[](int32_t id) const {
  const std::string &symbol = id2sym_.at(id);
  return symbol;
}
|
||||||
|
|
||||||
|
/// Look up the ID stored for `sym`. The lookup uses at(), so a symbol that
/// is not in the table raises std::out_of_range.
int32_t SymbolTable::operator[](const std::string &sym) const {
  const int32_t id = sym2id_.at(sym);
  return id;
}
|
||||||
|
|
||||||
|
/// Return true if some symbol is registered under `id`.
bool SymbolTable::contains(int32_t id) const {
  return id2sym_.find(id) != id2sym_.end();
}
|
||||||
|
|
||||||
|
bool SymbolTable::contains(const std::string &sym) const {
|
||||||
|
return sym2id_.count(sym) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
|
||||||
|
return os << symbol_table.ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
62
sherpa-onnx/csrc/symbol-table.h
Normal file
62
sherpa-onnx/csrc/symbol-table.h
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/// A two-way lookup table between symbols (strings) and integer IDs.
class SymbolTable {
 public:
  /// Create an empty table.
  SymbolTable() = default;

  /// Load a symbol table from a text file.
  ///
  /// The file holds whitespace-separated pairs of the form
  ///
  ///    sym ID
  ///
  /// A symbol that starts with the UTF-8 bytes of U+2581 has that prefix
  /// replaced by a single space.
  explicit SymbolTable(const std::string &filename);

  /// Return the whole table as text, one "sym ID" pair per line.
  std::string ToString() const;

  /// Return the symbol registered for the given ID.
  /// The ID must exist in the table.
  const std::string &operator[](int32_t id) const;

  /// Return the ID registered for the given symbol.
  /// The symbol must exist in the table.
  int32_t operator[](const std::string &sym) const;

  /// Return true if the table has an entry for the given ID.
  bool contains(int32_t id) const;

  /// Return true if the table has an entry for the given symbol.
  bool contains(const std::string &sym) const;

 private:
  // Forward (symbol -> ID) and backward (ID -> symbol) maps; they are
  // filled together by the file-based constructor.
  std::unordered_map<std::string, int32_t> sym2id_;
  std::unordered_map<int32_t, std::string> id2sym_;
};
|
||||||
|
|
||||||
|
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
|
|
||||||
// Dump every element of `vector` to the file `saveFileName`, one value per
// line, using the stream's default float formatting.
void vector2file(std::vector<float> vector, std::string saveFileName){
  std::ofstream out(saveFileName);
  for (const float value : vector) {
    out << value << '\n';
  }
}
|
|
||||||
|
|
||||||
|
|
||||||
// Convert token-ID hypotheses to text.
//
// @param token_map     Maps a token ID to its text.
// @param hyps          One token-ID sequence per utterance. The first
//                      `context_size` entries of each sequence are skipped.
// @param context_size  Number of leading entries to skip. Negative values
//                      are treated as 0.
// @return One string per hypothesis. A hypothesis with no entries after the
//         skipped prefix yields an empty string. IDs missing from
//         `token_map` and empty tokens contribute nothing.
std::vector<std::string> hyps2result(std::map<int, std::string> token_map, std::vector<std::vector<int32_t>> hyps, int context_size = 2){
  std::vector<std::string> results;
  results.reserve(hyps.size());

  const size_t first = context_size > 0 ? static_cast<size_t>(context_size) : 0;

  for (const auto &hyp : hyps) {
    std::string result;

    // Starting the loop at `first` (instead of indexing hyp[first] up front)
    // keeps short hypotheses from being read out of bounds.
    for (size_t i = first; i < hyp.size(); ++i) {
      // find() avoids inserting an empty entry for unknown IDs.
      const auto it = token_map.find(hyp[i]);
      if (it == token_map.end() || it->second.empty()) {
        continue;
      }
      const std::string &token = it->second;

      // TODO: recognising '_' is not working
      // NOTE(review): possibly because the word-boundary marker in the
      // token file is U+2581 rather than ASCII '_' -- confirm.
      if (i != first && token[0] == '_') {
        result += " ";
      }
      result += token;
    }
    results.push_back(result);
  }
  return results;
}
|
|
||||||
|
|
||||||
|
|
||||||
// Print the token IDs of the first hypothesis to stdout, skipping its first
// `context_size` entries. Each ID is followed by "-" and the line is closed
// with "|". Nothing is printed between the header and "|" when `hyps` is
// empty or `context_size` is negative.
void print_hyps(std::vector<std::vector<int32_t>> hyps, int context_size = 2){
  std::cout << "Hyps:" << std::endl;
  // Guard against an empty batch; hyps[0] would be undefined behaviour.
  if (!hyps.empty() && context_size >= 0) {
    const std::vector<int32_t> &first = hyps.front();
    for (size_t i = static_cast<size_t>(context_size); i < first.size(); ++i) {
      std::cout << first[i] << "-";
    }
  }
  std::cout << "|" << std::endl;
}
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
#include <iostream>
|
|
||||||
#include "onnxruntime_cxx_api.h"
|
|
||||||
|
|
||||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
|
|
||||||
const auto& api = Ort::GetApi();
|
|
||||||
OrtTensorRTProviderOptionsV2* tensorrt_options;
|
|
||||||
Ort::SessionOptions session_options;
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
|
||||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * Copy the first `tensor_length` floats of an ORT tensor into a std::vector.
 *
 * The tensor's data is accessed as float; the caller chooses how many
 * elements are copied.
 */
std::vector<float> ortVal2Vector(Ort::Value &tensor, int tensor_length){
  float* first = tensor.GetTensorMutableData<float>();
  float* last = first + tensor_length;
  return std::vector<float>(first, last);
}
|
|
||||||
|
|
||||||
|
|
||||||
// Print the first `num` float values of the first tensor in
// `output_tensors`, one "[index] = value" line per element.
void print_onnx_forward_output(std::vector<Ort::Value> &output_tensors, int num){
  float* values = output_tensors.front().GetTensorMutableData<float>();
  int idx = 0;
  while (idx < num) {
    printf("[%d] = %f\n", idx, values[idx]);
    ++idx;
  }
}
|
|
||||||
|
|
||||||
|
|
||||||
// Print the shape of the first tensor in `tensor` as "(d0,d1,...)" followed
// by a newline.
void print_shape_of_ort_val(std::vector<Ort::Value> &tensor){
  const auto dims = tensor.front().GetTensorTypeAndShapeInfo().GetShape();

  std::cout << "(";
  // Emit the separator before every dimension except the first one.
  const char* separator = "";
  for (const auto dim : dims) {
    std::cout << separator << dim;
    separator = ",";
  }
  std::cout << ")" << std::endl;
}
|
|
||||||
|
|
||||||
|
|
||||||
// Print a summary of an ONNX Runtime session to stdout: the number of
// inputs, the name of output 0 and, for every input, its name, element
// type and shape.
//
// @param session  Session to inspect.
// @param title    Label printed in the banner line.
void print_model_info(Ort::Session &session, std::string title){
  std::cout << "=== Printing '" << title << "' model ===" << std::endl;
  Ort::AllocatorWithDefaultOptions allocator;

  // print number of model input nodes
  const size_t num_input_nodes = session.GetInputCount();
  printf("Number of inputs = %zu\n", num_input_nodes);

  // Names returned by GetOutputName()/GetInputName() are allocated with
  // `allocator`; release them once printed so they do not leak.
  char* output_name = session.GetOutputName(0, allocator);
  printf("output name: %s\n", output_name);
  allocator.Free(output_name);

  // iterate over all input nodes
  for (size_t i = 0; i < num_input_nodes; i++) {
    // print input node names
    char* input_name = session.GetInputName(i, allocator);
    printf("Input %zu : name=%s\n", i, input_name);
    allocator.Free(input_name);

    // print input node types
    Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();

    ONNXTensorElementDataType type = tensor_info.GetElementType();
    // The enum is printed as its integer value.
    printf("Input %zu : type=%d\n", i, static_cast<int>(type));

    // print input shapes/dims
    const std::vector<int64_t> input_node_dims = tensor_info.GetShape();
    printf("Input %zu : num_dims=%zu\n", i, input_node_dims.size());
    for (size_t j = 0; j < input_node_dims.size(); j++)
      printf("Input %zu : dim %zu=%lld\n", i, j,
             static_cast<long long>(input_node_dims[j]));
  }
  std::cout << "=======================================" << std::endl;
}
|
|
||||||
108
sherpa-onnx/csrc/wave-reader.cc
Normal file
108
sherpa-onnx/csrc/wave-reader.cc
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
namespace {
// Layout of the 44-byte header of a canonical PCM wave file;
// see http://soundfile.sapp.org/doc/WaveFormat/
//
// Note: We assume little endian here
// TODO(fangjun): Support big endian
struct WaveHeader {
  // Check that the header describes a 16-bit, single-channel PCM file whose
  // data sub-chunk directly follows the format sub-chunk.
  //
  // NOTE: the checks are assert()s, so they are active only when NDEBUG is
  // not defined.
  void Validate() const {
    //                 F F I R
    assert(chunk_id == 0x46464952);
    assert(subchunk2_size >= 0);
    assert(chunk_size == 36 + subchunk2_size);
    //               E V A W
    assert(format == 0x45564157);
    //                       ' ' t m f
    assert(subchunk1_id == 0x20746d66);
    assert(subchunk1_size == 16);  // 16 for PCM
    assert(audio_format == 1);     // 1 for PCM
    assert(num_channels == 1);     // we support only single channel for now
    assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
    assert(block_align == num_channels * bits_per_sample / 8);
    assert(bits_per_sample == 16);  // we support only 16 bits per sample
    //                     a t a d
    // Any other sub-chunk in this position would be decoded as audio.
    assert(subchunk2_id == 0x61746164);
  }

  int32_t chunk_id;
  int32_t chunk_size;
  int32_t format;
  int32_t subchunk1_id;
  int32_t subchunk1_size;
  int16_t audio_format;
  int16_t num_channels;
  int32_t sample_rate;
  int32_t byte_rate;
  int16_t block_align;
  int16_t bits_per_sample;
  int32_t subchunk2_id;
  int32_t subchunk2_size;
};
static_assert(sizeof(WaveHeader) == 44, "");

// Read a wave file of mono-channel from `is`.
//
// @param is           Stream positioned at the start of the wave header.
// @param sample_rate  On return, it contains the sample rate stored in the
//                     header.
// @return The samples normalized to the range [-1, 1).
std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
  WaveHeader header;
  is.read(reinterpret_cast<char *>(&header), sizeof(header));
  assert(static_cast<bool>(is));

  header.Validate();

  *sample_rate = header.sample_rate;

  // header.subchunk2_size contains the number of bytes in the data.
  // As we assume each sample contains two bytes, so it is divided by 2 here
  std::vector<int16_t> samples(header.subchunk2_size / 2);

  // Read exactly as many bytes as the buffer holds. Using the raw
  // subchunk2_size would write one byte past the end when it is odd.
  is.read(reinterpret_cast<char *>(samples.data()),
          static_cast<std::streamsize>(samples.size() * sizeof(int16_t)));

  assert(static_cast<bool>(is));

  std::vector<float> ans;
  ans.reserve(samples.size());
  for (int16_t s : samples) {
    ans.push_back(static_cast<float>(s / 32768.));
  }

  return ans;
}

}  // namespace
|
||||||
|
|
||||||
|
std::vector<float> ReadWave(const std::string &filename,
|
||||||
|
float expected_sample_rate) {
|
||||||
|
std::ifstream is(filename, std::ifstream::binary);
|
||||||
|
float sample_rate;
|
||||||
|
auto samples = ReadWaveImpl(is, &sample_rate);
|
||||||
|
if (expected_sample_rate != sample_rate) {
|
||||||
|
std::cerr << "Expected sample rate: " << expected_sample_rate
|
||||||
|
<< ". Given: " << sample_rate << ".\n";
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
return samples;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
41
sherpa-onnx/csrc/wave-reader.h
Normal file
41
sherpa-onnx/csrc/wave-reader.h
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
*
|
||||||
|
* See LICENSE for clarification regarding multiple authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||||
|
|
||||||
|
#include <istream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/** Read a wave file with expected sample rate.
|
||||||
|
|
||||||
|
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
|
||||||
|
@param expected_sample_rate Expected sample rate of the wave file. If the
|
||||||
|
sample rate doesn't match, an error is printed and the process exits.
|
||||||
|
|
||||||
|
@return Return wave samples normalized to the range [-1, 1).
|
||||||
|
*/
|
||||||
|
std::vector<float> ReadWave(const std::string &filename,
|
||||||
|
float expected_sample_rate);
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||||
Reference in New Issue
Block a user