code refactoring and add CI (#11)
This commit is contained in:
85
.github/workflows/test-linux.yaml
vendored
Normal file
85
.github/workflows/test-linux.yaml
vendored
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
name: test-linux
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/test-linux.yaml'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/test-linux.yaml'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: test-linux-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test-linux:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Download pretrained model and test-data (English)
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||||
|
|
||||||
|
- name: Configure Cmake
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
cmake -D CMAKE_BUILD_TYPE=Release ..
|
||||||
|
|
||||||
|
- name: Build sherpa-onnx for ubuntu
|
||||||
|
run: |
|
||||||
|
cd build
|
||||||
|
make VERBOSE=1 -j3
|
||||||
|
|
||||||
|
- name: Run tests for ubuntu (English)
|
||||||
|
run: |
|
||||||
|
time ./build/bin/sherpa-onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
time ./build/bin/sherpa-onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav
|
||||||
|
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/encoder.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/decoder.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_encoder_proj.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/onnx/joiner_decoder_proj.onnx \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||||
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav
|
||||||
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
build
|
||||||
@@ -38,7 +38,8 @@ set(CMAKE_CXX_EXTENSIONS OFF)
|
|||||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
||||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
|
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
|
||||||
|
|
||||||
include(cmake/kaldi_native_io.cmake)
|
include(kaldi_native_io)
|
||||||
include(cmake/kaldi-native-fbank.cmake)
|
include(kaldi-native-fbank)
|
||||||
|
include(onnxruntime)
|
||||||
|
|
||||||
add_subdirectory(sherpa-onnx)
|
add_subdirectory(sherpa-onnx)
|
||||||
|
|||||||
@@ -1,27 +1,39 @@
|
|||||||
if(DEFINED ENV{KALDI_NATIVE_IO_INSTALL_PREFIX})
|
function(download_kaldi_native_io)
|
||||||
message(STATUS "Using environment variable KALDI_NATIVE_IO_INSTALL_PREFIX: $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX}")
|
if(CMAKE_VERSION VERSION_LESS 3.11)
|
||||||
set(KALDI_NATIVE_IO_CMAKE_PREFIX_PATH $ENV{KALDI_NATIVE_IO_INSTALL_PREFIX})
|
# FetchContent is available since 3.11,
|
||||||
else()
|
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
|
||||||
# PYTHON_EXECUTABLE is set by cmake/pybind11.cmake
|
# so that it can be used in lower CMake versions.
|
||||||
message(STATUS "Python executable: ${PYTHON_EXECUTABLE}")
|
message(STATUS "Use FetchContent provided by sherpa-onnx")
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
||||||
|
endif()
|
||||||
|
|
||||||
execute_process(
|
include(FetchContent)
|
||||||
COMMAND "${PYTHON_EXECUTABLE}" -c "import kaldi_native_io; print(kaldi_native_io.cmake_prefix_path)"
|
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
set(kaldi_native_io_URL "https://github.com/csukuangfj/kaldi_native_io/archive/refs/tags/v1.15.1.tar.gz")
|
||||||
OUTPUT_VARIABLE KALDI_NATIVE_IO_CMAKE_PREFIX_PATH
|
set(kaldi_native_io_HASH "SHA256=97377e1d61e99d8fc1d6037a418d3037522dfa46337e06162e24b1d97f3d70a6")
|
||||||
|
|
||||||
|
set(KALDI_NATIVE_IO_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
|
set(KALDI_NATIVE_IO_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
FetchContent_Declare(kaldi_native_io
|
||||||
|
URL ${kaldi_native_io_URL}
|
||||||
|
URL_HASH ${kaldi_native_io_HASH}
|
||||||
)
|
)
|
||||||
endif()
|
|
||||||
|
|
||||||
message(STATUS "KALDI_NATIVE_IO_CMAKE_PREFIX_PATH: ${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}")
|
FetchContent_GetProperties(kaldi_native_io)
|
||||||
list(APPEND CMAKE_PREFIX_PATH "${KALDI_NATIVE_IO_CMAKE_PREFIX_PATH}")
|
if(NOT kaldi_native_io_POPULATED)
|
||||||
|
message(STATUS "Downloading kaldi_native_io ${kaldi_native_io_URL}")
|
||||||
|
FetchContent_Populate(kaldi_native_io)
|
||||||
|
endif()
|
||||||
|
message(STATUS "kaldi_native_io is downloaded to ${kaldi_native_io_SOURCE_DIR}")
|
||||||
|
message(STATUS "kaldi_native_io's binary dir is ${kaldi_native_io_BINARY_DIR}")
|
||||||
|
|
||||||
find_package(kaldi_native_io REQUIRED)
|
add_subdirectory(${kaldi_native_io_SOURCE_DIR} ${kaldi_native_io_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||||
|
|
||||||
message(STATUS "KALDI_NATIVE_IO_FOUND: ${KALDI_NATIVE_IO_FOUND}")
|
target_include_directories(kaldi_native_io_core
|
||||||
message(STATUS "KALDI_NATIVE_IO_VERSION: ${KALDI_NATIVE_IO_VERSION}")
|
PUBLIC
|
||||||
message(STATUS "KALDI_NATIVE_IO_INCLUDE_DIRS: ${KALDI_NATIVE_IO_INCLUDE_DIRS}")
|
${kaldi_native_io_SOURCE_DIR}/
|
||||||
message(STATUS "KALDI_NATIVE_IO_CXX_FLAGS: ${KALDI_NATIVE_IO_CXX_FLAGS}")
|
)
|
||||||
message(STATUS "KALDI_NATIVE_IO_LIBRARIES: ${KALDI_NATIVE_IO_LIBRARIES}")
|
endfunction()
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${KALDI_NATIVE_IO_CXX_FLAGS}")
|
download_kaldi_native_io()
|
||||||
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
|
||||||
|
|||||||
55
cmake/onnxruntime.cmake
Normal file
55
cmake/onnxruntime.cmake
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
function(download_onnxruntime)
|
||||||
|
if(CMAKE_VERSION VERSION_LESS 3.11)
|
||||||
|
# FetchContent is available since 3.11,
|
||||||
|
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
|
||||||
|
# so that it can be used in lower CMake versions.
|
||||||
|
message(STATUS "Use FetchContent provided by sherpa-onnx")
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
|
||||||
|
if(UNIX AND NOT APPLE)
|
||||||
|
# set(onnxruntime_URL "http://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz")
|
||||||
|
|
||||||
|
# If you don't have access to the internet, you can first download onnxruntime to some directory, and the use
|
||||||
|
# set(onnxruntime_URL "file:///ceph-fj/fangjun/open-source/sherpa-onnx/onnxruntime-linux-x64-1.12.1.tgz")
|
||||||
|
|
||||||
|
set(onnxruntime_HASH "SHA256=8f6eb9e2da9cf74e7905bf3fc687ef52e34cc566af7af2f92dafe5a5d106aa3d")
|
||||||
|
# After downloading, it contains:
|
||||||
|
# ./lib/libonnxruntime.so.1.12.1
|
||||||
|
# ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.12.1
|
||||||
|
#
|
||||||
|
# ./include
|
||||||
|
# It contains all the needed header files
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Only support Linux at present. Will support other OSes later")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
FetchContent_Declare(onnxruntime
|
||||||
|
URL ${onnxruntime_URL}
|
||||||
|
URL_HASH ${onnxruntime_HASH}
|
||||||
|
)
|
||||||
|
|
||||||
|
FetchContent_GetProperties(onnxruntime)
|
||||||
|
if(NOT onnxruntime_POPULATED)
|
||||||
|
message(STATUS "Downloading onnxruntime ${onnxruntime_URL}")
|
||||||
|
FetchContent_Populate(onnxruntime)
|
||||||
|
endif()
|
||||||
|
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
|
||||||
|
|
||||||
|
find_library(location_onnxruntime onnxruntime
|
||||||
|
PATHS
|
||||||
|
"${onnxruntime_SOURCE_DIR}/lib"
|
||||||
|
)
|
||||||
|
|
||||||
|
message(STATUS "location_onnxruntime: ${location_onnxruntime}")
|
||||||
|
|
||||||
|
add_library(onnxruntime SHARED IMPORTED)
|
||||||
|
set_target_properties(onnxruntime PROPERTIES
|
||||||
|
IMPORTED_LOCATION ${location_onnxruntime}
|
||||||
|
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
|
||||||
|
)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
download_onnxruntime()
|
||||||
@@ -1,13 +1,8 @@
|
|||||||
add_executable(online-fbank-test online-fbank-test.cc)
|
include_directories(${CMAKE_SOURCE_DIR})
|
||||||
target_link_libraries(online-fbank-test kaldi-native-fbank-core)
|
|
||||||
|
|
||||||
include_directories(
|
|
||||||
${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/session/
|
|
||||||
${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/providers/tensorrt/
|
|
||||||
)
|
|
||||||
|
|
||||||
include_directories(
|
|
||||||
${KALDINATIVEIO}
|
|
||||||
)
|
|
||||||
add_executable(sherpa-onnx main.cpp)
|
add_executable(sherpa-onnx main.cpp)
|
||||||
target_link_libraries(sherpa-onnx onnxruntime kaldi-native-fbank-core kaldi_native_io_core)
|
|
||||||
|
target_link_libraries(sherpa-onnx
|
||||||
|
onnxruntime
|
||||||
|
kaldi-native-fbank-core
|
||||||
|
kaldi_native_io_core
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,25 +1,24 @@
|
|||||||
#include <vector>
|
|
||||||
#include <iostream>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <time.h>
|
|
||||||
#include <math.h>
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <math.h>
|
||||||
|
#include <time.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "fbank_features.h"
|
#include "sherpa-onnx/csrc/fbank_features.h"
|
||||||
#include "rnnt_beam_search.h"
|
#include "sherpa-onnx/csrc/rnnt_beam_search.h"
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
|
|
||||||
|
int main(int argc, char *argv[]) {
|
||||||
int main(int argc, char* argv[]) {
|
char *encoder_path = argv[1];
|
||||||
char* encoder_path = argv[1];
|
char *decoder_path = argv[2];
|
||||||
char* decoder_path = argv[2];
|
char *joiner_path = argv[3];
|
||||||
char* joiner_path = argv[3];
|
char *joiner_encoder_proj_path = argv[4];
|
||||||
char* joiner_encoder_proj_path = argv[4];
|
char *joiner_decoder_proj_path = argv[5];
|
||||||
char* joiner_decoder_proj_path = argv[5];
|
char *token_path = argv[6];
|
||||||
char* token_path = argv[6];
|
|
||||||
std::string search_method = argv[7];
|
std::string search_method = argv[7];
|
||||||
char* filename = argv[8];
|
char *filename = argv[8];
|
||||||
|
|
||||||
// General parameters
|
// General parameters
|
||||||
int numberOfThreads = 16;
|
int numberOfThreads = 16;
|
||||||
@@ -39,29 +38,26 @@ int main(int argc, char* argv[]) {
|
|||||||
// https://onnxruntime.ai/docs/performance/tune-performance.html
|
// https://onnxruntime.ai/docs/performance/tune-performance.html
|
||||||
session_options.SetIntraOpNumThreads(numberOfThreads);
|
session_options.SetIntraOpNumThreads(numberOfThreads);
|
||||||
session_options.SetInterOpNumThreads(numberOfThreads);
|
session_options.SetInterOpNumThreads(numberOfThreads);
|
||||||
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
session_options.SetGraphOptimizationLevel(
|
||||||
|
GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
||||||
session_options.SetLogSeverityLevel(4);
|
session_options.SetLogSeverityLevel(4);
|
||||||
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
|
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
|
||||||
|
|
||||||
api.CreateTensorRTProviderOptions(&tensorrt_options);
|
api.CreateTensorRTProviderOptions(&tensorrt_options);
|
||||||
std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
|
std::unique_ptr<OrtTensorRTProviderOptionsV2,
|
||||||
api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast<OrtSessionOptions*>(session_options), rel_trt_options.get());
|
decltype(api.ReleaseTensorRTProviderOptions)>
|
||||||
|
rel_trt_options(tensorrt_options, api.ReleaseTensorRTProviderOptions);
|
||||||
|
api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
|
||||||
|
static_cast<OrtSessionOptions *>(session_options), rel_trt_options.get());
|
||||||
|
|
||||||
// Define model
|
// Define model
|
||||||
auto model = get_model(
|
auto model =
|
||||||
encoder_path,
|
get_model(encoder_path, decoder_path, joiner_path,
|
||||||
decoder_path,
|
joiner_encoder_proj_path, joiner_decoder_proj_path, token_path);
|
||||||
joiner_path,
|
|
||||||
joiner_encoder_proj_path,
|
|
||||||
joiner_decoder_proj_path,
|
|
||||||
token_path
|
|
||||||
);
|
|
||||||
|
|
||||||
std::vector<std::string> filename_list {
|
std::vector<std::string> filename_list{filename};
|
||||||
filename
|
|
||||||
};
|
|
||||||
|
|
||||||
for (auto filename : filename_list){
|
for (auto filename : filename_list) {
|
||||||
std::cout << filename << std::endl;
|
std::cout << filename << std::endl;
|
||||||
auto samples = readWav(filename, true);
|
auto samples = readWav(filename, true);
|
||||||
int numSamples = samples.NumCols();
|
int numSamples = samples.NumCols();
|
||||||
@@ -72,26 +68,28 @@ int main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
// # === Encoder Out === #
|
// # === Encoder Out === #
|
||||||
int num_frames = features.size() / opts.mel_opts.num_bins;
|
int num_frames = features.size() / opts.mel_opts.num_bins;
|
||||||
auto encoder_out = model.encoder_forward(features,
|
auto encoder_out =
|
||||||
std::vector<int64_t> {num_frames},
|
model.encoder_forward(features, std::vector<int64_t>{num_frames},
|
||||||
std::vector<int64_t> {1, num_frames, 80},
|
std::vector<int64_t>{1, num_frames, 80},
|
||||||
std::vector<int64_t> {1},
|
std::vector<int64_t>{1}, memory_info);
|
||||||
memory_info);
|
|
||||||
|
|
||||||
// # === Search === #
|
// # === Search === #
|
||||||
std::vector<std::vector<int32_t>> hyps;
|
std::vector<std::vector<int32_t>> hyps;
|
||||||
if (search_method == "greedy")
|
if (search_method == "greedy")
|
||||||
hyps = GreedySearch(&model, &encoder_out);
|
hyps = GreedySearch(&model, &encoder_out);
|
||||||
else{
|
else {
|
||||||
std::cout << "wrong search method!" << std::endl;
|
std::cout << "wrong search method!" << std::endl;
|
||||||
exit(0);
|
exit(0);
|
||||||
}
|
}
|
||||||
auto results = hyps2result(model.tokens_map, hyps);
|
auto results = hyps2result(model.tokens_map, hyps);
|
||||||
|
|
||||||
// # === Print Elapsed Time === #
|
// # === Print Elapsed Time === #
|
||||||
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - tic);
|
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||||
std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds" << std::endl;
|
std::chrono::high_resolution_clock::now() - tic);
|
||||||
std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000) << std::endl;
|
std::cout << "Elapsed: " << float(elapsed.count()) / 1000 << " seconds"
|
||||||
|
<< std::endl;
|
||||||
|
std::cout << "rtf: " << float(elapsed.count()) / 1000 / (numSamples / 16000)
|
||||||
|
<< std::endl;
|
||||||
|
|
||||||
print_hyps(hyps);
|
print_hyps(hyps);
|
||||||
std::cout << results[0] << std::endl;
|
std::cout << results[0] << std::endl;
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ std::vector<std::vector<int32_t>> GreedySearch(
|
|||||||
auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
|
auto projected_encoder_out = model->joiner_encoder_proj_forward(encoder_out_vector,
|
||||||
std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
|
std::vector<int64_t> {encoder_out_dim1, encoder_out_dim2},
|
||||||
memory_info);
|
memory_info);
|
||||||
|
|
||||||
Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
|
Ort::Value &projected_encoder_out_tensor = projected_encoder_out[0];
|
||||||
int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
|
int projected_encoder_out_dim1 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[0];
|
||||||
int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
int projected_encoder_out_dim2 = projected_encoder_out_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||||
@@ -78,12 +77,12 @@ std::vector<std::vector<int32_t>> GreedySearch(
|
|||||||
|
|
||||||
auto logits = model->joiner_forward(cur_encoder_out,
|
auto logits = model->joiner_forward(cur_encoder_out,
|
||||||
projected_decoder_out_vector,
|
projected_decoder_out_vector,
|
||||||
std::vector<int64_t> {1, 1, 1, projected_encoder_out_dim2},
|
std::vector<int64_t> {1, projected_encoder_out_dim2},
|
||||||
std::vector<int64_t> {1, 1, 1, projected_decoder_out_dim},
|
std::vector<int64_t> {1, projected_decoder_out_dim},
|
||||||
memory_info);
|
memory_info);
|
||||||
|
|
||||||
Ort::Value &logits_tensor = logits[0];
|
Ort::Value &logits_tensor = logits[0];
|
||||||
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[3];
|
int logits_dim = logits_tensor.GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||||
auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
|
auto logits_vector = ortVal2Vector(logits_tensor, logits_dim);
|
||||||
|
|
||||||
int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));
|
int max_indices = static_cast<int>(std::distance(logits_vector.begin(), std::max_element(logits_vector.begin(), logits_vector.end())));
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <onnxruntime_cxx_api.h>
|
#include "onnxruntime_cxx_api.h"
|
||||||
|
|
||||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
|
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
|
||||||
const auto& api = Ort::GetApi();
|
const auto& api = Ort::GetApi();
|
||||||
|
|||||||
Reference in New Issue
Block a user