Add onnxruntime gpu for cmake (#153)
* add onnxruntime gpu for cmake * fix clang * fix typo * cpplint
This commit is contained in:
@@ -19,6 +19,7 @@ option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
|
|||||||
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
|
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
|
||||||
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
|
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
|
||||||
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
|
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
|
||||||
|
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
|
||||||
|
|
||||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
@@ -71,6 +72,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
|
|||||||
message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
|
message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
|
||||||
message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}")
|
message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}")
|
||||||
message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}")
|
message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}")
|
||||||
|
message(STATUS "SHERPA_ONNX_ENABLE_GPU ${SHERPA_ONNX_ENABLE_GPU}")
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||||
|
|||||||
@@ -33,6 +33,14 @@ function(download_onnxruntime)
|
|||||||
#
|
#
|
||||||
# ./include
|
# ./include
|
||||||
# It contains all the needed header files
|
# It contains all the needed header files
|
||||||
|
if(SHERPA_ONNX_ENABLE_GPU)
|
||||||
|
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.1/onnxruntime-linux-x64-gpu-1.14.1.tgz")
|
||||||
|
endif()
|
||||||
|
# After downloading, it contains:
|
||||||
|
# ./lib/libonnxruntime.so.1.14.1
|
||||||
|
# ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.1
|
||||||
|
# ./lib/libonnxruntime_providers_cuda.so
|
||||||
|
# ./include, which contains all the needed header files
|
||||||
elseif(APPLE)
|
elseif(APPLE)
|
||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download onnxruntime
|
# please pre-download onnxruntime
|
||||||
@@ -97,21 +105,28 @@ function(download_onnxruntime)
|
|||||||
message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later")
|
message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
foreach(f IN LISTS possible_file_locations)
|
if(NOT SHERPA_ONNX_ENABLE_GPU)
|
||||||
if(EXISTS ${f})
|
foreach(f IN LISTS possible_file_locations)
|
||||||
set(onnxruntime_URL "${f}")
|
if(EXISTS ${f})
|
||||||
file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
|
set(onnxruntime_URL "${f}")
|
||||||
set(onnxruntime_URL2)
|
file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
|
||||||
break()
|
set(onnxruntime_URL2)
|
||||||
endif()
|
break()
|
||||||
endforeach()
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
FetchContent_Declare(onnxruntime
|
FetchContent_Declare(onnxruntime
|
||||||
URL
|
URL
|
||||||
${onnxruntime_URL}
|
${onnxruntime_URL}
|
||||||
${onnxruntime_URL2}
|
${onnxruntime_URL2}
|
||||||
URL_HASH ${onnxruntime_HASH}
|
URL_HASH ${onnxruntime_HASH}
|
||||||
)
|
)
|
||||||
|
else()
|
||||||
|
FetchContent_Declare(onnxruntime
|
||||||
|
URL
|
||||||
|
${onnxruntime_URL}
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
FetchContent_GetProperties(onnxruntime)
|
FetchContent_GetProperties(onnxruntime)
|
||||||
if(NOT onnxruntime_POPULATED)
|
if(NOT onnxruntime_POPULATED)
|
||||||
@@ -134,6 +149,19 @@ function(download_onnxruntime)
|
|||||||
IMPORTED_LOCATION ${location_onnxruntime}
|
IMPORTED_LOCATION ${location_onnxruntime}
|
||||||
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
|
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if(SHERPA_ONNX_ENABLE_GPU)
|
||||||
|
find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
|
||||||
|
PATHS
|
||||||
|
"${onnxruntime_SOURCE_DIR}/lib"
|
||||||
|
NO_CMAKE_SYSTEM_PATH
|
||||||
|
)
|
||||||
|
add_library(onnxruntime_providers_cuda SHARED IMPORTED)
|
||||||
|
set_target_properties(onnxruntime_providers_cuda PROPERTIES
|
||||||
|
IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
set_property(TARGET onnxruntime
|
set_property(TARGET onnxruntime
|
||||||
PROPERTY
|
PROPERTY
|
||||||
@@ -185,6 +213,12 @@ if(DEFINED ENV{SHERPA_ONNXRUNTIME_LIB_DIR})
|
|||||||
if(NOT EXISTS ${location_onnxruntime_lib})
|
if(NOT EXISTS ${location_onnxruntime_lib})
|
||||||
set(location_onnxruntime_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime.a)
|
set(location_onnxruntime_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime.a)
|
||||||
endif()
|
endif()
|
||||||
|
if(SHERPA_ONNX_ENABLE_GPU)
|
||||||
|
set(location_onnxruntime_cuda_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime_providers_cuda.so)
|
||||||
|
if(NOT EXISTS ${location_onnxruntime_cuda_lib})
|
||||||
|
set(location_onnxruntime_cuda_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime_providers_cuda.a)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
else()
|
else()
|
||||||
find_library(location_onnxruntime_lib onnxruntime
|
find_library(location_onnxruntime_lib onnxruntime
|
||||||
PATHS
|
PATHS
|
||||||
@@ -192,9 +226,21 @@ else()
|
|||||||
/usr/lib
|
/usr/lib
|
||||||
/usr/local/lib
|
/usr/local/lib
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if(SHERPA_ONNX_ENABLE_GPU)
|
||||||
|
find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
|
||||||
|
PATHS
|
||||||
|
/lib
|
||||||
|
/usr/lib
|
||||||
|
/usr/local/lib
|
||||||
|
)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
message(STATUS "location_onnxruntime_lib: ${location_onnxruntime_lib}")
|
message(STATUS "location_onnxruntime_lib: ${location_onnxruntime_lib}")
|
||||||
|
if(SHERPA_ONNX_ENABLE_GPU)
|
||||||
|
message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}")
|
||||||
|
endif()
|
||||||
|
|
||||||
if(location_onnxruntime_header_dir AND location_onnxruntime_lib)
|
if(location_onnxruntime_header_dir AND location_onnxruntime_lib)
|
||||||
add_library(onnxruntime SHARED IMPORTED)
|
add_library(onnxruntime SHARED IMPORTED)
|
||||||
@@ -202,6 +248,12 @@ if(location_onnxruntime_header_dir AND location_onnxruntime_lib)
|
|||||||
IMPORTED_LOCATION ${location_onnxruntime_lib}
|
IMPORTED_LOCATION ${location_onnxruntime_lib}
|
||||||
INTERFACE_INCLUDE_DIRECTORIES "${location_onnxruntime_header_dir}"
|
INTERFACE_INCLUDE_DIRECTORIES "${location_onnxruntime_header_dir}"
|
||||||
)
|
)
|
||||||
|
if(SHERPA_ONNX_ENABLE_GPU AND location_onnxruntime_cuda_lib)
|
||||||
|
add_library(onnxruntime_providers_cuda SHARED IMPORTED)
|
||||||
|
set_target_properties(onnxruntime_providers_cuda PROPERTIES
|
||||||
|
IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
|
||||||
|
)
|
||||||
|
endif()
|
||||||
else()
|
else()
|
||||||
message(STATUS "Could not find a pre-installed onnxruntime. Downloading pre-compiled onnxruntime")
|
message(STATUS "Could not find a pre-installed onnxruntime. Downloading pre-compiled onnxruntime")
|
||||||
download_onnxruntime()
|
download_onnxruntime()
|
||||||
|
|||||||
@@ -78,6 +78,12 @@ target_link_libraries(sherpa-onnx-core
|
|||||||
kaldi-native-fbank-core
|
kaldi-native-fbank-core
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if(SHERPA_ONNX_ENABLE_GPU)
|
||||||
|
target_link_libraries(sherpa-onnx-core
|
||||||
|
onnxruntime_providers_cuda
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(SHERPA_ONNX_ENABLE_CHECK)
|
if(SHERPA_ONNX_ENABLE_CHECK)
|
||||||
target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1)
|
target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1)
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/session.h"
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/provider.h"
|
#include "sherpa-onnx/csrc/provider.h"
|
||||||
@@ -27,10 +29,20 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
|||||||
case Provider::kCPU:
|
case Provider::kCPU:
|
||||||
break; // nothing to do for the CPU provider
|
break; // nothing to do for the CPU provider
|
||||||
case Provider::kCUDA: {
|
case Provider::kCUDA: {
|
||||||
OrtCUDAProviderOptions options;
|
std::vector<std::string> available_providers =
|
||||||
options.device_id = 0;
|
Ort::GetAvailableProviders();
|
||||||
// set more options on need
|
if (std::find(available_providers.begin(), available_providers.end(),
|
||||||
sess_opts.AppendExecutionProvider_CUDA(options);
|
"CUDAExecutionProvider") != available_providers.end()) {
|
||||||
|
// The CUDA provider is available, proceed with setting the options
|
||||||
|
OrtCUDAProviderOptions options;
|
||||||
|
options.device_id = 0;
|
||||||
|
// set more options on need
|
||||||
|
sess_opts.AppendExecutionProvider_CUDA(options);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Please compile with -DSHERPA_ONNX_ENABLE_GPU=ON. Fallback to "
|
||||||
|
"cpu!");
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Provider::kCoreML: {
|
case Provider::kCoreML: {
|
||||||
|
|||||||
Reference in New Issue
Block a user