Refactor the code (#15)
* code refactoring * Remove reference files * Update README and CI * small fixes * fix style issues * add style check for CI * fix style issues * remove kaldi-native-io
This commit is contained in:
9
.clang-format
Normal file
9
.clang-format
Normal file
@@ -0,0 +1,9 @@
|
||||
---
|
||||
BasedOnStyle: Google
|
||||
---
|
||||
Language: Cpp
|
||||
Cpp11BracedListStyle: true
|
||||
Standard: Cpp11
|
||||
DerivePointerAlignment: false
|
||||
PointerAlignment: Right
|
||||
---
|
||||
58
.github/workflows/style_check.yaml
vendored
Normal file
58
.github/workflows/style_check.yaml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: style_check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- '.github/workflows/style_check.yaml'
|
||||
- 'sherpa-onnx/**'
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- '.github/workflows/style_check.yaml'
|
||||
- 'sherpa-onnx/**'
|
||||
|
||||
concurrency:
|
||||
group: style_check-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
style_check:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.8]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Check style with cpplint
|
||||
shell: bash
|
||||
working-directory: ${{github.workspace}}
|
||||
run: ./scripts/check_style_cpplint.sh
|
||||
9
.github/workflows/test-linux-macos.yaml
vendored
9
.github/workflows/test-linux-macos.yaml
vendored
@@ -59,31 +59,28 @@ jobs:
|
||||
- name: Run tests for ubuntu/macos (English)
|
||||
run: |
|
||||
time ./build/bin/sherpa-onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
greedy \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
||||
|
||||
time ./build/bin/sherpa-onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
greedy \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
|
||||
|
||||
time ./build/bin/sherpa-onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
greedy \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
|
||||
|
||||
@@ -38,7 +38,6 @@ set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
|
||||
|
||||
include(kaldi_native_io)
|
||||
include(kaldi-native-fbank)
|
||||
include(onnxruntime)
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ the following links:
|
||||
**NOTE**: We provide only non-streaming models at present.
|
||||
|
||||
|
||||
**HINT**: The script for exporting the English model can be found at
|
||||
<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/export.py>
|
||||
|
||||
# Usage
|
||||
|
||||
```bash
|
||||
@@ -34,13 +37,14 @@ cd ..
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
|
||||
./build/bin/sherpa-onnx --help
|
||||
|
||||
./build/bin/sherpa-onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||
greedy \
|
||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
||||
```
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
function(download_kaldi_native_io)
|
||||
if(CMAKE_VERSION VERSION_LESS 3.11)
|
||||
# FetchContent is available since 3.11,
|
||||
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
|
||||
# so that it can be used in lower CMake versions.
|
||||
message(STATUS "Use FetchContent provided by sherpa-onnx")
|
||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
||||
endif()
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz")
|
||||
set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6")
|
||||
|
||||
set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||
|
||||
FetchContent_Declare(kaldi_native_io
|
||||
URL ${kaldi_native_io_URL}
|
||||
URL_HASH ${kaldi_native_io_HASH}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(kaldi_native_io)
|
||||
if(NOT kaldi_native_io_POPULATED)
|
||||
message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}")
|
||||
FetchContent_Populate(kaldi_native_io)
|
||||
endif()
|
||||
message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}")
|
||||
message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}")
|
||||
|
||||
add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
|
||||
target_include_directories(kaldi_native_io_core
|
||||
PUBLIC
|
||||
${kaldi_native_io_SOURCE_DIR}/
|
||||
)
|
||||
endfunction()
|
||||
|
||||
download_kaldi_native_io()
|
||||
@@ -10,7 +10,7 @@ function(download_onnxruntime)
|
||||
include(FetchContent)
|
||||
|
||||
if(UNIX AND NOT APPLE)
|
||||
set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
|
||||
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
|
||||
|
||||
# If you don't have access to the internet, you can first download onnxruntime to some directory, and the use
|
||||
# set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz")
|
||||
|
||||
126
scripts/check_style_cpplint.sh
Executable file
126
scripts/check_style_cpplint.sh
Executable file
@@ -0,0 +1,126 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# (1) To check files of the last commit
|
||||
# ./scripts/check_style_cpplint.sh
|
||||
#
|
||||
# (2) To check changed files not committed yet
|
||||
# ./scripts/check_style_cpplint.sh 1
|
||||
#
|
||||
# (3) To check all files in the project
|
||||
# ./scripts/check_style_cpplint.sh 2
|
||||
|
||||
|
||||
cpplint_version="1.5.4"
|
||||
cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd)
|
||||
sherpa_onnx_dir=$(cd $cur_dir/.. && pwd)
|
||||
|
||||
build_dir=$sherpa_onnx_dir/build
|
||||
mkdir -p $build_dir
|
||||
|
||||
cpplint_src=$build_dir/cpplint-${cpplint_version}/cpplint.py
|
||||
|
||||
if [ ! -d "$build_dir/cpplint-${cpplint_version}" ]; then
|
||||
pushd $build_dir
|
||||
if command -v wget &> /dev/null; then
|
||||
wget https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz
|
||||
elif command -v curl &> /dev/null; then
|
||||
curl -O -SL https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz
|
||||
else
|
||||
echo "Please install wget or curl to download cpplint"
|
||||
exit 1
|
||||
fi
|
||||
tar xf ${cpplint_version}.tar.gz
|
||||
rm ${cpplint_version}.tar.gz
|
||||
|
||||
# cpplint will report the following error for: __host__ __device__ (
|
||||
#
|
||||
# Extra space before ( in function call [whitespace/parens] [4]
|
||||
#
|
||||
# the following patch disables the above error
|
||||
sed -i "3490i\ not Search(r'__host__ __device__\\\s+\\\(', fncall) and" $cpplint_src
|
||||
popd
|
||||
fi
|
||||
|
||||
source $sherpa_onnx_dir/scripts/utils.sh
|
||||
|
||||
# return true if the given file is a c++ source file
|
||||
# return false otherwise
|
||||
function is_source_code_file() {
|
||||
case "$1" in
|
||||
*.cc|*.h|*.cu)
|
||||
echo true;;
|
||||
*)
|
||||
echo false;;
|
||||
esac
|
||||
}
|
||||
|
||||
function check_style() {
|
||||
python3 $cpplint_src $1 || abort $1
|
||||
}
|
||||
|
||||
function check_last_commit() {
|
||||
files=$(git diff HEAD^1 --name-only --diff-filter=ACDMRUXB)
|
||||
echo $files
|
||||
}
|
||||
|
||||
function check_current_dir() {
|
||||
files=$(git status -s -uno --porcelain | awk '{
|
||||
if (NF == 4) {
|
||||
# a file has been renamed
|
||||
print $NF
|
||||
} else {
|
||||
print $2
|
||||
}}')
|
||||
|
||||
echo $files
|
||||
}
|
||||
|
||||
function do_check() {
|
||||
case "$1" in
|
||||
1)
|
||||
echo "Check changed files"
|
||||
files=$(check_current_dir)
|
||||
;;
|
||||
2)
|
||||
echo "Check all files"
|
||||
files=$(find $sherpa_onnx_dir/sherpa-onnx -name "*.h" -o -name "*.cc")
|
||||
;;
|
||||
*)
|
||||
echo "Check last commit"
|
||||
files=$(check_last_commit)
|
||||
;;
|
||||
esac
|
||||
|
||||
for f in $files; do
|
||||
need_check=$(is_source_code_file $f)
|
||||
if $need_check; then
|
||||
[[ -f $f ]] && check_style $f
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
function main() {
|
||||
do_check $1
|
||||
|
||||
ok "Great! Style check passed!"
|
||||
}
|
||||
|
||||
cd $sherpa_onnx_dir
|
||||
|
||||
main $1
|
||||
19
scripts/utils.sh
Normal file
19
scripts/utils.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
|
||||
default='\033[0m'
|
||||
bold='\033[1m'
|
||||
red='\033[31m'
|
||||
green='\033[32m'
|
||||
|
||||
function ok() {
|
||||
printf "${bold}${green}[OK]${default} $1\n"
|
||||
}
|
||||
|
||||
function error() {
|
||||
printf "${bold}${red}[FAILED]${default} $1\n"
|
||||
}
|
||||
|
||||
function abort() {
|
||||
printf "${bold}${red}[FAILED]${default} $1\n"
|
||||
exit 1
|
||||
}
|
||||
@@ -1,8 +1,17 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
add_executable(sherpa-onnx main.cpp)
|
||||
|
||||
add_executable(sherpa-onnx
|
||||
decode.cc
|
||||
rnnt-model.cc
|
||||
sherpa-onnx.cc
|
||||
symbol-table.cc
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
target_link_libraries(sherpa-onnx
|
||||
onnxruntime
|
||||
kaldi-native-fbank-core
|
||||
kaldi_native_io_core
|
||||
)
|
||||
|
||||
add_executable(sherpa-show-onnx-info show-onnx-info.cc)
|
||||
target_link_libraries(sherpa-show-onnx-info onnxruntime)
|
||||
|
||||
84
sherpa-onnx/csrc/decode.cc
Normal file
84
sherpa-onnx/csrc/decode.cc
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "sherpa-onnx/csrc/decode.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
|
||||
const Ort::Value &encoder_out) {
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented");
|
||||
Ort::Value projected_encoder_out =
|
||||
model.RunJoinerEncoderProj(encoder_out.GetTensorData<float>(),
|
||||
encoder_out_shape[1], encoder_out_shape[2]);
|
||||
|
||||
const float *p_projected_encoder_out =
|
||||
projected_encoder_out.GetTensorData<float>();
|
||||
|
||||
int32_t context_size = 2; // hard-code it to 2
|
||||
int32_t blank_id = 0; // hard-code it to 0
|
||||
std::vector<int32_t> hyp(context_size, blank_id);
|
||||
std::array<int64_t, 2> decoder_input{blank_id, blank_id};
|
||||
|
||||
Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size);
|
||||
|
||||
std::vector<int64_t> decoder_out_shape =
|
||||
decoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
Ort::Value projected_decoder_out = model.RunJoinerDecoderProj(
|
||||
decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
|
||||
|
||||
int32_t joiner_dim =
|
||||
projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
|
||||
int32_t T = encoder_out_shape[1];
|
||||
for (int32_t t = 0; t != T; ++t) {
|
||||
Ort::Value logit = model.RunJoiner(
|
||||
p_projected_encoder_out + t * joiner_dim,
|
||||
projected_decoder_out.GetTensorData<float>(), joiner_dim);
|
||||
|
||||
int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
|
||||
const float *p_logit = logit.GetTensorData<float>();
|
||||
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
static_cast<const float *>(p_logit) + vocab_size)));
|
||||
|
||||
if (y != blank_id) {
|
||||
decoder_input[0] = hyp.back();
|
||||
decoder_input[1] = y;
|
||||
hyp.push_back(y);
|
||||
decoder_out = model.RunDecoder(decoder_input.data(), context_size);
|
||||
projected_decoder_out = model.RunJoinerDecoderProj(
|
||||
decoder_out.GetTensorData<float>(), decoder_out_shape[2]);
|
||||
}
|
||||
}
|
||||
|
||||
return {hyp.begin() + context_size, hyp.end()};
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
40
sherpa-onnx/csrc/decode.h
Normal file
40
sherpa-onnx/csrc/decode.h
Normal file
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_DECODE_H_
|
||||
#define SHERPA_ONNX_CSRC_DECODE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/rnnt-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** Greedy search for non-streaming ASR.
|
||||
*
|
||||
* @TODO(fangjun) Support batch size > 1
|
||||
*
|
||||
* @param model The RnntModel
|
||||
* @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
|
||||
*/
|
||||
std::vector<int32_t> GreedySearch(RnntModel &model, // NOLINT
|
||||
const Ort::Value &encoder_out);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_DECODE_H_
|
||||
@@ -1,57 +0,0 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "kaldi_native_io/csrc/kaldi-io.h"
|
||||
#include "kaldi_native_io/csrc/wave-reader.h"
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
|
||||
|
||||
kaldiio::Matrix<float> readWav(std::string filename, bool log = false){
|
||||
if (log)
|
||||
std::cout << "reading " << filename << std::endl;
|
||||
|
||||
bool binary = true;
|
||||
kaldiio::Input ki(filename, &binary);
|
||||
kaldiio::WaveHolder wh;
|
||||
|
||||
if (!wh.Read(ki.Stream())) {
|
||||
std::cerr << "Failed to read " << filename;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
auto &wave_data = wh.Value();
|
||||
auto &d = wave_data.Data();
|
||||
|
||||
if (log)
|
||||
std::cout << "wav shape: " << "(" << d.NumRows() << "," << d.NumCols() << ")" << std::endl;
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
|
||||
std::vector<float> ComputeFeatures(knf::OnlineFbank &fbank, knf::FbankOptions opts, kaldiio::Matrix<float> samples, bool log = false){
|
||||
int numSamples = samples.NumCols();
|
||||
|
||||
for (int i = 0; i < numSamples; i++)
|
||||
{
|
||||
float currentSample = samples.Row(0).Data()[i] / 32768;
|
||||
fbank.AcceptWaveform(opts.frame_opts.samp_freq, ¤tSample, 1);
|
||||
}
|
||||
|
||||
std::vector<float> features;
|
||||
int32_t num_frames = fbank.NumFramesReady();
|
||||
for (int32_t i = 0; i != num_frames; ++i) {
|
||||
const float *frame = fbank.GetFrame(i);
|
||||
for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) {
|
||||
features.push_back(frame[k]);
|
||||
}
|
||||
}
|
||||
if (log){
|
||||
std::cout << "done feature extraction" << std::endl;
|
||||
std::cout << "extracted fbank shape " << "(" << num_frames << "," << opts.mel_opts.num_bins << ")" << std::endl;
|
||||
|
||||
for (int i=0; i< 20; i++)
|
||||
std::cout << features.at(i) << std::endl;
|
||||
}
|
||||
|
||||
return features;
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <math.h>
|
||||
#include <time.h>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/fbank_features.h"
|
||||
#include "sherpa-onnx/csrc/rnnt_beam_search.h"
|
||||
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
char *encoder_path = argv[1];
|
||||
char *decoder_path = argv[2];
|
||||
char *joiner_path = argv[3];
|
||||
char *joiner_encoder_proj_path = argv[4];
|
||||
char *joiner_decoder_proj_path = argv[5];
|
||||
char *token_path = argv[6];
|
||||
std::string search_method = argv[7];
|
||||
char *filename = argv[8];
|
||||
|
||||
// General parameters
|
||||
int numberOfThreads = 16;
|
||||
|
||||
// Initialize fbanks
|
||||
knf::FbankOptions opts;
|
||||
opts.frame_opts.dither = 0;
|
||||
opts.frame_opts.samp_freq = 16000;
|
||||
opts.frame_opts.frame_shift_ms = 10.0f;
|
||||
opts.frame_opts.frame_length_ms = 25.0f;
|
||||
opts.mel_opts.num_bins = 80;
|
||||
opts.frame_opts.window_type = "povey";
|
||||
opts.frame_opts.snip_edges = false;
|
||||
knf::OnlineFbank fbank(opts);
|
||||
|
||||
// set session opts
|
||||
// https://onnxruntime.ai/docs/performance/tune-performance.html
|
||||
session_options.SetIntraOpNumThreads(numberOfThreads);
|
||||
session_options.SetInterOpNumThreads(numberOfThreads);
|
||||
session_options.SetGraphOptimizationLevel(
|
||||
GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
||||
session_options.SetLogSeverityLevel(4);
|
||||
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
|
||||
|
||||
api.CreateTensorRTProviderOptions(&tensorrt_options);
|
||||
std::unique_ptr<OrtTensorRTProviderOptionsV2,
|
||||
decltype(api.ReleaseTensorRTProviderOptions)>
|
||||
rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
|
||||
api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
|
||||
static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get());
|
||||
|
||||
// Define model
|
||||
auto model =
|
||||
get_model(encoder_path, decoder_path, joiner_path,
|
||||
joiner_encoder_proj_path, joiner_decoder_proj_path, token_path);
|
||||
|
||||
std::vector<std::string> filename_list{filename};
|
||||
|
||||
for (auto filename : filename_list) {
|
||||
std::cout << filename << std::endl;
|
||||
auto samples = readWav(filename, true);
|
||||
int numSamples = samples.NumCols();
|
||||
|
||||
auto features = ComputeFeatures(fbank, opts, samples);
|
||||
|
||||
auto tic = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// # === Encoder Out === #
|
||||
int num_frames = features.size() / opts.mel_opts.num_bins;
|
||||
auto encoder_out =
|
||||
model.encoder_forward(features, std::vector<int64_t>{num_frames},
|
||||
std::vector<int64_t>{1, num_frames, 80},
|
||||
std::vector<int64_t>{1}, memory_info);
|
||||
|
||||
// # === Search === #
|
||||
std::vector<std::vector<int32_t>> hyps;
|
||||
if (search_method == "greedy")
|
||||
hyps = GreedySearch(&model, &encoder_out);
|
||||
else {
|
||||
std::cout << "wrong search method!" << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
auto results = hyps2result(model.tokens_map, hyps);
|
||||
|
||||
// # === Print Elapsed Time === #
|
||||
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::high_resolution_clock::now() - tic);
|
||||
std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds"
|
||||
<< std::endl;
|
||||
std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000)
|
||||
<< std::endl;
|
||||
|
||||
print_hyps(hyps);
|
||||
std::cout << results[0] << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,253 +0,0 @@
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <sys/stat.h>
|
||||
|
||||
#include "utils_onnx.h"
|
||||
|
||||
|
||||
struct Model
|
||||
{
|
||||
public:
|
||||
const char* encoder_path;
|
||||
const char* decoder_path;
|
||||
const char* joiner_path;
|
||||
const char* joiner_encoder_proj_path;
|
||||
const char* joiner_decoder_proj_path;
|
||||
const char* tokens_path;
|
||||
|
||||
Ort::Session encoder = load_model(encoder_path);
|
||||
Ort::Session decoder = load_model(decoder_path);
|
||||
Ort::Session joiner = load_model(joiner_path);
|
||||
Ort::Session joiner_encoder_proj = load_model(joiner_encoder_proj_path);
|
||||
Ort::Session joiner_decoder_proj = load_model(joiner_decoder_proj_path);
|
||||
std::map<int, std::string> tokens_map = get_token_map(tokens_path);
|
||||
|
||||
int32_t blank_id;
|
||||
int32_t unk_id;
|
||||
int32_t context_size;
|
||||
|
||||
std::vector<Ort::Value> encoder_forward(std::vector<float> in_vector,
|
||||
std::vector<int64_t> in_vector_length,
|
||||
std::vector<int64_t> feature_dims,
|
||||
std::vector<int64_t> feature_length_dims,
|
||||
Ort::MemoryInfo &memory_info){
|
||||
std::vector<Ort::Value> encoder_inputTensors;
|
||||
encoder_inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), feature_dims.data(), feature_dims.size()));
|
||||
encoder_inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector_length.data(), in_vector_length.size(), feature_length_dims.data(), feature_length_dims.size()));
|
||||
|
||||
std::vector<const char*> encoder_inputNames = {encoder.GetInputName(0, allocator), encoder.GetInputName(1, allocator)};
|
||||
std::vector<const char*> encoder_outputNames = {encoder.GetOutputName(0, allocator)};
|
||||
|
||||
auto out = encoder.Run(Ort::RunOptions{nullptr},
|
||||
encoder_inputNames.data(),
|
||||
encoder_inputTensors.data(),
|
||||
encoder_inputTensors.size(),
|
||||
encoder_outputNames.data(),
|
||||
encoder_outputNames.size());
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> decoder_forward(std::vector<int64_t> in_vector,
|
||||
std::vector<int64_t> dims,
|
||||
Ort::MemoryInfo &memory_info){
|
||||
std::vector<Ort::Value> inputTensors;
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
|
||||
|
||||
std::vector<const char*> inputNames {decoder.GetInputName(0, allocator)};
|
||||
std::vector<const char*> outputNames {decoder.GetOutputName(0, allocator)};
|
||||
|
||||
auto out = decoder.Run(Ort::RunOptions{nullptr},
|
||||
inputNames.data(),
|
||||
inputTensors.data(),
|
||||
inputTensors.size(),
|
||||
outputNames.data(),
|
||||
outputNames.size());
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> joiner_forward(std::vector<float> projected_encoder_out,
|
||||
std::vector<float> decoder_out,
|
||||
std::vector<int64_t> projected_encoder_out_dims,
|
||||
std::vector<int64_t> decoder_out_dims,
|
||||
Ort::MemoryInfo &memory_info){
|
||||
std::vector<Ort::Value> inputTensors;
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, projected_encoder_out.data(), projected_encoder_out.size(), projected_encoder_out_dims.data(), projected_encoder_out_dims.size()));
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, decoder_out.data(), decoder_out.size(), decoder_out_dims.data(), decoder_out_dims.size()));
|
||||
std::vector<const char*> inputNames = {joiner.GetInputName(0, allocator), joiner.GetInputName(1, allocator)};
|
||||
std::vector<const char*> outputNames = {joiner.GetOutputName(0, allocator)};
|
||||
|
||||
auto out = joiner.Run(Ort::RunOptions{nullptr},
|
||||
inputNames.data(),
|
||||
inputTensors.data(),
|
||||
inputTensors.size(),
|
||||
outputNames.data(),
|
||||
outputNames.size());
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> joiner_encoder_proj_forward(std::vector<float> in_vector,
|
||||
std::vector<int64_t> dims,
|
||||
Ort::MemoryInfo &memory_info){
|
||||
std::vector<Ort::Value> inputTensors;
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
|
||||
|
||||
std::vector<const char*> inputNames {joiner_encoder_proj.GetInputName(0, allocator)};
|
||||
std::vector<const char*> outputNames {joiner_encoder_proj.GetOutputName(0, allocator)};
|
||||
|
||||
auto out = joiner_encoder_proj.Run(Ort::RunOptions{nullptr},
|
||||
inputNames.data(),
|
||||
inputTensors.data(),
|
||||
inputTensors.size(),
|
||||
outputNames.data(),
|
||||
outputNames.size());
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> joiner_decoder_proj_forward(std::vector<float> in_vector,
|
||||
std::vector<int64_t> dims,
|
||||
Ort::MemoryInfo &memory_info){
|
||||
std::vector<Ort::Value> inputTensors;
|
||||
inputTensors.push_back(Ort::Value::CreateTensor<float>(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size()));
|
||||
|
||||
std::vector<const char*> inputNames {joiner_decoder_proj.GetInputName(0, allocator)};
|
||||
std::vector<const char*> outputNames {joiner_decoder_proj.GetOutputName(0, allocator)};
|
||||
|
||||
auto out = joiner_decoder_proj.Run(Ort::RunOptions{nullptr},
|
||||
inputNames.data(),
|
||||
inputTensors.data(),
|
||||
inputTensors.size(),
|
||||
outputNames.data(),
|
||||
outputNames.size());
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
Ort::Session load_model(const char* path){
|
||||
struct stat buffer;
|
||||
if (stat(path, &buffer) != 0){
|
||||
std::cout << "File does not exist!: " << path << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
std::cout << "loading " << path << std::endl;
|
||||
Ort::Session onnx_model(env, path, session_options);
|
||||
return onnx_model;
|
||||
}
|
||||
|
||||
void extract_constant_lm_parameters(){
|
||||
/*
|
||||
all_in_one contains these params. We should trace all_in_one and find 'constants_lm' nodes to extract these params
|
||||
For now, these params are set staticaly.
|
||||
in: Ort::Session &all_in_one
|
||||
out: {blank_id, unk_id, context_size}
|
||||
should return std::vector<int32_t>
|
||||
*/
|
||||
blank_id = 0;
|
||||
unk_id = 0;
|
||||
context_size = 2;
|
||||
}
|
||||
|
||||
std::map<int, std::string> get_token_map(const char* token_path){
|
||||
std::ifstream inFile;
|
||||
inFile.open(token_path);
|
||||
if (inFile.fail())
|
||||
std::cerr << "Could not find token file" << std::endl;
|
||||
|
||||
std::map<int, std::string> token_map;
|
||||
|
||||
std::string line;
|
||||
while (std::getline(inFile, line))
|
||||
{
|
||||
int id;
|
||||
std::string token;
|
||||
|
||||
std::istringstream iss(line);
|
||||
iss >> token;
|
||||
iss >> id;
|
||||
|
||||
token_map[id] = token;
|
||||
}
|
||||
|
||||
return token_map;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
Model get_model(std::string exp_path, char* tokens_path){
|
||||
Model model{
|
||||
(exp_path + "/encoder_simp.onnx").c_str(),
|
||||
(exp_path + "/decoder_simp.onnx").c_str(),
|
||||
(exp_path + "/joiner_simp.onnx").c_str(),
|
||||
(exp_path + "/joiner_encoder_proj_simp.onnx").c_str(),
|
||||
(exp_path + "/joiner_decoder_proj_simp.onnx").c_str(),
|
||||
tokens_path,
|
||||
};
|
||||
model.extract_constant_lm_parameters();
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
Model get_model(char* encoder_path,
|
||||
char* decoder_path,
|
||||
char* joiner_path,
|
||||
char* joiner_encoder_proj_path,
|
||||
char* joiner_decoder_proj_path,
|
||||
char* tokens_path){
|
||||
Model model{
|
||||
encoder_path,
|
||||
decoder_path,
|
||||
joiner_path,
|
||||
joiner_encoder_proj_path,
|
||||
joiner_decoder_proj_path,
|
||||
tokens_path,
|
||||
};
|
||||
model.extract_constant_lm_parameters();
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
|
||||
void doWarmup(Model *model, int numWarmup = 5){
|
||||
std::cout << "Warmup is started" << std::endl;
|
||||
|
||||
std::vector<float> encoder_warmup_sample (500 * 80, 1.0);
|
||||
for (int i=0; i<numWarmup; i++)
|
||||
auto encoder_out = model->encoder_forward(encoder_warmup_sample,
|
||||
std::vector<int64_t> {500},
|
||||
std::vector<int64_t> {1, 500, 80},
|
||||
std::vector<int64_t> {1},
|
||||
memory_info);
|
||||
|
||||
std::vector<int64_t> decoder_warmup_sample {1, 1};
|
||||
for (int i=0; i<numWarmup; i++)
|
||||
auto decoder_out = model->decoder_forward(decoder_warmup_sample,
|
||||
std::vector<int64_t> {1, 2},
|
||||
memory_info);
|
||||
|
||||
std::vector<float> joiner_warmup_sample1 (512, 1.0);
|
||||
std::vector<float> joiner_warmup_sample2 (512, 1.0);
|
||||
for (int i=0; i<numWarmup; i++)
|
||||
auto logits = model->joiner_forward(joiner_warmup_sample1,
|
||||
joiner_warmup_sample2,
|
||||
std::vector<int64_t> {1, 1, 1, 512},
|
||||
std::vector<int64_t> {1, 1, 1, 512},
|
||||
memory_info);
|
||||
|
||||
std::vector<float> joiner_encoder_proj_warmup_sample (100 * 512, 1.0);
|
||||
for (int i=0; i<numWarmup; i++)
|
||||
auto projected_encoder_out = model->joiner_encoder_proj_forward(joiner_encoder_proj_warmup_sample,
|
||||
std::vector<int64_t> {100, 512},
|
||||
memory_info);
|
||||
|
||||
std::vector<float> joiner_decoder_proj_warmup_sample (512, 1.0);
|
||||
for (int i=0; i<numWarmup; i++)
|
||||
auto projected_decoder_out = model->joiner_decoder_proj_forward(joiner_decoder_proj_warmup_sample,
|
||||
std::vector<int64_t> {1, 512},
|
||||
memory_info);
|
||||
std::cout << "Warmup is done" << std::endl;
|
||||
}
|
||||
247
sherpa-onnx/csrc/rnnt-model.cc
Normal file
247
sherpa-onnx/csrc/rnnt-model.cc
Normal file
@@ -0,0 +1,247 @@
|
||||
/**
|
||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "sherpa-onnx/csrc/rnnt-model.h"
|
||||
|
||||
#include <array>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/**
|
||||
* Get the input names of a model.
|
||||
*
|
||||
* @param sess An onnxruntime session.
|
||||
* @param input_names. On return, it contains the input names of the model.
|
||||
* @param input_names_ptr. On return, input_names_ptr[i] contains
|
||||
* input_names[i].c_str()
|
||||
*/
|
||||
static void GetInputNames(Ort::Session *sess,
|
||||
std::vector<std::string> *input_names,
|
||||
std::vector<const char *> *input_names_ptr) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
size_t node_count = sess->GetInputCount();
|
||||
input_names->resize(node_count);
|
||||
input_names_ptr->resize(node_count);
|
||||
for (size_t i = 0; i != node_count; ++i) {
|
||||
auto tmp = sess->GetInputNameAllocated(i, allocator);
|
||||
(*input_names)[i] = tmp.get();
|
||||
(*input_names_ptr)[i] = (*input_names)[i].c_str();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the output names of a model.
|
||||
*
|
||||
* @param sess An onnxruntime session.
|
||||
* @param output_names. On return, it contains the output names of the model.
|
||||
* @param output_names_ptr. On return, output_names_ptr[i] contains
|
||||
* output_names[i].c_str()
|
||||
*/
|
||||
static void GetOutputNames(Ort::Session *sess,
|
||||
std::vector<std::string> *output_names,
|
||||
std::vector<const char *> *output_names_ptr) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
size_t node_count = sess->GetOutputCount();
|
||||
output_names->resize(node_count);
|
||||
output_names_ptr->resize(node_count);
|
||||
for (size_t i = 0; i != node_count; ++i) {
|
||||
auto tmp = sess->GetOutputNameAllocated(i, allocator);
|
||||
(*output_names)[i] = tmp.get();
|
||||
(*output_names_ptr)[i] = (*output_names)[i].c_str();
|
||||
}
|
||||
}
|
||||
|
||||
RnntModel::RnntModel(const std::string &encoder_filename,
|
||||
const std::string &decoder_filename,
|
||||
const std::string &joiner_filename,
|
||||
const std::string &joiner_encoder_proj_filename,
|
||||
const std::string &joiner_decoder_proj_filename,
|
||||
int32_t num_threads)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING) {
|
||||
sess_opts_.SetIntraOpNumThreads(num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(num_threads);
|
||||
|
||||
InitEncoder(encoder_filename);
|
||||
InitDecoder(decoder_filename);
|
||||
InitJoiner(joiner_filename);
|
||||
InitJoinerEncoderProj(joiner_encoder_proj_filename);
|
||||
InitJoinerDecoderProj(joiner_decoder_proj_filename);
|
||||
}
|
||||
|
||||
void RnntModel::InitEncoder(const std::string &filename) {
|
||||
encoder_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
|
||||
&encoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
|
||||
&encoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
void RnntModel::InitDecoder(const std::string &filename) {
|
||||
decoder_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
|
||||
&decoder_input_names_ptr_);
|
||||
|
||||
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
|
||||
&decoder_output_names_ptr_);
|
||||
}
|
||||
|
||||
void RnntModel::InitJoiner(const std::string &filename) {
|
||||
joiner_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(joiner_sess_.get(), &joiner_input_names_,
|
||||
&joiner_input_names_ptr_);
|
||||
|
||||
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
|
||||
&joiner_output_names_ptr_);
|
||||
}
|
||||
|
||||
void RnntModel::InitJoinerEncoderProj(const std::string &filename) {
|
||||
joiner_encoder_proj_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(joiner_encoder_proj_sess_.get(),
|
||||
&joiner_encoder_proj_input_names_,
|
||||
&joiner_encoder_proj_input_names_ptr_);
|
||||
|
||||
GetOutputNames(joiner_encoder_proj_sess_.get(),
|
||||
&joiner_encoder_proj_output_names_,
|
||||
&joiner_encoder_proj_output_names_ptr_);
|
||||
}
|
||||
|
||||
void RnntModel::InitJoinerDecoderProj(const std::string &filename) {
|
||||
joiner_decoder_proj_sess_ =
|
||||
std::make_unique<Ort::Session>(env_, filename.c_str(), sess_opts_);
|
||||
|
||||
GetInputNames(joiner_decoder_proj_sess_.get(),
|
||||
&joiner_decoder_proj_input_names_,
|
||||
&joiner_decoder_proj_input_names_ptr_);
|
||||
|
||||
GetOutputNames(joiner_decoder_proj_sess_.get(),
|
||||
&joiner_decoder_proj_output_names_,
|
||||
&joiner_decoder_proj_output_names_ptr_);
|
||||
}
|
||||
|
||||
Ort::Value RnntModel::RunEncoder(const float *features, int32_t T,
|
||||
int32_t feature_dim) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
std::array<int64_t, 3> x_shape{1, T, feature_dim};
|
||||
Ort::Value x =
|
||||
Ort::Value::CreateTensor(memory_info, const_cast<float *>(features),
|
||||
T * feature_dim, x_shape.data(), x_shape.size());
|
||||
|
||||
std::array<int64_t, 1> x_lens_shape{1};
|
||||
int64_t x_lens_tmp = T;
|
||||
|
||||
Ort::Value x_lens = Ort::Value::CreateTensor(
|
||||
memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size());
|
||||
|
||||
std::array<Ort::Value, 2> encoder_inputs{std::move(x), std::move(x_lens)};
|
||||
|
||||
// Note: We discard encoder_out_lens since we only implement
|
||||
// batch==1.
|
||||
auto encoder_out = encoder_sess_->Run(
|
||||
{}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
|
||||
encoder_inputs.size(), encoder_output_names_ptr_.data(),
|
||||
encoder_output_names_ptr_.size());
|
||||
return std::move(encoder_out[0]);
|
||||
}
|
||||
Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T,
|
||||
int32_t encoder_out_dim) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 2> in_shape{T, encoder_out_dim};
|
||||
Ort::Value in = Ort::Value::CreateTensor(
|
||||
memory_info, const_cast<float *>(encoder_out), T * encoder_out_dim,
|
||||
in_shape.data(), in_shape.size());
|
||||
|
||||
auto encoder_proj_out = joiner_encoder_proj_sess_->Run(
|
||||
{}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1,
|
||||
joiner_encoder_proj_output_names_ptr_.data(),
|
||||
joiner_encoder_proj_output_names_ptr_.size());
|
||||
return std::move(encoder_proj_out[0]);
|
||||
}
|
||||
|
||||
Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input,
|
||||
int32_t context_size) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
|
||||
std::array<int64_t, 2> shape{batch_size, context_size};
|
||||
Ort::Value in = Ort::Value::CreateTensor(
|
||||
memory_info, const_cast<int64_t *>(decoder_input),
|
||||
batch_size * context_size, shape.data(), shape.size());
|
||||
|
||||
auto decoder_out = decoder_sess_->Run(
|
||||
{}, decoder_input_names_ptr_.data(), &in, 1,
|
||||
decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
|
||||
return std::move(decoder_out[0]);
|
||||
}
|
||||
|
||||
Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out,
|
||||
int32_t decoder_out_dim) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
|
||||
std::array<int64_t, 2> shape{batch_size, decoder_out_dim};
|
||||
Ort::Value in = Ort::Value::CreateTensor(
|
||||
memory_info, const_cast<float *>(decoder_out),
|
||||
batch_size * decoder_out_dim, shape.data(), shape.size());
|
||||
|
||||
auto decoder_proj_out = joiner_decoder_proj_sess_->Run(
|
||||
{}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1,
|
||||
joiner_decoder_proj_output_names_ptr_.data(),
|
||||
joiner_decoder_proj_output_names_ptr_.size());
|
||||
return std::move(decoder_proj_out[0]);
|
||||
}
|
||||
|
||||
Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out,
|
||||
const float *projected_decoder_out,
|
||||
int32_t joiner_dim) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1
|
||||
std::array<int64_t, 2> shape{batch_size, joiner_dim};
|
||||
|
||||
Ort::Value enc = Ort::Value::CreateTensor(
|
||||
memory_info, const_cast<float *>(projected_encoder_out),
|
||||
batch_size * joiner_dim, shape.data(), shape.size());
|
||||
|
||||
Ort::Value dec = Ort::Value::CreateTensor(
|
||||
memory_info, const_cast<float *>(projected_decoder_out),
|
||||
batch_size * joiner_dim, shape.data(), shape.size());
|
||||
|
||||
std::array<Ort::Value, 2> inputs{std::move(enc), std::move(dec)};
|
||||
|
||||
auto logit = joiner_sess_->Run(
|
||||
{}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());
|
||||
|
||||
return std::move(logit[0]);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
148
sherpa-onnx/csrc/rnnt-model.h
Normal file
148
sherpa-onnx/csrc/rnnt-model.h
Normal file
@@ -0,0 +1,148 @@
|
||||
/**
|
||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class RnntModel {
|
||||
public:
|
||||
/**
|
||||
* @param encoder_filename Path to the encoder model
|
||||
* @param decoder_filename Path to the decoder model
|
||||
* @param joiner_filename Path to the joiner model
|
||||
* @param joiner_encoder_proj_filename Path to the joiner encoder_proj model
|
||||
* @param joiner_decoder_proj_filename Path to the joiner decoder_proj model
|
||||
* @param num_threads Number of threads to use to run the models
|
||||
*/
|
||||
RnntModel(const std::string &encoder_filename,
|
||||
const std::string &decoder_filename,
|
||||
const std::string &joiner_filename,
|
||||
const std::string &joiner_encoder_proj_filename,
|
||||
const std::string &joiner_decoder_proj_filename,
|
||||
int32_t num_threads);
|
||||
|
||||
/** Run the encoder model.
|
||||
*
|
||||
* @TODO(fangjun): Support batch_size > 1
|
||||
*
|
||||
* @param features A tensor of shape (batch_size, T, feature_dim)
|
||||
* @param T Number of feature frames
|
||||
* @param feature_dim Dimension of the feature.
|
||||
*
|
||||
* @return Return a tensor of shape (batch_size, T', encoder_out_dim)
|
||||
*/
|
||||
Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim);
|
||||
|
||||
/** Run the joiner encoder_proj model.
|
||||
*
|
||||
* @param encoder_out A tensor of shape (T, encoder_out_dim)
|
||||
* @param T Number of frames in encoder_out.
|
||||
* @param encoder_out_dim Dimension of encoder_out.
|
||||
*
|
||||
* @return Return a tensor of shape (T, joiner_dim)
|
||||
*
|
||||
*/
|
||||
Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T,
|
||||
int32_t encoder_out_dim);
|
||||
|
||||
/** Run the decoder model.
|
||||
*
|
||||
* @TODO(fangjun): Support batch_size > 1
|
||||
*
|
||||
* @param decoder_input A tensor of shape (batch_size, context_size).
|
||||
* @return Return a tensor of shape (batch_size, 1, decoder_out_dim)
|
||||
*/
|
||||
Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size);
|
||||
|
||||
/** Run joiner decoder_proj model.
|
||||
*
|
||||
* @TODO(fangjun): Support batch_size > 1
|
||||
*
|
||||
* @param decoder_out A tensor of shape (batch_size, decoder_out_dim)
|
||||
* @param decoder_out_dim Output dimension of the decoder_out.
|
||||
*
|
||||
* @return Return a tensor of shape (batch_size, joiner_dim);
|
||||
*/
|
||||
Ort::Value RunJoinerDecoderProj(const float *decoder_out,
|
||||
int32_t decoder_out_dim);
|
||||
|
||||
/** Run the joiner model.
|
||||
*
|
||||
* @TODO(fangjun): Support batch_size > 1
|
||||
*
|
||||
* @param projected_encoder_out A tensor of shape (batch_size, joiner_dim).
|
||||
* @param projected_decoder_out A tensor of shape (batch_size, joiner_dim).
|
||||
*
|
||||
* @return Return a tensor of shape (batch_size, vocab_size)
|
||||
*/
|
||||
Ort::Value RunJoiner(const float *projected_encoder_out,
|
||||
const float *projected_decoder_out, int32_t joiner_dim);
|
||||
|
||||
private:
|
||||
void InitEncoder(const std::string &encoder_filename);
|
||||
void InitDecoder(const std::string &decoder_filename);
|
||||
void InitJoiner(const std::string &joiner_filename);
|
||||
void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename);
|
||||
void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename);
|
||||
|
||||
private:
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
std::unique_ptr<Ort::Session> encoder_sess_;
|
||||
std::unique_ptr<Ort::Session> decoder_sess_;
|
||||
std::unique_ptr<Ort::Session> joiner_sess_;
|
||||
std::unique_ptr<Ort::Session> joiner_encoder_proj_sess_;
|
||||
std::unique_ptr<Ort::Session> joiner_decoder_proj_sess_;
|
||||
|
||||
std::vector<std::string> encoder_input_names_;
|
||||
std::vector<const char *> encoder_input_names_ptr_;
|
||||
std::vector<std::string> encoder_output_names_;
|
||||
std::vector<const char *> encoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> decoder_input_names_;
|
||||
std::vector<const char *> decoder_input_names_ptr_;
|
||||
std::vector<std::string> decoder_output_names_;
|
||||
std::vector<const char *> decoder_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> joiner_input_names_;
|
||||
std::vector<const char *> joiner_input_names_ptr_;
|
||||
std::vector<std::string> joiner_output_names_;
|
||||
std::vector<const char *> joiner_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> joiner_encoder_proj_input_names_;
|
||||
std::vector<const char *> joiner_encoder_proj_input_names_ptr_;
|
||||
std::vector<std::string> joiner_encoder_proj_output_names_;
|
||||
std::vector<const char *> joiner_encoder_proj_output_names_ptr_;
|
||||
|
||||
std::vector<std::string> joiner_decoder_proj_input_names_;
|
||||
std::vector<const char *> joiner_decoder_proj_input_names_ptr_;
|
||||
std::vector<std::string> joiner_decoder_proj_output_names_;
|
||||
std::vector<const char *> joiner_decoder_proj_output_names_ptr_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_
|
||||
@@ -1,120 +0,0 @@
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <time.h>
|
||||
|
||||
#include "models.h"
|
||||
#include "utils.h"
|
||||
|
||||
|
||||
std::vector<float> getEncoderCol(Ort::Value &tensor, int start, int length){
|
||||
float* floatarr = tensor.GetTensorMutableData<float>();
|
||||
std::vector<float> vector {floatarr + start, floatarr + length};
|
||||
return vector;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Assume batch size = 1
|
||||
*/
|
||||
std::vector<int64_t> BuildDecoderInput(const std::vector<std::vector<int32_t>> &hyps,
|
||||
std::vector<int64_t> &decoder_input) {
|
||||
|
||||
int32_t context_size = decoder_input.size();
|
||||
int32_t hyps_length = hyps[0].size();
|
||||
for (int i=0; i < context_size; i++)
|
||||
decoder_input[i] = hyps[0][hyps_length-context_size+i];
|
||||
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int32_t>> GreedySearch(
|
||||
Model *model, // NOLINT
|
||||
std::vector<Ort::Value> *encoder_out){
|
||||
Ort::Value &encoder_out_tensor = encoder_out->at(0);
|
||||
int encoder_out_dim1 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
int encoder_out_dim2 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||
auto encoder_out_vector = ortVal2Vector(encoder_out_tensor, encoder_out_dim1 * encoder_out_dim2);
|
||||
|
||||
// # === Greedy Search === #
|
||||
int32_t batch_size = 1;
|
||||
std::vector<int32_t> blanks(model->context_size, model->blank_id);
|
||||
std::vector<std::vector<int32_t>> hyps(batch_size, blanks);
|
||||
std::vector<int64_t> decoder_input(model->context_size, model->blank_id);
|
||||
|
||||
auto decoder_out = model->decoder_forward(decoder_input,
|
||||
std::vector<int64_t> {batch_size, model->context_size},
|
||||
memory_info);
|
||||
|
||||
Ort::Value &decoder_out_tensor = decoder_out[0];
|
||||
int decoder_out_dim = decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||
auto decoder_out_vector = ortVal2Vector(decoder_out_tensor, decoder_out_dim);
|
||||
|
||||
decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,
|
||||
std::vector<int64_t> {1, decoder_out_dim},
|
||||
memory_info);
|
||||
Ort::Value &projected_decoder_out_tensor = decoder_out[0];
|
||||
auto projected_decoder_out_dim = projected_decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
auto projected_decoder_out_vector = ortVal2Vector(projected_decoder_out_tensor, projected_decoder_out_dim);
|
||||
|
||||
auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
|
||||
std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
|
||||
memory_info);
|
||||
Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
|
||||
int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
|
||||
int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
auto projected_encoder_out_vector = ortVal2Vector(projected_encoder_out_tensor, projected_encoder_out_dim1 * projected_encoder_out_dim2);
|
||||
|
||||
int32_t offset = 0;
|
||||
for (int i=0; i< projected_encoder_out_dim1; i++){
|
||||
int32_t cur_batch_size = 1;
|
||||
int32_t start = offset;
|
||||
int32_t end = start + cur_batch_size;
|
||||
offset = end;
|
||||
|
||||
auto cur_encoder_out = getEncoderCol(projected_encoder_out_tensor, start * projected_encoder_out_dim2, end * projected_encoder_out_dim2);
|
||||
|
||||
auto logits = model->joiner_forward(cur_encoder_out,
|
||||
projected_decoder_out_vector,
|
||||
std::vector<int64_t> {1, projected_encoder_out_dim2},
|
||||
std::vector<int64_t> {1, projected_decoder_out_dim},
|
||||
memory_info);
|
||||
|
||||
Ort::Value &logits_tensor = logits[0];
|
||||
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
|
||||
|
||||
int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));
|
||||
bool emitted = false;
|
||||
|
||||
for (int32_t k = 0; k != cur_batch_size; ++k) {
|
||||
auto index = max_indices;
|
||||
if (index != model->blank_id && index != model->unk_id) {
|
||||
emitted = true;
|
||||
hyps[k].push_back(index);
|
||||
}
|
||||
}
|
||||
|
||||
if (emitted) {
|
||||
decoder_input = BuildDecoderInput(hyps, decoder_input);
|
||||
|
||||
decoder_out = model->decoder_forward(decoder_input,
|
||||
std::vector<int64_t> {batch_size, model->context_size},
|
||||
memory_info);
|
||||
|
||||
decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||
decoder_out_vector = ortVal2Vector(decoder_out[0], decoder_out_dim);
|
||||
|
||||
decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector,
|
||||
std::vector<int64_t> {1, decoder_out_dim},
|
||||
memory_info);
|
||||
|
||||
projected_decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
projected_decoder_out_vector = ortVal2Vector(decoder_out[0], projected_decoder_out_dim);
|
||||
}
|
||||
}
|
||||
|
||||
return hyps;
|
||||
}
|
||||
|
||||
129
sherpa-onnx/csrc/sherpa-onnx.cc
Normal file
129
sherpa-onnx/csrc/sherpa-onnx.cc
Normal file
@@ -0,0 +1,129 @@
|
||||
/**
|
||||
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
#include "sherpa-onnx/csrc/decode.h"
|
||||
#include "sherpa-onnx/csrc/rnnt-model.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
/** Compute fbank features of the input wave filename.
|
||||
*
|
||||
* @param wav_filename. Path to a mono wave file.
|
||||
* @param expected_sampling_rate Expected sampling rate of the input wave file.
|
||||
* @param num_frames On return, it contains the number of feature frames.
|
||||
* @return Return the computed feature of shape (num_frames, feature_dim)
|
||||
* stored in row-major.
|
||||
*/
|
||||
static std::vector<float> ComputeFeatures(const std::string &wav_filename,
|
||||
float expected_sampling_rate,
|
||||
int32_t *num_frames) {
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate);
|
||||
|
||||
float duration = samples.size() / expected_sampling_rate;
|
||||
|
||||
std::cout << "wav filename: " << wav_filename << "\n";
|
||||
std::cout << "wav duration (s): " << duration << "\n";
|
||||
|
||||
knf::FbankOptions opts;
|
||||
opts.frame_opts.dither = 0;
|
||||
opts.frame_opts.snip_edges = false;
|
||||
opts.frame_opts.samp_freq = expected_sampling_rate;
|
||||
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
opts.mel_opts.num_bins = feature_dim;
|
||||
|
||||
knf::OnlineFbank fbank(opts);
|
||||
fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
|
||||
fbank.InputFinished();
|
||||
|
||||
*num_frames = fbank.NumFramesReady();
|
||||
|
||||
std::vector<float> features(*num_frames * feature_dim);
|
||||
float *p = features.data();
|
||||
|
||||
for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) {
|
||||
const float *f = fbank.GetFrame(i);
|
||||
std::copy(f, f + feature_dim, p);
|
||||
}
|
||||
|
||||
return features;
|
||||
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 8 || argc > 9) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx \
|
||||
/path/to/tokens.txt \
|
||||
/path/to/encoder.onnx \
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
/path/to/joiner_encoder_proj.ncnn.param \
|
||||
/path/to/joiner_decoder_proj.ncnn.param \
|
||||
/path/to/foo.wav [num_threads]
|
||||
|
||||
You can download pre-trained models from the following repository:
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
)usage";
|
||||
std::cerr << usage << "\n";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string tokens = argv[1];
|
||||
std::string encoder = argv[2];
|
||||
std::string decoder = argv[3];
|
||||
std::string joiner = argv[4];
|
||||
std::string joiner_encoder_proj = argv[5];
|
||||
std::string joiner_decoder_proj = argv[6];
|
||||
std::string wav_filename = argv[7];
|
||||
int32_t num_threads = 4;
|
||||
if (argc == 9) {
|
||||
num_threads = atoi(argv[8]);
|
||||
}
|
||||
|
||||
sherpa_onnx::SymbolTable sym(tokens);
|
||||
|
||||
int32_t num_frames;
|
||||
auto features = ComputeFeatures(wav_filename, 16000, &num_frames);
|
||||
int32_t feature_dim = features.size() / num_frames;
|
||||
|
||||
sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj,
|
||||
joiner_decoder_proj, num_threads);
|
||||
Ort::Value encoder_out =
|
||||
model.RunEncoder(features.data(), num_frames, feature_dim);
|
||||
|
||||
auto hyp = sherpa_onnx::GreedySearch(model, encoder_out);
|
||||
|
||||
std::string text;
|
||||
for (auto i : hyp) {
|
||||
text += sym[i];
|
||||
}
|
||||
|
||||
std::cout << "Recognition result for " << wav_filename << "\n"
|
||||
<< text << "\n";
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -15,34 +15,21 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
int main() {
|
||||
knf::FbankOptions opts;
|
||||
opts.frame_opts.dither = 0;
|
||||
opts.mel_opts.num_bins = 10;
|
||||
|
||||
knf::OnlineFbank fbank(opts);
|
||||
for (int32_t i = 0; i < 1600; ++i) {
|
||||
float s = (i * i - i / 2) / 32767.;
|
||||
fbank.AcceptWaveform(16000, &s, 1);
|
||||
}
|
||||
|
||||
std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
|
||||
std::vector<std::string> providers = Ort::GetAvailableProviders();
|
||||
std::ostringstream os;
|
||||
|
||||
int32_t n = fbank.NumFramesReady();
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
const float *frame = fbank.GetFrame(i);
|
||||
for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) {
|
||||
os << frame[k] << ", ";
|
||||
}
|
||||
os << "\n";
|
||||
os << "Available providers: ";
|
||||
std::string sep = "";
|
||||
for (const auto &p : providers) {
|
||||
os << sep << p;
|
||||
sep = ", ";
|
||||
}
|
||||
|
||||
std::cout << os.str() << "\n";
|
||||
|
||||
return 0;
|
||||
}
|
||||
78
sherpa-onnx/csrc/symbol-table.cc
Normal file
78
sherpa-onnx/csrc/symbol-table.cc
Normal file
@@ -0,0 +1,78 @@
|
||||
/**
|
||||
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
SymbolTable::SymbolTable(const std::string &filename) {
|
||||
std::ifstream is(filename);
|
||||
std::string sym;
|
||||
int32_t id;
|
||||
while (is >> sym >> id) {
|
||||
if (sym.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
sym = sym.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
|
||||
assert(!sym.empty());
|
||||
assert(sym2id_.count(sym) == 0);
|
||||
assert(id2sym_.count(id) == 0);
|
||||
|
||||
sym2id_.insert({sym, id});
|
||||
id2sym_.insert({id, sym});
|
||||
}
|
||||
assert(is.eof());
|
||||
}
|
||||
|
||||
std::string SymbolTable::ToString() const {
|
||||
std::ostringstream os;
|
||||
char sep = ' ';
|
||||
for (const auto &p : sym2id_) {
|
||||
os << p.first << sep << p.second << "\n";
|
||||
}
|
||||
return os.str();
|
||||
}
|
||||
|
||||
const std::string &SymbolTable::operator[](int32_t id) const {
|
||||
return id2sym_.at(id);
|
||||
}
|
||||
|
||||
int32_t SymbolTable::operator[](const std::string &sym) const {
|
||||
return sym2id_.at(sym);
|
||||
}
|
||||
|
||||
bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; }
|
||||
|
||||
bool SymbolTable::contains(const std::string &sym) const {
|
||||
return sym2id_.count(sym) != 0;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
|
||||
return os << symbol_table.ToString();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
62
sherpa-onnx/csrc/symbol-table.h
Normal file
62
sherpa-onnx/csrc/symbol-table.h
Normal file
@@ -0,0 +1,62 @@
|
||||
/**
|
||||
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
||||
#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// It manages mapping between symbols and integer IDs.
|
||||
class SymbolTable {
|
||||
public:
|
||||
SymbolTable() = default;
|
||||
/// Construct a symbol table from a file.
|
||||
/// Each line in the file contains two fields:
|
||||
///
|
||||
/// sym ID
|
||||
///
|
||||
/// Fields are separated by space(s).
|
||||
explicit SymbolTable(const std::string &filename);
|
||||
|
||||
/// Return a string representation of this symbol table
|
||||
std::string ToString() const;
|
||||
|
||||
/// Return the symbol corresponding to the given ID.
|
||||
const std::string &operator[](int32_t id) const;
|
||||
/// Return the ID corresponding to the given symbol.
|
||||
int32_t operator[](const std::string &sym) const;
|
||||
|
||||
/// Return true if there is a symbol with the given ID.
|
||||
bool contains(int32_t id) const;
|
||||
|
||||
/// Return true if there is a given symbol in the symbol table.
|
||||
bool contains(const std::string &sym) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, int32_t> sym2id_;
|
||||
std::unordered_map<int32_t, std::string> id2sym_;
|
||||
};
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_
|
||||
@@ -1,39 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
|
||||
void vector2file(std::vector<float> vector, std::string saveFileName){
|
||||
std::ofstream f(saveFileName);
|
||||
for(std::vector<float>::const_iterator i = vector.begin(); i != vector.end(); ++i) {
|
||||
f << *i << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::string> hyps2result(std::map<int, std::string> token_map, std::vector<std::vector<int32_t>> hyps, int context_size = 2){
|
||||
std::vector<std::string> results;
|
||||
|
||||
for (int k=0; k < hyps.size(); k++){
|
||||
std::string result = token_map[hyps[k][context_size]];
|
||||
|
||||
for (int i=context_size+1; i < hyps[k].size(); i++){
|
||||
std::string token = token_map[hyps[k][i]];
|
||||
|
||||
// TODO: recognising '_' is not working
|
||||
if (token.at(0) == '_')
|
||||
result += " " + token;
|
||||
else
|
||||
result += token;
|
||||
}
|
||||
results.push_back(result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
void print_hyps(std::vector<std::vector<int32_t>> hyps, int context_size = 2){
|
||||
std::cout << "Hyps:" << std::endl;
|
||||
for (int i=context_size; i<hyps[0].size(); i++)
|
||||
std::cout << hyps[0][i] << "-";
|
||||
std::cout << "|" << std::endl;
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
#include <iostream>
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
|
||||
const auto& api = Ort::GetApi();
|
||||
OrtTensorRTProviderOptionsV2* tensorrt_options;
|
||||
Ort::SessionOptions session_options;
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
|
||||
|
||||
std::vector<float> ortVal2Vector(Ort::Value &tensor, int tensor_length){
|
||||
/**
|
||||
* convert ort tensor to vector
|
||||
*/
|
||||
float* floatarr = tensor.GetTensorMutableData<float>();
|
||||
std::vector<float> vector {floatarr, floatarr + tensor_length};
|
||||
return vector;
|
||||
}
|
||||
|
||||
|
||||
void print_onnx_forward_output(std::vector<Ort::Value> &output_tensors, int num){
|
||||
float* floatarr = output_tensors.front().GetTensorMutableData<float>();
|
||||
for (int i = 0; i < num; i++)
|
||||
printf("[%d] = %f\n", i, floatarr[i]);
|
||||
}
|
||||
|
||||
|
||||
void print_shape_of_ort_val(std::vector<Ort::Value> &tensor){
|
||||
auto out_shape = tensor.front().GetTensorTypeAndShapeInfo().GetShape();
|
||||
auto out_size = out_shape.size();
|
||||
std::cout << "(";
|
||||
for (int i=0; i<out_size; i++){
|
||||
std::cout << out_shape[i];
|
||||
if (i < out_size-1)
|
||||
std::cout << ",";
|
||||
}
|
||||
std::cout << ")" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
void print_model_info(Ort::Session &session, std::string title){
|
||||
std::cout << "=== Printing '" << title << "' model ===" << std::endl;
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
// print number of model input nodes
|
||||
size_t num_input_nodes = session.GetInputCount();
|
||||
std::vector<const char*> input_node_names(num_input_nodes);
|
||||
std::vector<int64_t> input_node_dims;
|
||||
|
||||
printf("Number of inputs = %zu\n", num_input_nodes);
|
||||
|
||||
char* output_name = session.GetOutputName(0, allocator);
|
||||
printf("output name: %s\n", output_name);
|
||||
|
||||
// iterate over all input nodes
|
||||
for (int i = 0; i < num_input_nodes; i++) {
|
||||
// print input node names
|
||||
char* input_name = session.GetInputName(i, allocator);
|
||||
printf("Input %d : name=%s\n", i, input_name);
|
||||
input_node_names[i] = input_name;
|
||||
|
||||
// print input node types
|
||||
Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
|
||||
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
|
||||
|
||||
ONNXTensorElementDataType type = tensor_info.GetElementType();
|
||||
printf("Input %d : type=%d\n", i, type);
|
||||
|
||||
// print input shapes/dims
|
||||
input_node_dims = tensor_info.GetShape();
|
||||
printf("Input %d : num_dims=%zu\n", i, input_node_dims.size());
|
||||
for (size_t j = 0; j < input_node_dims.size(); j++)
|
||||
printf("Input %d : dim %zu=%jd\n", i, j, input_node_dims[j]);
|
||||
}
|
||||
std::cout << "=======================================" << std::endl;
|
||||
}
|
||||
108
sherpa-onnx/csrc/wave-reader.cc
Normal file
108
sherpa-onnx/csrc/wave-reader.cc
Normal file
@@ -0,0 +1,108 @@
|
||||
/**
|
||||
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
namespace {
|
||||
// see http://soundfile.sapp.org/doc/WaveFormat/
|
||||
//
|
||||
// Note: We assume little endian here
|
||||
// TODO(fangjun): Support big endian
|
||||
struct WaveHeader {
|
||||
void Validate() const {
|
||||
// F F I R
|
||||
assert(chunk_id == 0x46464952);
|
||||
assert(chunk_size == 36 + subchunk2_size);
|
||||
// E V A W
|
||||
assert(format == 0x45564157);
|
||||
assert(subchunk1_id == 0x20746d66);
|
||||
assert(subchunk1_size == 16); // 16 for PCM
|
||||
assert(audio_format == 1); // 1 for PCM
|
||||
assert(num_channels == 1); // we support only single channel for now
|
||||
assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
|
||||
assert(block_align == num_channels * bits_per_sample / 8);
|
||||
assert(bits_per_sample == 16); // we support only 16 bits per sample
|
||||
}
|
||||
|
||||
int32_t chunk_id;
|
||||
int32_t chunk_size;
|
||||
int32_t format;
|
||||
int32_t subchunk1_id;
|
||||
int32_t subchunk1_size;
|
||||
int16_t audio_format;
|
||||
int16_t num_channels;
|
||||
int32_t sample_rate;
|
||||
int32_t byte_rate;
|
||||
int16_t block_align;
|
||||
int16_t bits_per_sample;
|
||||
int32_t subchunk2_id;
|
||||
int32_t subchunk2_size;
|
||||
};
|
||||
static_assert(sizeof(WaveHeader) == 44, "");
|
||||
|
||||
// Read a wave file of mono-channel.
|
||||
// Return its samples normalized to the range [-1, 1).
|
||||
std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
|
||||
WaveHeader header;
|
||||
is.read(reinterpret_cast<char *>(&header), sizeof(header));
|
||||
assert(static_cast<bool>(is));
|
||||
|
||||
header.Validate();
|
||||
|
||||
*sample_rate = header.sample_rate;
|
||||
|
||||
// header.subchunk2_size contains the number of bytes in the data.
|
||||
// As we assume each sample contains two bytes, so it is divided by 2 here
|
||||
std::vector<int16_t> samples(header.subchunk2_size / 2);
|
||||
|
||||
is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
|
||||
|
||||
assert(static_cast<bool>(is));
|
||||
|
||||
std::vector<float> ans(samples.size());
|
||||
for (int32_t i = 0; i != ans.size(); ++i) {
|
||||
ans[i] = samples[i] / 32768.;
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<float> ReadWave(const std::string &filename,
|
||||
float expected_sample_rate) {
|
||||
std::ifstream is(filename, std::ifstream::binary);
|
||||
float sample_rate;
|
||||
auto samples = ReadWaveImpl(is, &sample_rate);
|
||||
if (expected_sample_rate != sample_rate) {
|
||||
std::cerr << "Expected sample rate: " << expected_sample_rate
|
||||
<< ". Given: " << sample_rate << ".\n";
|
||||
exit(-1);
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
41
sherpa-onnx/csrc/wave-reader.h
Normal file
41
sherpa-onnx/csrc/wave-reader.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
*
|
||||
* See LICENSE for clarification regarding multiple authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||
#define SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||
|
||||
#include <istream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** Read a wave file with expected sample rate.
|
||||
|
||||
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
|
||||
@param expected_sample_rate Expected sample rate of the wave file. If the
|
||||
sample rate don't match, it throws an exception.
|
||||
|
||||
@return Return wave samples normalized to the range [-1, 1).
|
||||
*/
|
||||
std::vector<float> ReadWave(const std::string &filename,
|
||||
float expected_sample_rate);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_
|
||||
Reference in New Issue
Block a user