Support TensorRT provider (#921)
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com> Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
This commit is contained in:
@@ -24,6 +24,8 @@ Provider StringToProvider(std::string s) {
|
|||||||
return Provider::kXnnpack;
|
return Provider::kXnnpack;
|
||||||
} else if (s == "nnapi") {
|
} else if (s == "nnapi") {
|
||||||
return Provider::kNNAPI;
|
return Provider::kNNAPI;
|
||||||
|
} else if (s == "trt") {
|
||||||
|
return Provider::kTRT;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
|
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
|
||||||
return Provider::kCPU;
|
return Provider::kCPU;
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ enum class Provider {
|
|||||||
kCoreML = 2, // CoreMLExecutionProvider
|
kCoreML = 2, // CoreMLExecutionProvider
|
||||||
kXnnpack = 3, // XnnpackExecutionProvider
|
kXnnpack = 3, // XnnpackExecutionProvider
|
||||||
kNNAPI = 4, // NnapiExecutionProvider
|
kNNAPI = 4, // NnapiExecutionProvider
|
||||||
|
kTRT = 5, // TensorRTExecutionProvider
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -21,6 +21,16 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
|
||||||
|
static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||||
|
const auto &api = Ort::GetApi();
|
||||||
|
const char *msg = api.GetErrorMessage(status);
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Failed to enable TensorRT : %s."
|
||||||
|
"Available providers: %s. Fallback to cuda", msg, s);
|
||||||
|
api.ReleaseStatus(status);
|
||||||
|
}
|
||||||
|
|
||||||
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||||
std::string provider_str) {
|
std::string provider_str) {
|
||||||
Provider p = StringToProvider(std::move(provider_str));
|
Provider p = StringToProvider(std::move(provider_str));
|
||||||
@@ -53,6 +63,57 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Provider::kTRT: {
|
||||||
|
struct TrtPairs {
|
||||||
|
const char* op_keys;
|
||||||
|
const char* op_values;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<TrtPairs> trt_options = {
|
||||||
|
{"device_id", "0"},
|
||||||
|
{"trt_max_workspace_size", "2147483648"},
|
||||||
|
{"trt_max_partition_iterations", "10"},
|
||||||
|
{"trt_min_subgraph_size", "5"},
|
||||||
|
{"trt_fp16_enable", "0"},
|
||||||
|
{"trt_detailed_build_log", "0"},
|
||||||
|
{"trt_engine_cache_enable", "1"},
|
||||||
|
{"trt_engine_cache_path", "."},
|
||||||
|
{"trt_timing_cache_enable", "1"},
|
||||||
|
{"trt_timing_cache_path", "."}
|
||||||
|
};
|
||||||
|
// ToDo : Trt configs
|
||||||
|
// "trt_int8_enable"
|
||||||
|
// "trt_int8_use_native_calibration_table"
|
||||||
|
// "trt_dump_subgraphs"
|
||||||
|
|
||||||
|
std::vector<const char*> option_keys, option_values;
|
||||||
|
for (const TrtPairs& pair : trt_options) {
|
||||||
|
option_keys.emplace_back(pair.op_keys);
|
||||||
|
option_values.emplace_back(pair.op_values);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> available_providers =
|
||||||
|
Ort::GetAvailableProviders();
|
||||||
|
if (std::find(available_providers.begin(), available_providers.end(),
|
||||||
|
"TensorrtExecutionProvider") != available_providers.end()) {
|
||||||
|
const auto& api = Ort::GetApi();
|
||||||
|
|
||||||
|
OrtTensorRTProviderOptionsV2* tensorrt_options;
|
||||||
|
OrtStatus *statusC = api.CreateTensorRTProviderOptions(
|
||||||
|
&tensorrt_options);
|
||||||
|
OrtStatus *statusU = api.UpdateTensorRTProviderOptions(
|
||||||
|
tensorrt_options, option_keys.data(), option_values.data(),
|
||||||
|
option_keys.size());
|
||||||
|
sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
|
||||||
|
|
||||||
|
if (statusC) { OrtStatusFailure(statusC, os.str().c_str()); }
|
||||||
|
if (statusU) { OrtStatusFailure(statusU, os.str().c_str()); }
|
||||||
|
|
||||||
|
api.ReleaseTensorRTProviderOptions(tensorrt_options);
|
||||||
|
}
|
||||||
|
// break; is omitted here intentionally so that
|
||||||
|
// if TRT not available, CUDA will be used
|
||||||
|
}
|
||||||
case Provider::kCUDA: {
|
case Provider::kCUDA: {
|
||||||
if (std::find(available_providers.begin(), available_providers.end(),
|
if (std::find(available_providers.begin(), available_providers.end(),
|
||||||
"CUDAExecutionProvider") != available_providers.end()) {
|
"CUDAExecutionProvider") != available_providers.end()) {
|
||||||
@@ -116,7 +177,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sess_opts;
|
return sess_opts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user