diff --git a/CMakeLists.txt b/CMakeLists.txt index 209f5e33..15bd7895 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) +option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") @@ -71,6 +72,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}") message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}") message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}") message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}") +message(STATUS "SHERPA_ONNX_ENABLE_GPU ${SHERPA_ONNX_ENABLE_GPU}") set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") set(CMAKE_CXX_EXTENSIONS OFF) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 202b586c..00078231 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -33,6 +33,14 @@ function(download_onnxruntime) # # ./include # It contains all the needed header files + if(SHERPA_ONNX_ENABLE_GPU) + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.1/onnxruntime-linux-x64-gpu-1.14.1.tgz") + endif() + # After downloading, it contains: + # ./lib/libonnxruntime.so.1.14.1 + # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.1 + # ./lib/libonnxruntime_providers_cuda.so + # ./include, which contains all the needed header files elseif(APPLE) # If you don't have access to the Internet, # please pre-download onnxruntime @@ -97,21 +105,28 @@ function(download_onnxruntime) message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later") endif() - foreach(f IN LISTS possible_file_locations) - if(EXISTS ${f}) - set(onnxruntime_URL "${f}") - file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL) - set(onnxruntime_URL2) - break() - endif() - endforeach() + if(NOT SHERPA_ONNX_ENABLE_GPU) + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(onnxruntime_URL "${f}") + file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL) + set(onnxruntime_URL2) + break() + endif() + endforeach() - FetchContent_Declare(onnxruntime - URL - ${onnxruntime_URL} - ${onnxruntime_URL2} - URL_HASH ${onnxruntime_HASH} - ) + FetchContent_Declare(onnxruntime + URL + ${onnxruntime_URL} + ${onnxruntime_URL2} + URL_HASH ${onnxruntime_HASH} + ) + else() + FetchContent_Declare(onnxruntime + URL + ${onnxruntime_URL} + ) + endif() FetchContent_GetProperties(onnxruntime) if(NOT onnxruntime_POPULATED) @@ -134,6 +149,19 @@ function(download_onnxruntime) IMPORTED_LOCATION ${location_onnxruntime} INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" ) + + if(SHERPA_ONNX_ENABLE_GPU) + find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda + PATHS + "${onnxruntime_SOURCE_DIR}/lib" + NO_CMAKE_SYSTEM_PATH + ) + add_library(onnxruntime_providers_cuda SHARED IMPORTED) + set_target_properties(onnxruntime_providers_cuda PROPERTIES + IMPORTED_LOCATION ${location_onnxruntime_cuda_lib} + ) + endif() + if(WIN32) set_property(TARGET onnxruntime PROPERTY @@ -185,6 +213,12 @@ if(DEFINED ENV{SHERPA_ONNXRUNTIME_LIB_DIR}) if(NOT EXISTS ${location_onnxruntime_lib}) set(location_onnxruntime_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime.a) endif() + if(SHERPA_ONNX_ENABLE_GPU) + set(location_onnxruntime_cuda_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime_providers_cuda.so) + if(NOT EXISTS ${location_onnxruntime_cuda_lib}) + set(location_onnxruntime_cuda_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime_providers_cuda.a) + endif() + endif() else() find_library(location_onnxruntime_lib onnxruntime PATHS @@ -192,9 +226,21 @@ else() /usr/lib /usr/local/lib ) + + if(SHERPA_ONNX_ENABLE_GPU) + find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda + PATHS + /lib + /usr/lib + /usr/local/lib + ) + endif() endif() message(STATUS "location_onnxruntime_lib: ${location_onnxruntime_lib}") +if(SHERPA_ONNX_ENABLE_GPU) + message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}") +endif() if(location_onnxruntime_header_dir AND location_onnxruntime_lib) add_library(onnxruntime SHARED IMPORTED) @@ -202,6 +248,12 @@ if(location_onnxruntime_header_dir AND location_onnxruntime_lib) IMPORTED_LOCATION ${location_onnxruntime_lib} INTERFACE_INCLUDE_DIRECTORIES "${location_onnxruntime_header_dir}" ) + if(SHERPA_ONNX_ENABLE_GPU AND location_onnxruntime_cuda_lib) + add_library(onnxruntime_providers_cuda SHARED IMPORTED) + set_target_properties(onnxruntime_providers_cuda PROPERTIES + IMPORTED_LOCATION ${location_onnxruntime_cuda_lib} + ) + endif() else() message(STATUS "Could not find a pre-installed onnxruntime. Downloading pre-compiled onnxruntime") download_onnxruntime() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index a8efdf59..82a831fe 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -78,6 +78,12 @@ target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core ) +if(SHERPA_ONNX_ENABLE_GPU) + target_link_libraries(sherpa-onnx-core + onnxruntime_providers_cuda + ) +endif() + if(SHERPA_ONNX_ENABLE_CHECK) target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1) diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 9920ec17..49979873 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -4,8 +4,10 @@ #include "sherpa-onnx/csrc/session.h" +#include #include #include +#include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/provider.h" @@ -27,10 +29,20 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, case Provider::kCPU: break; // nothing to do for the CPU provider case Provider::kCUDA: { - OrtCUDAProviderOptions options; - options.device_id = 0; - // set more options on need - sess_opts.AppendExecutionProvider_CUDA(options); + std::vector available_providers = + Ort::GetAvailableProviders(); + if (std::find(available_providers.begin(), available_providers.end(), + "CUDAExecutionProvider") != available_providers.end()) { + // The CUDA provider is available, proceed with setting the options + OrtCUDAProviderOptions options; + options.device_id = 0; + // set more options on need + sess_opts.AppendExecutionProvider_CUDA(options); + } else { + SHERPA_ONNX_LOGE( + "Please compile with -DSHERPA_ONNX_ENABLE_GPU=ON. Fallback to " + "cpu!"); + } break; } case Provider::kCoreML: {