feat: add directml support (#1153)
This commit is contained in:
@@ -19,6 +19,10 @@
|
||||
#include "nnapi_provider_factory.h" // NOLINT
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
|
||||
#include "dml_provider_factory.h" // NOLINT
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Provider::kDirectML: {
|
||||
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
|
||||
sess_opts.DisableMemPattern();
|
||||
sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
|
||||
int32_t device_id = 0;
|
||||
OrtStatus *status =
|
||||
OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
|
||||
if (status) {
|
||||
const auto &api = Ort::GetApi();
|
||||
const char *msg = api.GetErrorMessage(status);
|
||||
SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
|
||||
api.ReleaseStatus(status);
|
||||
}
|
||||
#else
|
||||
SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
case Provider::kCoreML: {
|
||||
#if defined(__APPLE__)
|
||||
uint32_t coreml_flags = 0;
|
||||
|
||||
Reference in New Issue
Block a user