From 8ee02c28b0f7dcab91fe76e039303fc5320a2143 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 12 Oct 2022 21:35:07 +0800 Subject: [PATCH] Refactor the code (#15) * code refactoring * Remove reference files * Update README and CI * small fixes * fix style issues * add style check for CI * fix style issues * remove kaldi-native-io --- .clang-format | 9 + .github/workflows/style_check.yaml | 58 ++++ .github/workflows/test-linux-macos.yaml | 9 +- CMakeLists.txt | 1 - README.md | 8 +- cmake/kaldi_native_io.cmake | 39 --- cmake/onnxruntime.cmake | 2 +- scripts/check_style_cpplint.sh | 126 +++++++++ scripts/utils.sh | 19 ++ sherpa-onnx/csrc/CMakeLists.txt | 13 +- sherpa-onnx/csrc/decode.cc | 84 ++++++ sherpa-onnx/csrc/decode.h | 40 +++ sherpa-onnx/csrc/fbank_features.h | 57 ---- sherpa-onnx/csrc/main.cpp | 99 ------- sherpa-onnx/csrc/models.h | 253 ------------------ sherpa-onnx/csrc/rnnt-model.cc | 247 +++++++++++++++++ sherpa-onnx/csrc/rnnt-model.h | 148 ++++++++++ sherpa-onnx/csrc/rnnt_beam_search.h | 120 --------- sherpa-onnx/csrc/sherpa-onnx.cc | 129 +++++++++ ...online-fbank-test.cc => show-onnx-info.cc} | 31 +-- sherpa-onnx/csrc/symbol-table.cc | 78 ++++++ sherpa-onnx/csrc/symbol-table.h | 62 +++++ sherpa-onnx/csrc/utils.h | 39 --- sherpa-onnx/csrc/utils_onnx.h | 77 ------ sherpa-onnx/csrc/wave-reader.cc | 108 ++++++++ sherpa-onnx/csrc/wave-reader.h | 41 +++ 26 files changed, 1179 insertions(+), 718 deletions(-) create mode 100644 .clang-format create mode 100644 .github/workflows/style_check.yaml delete mode 100644 cmake/kaldi_native_io.cmake create mode 100755 scripts/check_style_cpplint.sh create mode 100644 scripts/utils.sh create mode 100644 sherpa-onnx/csrc/decode.cc create mode 100644 sherpa-onnx/csrc/decode.h delete mode 100644 sherpa-onnx/csrc/fbank_features.h delete mode 100644 sherpa-onnx/csrc/main.cpp delete mode 100644 sherpa-onnx/csrc/models.h create mode 100644 sherpa-onnx/csrc/rnnt-model.cc create mode 100644 sherpa-onnx/csrc/rnnt-model.h delete mode 100644 sherpa-onnx/csrc/rnnt_beam_search.h create mode 100644 sherpa-onnx/csrc/sherpa-onnx.cc rename sherpa-onnx/csrc/{online-fbank-test.cc => show-onnx-info.cc} (60%) create mode 100644 sherpa-onnx/csrc/symbol-table.cc create mode 100644 sherpa-onnx/csrc/symbol-table.h delete mode 100644 sherpa-onnx/csrc/utils.h delete mode 100644 sherpa-onnx/csrc/utils_onnx.h create mode 100644 sherpa-onnx/csrc/wave-reader.cc create mode 100644 sherpa-onnx/csrc/wave-reader.h diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..c65e7720 --- /dev/null +++ b/.clang-format @@ -0,0 +1,9 @@ +--- +BasedOnStyle: Google +--- +Language: Cpp +Cpp11BracedListStyle: true +Standard: Cpp11 +DerivePointerAlignment: false +PointerAlignment: Right +--- diff --git a/.github/workflows/style_check.yaml b/.github/workflows/style_check.yaml new file mode 100644 index 00000000..dc0f775f --- /dev/null +++ b/.github/workflows/style_check.yaml @@ -0,0 +1,58 @@ +# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: style_check + +on: + push: + branches: + - master + paths: + - '.github/workflows/style_check.yaml' + - 'sherpa-onnx/**' + pull_request: + branches: + - master + paths: + - '.github/workflows/style_check.yaml' + - 'sherpa-onnx/**' + +concurrency: + group: style_check-${{ github.ref }} + cancel-in-progress: true + +jobs: + style_check: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + + - name: Check style with cpplint + shell: bash + working-directory: ${{github.workspace}} + run: ./scripts/check_style_cpplint.sh diff --git a/.github/workflows/test-linux-macos.yaml b/.github/workflows/test-linux-macos.yaml index e182b961..08671bb7 100644 --- a/.github/workflows/test-linux-macos.yaml +++ b/.github/workflows/test-linux-macos.yaml @@ -59,31 +59,28 @@ jobs: - name: Run tests for ubuntu/macos (English) run: | time ./build/bin/sherpa-onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - greedy \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav time ./build/bin/sherpa-onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - greedy \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav time ./build/bin/sherpa-onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - greedy \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav diff --git a/CMakeLists.txt b/CMakeLists.txt index c48ff078..72fc4981 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,6 @@ set(CMAKE_CXX_EXTENSIONS OFF) list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) -include(kaldi_native_io) include(kaldi-native-fbank) include(onnxruntime) diff --git a/README.md b/README.md index 4dab1a78..de50b573 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,9 @@ the following links: **NOTE**: We provide only non-streaming models at present. +**HINT**: The script for exporting the English model can be found at + + # Usage ```bash @@ -34,13 +37,14 @@ cd .. git lfs install git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +./build/bin/sherpa-onnx --help + ./build/bin/sherpa-onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ - ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ - greedy \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav ``` diff --git a/cmake/kaldi_native_io.cmake b/cmake/kaldi_native_io.cmake deleted file mode 100644 index 9e7cdec4..00000000 --- a/cmake/kaldi_native_io.cmake +++ /dev/null @@ -1,39 +0,0 @@ -function(download_kaldi_native_io) - if(CMAKE_VERSION VERSION_LESS 3.11) - # FetchContent is available since 3.11, - # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules - # so that it can be used in lower CMake versions. - message(STATUS "Use FetchContent provided by sherpa-onnx") - list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) - endif() - - include(FetchContent) - - set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz") - set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6") - - set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE) - set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE) - - FetchContent_Declare(kaldi_native_io - URL ${kaldi_native_io_URL} - URL_HASH ${kaldi_native_io_HASH} - ) - - FetchContent_GetProperties(kaldi_native_io) - if(NOT kaldi_native_io_POPULATED) - message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}") - FetchContent_Populate(kaldi_native_io) - endif() - message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}") - message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}") - - add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL) - - target_include_directories(kaldi_native_io_core - PUBLIC - ${kaldi_native_io_SOURCE_DIR}/ - ) -endfunction() - -download_kaldi_native_io() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index bfcaf8b3..0c5cb5f9 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -10,7 +10,7 @@ function(download_onnxruntime) include(FetchContent) if(UNIX AND NOT APPLE) - set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz") + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz") # If you don't have access to the internet, you can first download onnxruntime to some directory, and the use # set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz") diff --git a/scripts/check_style_cpplint.sh b/scripts/check_style_cpplint.sh new file mode 100755 index 00000000..f81e02a2 --- /dev/null +++ b/scripts/check_style_cpplint.sh @@ -0,0 +1,126 @@ +#!/bin/bash +# +# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Usage: +# +# (1) To check files of the last commit +# ./scripts/check_style_cpplint.sh +# +# (2) To check changed files not committed yet +# ./scripts/check_style_cpplint.sh 1 +# +# (3) To check all files in the project +# ./scripts/check_style_cpplint.sh 2 + + +cpplint_version="1.5.4" +cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd) +sherpa_onnx_dir=$(cd $cur_dir/.. && pwd) + +build_dir=$sherpa_onnx_dir/build +mkdir -p $build_dir + +cpplint_src=$build_dir/cpplint-${cpplint_version}/cpplint.py + +if [ ! -d "$build_dir/cpplint-${cpplint_version}" ]; then + pushd $build_dir + if command -v wget &> /dev/null; then + wget https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz + elif command -v curl &> /dev/null; then + curl -O -SL https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz + else + echo "Please install wget or curl to download cpplint" + exit 1 + fi + tar xf ${cpplint_version}.tar.gz + rm ${cpplint_version}.tar.gz + + # cpplint will report the following error for: __host__ __device__ ( + # + # Extra space before ( in function call [whitespace/parens] [4] + # + # the following patch disables the above error + sed -i "3490i\ not Search(r'__host__ __device__\\\s+\\\(', fncall) and" $cpplint_src + popd +fi + +source $sherpa_onnx_dir/scripts/utils.sh + +# return true if the given file is a c++ source file +# return false otherwise +function is_source_code_file() { + case "$1" in + *.cc|*.h|*.cu) + echo true;; + *) + echo false;; + esac +} + +function check_style() { + python3 $cpplint_src $1 || abort $1 +} + +function check_last_commit() { + files=$(git diff HEAD^1 --name-only --diff-filter=ACDMRUXB) + echo $files +} + +function check_current_dir() { + files=$(git status -s -uno --porcelain | awk '{ + if (NF == 4) { + # a file has been renamed + print $NF + } else { + print $2 + }}') + + echo $files +} + +function do_check() { + case "$1" in + 1) + echo "Check changed files" + files=$(check_current_dir) + ;; + 2) + echo "Check all files" + files=$(find $sherpa_onnx_dir/sherpa-onnx -name "*.h" -o -name "*.cc") + ;; + *) + echo "Check last commit" + files=$(check_last_commit) + ;; + esac + + for f in $files; do + need_check=$(is_source_code_file $f) + if $need_check; then + [[ -f $f ]] && check_style $f + fi + done +} + +function main() { + do_check $1 + + ok "Great! Style check passed!" +} + +cd $sherpa_onnx_dir + +main $1 diff --git a/scripts/utils.sh b/scripts/utils.sh new file mode 100644 index 00000000..fb424a7b --- /dev/null +++ b/scripts/utils.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +default='\033[0m' +bold='\033[1m' +red='\033[31m' +green='\033[32m' + +function ok() { + printf "${bold}${green}[OK]${default} $1\n" +} + +function error() { + printf "${bold}${red}[FAILED]${default} $1\n" +} + +function abort() { + printf "${bold}${red}[FAILED]${default} $1\n" + exit 1 +} diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 94ef9147..d255456e 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -1,8 +1,17 @@ include_directories(${CMAKE_SOURCE_DIR}) -add_executable(sherpa-onnx main.cpp) + +add_executable(sherpa-onnx + decode.cc + rnnt-model.cc + sherpa-onnx.cc + symbol-table.cc + wave-reader.cc +) target_link_libraries(sherpa-onnx onnxruntime kaldi-native-fbank-core - kaldi_native_io_core ) + +add_executable(sherpa-show-onnx-info show-onnx-info.cc) +target_link_libraries(sherpa-show-onnx-info onnxruntime) diff --git a/sherpa-onnx/csrc/decode.cc b/sherpa-onnx/csrc/decode.cc new file mode 100644 index 00000000..5e5cf65b --- /dev/null +++ b/sherpa-onnx/csrc/decode.cc @@ -0,0 +1,84 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sherpa-onnx/csrc/decode.h" + +#include + +#include +#include + +namespace sherpa_onnx { + +std::vector GreedySearch(RnntModel &model, // NOLINT + const Ort::Value &encoder_out) { + std::vector encoder_out_shape = + encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + assert(encoder_out_shape[0] == 1 && "Only batch_size=1 is implemented"); + Ort::Value projected_encoder_out = + model.RunJoinerEncoderProj(encoder_out.GetTensorData(), + encoder_out_shape[1], encoder_out_shape[2]); + + const float *p_projected_encoder_out = + projected_encoder_out.GetTensorData(); + + int32_t context_size = 2; // hard-code it to 2 + int32_t blank_id = 0; // hard-code it to 0 + std::vector hyp(context_size, blank_id); + std::array decoder_input{blank_id, blank_id}; + + Ort::Value decoder_out = model.RunDecoder(decoder_input.data(), context_size); + + std::vector decoder_out_shape = + decoder_out.GetTensorTypeAndShapeInfo().GetShape(); + + Ort::Value projected_decoder_out = model.RunJoinerDecoderProj( + decoder_out.GetTensorData(), decoder_out_shape[2]); + + int32_t joiner_dim = + projected_decoder_out.GetTensorTypeAndShapeInfo().GetShape()[1]; + + int32_t T = encoder_out_shape[1]; + for (int32_t t = 0; t != T; ++t) { + Ort::Value logit = model.RunJoiner( + p_projected_encoder_out + t * joiner_dim, + projected_decoder_out.GetTensorData(), joiner_dim); + + int32_t vocab_size = logit.GetTensorTypeAndShapeInfo().GetShape()[1]; + + const float *p_logit = logit.GetTensorData(); + + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + + if (y != blank_id) { + decoder_input[0] = hyp.back(); + decoder_input[1] = y; + hyp.push_back(y); + decoder_out = model.RunDecoder(decoder_input.data(), context_size); + projected_decoder_out = model.RunJoinerDecoderProj( + decoder_out.GetTensorData(), decoder_out_shape[2]); + } + } + + return {hyp.begin() + context_size, hyp.end()}; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/decode.h b/sherpa-onnx/csrc/decode.h new file mode 100644 index 00000000..7511247c --- /dev/null +++ b/sherpa-onnx/csrc/decode.h @@ -0,0 +1,40 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SHERPA_ONNX_CSRC_DECODE_H_ +#define SHERPA_ONNX_CSRC_DECODE_H_ + +#include + +#include "sherpa-onnx/csrc/rnnt-model.h" + +namespace sherpa_onnx { + +/** Greedy search for non-streaming ASR. + * + * @TODO(fangjun) Support batch size > 1 + * + * @param model The RnntModel + * @param encoder_out Its shape is (1, num_frames, encoder_out_dim). + */ +std::vector GreedySearch(RnntModel &model, // NOLINT + const Ort::Value &encoder_out); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_DECODE_H_ diff --git a/sherpa-onnx/csrc/fbank_features.h b/sherpa-onnx/csrc/fbank_features.h deleted file mode 100644 index d0caab06..00000000 --- a/sherpa-onnx/csrc/fbank_features.h +++ /dev/null @@ -1,57 +0,0 @@ -#include - -#include "kaldi_native_io/csrc/kaldi-io.h" -#include "kaldi_native_io/csrc/wave-reader.h" -#include "kaldi-native-fbank/csrc/online-feature.h" - - -kaldiio::Matrix readWav(std::string filename, bool log = false){ - if (log) - std::cout << "reading " << filename << std::endl; - - bool binary = true; - kaldiio::Input ki(filename, &binary); - kaldiio::WaveHolder wh; - - if (!wh.Read(ki.Stream())) { - std::cerr << "Failed to read " << filename; - exit(EXIT_FAILURE); - } - - auto &wave_data = wh.Value(); - auto &d = wave_data.Data(); - - if (log) - std::cout << "wav shape: " << "(" << d.NumRows() << "," << d.NumCols() << ")" << std::endl; - - return d; -} - - -std::vector ComputeFeatures(knf::OnlineFbank &fbank, knf::FbankOptions opts, kaldiio::Matrix samples, bool log = false){ - int numSamples = samples.NumCols(); - - for (int i = 0; i < numSamples; i++) - { - float currentSample = samples.Row(0).Data()[i] / 32768; - fbank.AcceptWaveform(opts.frame_opts.samp_freq, ¤tSample, 1); - } - - std::vector features; - int32_t num_frames = fbank.NumFramesReady(); - for (int32_t i = 0; i != num_frames; ++i) { - const float *frame = fbank.GetFrame(i); - for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) { - features.push_back(frame[k]); - } - } - if (log){ - std::cout << "done feature extraction" << std::endl; - std::cout << "extracted fbank shape " << "(" << num_frames << "," << opts.mel_opts.num_bins << ")" << std::endl; - - for (int i=0; i< 20; i++) - std::cout << features.at(i) << std::endl; - } - - return features; -} \ No newline at end of file diff --git a/sherpa-onnx/csrc/main.cpp b/sherpa-onnx/csrc/main.cpp deleted file mode 100644 index 77577c2e..00000000 --- a/sherpa-onnx/csrc/main.cpp +++ /dev/null @@ -1,99 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "sherpa-onnx/csrc/fbank_features.h" -#include "sherpa-onnx/csrc/rnnt_beam_search.h" - -#include "kaldi-native-fbank/csrc/online-feature.h" - -int main(int argc, char *argv[]) { - char *encoder_path = argv[1]; - char *decoder_path = argv[2]; - char *joiner_path = argv[3]; - char *joiner_encoder_proj_path = argv[4]; - char *joiner_decoder_proj_path = argv[5]; - char *token_path = argv[6]; - std::string search_method = argv[7]; - char *filename = argv[8]; - - // General parameters - int numberOfThreads = 16; - - // Initialize fbanks - knf::FbankOptions opts; - opts.frame_opts.dither = 0; - opts.frame_opts.samp_freq = 16000; - opts.frame_opts.frame_shift_ms = 10.0f; - opts.frame_opts.frame_length_ms = 25.0f; - opts.mel_opts.num_bins = 80; - opts.frame_opts.window_type = "povey"; - opts.frame_opts.snip_edges = false; - knf::OnlineFbank fbank(opts); - - // set session opts - // https://onnxruntime.ai/docs/performance/tune-performance.html - session_options.SetIntraOpNumThreads(numberOfThreads); - session_options.SetInterOpNumThreads(numberOfThreads); - session_options.SetGraphOptimizationLevel( - GraphOptimizationLevel::ORT_ENABLE_EXTENDED); - session_options.SetLogSeverityLevel(4); - session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); - - api.CreateTensorRTProviderOptions(&tensorrt_options); - std::unique_ptr - rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions); - api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), rel_trt_options.get()); - - // Define model - auto model = - get_model(encoder_path, decoder_path, joiner_path, - joiner_encoder_proj_path, joiner_decoder_proj_path, token_path); - - std::vector filename_list{filename}; - - for (auto filename : filename_list) { - std::cout << filename << std::endl; - auto samples = readWav(filename, true); - int numSamples = samples.NumCols(); - - auto features = ComputeFeatures(fbank, opts, samples); - - auto tic = std::chrono::high_resolution_clock::now(); - - // # === Encoder Out === # - int num_frames = features.size() / opts.mel_opts.num_bins; - auto encoder_out = - model.encoder_forward(features, std::vector{num_frames}, - std::vector{1, num_frames, 80}, - std::vector{1}, memory_info); - - // # === Search === # - std::vector> hyps; - if (search_method == "greedy") - hyps = GreedySearch(&model, &encoder_out); - else { - std::cout << "wrong search method!" << std::endl; - exit(0); - } - auto results = hyps2result(model.tokens_map, hyps); - - // # === Print Elapsed Time === # - auto elapsed = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - tic); - std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" - << std::endl; - std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) - << std::endl; - - print_hyps(hyps); - std::cout << results[0] << std::endl; - } - - return 0; -} diff --git a/sherpa-onnx/csrc/models.h b/sherpa-onnx/csrc/models.h deleted file mode 100644 index b18d168a..00000000 --- a/sherpa-onnx/csrc/models.h +++ /dev/null @@ -1,253 +0,0 @@ -#include -#include -#include -#include -#include - -#include "utils_onnx.h" - - -struct Model -{ - public: - const char* encoder_path; - const char* decoder_path; - const char* joiner_path; - const char* joiner_encoder_proj_path; - const char* joiner_decoder_proj_path; - const char* tokens_path; - - Ort::Session encoder = load_model(encoder_path); - Ort::Session decoder = load_model(decoder_path); - Ort::Session joiner = load_model(joiner_path); - Ort::Session joiner_encoder_proj = load_model(joiner_encoder_proj_path); - Ort::Session joiner_decoder_proj = load_model(joiner_decoder_proj_path); - std::map tokens_map = get_token_map(tokens_path); - - int32_t blank_id; - int32_t unk_id; - int32_t context_size; - - std::vector encoder_forward(std::vector in_vector, - std::vector in_vector_length, - std::vector feature_dims, - std::vector feature_length_dims, - Ort::MemoryInfo &memory_info){ - std::vector encoder_inputTensors; - encoder_inputTensors.push_back(Ort::Value::CreateTensor(memory_info, in_vector.data(), in_vector.size(), feature_dims.data(), feature_dims.size())); - encoder_inputTensors.push_back(Ort::Value::CreateTensor(memory_info, in_vector_length.data(), in_vector_length.size(), feature_length_dims.data(), feature_length_dims.size())); - - std::vector encoder_inputNames = {encoder.GetInputName(0, allocator), encoder.GetInputName(1, allocator)}; - std::vector encoder_outputNames = {encoder.GetOutputName(0, allocator)}; - - auto out = encoder.Run(Ort::RunOptions{nullptr}, - encoder_inputNames.data(), - encoder_inputTensors.data(), - encoder_inputTensors.size(), - encoder_outputNames.data(), - encoder_outputNames.size()); - return out; - } - - std::vector decoder_forward(std::vector in_vector, - std::vector dims, - Ort::MemoryInfo &memory_info){ - std::vector inputTensors; - inputTensors.push_back(Ort::Value::CreateTensor(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size())); - - std::vector inputNames {decoder.GetInputName(0, allocator)}; - std::vector outputNames {decoder.GetOutputName(0, allocator)}; - - auto out = decoder.Run(Ort::RunOptions{nullptr}, - inputNames.data(), - inputTensors.data(), - inputTensors.size(), - outputNames.data(), - outputNames.size()); - - return out; - } - - std::vector joiner_forward(std::vector projected_encoder_out, - std::vector decoder_out, - std::vector projected_encoder_out_dims, - std::vector decoder_out_dims, - Ort::MemoryInfo &memory_info){ - std::vector inputTensors; - inputTensors.push_back(Ort::Value::CreateTensor(memory_info, projected_encoder_out.data(), projected_encoder_out.size(), projected_encoder_out_dims.data(), projected_encoder_out_dims.size())); - inputTensors.push_back(Ort::Value::CreateTensor(memory_info, decoder_out.data(), decoder_out.size(), decoder_out_dims.data(), decoder_out_dims.size())); - std::vector inputNames = {joiner.GetInputName(0, allocator), joiner.GetInputName(1, allocator)}; - std::vector outputNames = {joiner.GetOutputName(0, allocator)}; - - auto out = joiner.Run(Ort::RunOptions{nullptr}, - inputNames.data(), - inputTensors.data(), - inputTensors.size(), - outputNames.data(), - outputNames.size()); - - return out; - } - - std::vector joiner_encoder_proj_forward(std::vector in_vector, - std::vector dims, - Ort::MemoryInfo &memory_info){ - std::vector inputTensors; - inputTensors.push_back(Ort::Value::CreateTensor(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size())); - - std::vector inputNames {joiner_encoder_proj.GetInputName(0, allocator)}; - std::vector outputNames {joiner_encoder_proj.GetOutputName(0, allocator)}; - - auto out = joiner_encoder_proj.Run(Ort::RunOptions{nullptr}, - inputNames.data(), - inputTensors.data(), - inputTensors.size(), - outputNames.data(), - outputNames.size()); - - return out; - } - - std::vector joiner_decoder_proj_forward(std::vector in_vector, - std::vector dims, - Ort::MemoryInfo &memory_info){ - std::vector inputTensors; - inputTensors.push_back(Ort::Value::CreateTensor(memory_info, in_vector.data(), in_vector.size(), dims.data(), dims.size())); - - std::vector inputNames {joiner_decoder_proj.GetInputName(0, allocator)}; - std::vector outputNames {joiner_decoder_proj.GetOutputName(0, allocator)}; - - auto out = joiner_decoder_proj.Run(Ort::RunOptions{nullptr}, - inputNames.data(), - inputTensors.data(), - inputTensors.size(), - outputNames.data(), - outputNames.size()); - - return out; - } - - Ort::Session load_model(const char* path){ - struct stat buffer; - if (stat(path, &buffer) != 0){ - std::cout << "File does not exist!: " << path << std::endl; - exit(0); - } - std::cout << "loading " << path << std::endl; - Ort::Session onnx_model(env, path, session_options); - return onnx_model; - } - - void extract_constant_lm_parameters(){ - /* - all_in_one contains these params. We should trace all_in_one and find 'constants_lm' nodes to extract these params - For now, these params are set staticaly. - in: Ort::Session &all_in_one - out: {blank_id, unk_id, context_size} - should return std::vector - */ - blank_id = 0; - unk_id = 0; - context_size = 2; - } - - std::map get_token_map(const char* token_path){ - std::ifstream inFile; - inFile.open(token_path); - if (inFile.fail()) - std::cerr << "Could not find token file" << std::endl; - - std::map token_map; - - std::string line; - while (std::getline(inFile, line)) - { - int id; - std::string token; - - std::istringstream iss(line); - iss >> token; - iss >> id; - - token_map[id] = token; - } - - return token_map; - } - -}; - - -Model get_model(std::string exp_path, char* tokens_path){ - Model model{ - (exp_path + "/encoder_simp.onnx").c_str(), - (exp_path + "/decoder_simp.onnx").c_str(), - (exp_path + "/joiner_simp.onnx").c_str(), - (exp_path + "/joiner_encoder_proj_simp.onnx").c_str(), - (exp_path + "/joiner_decoder_proj_simp.onnx").c_str(), - tokens_path, - }; - model.extract_constant_lm_parameters(); - - return model; -} - -Model get_model(char* encoder_path, - char* decoder_path, - char* joiner_path, - char* joiner_encoder_proj_path, - char* joiner_decoder_proj_path, - char* tokens_path){ - Model model{ - encoder_path, - decoder_path, - joiner_path, - joiner_encoder_proj_path, - joiner_decoder_proj_path, - tokens_path, - }; - model.extract_constant_lm_parameters(); - - return model; -} - - -void doWarmup(Model *model, int numWarmup = 5){ - std::cout << "Warmup is started" << std::endl; - - std::vector encoder_warmup_sample (500 * 80, 1.0); - for (int i=0; iencoder_forward(encoder_warmup_sample, - std::vector {500}, - std::vector {1, 500, 80}, - std::vector {1}, - memory_info); - - std::vector decoder_warmup_sample {1, 1}; - for (int i=0; idecoder_forward(decoder_warmup_sample, - std::vector {1, 2}, - memory_info); - - std::vector joiner_warmup_sample1 (512, 1.0); - std::vector joiner_warmup_sample2 (512, 1.0); - for (int i=0; ijoiner_forward(joiner_warmup_sample1, - joiner_warmup_sample2, - std::vector {1, 1, 1, 512}, - std::vector {1, 1, 1, 512}, - memory_info); - - std::vector joiner_encoder_proj_warmup_sample (100 * 512, 1.0); - for (int i=0; ijoiner_encoder_proj_forward(joiner_encoder_proj_warmup_sample, - std::vector {100, 512}, - memory_info); - - std::vector joiner_decoder_proj_warmup_sample (512, 1.0); - for (int i=0; ijoiner_decoder_proj_forward(joiner_decoder_proj_warmup_sample, - std::vector {1, 512}, - memory_info); - std::cout << "Warmup is done" << std::endl; -} diff --git a/sherpa-onnx/csrc/rnnt-model.cc b/sherpa-onnx/csrc/rnnt-model.cc new file mode 100644 index 00000000..c7f9279a --- /dev/null +++ b/sherpa-onnx/csrc/rnnt-model.cc @@ -0,0 +1,247 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sherpa-onnx/csrc/rnnt-model.h" + +#include +#include +#include + +namespace sherpa_onnx { + +/** + * Get the input names of a model. + * + * @param sess An onnxruntime session. + * @param input_names. On return, it contains the input names of the model. + * @param input_names_ptr. On return, input_names_ptr[i] contains + * input_names[i].c_str() + */ +static void GetInputNames(Ort::Session *sess, + std::vector *input_names, + std::vector *input_names_ptr) { + Ort::AllocatorWithDefaultOptions allocator; + size_t node_count = sess->GetInputCount(); + input_names->resize(node_count); + input_names_ptr->resize(node_count); + for (size_t i = 0; i != node_count; ++i) { + auto tmp = sess->GetInputNameAllocated(i, allocator); + (*input_names)[i] = tmp.get(); + (*input_names_ptr)[i] = (*input_names)[i].c_str(); + } +} + +/** + * Get the output names of a model. + * + * @param sess An onnxruntime session. + * @param output_names. On return, it contains the output names of the model. + * @param output_names_ptr. On return, output_names_ptr[i] contains + * output_names[i].c_str() + */ +static void GetOutputNames(Ort::Session *sess, + std::vector *output_names, + std::vector *output_names_ptr) { + Ort::AllocatorWithDefaultOptions allocator; + size_t node_count = sess->GetOutputCount(); + output_names->resize(node_count); + output_names_ptr->resize(node_count); + for (size_t i = 0; i != node_count; ++i) { + auto tmp = sess->GetOutputNameAllocated(i, allocator); + (*output_names)[i] = tmp.get(); + (*output_names_ptr)[i] = (*output_names)[i].c_str(); + } +} + +RnntModel::RnntModel(const std::string &encoder_filename, + const std::string &decoder_filename, + const std::string &joiner_filename, + const std::string &joiner_encoder_proj_filename, + const std::string &joiner_decoder_proj_filename, + int32_t num_threads) + : env_(ORT_LOGGING_LEVEL_WARNING) { + sess_opts_.SetIntraOpNumThreads(num_threads); + sess_opts_.SetInterOpNumThreads(num_threads); + + InitEncoder(encoder_filename); + InitDecoder(decoder_filename); + InitJoiner(joiner_filename); + InitJoinerEncoderProj(joiner_encoder_proj_filename); + InitJoinerDecoderProj(joiner_decoder_proj_filename); +} + +void RnntModel::InitEncoder(const std::string &filename) { + encoder_sess_ = + std::make_unique(env_, filename.c_str(), sess_opts_); + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); +} + +void RnntModel::InitDecoder(const std::string &filename) { + decoder_sess_ = + std::make_unique(env_, filename.c_str(), sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); +} + +void RnntModel::InitJoiner(const std::string &filename) { + joiner_sess_ = + std::make_unique(env_, filename.c_str(), sess_opts_); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); +} + +void RnntModel::InitJoinerEncoderProj(const std::string &filename) { + joiner_encoder_proj_sess_ = + std::make_unique(env_, filename.c_str(), sess_opts_); + + GetInputNames(joiner_encoder_proj_sess_.get(), + &joiner_encoder_proj_input_names_, + &joiner_encoder_proj_input_names_ptr_); + + GetOutputNames(joiner_encoder_proj_sess_.get(), + &joiner_encoder_proj_output_names_, + &joiner_encoder_proj_output_names_ptr_); +} + +void RnntModel::InitJoinerDecoderProj(const std::string &filename) { + joiner_decoder_proj_sess_ = + std::make_unique(env_, filename.c_str(), sess_opts_); + + GetInputNames(joiner_decoder_proj_sess_.get(), + &joiner_decoder_proj_input_names_, + &joiner_decoder_proj_input_names_ptr_); + + GetOutputNames(joiner_decoder_proj_sess_.get(), + &joiner_decoder_proj_output_names_, + &joiner_decoder_proj_output_names_ptr_); +} + +Ort::Value RnntModel::RunEncoder(const float *features, int32_t T, + int32_t feature_dim) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + std::array x_shape{1, T, feature_dim}; + Ort::Value x = + Ort::Value::CreateTensor(memory_info, const_cast(features), + T * feature_dim, x_shape.data(), x_shape.size()); + + std::array x_lens_shape{1}; + int64_t x_lens_tmp = T; + + Ort::Value x_lens = Ort::Value::CreateTensor( + memory_info, &x_lens_tmp, 1, x_lens_shape.data(), x_lens_shape.size()); + + std::array encoder_inputs{std::move(x), std::move(x_lens)}; + + // Note: We discard encoder_out_lens since we only implement + // batch==1. + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), + encoder_inputs.size(), encoder_output_names_ptr_.data(), + encoder_output_names_ptr_.size()); + return std::move(encoder_out[0]); +} +Ort::Value RnntModel::RunJoinerEncoderProj(const float *encoder_out, int32_t T, + int32_t encoder_out_dim) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array in_shape{T, encoder_out_dim}; + Ort::Value in = Ort::Value::CreateTensor( + memory_info, const_cast(encoder_out), T * encoder_out_dim, + in_shape.data(), in_shape.size()); + + auto encoder_proj_out = joiner_encoder_proj_sess_->Run( + {}, joiner_encoder_proj_input_names_ptr_.data(), &in, 1, + joiner_encoder_proj_output_names_ptr_.data(), + joiner_encoder_proj_output_names_ptr_.size()); + return std::move(encoder_proj_out[0]); +} + +Ort::Value RnntModel::RunDecoder(const int64_t *decoder_input, + int32_t context_size) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 + std::array shape{batch_size, context_size}; + Ort::Value in = Ort::Value::CreateTensor( + memory_info, const_cast(decoder_input), + batch_size * context_size, shape.data(), shape.size()); + + auto decoder_out = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), &in, 1, + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); + return std::move(decoder_out[0]); +} + +Ort::Value RnntModel::RunJoinerDecoderProj(const float *decoder_out, + int32_t decoder_out_dim) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 + std::array shape{batch_size, decoder_out_dim}; + Ort::Value in = Ort::Value::CreateTensor( + memory_info, const_cast(decoder_out), + batch_size * decoder_out_dim, shape.data(), shape.size()); + + auto decoder_proj_out = joiner_decoder_proj_sess_->Run( + {}, joiner_decoder_proj_input_names_ptr_.data(), &in, 1, + joiner_decoder_proj_output_names_ptr_.data(), + joiner_decoder_proj_output_names_ptr_.size()); + return std::move(decoder_proj_out[0]); +} + +Ort::Value RnntModel::RunJoiner(const float *projected_encoder_out, + const float *projected_decoder_out, + int32_t joiner_dim) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + int32_t batch_size = 1; // TODO(fangjun): handle the case when it's > 1 + std::array shape{batch_size, joiner_dim}; + + Ort::Value enc = Ort::Value::CreateTensor( + memory_info, const_cast(projected_encoder_out), + batch_size * joiner_dim, shape.data(), shape.size()); + + Ort::Value dec = Ort::Value::CreateTensor( + memory_info, const_cast(projected_decoder_out), + batch_size * joiner_dim, shape.data(), shape.size()); + + std::array inputs{std::move(enc), std::move(dec)}; + + auto logit = joiner_sess_->Run( + {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(), + joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size()); + + return std::move(logit[0]); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rnnt-model.h b/sherpa-onnx/csrc/rnnt-model.h new file mode 100644 index 00000000..9068d2cb --- /dev/null +++ b/sherpa-onnx/csrc/rnnt-model.h @@ -0,0 +1,148 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SHERPA_ONNX_CSRC_RNNT_MODEL_H_ +#define SHERPA_ONNX_CSRC_RNNT_MODEL_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +class RnntModel { + public: + /** + * @param encoder_filename Path to the encoder model + * @param decoder_filename Path to the decoder model + * @param joiner_filename Path to the joiner model + * @param joiner_encoder_proj_filename Path to the joiner encoder_proj model + * @param joiner_decoder_proj_filename Path to the joiner decoder_proj model + * @param num_threads Number of threads to use to run the models + */ + RnntModel(const std::string &encoder_filename, + const std::string &decoder_filename, + const std::string &joiner_filename, + const std::string &joiner_encoder_proj_filename, + const std::string &joiner_decoder_proj_filename, + int32_t num_threads); + + /** Run the encoder model. + * + * @TODO(fangjun): Support batch_size > 1 + * + * @param features A tensor of shape (batch_size, T, feature_dim) + * @param T Number of feature frames + * @param feature_dim Dimension of the feature. + * + * @return Return a tensor of shape (batch_size, T', encoder_out_dim) + */ + Ort::Value RunEncoder(const float *features, int32_t T, int32_t feature_dim); + + /** Run the joiner encoder_proj model. + * + * @param encoder_out A tensor of shape (T, encoder_out_dim) + * @param T Number of frames in encoder_out. + * @param encoder_out_dim Dimension of encoder_out. + * + * @return Return a tensor of shape (T, joiner_dim) + * + */ + Ort::Value RunJoinerEncoderProj(const float *encoder_out, int32_t T, + int32_t encoder_out_dim); + + /** Run the decoder model. + * + * @TODO(fangjun): Support batch_size > 1 + * + * @param decoder_input A tensor of shape (batch_size, context_size). + * @return Return a tensor of shape (batch_size, 1, decoder_out_dim) + */ + Ort::Value RunDecoder(const int64_t *decoder_input, int32_t context_size); + + /** Run joiner decoder_proj model. + * + * @TODO(fangjun): Support batch_size > 1 + * + * @param decoder_out A tensor of shape (batch_size, decoder_out_dim) + * @param decoder_out_dim Output dimension of the decoder_out. + * + * @return Return a tensor of shape (batch_size, joiner_dim); + */ + Ort::Value RunJoinerDecoderProj(const float *decoder_out, + int32_t decoder_out_dim); + + /** Run the joiner model. + * + * @TODO(fangjun): Support batch_size > 1 + * + * @param projected_encoder_out A tensor of shape (batch_size, joiner_dim). + * @param projected_decoder_out A tensor of shape (batch_size, joiner_dim). + * + * @return Return a tensor of shape (batch_size, vocab_size) + */ + Ort::Value RunJoiner(const float *projected_encoder_out, + const float *projected_decoder_out, int32_t joiner_dim); + + private: + void InitEncoder(const std::string &encoder_filename); + void InitDecoder(const std::string &decoder_filename); + void InitJoiner(const std::string &joiner_filename); + void InitJoinerEncoderProj(const std::string &joiner_encoder_proj_filename); + void InitJoinerDecoderProj(const std::string &joiner_decoder_proj_filename); + + private: + Ort::Env env_; + Ort::SessionOptions sess_opts_; + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + std::unique_ptr joiner_encoder_proj_sess_; + std::unique_ptr joiner_decoder_proj_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + std::vector joiner_encoder_proj_input_names_; + std::vector joiner_encoder_proj_input_names_ptr_; + std::vector joiner_encoder_proj_output_names_; + std::vector joiner_encoder_proj_output_names_ptr_; + + std::vector joiner_decoder_proj_input_names_; + std::vector joiner_decoder_proj_input_names_ptr_; + std::vector joiner_decoder_proj_output_names_; + std::vector joiner_decoder_proj_output_names_ptr_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RNNT_MODEL_H_ diff --git a/sherpa-onnx/csrc/rnnt_beam_search.h b/sherpa-onnx/csrc/rnnt_beam_search.h deleted file mode 100644 index bf0d6425..00000000 --- a/sherpa-onnx/csrc/rnnt_beam_search.h +++ /dev/null @@ -1,120 +0,0 @@ -#include -#include -#include -#include - -#include "models.h" -#include "utils.h" - - -std::vector getEncoderCol(Ort::Value &tensor, int start, int length){ - float* floatarr = tensor.GetTensorMutableData(); - std::vector vector {floatarr + start, floatarr + length}; - return vector; -} - - -/** - * Assume batch size = 1 - */ -std::vector BuildDecoderInput(const std::vector> &hyps, - std::vector &decoder_input) { - - int32_t context_size = decoder_input.size(); - int32_t hyps_length = hyps[0].size(); - for (int i=0; i < context_size; i++) - decoder_input[i] = hyps[0][hyps_length-context_size+i]; - - return decoder_input; -} - - -std::vector> GreedySearch( - Model *model, // NOLINT - std::vector *encoder_out){ - Ort::Value &encoder_out_tensor = encoder_out->at(0); - int encoder_out_dim1 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; - int encoder_out_dim2 = encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2]; - auto encoder_out_vector = ortVal2Vector(encoder_out_tensor, encoder_out_dim1 * encoder_out_dim2); - - // # === Greedy Search === # - int32_t batch_size = 1; - std::vector blanks(model->context_size, model->blank_id); - std::vector> hyps(batch_size, blanks); - std::vector decoder_input(model->context_size, model->blank_id); - - auto decoder_out = model->decoder_forward(decoder_input, - std::vector {batch_size, model->context_size}, - memory_info); - - Ort::Value &decoder_out_tensor = decoder_out[0]; - int decoder_out_dim = decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[2]; - auto decoder_out_vector = ortVal2Vector(decoder_out_tensor, decoder_out_dim); - - decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector, - std::vector {1, decoder_out_dim}, - memory_info); - Ort::Value &projected_decoder_out_tensor = decoder_out[0]; - auto projected_decoder_out_dim = projected_decoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; - auto projected_decoder_out_vector = ortVal2Vector(projected_decoder_out_tensor, projected_decoder_out_dim); - - auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector, - std::vector {encoder_out_dim1, encoder_out_dim2}, - memory_info); - Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0]; - int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0]; - int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; - auto projected_encoder_out_vector = ortVal2Vector(projected_encoder_out_tensor, projected_encoder_out_dim1 * projected_encoder_out_dim2); - - int32_t offset = 0; - for (int i=0; i< projected_encoder_out_dim1; i++){ - int32_t cur_batch_size = 1; - int32_t start = offset; - int32_t end = start + cur_batch_size; - offset = end; - - auto cur_encoder_out = getEncoderCol(projected_encoder_out_tensor, start * projected_encoder_out_dim2, end * projected_encoder_out_dim2); - - auto logits = model->joiner_forward(cur_encoder_out, - projected_decoder_out_vector, - std::vector {1, projected_encoder_out_dim2}, - std::vector {1, projected_decoder_out_dim}, - memory_info); - - Ort::Value &logits_tensor = logits[0]; - int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; - auto logits_vector = ortVal2Vector(logits_tensor, logits_dim); - - int max_indices = static_cast(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end()))); - bool emitted = false; - - for (int32_t k = 0; k != cur_batch_size; ++k) { - auto index = max_indices; - if (index != model->blank_id && index != model->unk_id) { - emitted = true; - hyps[k].push_back(index); - } - } - - if (emitted) { - decoder_input = BuildDecoderInput(hyps, decoder_input); - - decoder_out = model->decoder_forward(decoder_input, - std::vector {batch_size, model->context_size}, - memory_info); - - decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[2]; - decoder_out_vector = ortVal2Vector(decoder_out[0], decoder_out_dim); - - decoder_out = model->joiner_decoder_proj_forward(decoder_out_vector, - std::vector {1, decoder_out_dim}, - memory_info); - - projected_decoder_out_dim = decoder_out[0].GetTensorTypeAndShapeInfo().GetShape()[1]; - projected_decoder_out_vector = ortVal2Vector(decoder_out[0], projected_decoder_out_dim); - } - } - - return hyps; -} - diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc new file mode 100644 index 00000000..f0f2261c --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -0,0 +1,129 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "kaldi-native-fbank/csrc/online-feature.h" +#include "sherpa-onnx/csrc/decode.h" +#include "sherpa-onnx/csrc/rnnt-model.h" +#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/wave-reader.h" + +/** Compute fbank features of the input wave filename. + * + * @param wav_filename. Path to a mono wave file. + * @param expected_sampling_rate Expected sampling rate of the input wave file. + * @param num_frames On return, it contains the number of feature frames. + * @return Return the computed feature of shape (num_frames, feature_dim) + * stored in row-major. + */ +static std::vector ComputeFeatures(const std::string &wav_filename, + float expected_sampling_rate, + int32_t *num_frames) { + std::vector samples = + sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate); + + float duration = samples.size() / expected_sampling_rate; + + std::cout << "wav filename: " << wav_filename << "\n"; + std::cout << "wav duration (s): " << duration << "\n"; + + knf::FbankOptions opts; + opts.frame_opts.dither = 0; + opts.frame_opts.snip_edges = false; + opts.frame_opts.samp_freq = expected_sampling_rate; + + int32_t feature_dim = 80; + + opts.mel_opts.num_bins = feature_dim; + + knf::OnlineFbank fbank(opts); + fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); + fbank.InputFinished(); + + *num_frames = fbank.NumFramesReady(); + + std::vector features(*num_frames * feature_dim); + float *p = features.data(); + + for (int32_t i = 0; i != fbank.NumFramesReady(); ++i, p += feature_dim) { + const float *f = fbank.GetFrame(i); + std::copy(f, f + feature_dim, p); + } + + return features; +} + +int main(int32_t argc, char *argv[]) { + if (argc < 8 || argc > 9) { + const char *usage = R"usage( +Usage: + ./bin/sherpa-onnx \ + /path/to/tokens.txt \ + /path/to/encoder.onnx \ + /path/to/decoder.onnx \ + /path/to/joiner.onnx \ + /path/to/joiner_encoder_proj.ncnn.param \ + /path/to/joiner_decoder_proj.ncnn.param \ + /path/to/foo.wav [num_threads] + +You can download pre-trained models from the following repository: +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +)usage"; + std::cerr << usage << "\n"; + + return 0; + } + + std::string tokens = argv[1]; + std::string encoder = argv[2]; + std::string decoder = argv[3]; + std::string joiner = argv[4]; + std::string joiner_encoder_proj = argv[5]; + std::string joiner_decoder_proj = argv[6]; + std::string wav_filename = argv[7]; + int32_t num_threads = 4; + if (argc == 9) { + num_threads = atoi(argv[8]); + } + + sherpa_onnx::SymbolTable sym(tokens); + + int32_t num_frames; + auto features = ComputeFeatures(wav_filename, 16000, &num_frames); + int32_t feature_dim = features.size() / num_frames; + + sherpa_onnx::RnntModel model(encoder, decoder, joiner, joiner_encoder_proj, + joiner_decoder_proj, num_threads); + Ort::Value encoder_out = + model.RunEncoder(features.data(), num_frames, feature_dim); + + auto hyp = sherpa_onnx::GreedySearch(model, encoder_out); + + std::string text; + for (auto i : hyp) { + text += sym[i]; + } + + std::cout << "Recognition result for " << wav_filename << "\n" + << text << "\n"; + + return 0; +} diff --git a/sherpa-onnx/csrc/online-fbank-test.cc b/sherpa-onnx/csrc/show-onnx-info.cc similarity index 60% rename from sherpa-onnx/csrc/online-fbank-test.cc rename to sherpa-onnx/csrc/show-onnx-info.cc index 9f595cf7..3ee78fbb 100644 --- a/sherpa-onnx/csrc/online-fbank-test.cc +++ b/sherpa-onnx/csrc/show-onnx-info.cc @@ -15,34 +15,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include +#include -#include "kaldi-native-fbank/csrc/online-feature.h" +#include "onnxruntime_cxx_api.h" // NOLINT int main() { - knf::FbankOptions opts; - opts.frame_opts.dither = 0; - opts.mel_opts.num_bins = 10; - - knf::OnlineFbank fbank(opts); - for (int32_t i = 0; i < 1600; ++i) { - float s = (i * i - i / 2) / 32767.; - fbank.AcceptWaveform(16000, &s, 1); - } - + std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n"; + std::vector providers = Ort::GetAvailableProviders(); std::ostringstream os; - - int32_t n = fbank.NumFramesReady(); - for (int32_t i = 0; i != n; ++i) { - const float *frame = fbank.GetFrame(i); - for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) { - os << frame[k] << ", "; - } - os << "\n"; + os << "Available providers: "; + std::string sep = ""; + for (const auto &p : providers) { + os << sep << p; + sep = ", "; } - std::cout << os.str() << "\n"; - return 0; } diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc new file mode 100644 index 00000000..50ffe961 --- /dev/null +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sherpa-onnx/csrc/symbol-table.h" + +#include +#include +#include + +namespace sherpa_onnx { + +SymbolTable::SymbolTable(const std::string &filename) { + std::ifstream is(filename); + std::string sym; + int32_t id; + while (is >> sym >> id) { + if (sym.size() >= 3) { + // For BPE-based models, we replace ▁ with a space + // Unicode 9601, hex 0x2581, utf8 0xe29681 + const uint8_t *p = reinterpret_cast(sym.c_str()); + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + sym = sym.replace(0, 3, " "); + } + } + + assert(!sym.empty()); + assert(sym2id_.count(sym) == 0); + assert(id2sym_.count(id) == 0); + + sym2id_.insert({sym, id}); + id2sym_.insert({id, sym}); + } + assert(is.eof()); +} + +std::string SymbolTable::ToString() const { + std::ostringstream os; + char sep = ' '; + for (const auto &p : sym2id_) { + os << p.first << sep << p.second << "\n"; + } + return os.str(); +} + +const std::string &SymbolTable::operator[](int32_t id) const { + return id2sym_.at(id); +} + +int32_t SymbolTable::operator[](const std::string &sym) const { + return sym2id_.at(sym); +} + +bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; } + +bool SymbolTable::contains(const std::string &sym) const { + return sym2id_.count(sym) != 0; +} + +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) { + return os << symbol_table.ToString(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h new file mode 100644 index 00000000..46044cfd --- /dev/null +++ b/sherpa-onnx/csrc/symbol-table.h @@ -0,0 +1,62 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ +#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ + +#include +#include + +namespace sherpa_onnx { + +/// It manages mapping between symbols and integer IDs. +class SymbolTable { + public: + SymbolTable() = default; + /// Construct a symbol table from a file. + /// Each line in the file contains two fields: + /// + /// sym ID + /// + /// Fields are separated by space(s). + explicit SymbolTable(const std::string &filename); + + /// Return a string representation of this symbol table + std::string ToString() const; + + /// Return the symbol corresponding to the given ID. + const std::string &operator[](int32_t id) const; + /// Return the ID corresponding to the given symbol. + int32_t operator[](const std::string &sym) const; + + /// Return true if there is a symbol with the given ID. + bool contains(int32_t id) const; + + /// Return true if there is a given symbol in the symbol table. + bool contains(const std::string &sym) const; + + private: + std::unordered_map sym2id_; + std::unordered_map id2sym_; +}; + +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h deleted file mode 100644 index 17fdbbc0..00000000 --- a/sherpa-onnx/csrc/utils.h +++ /dev/null @@ -1,39 +0,0 @@ -#include -#include - - -void vector2file(std::vector vector, std::string saveFileName){ - std::ofstream f(saveFileName); - for(std::vector::const_iterator i = vector.begin(); i != vector.end(); ++i) { - f << *i << '\n'; - } -} - - -std::vector hyps2result(std::map token_map, std::vector> hyps, int context_size = 2){ - std::vector results; - - for (int k=0; k < hyps.size(); k++){ - std::string result = token_map[hyps[k][context_size]]; - - for (int i=context_size+1; i < hyps[k].size(); i++){ - std::string token = token_map[hyps[k][i]]; - - // TODO: recognising '_' is not working - if (token.at(0) == '_') - result += " " + token; - else - result += token; - } - results.push_back(result); - } - return results; -} - - -void print_hyps(std::vector> hyps, int context_size = 2){ - std::cout << "Hyps:" << std::endl; - for (int i=context_size; i -#include "onnxruntime_cxx_api.h" - -Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); -const auto& api = Ort::GetApi(); -OrtTensorRTProviderOptionsV2* tensorrt_options; -Ort::SessionOptions session_options; -Ort::AllocatorWithDefaultOptions allocator; -auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - - -std::vector ortVal2Vector(Ort::Value &tensor, int tensor_length){ - /** - * convert ort tensor to vector - */ - float* floatarr = tensor.GetTensorMutableData(); - std::vector vector {floatarr, floatarr + tensor_length}; - return vector; -} - - -void print_onnx_forward_output(std::vector &output_tensors, int num){ - float* floatarr = output_tensors.front().GetTensorMutableData(); - for (int i = 0; i < num; i++) - printf("[%d] = %f\n", i, floatarr[i]); -} - - -void print_shape_of_ort_val(std::vector &tensor){ - auto out_shape = tensor.front().GetTensorTypeAndShapeInfo().GetShape(); - auto out_size = out_shape.size(); - std::cout << "("; - for (int i=0; i input_node_names(num_input_nodes); - std::vector input_node_dims; - - printf("Number of inputs = %zu\n", num_input_nodes); - - char* output_name = session.GetOutputName(0, allocator); - printf("output name: %s\n", output_name); - - // iterate over all input nodes - for (int i = 0; i < num_input_nodes; i++) { - // print input node names - char* input_name = session.GetInputName(i, allocator); - printf("Input %d : name=%s\n", i, input_name); - input_node_names[i] = input_name; - - // print input node types - Ort::TypeInfo type_info = session.GetInputTypeInfo(i); - auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); - - ONNXTensorElementDataType type = tensor_info.GetElementType(); - printf("Input %d : type=%d\n", i, type); - - // print input shapes/dims - input_node_dims = tensor_info.GetShape(); - printf("Input %d : num_dims=%zu\n", i, input_node_dims.size()); - for (size_t j = 0; j < input_node_dims.size(); j++) - printf("Input %d : dim %zu=%jd\n", i, j, input_node_dims[j]); - } - std::cout << "=======================================" << std::endl; -} diff --git a/sherpa-onnx/csrc/wave-reader.cc b/sherpa-onnx/csrc/wave-reader.cc new file mode 100644 index 00000000..505f18b9 --- /dev/null +++ b/sherpa-onnx/csrc/wave-reader.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sherpa-onnx/csrc/wave-reader.h" + +#include +#include +#include +#include +#include + +namespace sherpa_onnx { + +namespace { +// see http://soundfile.sapp.org/doc/WaveFormat/ +// +// Note: We assume little endian here +// TODO(fangjun): Support big endian +struct WaveHeader { + void Validate() const { + // F F I R + assert(chunk_id == 0x46464952); + assert(chunk_size == 36 + subchunk2_size); + // E V A W + assert(format == 0x45564157); + assert(subchunk1_id == 0x20746d66); + assert(subchunk1_size == 16); // 16 for PCM + assert(audio_format == 1); // 1 for PCM + assert(num_channels == 1); // we support only single channel for now + assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8); + assert(block_align == num_channels * bits_per_sample / 8); + assert(bits_per_sample == 16); // we support only 16 bits per sample + } + + int32_t chunk_id; + int32_t chunk_size; + int32_t format; + int32_t subchunk1_id; + int32_t subchunk1_size; + int16_t audio_format; + int16_t num_channels; + int32_t sample_rate; + int32_t byte_rate; + int16_t block_align; + int16_t bits_per_sample; + int32_t subchunk2_id; + int32_t subchunk2_size; +}; +static_assert(sizeof(WaveHeader) == 44, ""); + +// Read a wave file of mono-channel. +// Return its samples normalized to the range [-1, 1). +std::vector ReadWaveImpl(std::istream &is, float *sample_rate) { + WaveHeader header; + is.read(reinterpret_cast(&header), sizeof(header)); + assert(static_cast(is)); + + header.Validate(); + + *sample_rate = header.sample_rate; + + // header.subchunk2_size contains the number of bytes in the data. + // As we assume each sample contains two bytes, so it is divided by 2 here + std::vector samples(header.subchunk2_size / 2); + + is.read(reinterpret_cast(samples.data()), header.subchunk2_size); + + assert(static_cast(is)); + + std::vector ans(samples.size()); + for (int32_t i = 0; i != ans.size(); ++i) { + ans[i] = samples[i] / 32768.; + } + + return ans; +} + +} // namespace + +std::vector ReadWave(const std::string &filename, + float expected_sample_rate) { + std::ifstream is(filename, std::ifstream::binary); + float sample_rate; + auto samples = ReadWaveImpl(is, &sample_rate); + if (expected_sample_rate != sample_rate) { + std::cerr << "Expected sample rate: " << expected_sample_rate + << ". Given: " << sample_rate << ".\n"; + exit(-1); + } + return samples; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/wave-reader.h b/sherpa-onnx/csrc/wave-reader.h new file mode 100644 index 00000000..7db5c1f9 --- /dev/null +++ b/sherpa-onnx/csrc/wave-reader.h @@ -0,0 +1,41 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_ +#define SHERPA_ONNX_CSRC_WAVE_READER_H_ + +#include +#include +#include + +namespace sherpa_onnx { + +/** Read a wave file with expected sample rate. + + @param filename Path to a wave file. It MUST be single channel, PCM encoded. + @param expected_sample_rate Expected sample rate of the wave file. If the + sample rate don't match, it throws an exception. + + @return Return wave samples normalized to the range [-1, 1). + */ +std::vector ReadWave(const std::string &filename, + float expected_sample_rate); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_