feat: add directml support (#1153)
This commit is contained in:
@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
|
||||
return Provider::kNNAPI;
|
||||
} else if (s == "trt") {
|
||||
return Provider::kTRT;
|
||||
} else if (s == "directml") {
|
||||
return Provider::kDirectML;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
|
||||
return Provider::kCPU;
|
||||
|
||||
@@ -14,12 +14,13 @@ namespace sherpa_onnx {
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
|
||||
// for a list of available providers
|
||||
enum class Provider {
|
||||
kCPU = 0, // CPUExecutionProvider
|
||||
kCUDA = 1, // CUDAExecutionProvider
|
||||
kCoreML = 2, // CoreMLExecutionProvider
|
||||
kXnnpack = 3, // XnnpackExecutionProvider
|
||||
kNNAPI = 4, // NnapiExecutionProvider
|
||||
kTRT = 5, // TensorRTExecutionProvider
|
||||
kCPU = 0, // CPUExecutionProvider
|
||||
kCUDA = 1, // CUDAExecutionProvider
|
||||
kCoreML = 2, // CoreMLExecutionProvider
|
||||
kXnnpack = 3, // XnnpackExecutionProvider
|
||||
kNNAPI = 4, // NnapiExecutionProvider
|
||||
kTRT = 5, // TensorRTExecutionProvider
|
||||
kDirectML = 6, // DmlExecutionProvider
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
#include "nnapi_provider_factory.h" // NOLINT
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
|
||||
#include "dml_provider_factory.h" // NOLINT
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Provider::kDirectML: {
|
||||
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
|
||||
sess_opts.DisableMemPattern();
|
||||
sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
|
||||
int32_t device_id = 0;
|
||||
OrtStatus *status =
|
||||
OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
|
||||
if (status) {
|
||||
const auto &api = Ort::GetApi();
|
||||
const char *msg = api.GetErrorMessage(status);
|
||||
SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
|
||||
api.ReleaseStatus(status);
|
||||
}
|
||||
#else
|
||||
SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
case Provider::kCoreML: {
|
||||
#if defined(__APPLE__)
|
||||
uint32_t coreml_flags = 0;
|
||||
|
||||
Reference in New Issue
Block a user