feat: add directml support (#1153)
This commit is contained in:
@@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
|
|||||||
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
|
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
|
||||||
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
|
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
|
||||||
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
|
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
|
||||||
|
option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
|
||||||
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
|
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
|
||||||
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
|
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
|
||||||
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
|
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
|
||||||
@@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.")
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# DirectML support: first remind the user about the Windows 10 SDK
# requirement, then make sure the project is built as shared libraries.
if(SHERPA_ONNX_ENABLE_DIRECTML)
  message(WARNING "Compiling with DirectML enabled. Please make sure Windows 10 SDK
is installed on your system. Otherwise, you will get errors at runtime.
Please refer to
https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements
to install Windows 10 SDK if you have not installed it.")
endif()

# The DirectML onnxruntime package is only consumed as a shared library
# (cmake/onnxruntime-win-x64-directml.cmake rejects static builds), so a
# static configuration is switched to shared here, overriding the cache.
if(SHERPA_ONNX_ENABLE_DIRECTML AND NOT BUILD_SHARED_LIBS)
  message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_DIRECTML is ON")
  set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
endif()
|
||||||
|
|
||||||
# see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html
|
# see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html
|
||||||
# https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake
|
# https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake
|
||||||
if(MSVC)
|
if(MSVC)
|
||||||
@@ -160,6 +174,14 @@ else()
|
|||||||
add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
|
add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Expose the DirectML switch to the C++ sources as the preprocessor symbol
# SHERPA_ONNX_ENABLE_DIRECTML (1 when enabled, 0 otherwise). The symbol is
# always defined so that `#if SHERPA_ONNX_ENABLE_DIRECTML == 1` is well-formed.
if(SHERPA_ONNX_ENABLE_DIRECTML)
  message(STATUS "DirectML is enabled")
  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
else()
  # The option defaults to OFF, so this is the normal case on every platform.
  # Report it as a plain status line; a WARNING here would fire on every
  # ordinary configure run.
  message(STATUS "DirectML is disabled")
  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
endif()
|
||||||
|
|
||||||
if(SHERPA_ONNX_ENABLE_WASM_TTS)
|
if(SHERPA_ONNX_ENABLE_WASM_TTS)
|
||||||
if(NOT SHERPA_ONNX_ENABLE_TTS)
|
if(NOT SHERPA_ONNX_ENABLE_TTS)
|
||||||
message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")
|
message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")
|
||||||
|
|||||||
161
cmake/onnxruntime-win-x64-directml.cmake
Normal file
161
cmake/onnxruntime-win-x64-directml.cmake
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
# Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
# Print the platform variables the checks below depend on.
foreach(platform_var IN ITEMS
    CMAKE_SYSTEM_NAME
    CMAKE_SYSTEM_PROCESSOR
    CMAKE_VS_PLATFORM_NAME)
  message(STATUS "${platform_var}: ${${platform_var}}")
endforeach()

# Preconditions: this file only handles a shared-library build for
# Windows x64 with DirectML turned on. Anything else aborts the configure.

if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Windows")
  message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}")
endif()

# Both spellings of the platform name ("X64" and "x64") are accepted.
if(NOT CMAKE_VS_PLATFORM_NAME STREQUAL X64 AND NOT CMAKE_VS_PLATFORM_NAME STREQUAL x64)
  message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}")
endif()

if(NOT BUILD_SHARED_LIBS)
  message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

if(NOT SHERPA_ONNX_ENABLE_DIRECTML)
  message(FATAL_ERROR "This file is for DirectML. Given SHERPA_ONNX_ENABLE_DIRECTML: ${SHERPA_ONNX_ENABLE_DIRECTML}")
endif()
|
||||||
|
|
||||||
|
# Fetch the onnxruntime DirectML NuGet package (a .nupkg archive).
# onnxruntime_URL is the primary download location, onnxruntime_URL2 is a
# mirror, and onnxruntime_HASH is the checksum FetchContent verifies.
set(onnxruntime_URL  "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a")

# If you don't have access to the Internet,
# please download onnxruntime to one of the following locations.
# You can add more if you want.
#
# This file is Windows-only, where HOME is usually unset and /tmp does not
# normally exist, so the user's profile directory (USERPROFILE) is searched
# as well. The original locations are kept, in their original order of
# precedence, for environments that do define them.
set(possible_file_locations
  "$ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg"
  "$ENV{USERPROFILE}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg"
  "${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg"
  "${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg"
  "/tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg"
)

# The first existing local copy replaces the remote URL; the mirror is then
# cleared so that FetchContent only sees the local file.
foreach(f IN LISTS possible_file_locations)
  if(EXISTS "${f}")
    set(onnxruntime_URL "${f}")
    file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
    message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
    set(onnxruntime_URL2)
    break()
  endif()
endforeach()

# onnxruntime_URL2 is intentionally left unquoted: when it has been unset
# above it must expand to nothing rather than to an empty URL entry.
FetchContent_Declare(onnxruntime
  URL
    ${onnxruntime_URL}
    ${onnxruntime_URL2}
  URL_HASH ${onnxruntime_HASH}
)

FetchContent_GetProperties(onnxruntime)
if(NOT onnxruntime_POPULATED)
  message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
  FetchContent_Populate(onnxruntime)
endif()
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
|
||||||
|
|
||||||
|
# Directory inside the unpacked package that contains onnxruntime.dll and
# its import library onnxruntime.lib.
set(onnxruntime_native_dir "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native")

# Look for the library only inside the unpacked package.
find_library(location_onnxruntime onnxruntime
  PATHS "${onnxruntime_native_dir}"
  NO_CMAKE_SYSTEM_PATH
)
message(STATUS "location_onnxruntime: ${location_onnxruntime}")

# Wrap the prebuilt binaries in an imported target named `onnxruntime`.
add_library(onnxruntime SHARED IMPORTED)
set_target_properties(onnxruntime PROPERTIES
  IMPORTED_LOCATION ${location_onnxruntime}
  INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include"
  IMPORTED_IMPLIB "${onnxruntime_native_dir}/onnxruntime.lib"
)

# Copy the DLL into the build tree's bin/<build type> directory.
file(COPY "${onnxruntime_native_dir}/onnxruntime.dll"
  DESTINATION "${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}"
)

# Every onnxruntime.* file in the native directory is installed.
file(GLOB onnxruntime_lib_files "${onnxruntime_native_dir}/onnxruntime.*")
message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")

# Python builds install to the parent of the install prefix ("..");
# all other builds install to lib. Both variants also install to bin.
if(SHERPA_ONNX_ENABLE_PYTHON)
  set(onnxruntime_lib_install_dir ..)
else()
  set(onnxruntime_lib_install_dir lib)
endif()
install(FILES ${onnxruntime_lib_files} DESTINATION ${onnxruntime_lib_install_dir})
install(FILES ${onnxruntime_lib_files} DESTINATION bin)
|
||||||
|
|
||||||
|
# Setup DirectML
#
# Fetch the Microsoft.AI.DirectML NuGet package. directml_URL is the download
# location and directml_HASH is the checksum FetchContent verifies.
set(directml_URL "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.0")
set(directml_HASH "SHA256=10d175f8e97447712b3680e3ac020bbb8eafdf651332b48f09ffee2eec801c23")

# Locations searched for an already-downloaded package, so the build also
# works without Internet access. This file is Windows-only, where HOME is
# usually unset and /tmp does not normally exist, so the user's profile
# directory (USERPROFILE) is searched as well. The original locations are
# kept, in their original order of precedence.
set(possible_directml_file_locations
  "$ENV{HOME}/Downloads/Microsoft.AI.DirectML.1.15.0.nupkg"
  "$ENV{USERPROFILE}/Downloads/Microsoft.AI.DirectML.1.15.0.nupkg"
  "${PROJECT_SOURCE_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg"
  "${PROJECT_BINARY_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg"
  "/tmp/Microsoft.AI.DirectML.1.15.0.nupkg"
)

# The first existing local copy replaces the remote URL.
foreach(f IN LISTS possible_directml_file_locations)
  if(EXISTS "${f}")
    set(directml_URL "${f}")
    file(TO_CMAKE_PATH "${directml_URL}" directml_URL)
    message(STATUS "Found local downloaded DirectML: ${directml_URL}")
    break()
  endif()
endforeach()

FetchContent_Declare(directml
  URL
    ${directml_URL}
  URL_HASH ${directml_HASH}
)

FetchContent_GetProperties(directml)
if(NOT directml_POPULATED)
  message(STATUS "Downloading DirectML from ${directml_URL}")
  FetchContent_Populate(directml)
endif()
message(STATUS "DirectML is downloaded to ${directml_SOURCE_DIR}")
|
||||||
|
|
||||||
|
# Directory inside the unpacked package that contains DirectML.dll and its
# import library DirectML.lib.
set(directml_bin_dir "${directml_SOURCE_DIR}/bin/x64-win")

# Look for the library only inside the unpacked package.
find_library(location_directml DirectML
  PATHS "${directml_bin_dir}"
  NO_CMAKE_SYSTEM_PATH
)
message(STATUS "location_directml: ${location_directml}")

# Wrap the prebuilt binaries in an imported target named `directml`.
# NOTE(review): the include directory below is the binary directory; confirm
# whether the package's headers are expected to be found there.
add_library(directml SHARED IMPORTED)
set_target_properties(directml PROPERTIES
  IMPORTED_LOCATION ${location_directml}
  INTERFACE_INCLUDE_DIRECTORIES "${directml_bin_dir}"
  IMPORTED_IMPLIB "${directml_bin_dir}/DirectML.lib"
)

# Copy the DLL into the build tree's bin/<build type> directory.
file(COPY "${directml_bin_dir}/DirectML.dll"
  DESTINATION "${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}"
)

# Every DirectML.* file in that directory is installed to both lib and bin.
file(GLOB directml_lib_files "${directml_bin_dir}/DirectML.*")
message(STATUS "DirectML lib files: ${directml_lib_files}")

foreach(directml_install_dir IN ITEMS lib bin)
  install(FILES ${directml_lib_files} DESTINATION ${directml_install_dir})
endforeach()
|
||||||
@@ -95,7 +95,10 @@ function(download_onnxruntime)
|
|||||||
include(onnxruntime-win-arm64)
|
include(onnxruntime-win-arm64)
|
||||||
else()
|
else()
|
||||||
# for 64-bit windows (x64)
|
# for 64-bit windows (x64)
|
||||||
if(BUILD_SHARED_LIBS)
|
if(SHERPA_ONNX_ENABLE_DIRECTML)
|
||||||
|
message(STATUS "Use DirectML")
|
||||||
|
include(onnxruntime-win-x64-directml)
|
||||||
|
elseif(BUILD_SHARED_LIBS)
|
||||||
message(STATUS "Use dynamic onnxruntime libraries")
|
message(STATUS "Use dynamic onnxruntime libraries")
|
||||||
if(SHERPA_ONNX_ENABLE_GPU)
|
if(SHERPA_ONNX_ENABLE_GPU)
|
||||||
include(onnxruntime-win-x64-gpu)
|
include(onnxruntime-win-x64-gpu)
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
|
|||||||
return Provider::kNNAPI;
|
return Provider::kNNAPI;
|
||||||
} else if (s == "trt") {
|
} else if (s == "trt") {
|
||||||
return Provider::kTRT;
|
return Provider::kTRT;
|
||||||
|
} else if (s == "directml") {
|
||||||
|
return Provider::kDirectML;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
|
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
|
||||||
return Provider::kCPU;
|
return Provider::kCPU;
|
||||||
|
|||||||
@@ -14,12 +14,13 @@ namespace sherpa_onnx {
|
|||||||
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
|
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
|
||||||
// for a list of available providers
|
// for a list of available providers
|
||||||
enum class Provider {
|
enum class Provider {
|
||||||
kCPU = 0, // CPUExecutionProvider
|
kCPU = 0, // CPUExecutionProvider
|
||||||
kCUDA = 1, // CUDAExecutionProvider
|
kCUDA = 1, // CUDAExecutionProvider
|
||||||
kCoreML = 2, // CoreMLExecutionProvider
|
kCoreML = 2, // CoreMLExecutionProvider
|
||||||
kXnnpack = 3, // XnnpackExecutionProvider
|
kXnnpack = 3, // XnnpackExecutionProvider
|
||||||
kNNAPI = 4, // NnapiExecutionProvider
|
kNNAPI = 4, // NnapiExecutionProvider
|
||||||
kTRT = 5, // TensorRTExecutionProvider
|
kTRT = 5, // TensorRTExecutionProvider
|
||||||
|
kDirectML = 6, // DmlExecutionProvider
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -19,6 +19,10 @@
|
|||||||
#include "nnapi_provider_factory.h" // NOLINT
|
#include "nnapi_provider_factory.h" // NOLINT
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
|
||||||
|
#include "dml_provider_factory.h" // NOLINT
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||||
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Provider::kDirectML: {
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
  // Memory-pattern optimization is turned off and execution is forced to
  // be sequential before the DirectML provider is appended (see the ONNX
  // Runtime DirectML execution provider documentation for these settings).
  sess_opts.DisableMemPattern();
  sess_opts.SetExecutionMode(ORT_SEQUENTIAL);

  // Device 0 is always used.
  int32_t device_id = 0;
  OrtStatus *status =
      OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
  if (status) {
    // A non-null status means the provider could not be appended. Log the
    // reason, release the status object and carry on without DirectML.
    const auto &api = Ort::GetApi();
    const char *msg = api.GetErrorMessage(status);
    SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
    api.ReleaseStatus(status);
  }
#elif defined(_WIN32)
  // Windows build configured without SHERPA_ONNX_ENABLE_DIRECTML: the
  // platform is fine, the feature was simply not compiled in.
  SHERPA_ONNX_LOGE(
      "DirectML support is not enabled in this build. Please rebuild with "
      "-DSHERPA_ONNX_ENABLE_DIRECTML=ON. Fallback to cpu!");
#else
  SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
#endif
  break;
}
|
||||||
case Provider::kCoreML: {
|
case Provider::kCoreML: {
|
||||||
#if defined(__APPLE__)
|
#if defined(__APPLE__)
|
||||||
uint32_t coreml_flags = 0;
|
uint32_t coreml_flags = 0;
|
||||||
|
|||||||
Reference in New Issue
Block a user