diff --git a/.github/workflows/test-linux.yaml b/.github/workflows/test-linux.yaml new file mode 100644 index 00000000..09ff6b81 --- /dev/null +++ b/.github/workflows/test-linux.yaml @@ -0,0 +1,85 @@ +name: test-linux + +on: + push: + branches: + - master + paths: + - '.github/workflows/test-linux.yaml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + pull_request: + branches: + - master + paths: + - '.github/workflows/test-linux.yaml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + +concurrency: + group: test-linux-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + test-linux: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Download pretrained model and test-data (English) + shell: bash + run: | + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 + + - name: Configure Cmake + shell: bash + run: | + mkdir build + cd build + cmake -D CMAKE_BUILD_TYPE=Release .. + + - name: Build sherpa-onnx for ubuntu + run: | + cd build + make VERBOSE=1 -j3 + + - name: Run tests for ubuntu (English) + run: | + time ./build/bin/sherpa-onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav + + time ./build/bin/sherpa-onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav + + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..378eac25 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +build diff --git a/CMakeLists.txt b/CMakeLists.txt index a8ec5099..c48ff078 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,8 @@ set(CMAKE_CXX_EXTENSIONS OFF) list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) -include(cmake/kaldi_native_io.cmake) -include(cmake/kaldi-native-fbank.cmake) +include(kaldi_native_io) +include(kaldi-native-fbank) +include(onnxruntime) add_subdirectory(sherpa-onnx) diff --git a/cmake/kaldi_native_io.cmake b/cmake/kaldi_native_io.cmake index 406f2efb..9e7cdec4 100644 --- a/cmake/kaldi_native_io.cmake +++ b/cmake/kaldi_native_io.cmake @@ -1,27 +1,39 @@ -if(DEFINED ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}) - message(STATUS "Using environment variable KALDI_NATIVE_IO_INSTALL_PREFIX: $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}") - set(KALDI_NATIVE_IO_CMAKE_PREFIX_PATH $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}) -else() - # PYTHON_EXECUTABLE is set by cmake/pybind11.cmake - message(STATUS "Python executable: ${PYTHON_EXECUTABLE}") +function(download_kaldi_native_io) + if(CMAKE_VERSION VERSION_LESS 3.11) + # FetchContent is available since 3.11, + # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules + # so that it can be used in lower CMake versions. + message(STATUS "Use FetchContent provided by sherpa-onnx") + list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) + endif() - execute_process( - COMMAND "${PYTHON_EXECUTABLE}" -c "import kaldi_native_io; print(kaldi_native_io.cmake_prefix_path)" - OUTPUT_STRIP_TRAILING_WHITESPACE - OUTPUT_VARIABLE KALDI_NATIVE_IO_CMAKE_PREFIX_PATH + include(FetchContent) + + set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz") + set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6") + + set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(kaldi_native_io + URL ${kaldi_native_io_URL} + URL_HASH ${kaldi_native_io_HASH} ) -endif() -message(STATUS "KALDI_NATIVE_IO_CMAKE_PREFIX_PATH: ${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}") -list(APPEND CMAKE_PREFIX_PATH "${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}") + FetchContent_GetProperties(kaldi_native_io) + if(NOT kaldi_native_io_POPULATED) + message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}") + FetchContent_Populate(kaldi_native_io) + endif() + message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}") + message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}") -find_package(kaldi_native_io REQUIRED) + add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL) -message(STATUS "KALDI_NATIVE_IO_FOUND: ${KALDI_NATIVE_IO_FOUND}") -message(STATUS "KALDI_NATIVE_IO_VERSION: ${KALDI_NATIVE_IO_VERSION}") -message(STATUS "KALDI_NATIVE_IO_INCLUDE_DIRS: ${KALDI_NATIVE_IO_INCLUDE_DIRS}") -message(STATUS "KALDI_NATIVE_IO_CXX_FLAGS: ${KALDI_NATIVE_IO_CXX_FLAGS}") -message(STATUS "KALDI_NATIVE_IO_LIBRARIES: ${KALDI_NATIVE_IO_LIBRARIES}") + target_include_directories(kaldi_native_io_core + PUBLIC + ${kaldi_native_io_SOURCE_DIR}/ + ) +endfunction() -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${KALDI_NATIVE_IO_CXX_FLAGS}") -message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") \ No newline at end of file +download_kaldi_native_io() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake new file mode 100644 index 00000000..a9ea794d --- /dev/null +++ b/cmake/onnxruntime.cmake @@ -0,0 +1,55 @@ +function(download_onnxruntime) + if(CMAKE_VERSION VERSION_LESS 3.11) + # FetchContent is available since 3.11, + # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules + # so that it can be used in lower CMake versions. + message(STATUS "Use FetchContent provided by sherpa-onnx") + list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) + endif() + + include(FetchContent) + + if(UNIX AND NOT APPLE) + # set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz") + + # If you don't have access to the internet, you can first download onnxruntime to some directory, and the use + # set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz") + + set(onnxruntime_HASH "SHA256=8f6eb9e2da9cf74e7905bf3fc687ef52e34cc566af7af2f92dafe5a5d106aa3d") + # After downloading, it contains: + # ./lib/libonnxruntime.so.1.12.1 + # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.12.1 + # + # ./include + # It contains all the needed header files + else() + message(FATAL_ERROR "Only support Linux at present. Will support other OSes later") + endif() + + FetchContent_Declare(onnxruntime + URL ${onnxruntime_URL} + URL_HASH ${onnxruntime_HASH} + ) + + FetchContent_GetProperties(onnxruntime) + if(NOT onnxruntime_POPULATED) + message(STATUS "Downloading onnxruntime ${onnxruntime_URL}") + FetchContent_Populate(onnxruntime) + endif() + message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") + + find_library(location_onnxruntime onnxruntime + PATHS + "${onnxruntime_SOURCE_DIR}/lib" + ) + + message(STATUS "location_onnxruntime: ${location_onnxruntime}") + + add_library(onnxruntime SHARED IMPORTED) + set_target_properties(onnxruntime PROPERTIES + IMPORTED_LOCATION ${location_onnxruntime} + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" + ) +endfunction() + +download_onnxruntime() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index e6bfd4e6..94ef9147 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -1,13 +1,8 @@ -add_executable(online-fbank-test online-fbank-test.cc) -target_link_libraries(online-fbank-test kaldi-native-fbank-core) - -include_directories( - ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/session/ - ${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/providers/tensorrt/ -) - -include_directories( - ${KALDINATIVEIO} -) +include_directories(${CMAKE_SOURCE_DIR}) add_executable(sherpa-onnx main.cpp) -target_link_libraries(sherpa-onnx onnxruntime kaldi-native-fbank-core kaldi_native_io_core) + +target_link_libraries(sherpa-onnx + onnxruntime + kaldi-native-fbank-core + kaldi_native_io_core +) diff --git a/sherpa-onnx/csrc/main.cpp b/sherpa-onnx/csrc/main.cpp index cd240915..77577c2e 100644 --- a/sherpa-onnx/csrc/main.cpp +++ b/sherpa-onnx/csrc/main.cpp @@ -1,101 +1,99 @@ -#include -#include #include -#include -#include #include +#include +#include +#include +#include -#include "fbank_features.h" -#include "rnnt_beam_search.h" +#include "sherpa-onnx/csrc/fbank_features.h" +#include "sherpa-onnx/csrc/rnnt_beam_search.h" #include "kaldi-native-fbank/csrc/online-feature.h" +int main(int argc, char *argv[]) { + char *encoder_path = argv[1]; + char *decoder_path = argv[2]; + char *joiner_path = argv[3]; + char *joiner_encoder_proj_path = argv[4]; + char *joiner_decoder_proj_path = argv[5]; + char *token_path = argv[6]; + std::string search_method = argv[7]; + char *filename = argv[8]; -int main(int argc, char* argv[]) { - char* encoder_path = argv[1]; - char* decoder_path = argv[2]; - char* joiner_path = argv[3]; - char* joiner_encoder_proj_path = argv[4]; - char* joiner_decoder_proj_path = argv[5]; - char* token_path = argv[6]; - std::string search_method = argv[7]; - char* filename = argv[8]; + // General parameters + int numberOfThreads = 16; - // General parameters - int numberOfThreads = 16; + // Initialize fbanks + knf::FbankOptions opts; + opts.frame_opts.dither = 0; + opts.frame_opts.samp_freq = 16000; + opts.frame_opts.frame_shift_ms = 10.0f; + opts.frame_opts.frame_length_ms = 25.0f; + opts.mel_opts.num_bins = 80; + opts.frame_opts.window_type = "povey"; + opts.frame_opts.snip_edges = false; + knf::OnlineFbank fbank(opts); - // Initialize fbanks - knf::FbankOptions opts; - opts.frame_opts.dither = 0; - opts.frame_opts.samp_freq = 16000; - opts.frame_opts.frame_shift_ms = 10.0f; - opts.frame_opts.frame_length_ms = 25.0f; - opts.mel_opts.num_bins = 80; - opts.frame_opts.window_type = "povey"; - opts.frame_opts.snip_edges = false; - knf::OnlineFbank fbank(opts); + // set session opts + // https://onnxruntime.ai/docs/performance/tune-performance.html + session_options.SetIntraOpNumThreads(numberOfThreads); + session_options.SetInterOpNumThreads(numberOfThreads); + session_options.SetGraphOptimizationLevel( + GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + session_options.SetLogSeverityLevel(4); + session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); - // set session opts - // https://onnxruntime.ai/docs/performance/tune-performance.html - session_options.SetIntraOpNumThreads(numberOfThreads); - session_options.SetInterOpNumThreads(numberOfThreads); - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); - session_options.SetLogSeverityLevel(4); - session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); - - api.CreateTensorRTProviderOptions(&tensorrt_options); - std::unique_ptr rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions); - api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), rel_trt_options.get()); + api.CreateTensorRTProviderOptions(&tensorrt_options); + std::unique_ptr + rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions); + api.SessionOptionsAppendExecutionProvider_TensorRT_V2( + static_cast(session_options), rel_trt_options.get()); - // Define model - auto model = get_model( - encoder_path, - decoder_path, - joiner_path, - joiner_encoder_proj_path, - joiner_decoder_proj_path, - token_path - ); - - std::vector filename_list { - filename - }; + // Define model + auto model = + get_model(encoder_path, decoder_path, joiner_path, + joiner_encoder_proj_path, joiner_decoder_proj_path, token_path); - for (auto filename : filename_list){ - std::cout << filename << std::endl; - auto samples = readWav(filename, true); - int numSamples = samples.NumCols(); + std::vector filename_list{filename}; - auto features = ComputeFeatures(fbank, opts, samples); + for (auto filename : filename_list) { + std::cout << filename << std::endl; + auto samples = readWav(filename, true); + int numSamples = samples.NumCols(); - auto tic = std::chrono::high_resolution_clock::now(); + auto features = ComputeFeatures(fbank, opts, samples); - // # === Encoder Out === # - int num_frames = features.size() / opts.mel_opts.num_bins; - auto encoder_out = model.encoder_forward(features, - std::vector {num_frames}, - std::vector {1, num_frames, 80}, - std::vector {1}, - memory_info); + auto tic = std::chrono::high_resolution_clock::now(); - // # === Search === # - std::vector> hyps; - if (search_method == "greedy") - hyps = GreedySearch(&model, &encoder_out); - else{ - std::cout << "wrong search method!" << std::endl; - exit(0); - } - auto results = hyps2result(model.tokens_map, hyps); + // # === Encoder Out === # + int num_frames = features.size() / opts.mel_opts.num_bins; + auto encoder_out = + model.encoder_forward(features, std::vector{num_frames}, + std::vector{1, num_frames, 80}, + std::vector{1}, memory_info); - // # === Print Elapsed Time === # - auto elapsed = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - tic); - std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl; - std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl; - - print_hyps(hyps); - std::cout << results[0] << std::endl; + // # === Search === # + std::vector> hyps; + if (search_method == "greedy") + hyps = GreedySearch(&model, &encoder_out); + else { + std::cout << "wrong search method!" << std::endl; + exit(0); } + auto results = hyps2result(model.tokens_map, hyps); - return 0; + // # === Print Elapsed Time === # + auto elapsed = std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - tic); + std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" + << std::endl; + std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) + << std::endl; + + print_hyps(hyps); + std::cout << results[0] << std::endl; + } + + return 0; } diff --git a/sherpa-onnx/csrc/rnnt_beam_search.h b/sherpa-onnx/csrc/rnnt_beam_search.h index c027680f..bf0d6425 100644 --- a/sherpa-onnx/csrc/rnnt_beam_search.h +++ b/sherpa-onnx/csrc/rnnt_beam_search.h @@ -61,7 +61,6 @@ std::vector> GreedySearch( auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector, std::vector {encoder_out_dim1, encoder_out_dim2}, memory_info); - Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0]; int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0]; int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; @@ -78,12 +77,12 @@ std::vector> GreedySearch( auto logits = model->joiner_forward(cur_encoder_out, projected_decoder_out_vector, - std::vector {1, 1, 1, projected_encoder_out_dim2}, - std::vector {1, 1, 1, projected_decoder_out_dim}, + std::vector {1, projected_encoder_out_dim2}, + std::vector {1, projected_decoder_out_dim}, memory_info); Ort::Value &logits_tensor = logits[0]; - int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3]; + int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1]; auto logits_vector = ortVal2Vector(logits_tensor, logits_dim); int max_indices = static_cast(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end()))); diff --git a/sherpa-onnx/csrc/utils_onnx.h b/sherpa-onnx/csrc/utils_onnx.h index 4da8d752..5683c901 100644 --- a/sherpa-onnx/csrc/utils_onnx.h +++ b/sherpa-onnx/csrc/utils_onnx.h @@ -1,5 +1,5 @@ #include -#include +#include "onnxruntime_cxx_api.h" Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); const auto& api = Ort::GetApi();