feat: add directml support (#1153)

This commit is contained in:
thewh1teagle
2024-07-22 18:50:48 +03:00
committed by GitHub
parent ea1d81bdfe
commit d32a46169f
6 changed files with 218 additions and 7 deletions

View File

@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
return Provider::kNNAPI;
} else if (s == "trt") {
return Provider::kTRT;
} else if (s == "directml") {
return Provider::kDirectML;
} else {
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
return Provider::kCPU;

View File

@@ -14,12 +14,13 @@ namespace sherpa_onnx {
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
// for a list of available providers
enum class Provider {
kCPU = 0, // CPUExecutionProvider
kCUDA = 1, // CUDAExecutionProvider
kCoreML = 2, // CoreMLExecutionProvider
kXnnpack = 3, // XnnpackExecutionProvider
kNNAPI = 4, // NnapiExecutionProvider
kTRT = 5, // TensorRTExecutionProvider
kCPU = 0, // CPUExecutionProvider
kCUDA = 1, // CUDAExecutionProvider
kCoreML = 2, // CoreMLExecutionProvider
kXnnpack = 3, // XnnpackExecutionProvider
kNNAPI = 4, // NnapiExecutionProvider
kTRT = 5, // TensorRTExecutionProvider
kDirectML = 6, // DmlExecutionProvider
};
/**

View File

@@ -19,6 +19,10 @@
#include "nnapi_provider_factory.h" // NOLINT
#endif
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
#include "dml_provider_factory.h" // NOLINT
#endif
namespace sherpa_onnx {
static void OrtStatusFailure(OrtStatus *status, const char *s) {
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
}
break;
}
case Provider::kDirectML: {
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
sess_opts.DisableMemPattern();
sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
int32_t device_id = 0;
OrtStatus *status =
OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
if (status) {
const auto &api = Ort::GetApi();
const char *msg = api.GetErrorMessage(status);
SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
api.ReleaseStatus(status);
}
#else
SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
#endif
break;
}
case Provider::kCoreML: {
#if defined(__APPLE__)
uint32_t coreml_flags = 0;