feat: add directml support (#1153)

This commit is contained in:
thewh1teagle
2024-07-22 18:50:48 +03:00
committed by GitHub
parent ea1d81bdfe
commit d32a46169f
6 changed files with 218 additions and 7 deletions

View File

@@ -19,6 +19,10 @@
#include "nnapi_provider_factory.h" // NOLINT
#endif
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
#include "dml_provider_factory.h" // NOLINT
#endif
namespace sherpa_onnx {
static void OrtStatusFailure(OrtStatus *status, const char *s) {
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
}
break;
}
case Provider::kDirectML: {
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
sess_opts.DisableMemPattern();
sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
int32_t device_id = 0;
OrtStatus *status =
OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
if (status) {
const auto &api = Ort::GetApi();
const char *msg = api.GetErrorMessage(status);
SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
api.ReleaseStatus(status);
}
#else
SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
#endif
break;
}
case Provider::kCoreML: {
#if defined(__APPLE__)
uint32_t coreml_flags = 0;