feat: add directml support (#1153)

This commit is contained in:
thewh1teagle
2024-07-22 18:50:48 +03:00
committed by GitHub
parent ea1d81bdfe
commit d32a46169f
6 changed files with 218 additions and 7 deletions

View File

@@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
@@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.")
endif()
endif()
if(SHERPA_ONNX_ENABLE_DIRECTML)
message(WARNING "\
Compiling with DirectML enabled. Please make sure Windows 10 SDK
is installed on your system. Otherwise, you will get errors at runtime.
Please refer to
https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements
to install Windows 10 SDK if you have not installed it.")
if(NOT BUILD_SHARED_LIBS)
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_DIRECTML is ON")
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
endif()
endif()
# see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html
# https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake
if(MSVC)
@@ -160,6 +174,14 @@ else()
add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
endif()
if(SHERPA_ONNX_ENABLE_DIRECTML)
message(STATUS "DirectML is enabled")
add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
else()
message(WARNING "DirectML is disabled")
add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
endif()
if(SHERPA_ONNX_ENABLE_WASM_TTS)
if(NOT SHERPA_ONNX_ENABLE_TTS)
message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")