Support clang-tidy (#1034)
This commit is contained in:
@@ -21,14 +21,14 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
|
||||
static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||
const auto &api = Ort::GetApi();
|
||||
const char *msg = api.GetErrorMessage(status);
|
||||
SHERPA_ONNX_LOGE(
|
||||
const auto &api = Ort::GetApi();
|
||||
const char *msg = api.GetErrorMessage(status);
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to enable TensorRT : %s."
|
||||
"Available providers: %s. Fallback to cuda", msg, s);
|
||||
api.ReleaseStatus(status);
|
||||
"Available providers: %s. Fallback to cuda",
|
||||
msg, s);
|
||||
api.ReleaseStatus(status);
|
||||
}
|
||||
|
||||
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
@@ -65,29 +65,28 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
}
|
||||
case Provider::kTRT: {
|
||||
struct TrtPairs {
|
||||
const char* op_keys;
|
||||
const char* op_values;
|
||||
const char *op_keys;
|
||||
const char *op_values;
|
||||
};
|
||||
|
||||
std::vector<TrtPairs> trt_options = {
|
||||
{"device_id", "0"},
|
||||
{"trt_max_workspace_size", "2147483648"},
|
||||
{"trt_max_partition_iterations", "10"},
|
||||
{"trt_min_subgraph_size", "5"},
|
||||
{"trt_fp16_enable", "0"},
|
||||
{"trt_detailed_build_log", "0"},
|
||||
{"trt_engine_cache_enable", "1"},
|
||||
{"trt_engine_cache_path", "."},
|
||||
{"trt_timing_cache_enable", "1"},
|
||||
{"trt_timing_cache_path", "."}
|
||||
};
|
||||
{"device_id", "0"},
|
||||
{"trt_max_workspace_size", "2147483648"},
|
||||
{"trt_max_partition_iterations", "10"},
|
||||
{"trt_min_subgraph_size", "5"},
|
||||
{"trt_fp16_enable", "0"},
|
||||
{"trt_detailed_build_log", "0"},
|
||||
{"trt_engine_cache_enable", "1"},
|
||||
{"trt_engine_cache_path", "."},
|
||||
{"trt_timing_cache_enable", "1"},
|
||||
{"trt_timing_cache_path", "."}};
|
||||
// ToDo : Trt configs
|
||||
// "trt_int8_enable"
|
||||
// "trt_int8_use_native_calibration_table"
|
||||
// "trt_dump_subgraphs"
|
||||
|
||||
std::vector<const char*> option_keys, option_values;
|
||||
for (const TrtPairs& pair : trt_options) {
|
||||
std::vector<const char *> option_keys, option_values;
|
||||
for (const TrtPairs &pair : trt_options) {
|
||||
option_keys.emplace_back(pair.op_keys);
|
||||
option_values.emplace_back(pair.op_values);
|
||||
}
|
||||
@@ -95,19 +94,23 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
std::vector<std::string> available_providers =
|
||||
Ort::GetAvailableProviders();
|
||||
if (std::find(available_providers.begin(), available_providers.end(),
|
||||
"TensorrtExecutionProvider") != available_providers.end()) {
|
||||
const auto& api = Ort::GetApi();
|
||||
"TensorrtExecutionProvider") != available_providers.end()) {
|
||||
const auto &api = Ort::GetApi();
|
||||
|
||||
OrtTensorRTProviderOptionsV2* tensorrt_options;
|
||||
OrtStatus *statusC = api.CreateTensorRTProviderOptions(
|
||||
&tensorrt_options);
|
||||
OrtTensorRTProviderOptionsV2 *tensorrt_options = nullptr;
|
||||
OrtStatus *statusC =
|
||||
api.CreateTensorRTProviderOptions(&tensorrt_options);
|
||||
OrtStatus *statusU = api.UpdateTensorRTProviderOptions(
|
||||
tensorrt_options, option_keys.data(), option_values.data(),
|
||||
option_keys.size());
|
||||
tensorrt_options, option_keys.data(), option_values.data(),
|
||||
option_keys.size());
|
||||
sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
|
||||
|
||||
if (statusC) { OrtStatusFailure(statusC, os.str().c_str()); }
|
||||
if (statusU) { OrtStatusFailure(statusU, os.str().c_str()); }
|
||||
if (statusC) {
|
||||
OrtStatusFailure(statusC, os.str().c_str());
|
||||
}
|
||||
if (statusU) {
|
||||
OrtStatusFailure(statusU, os.str().c_str());
|
||||
}
|
||||
|
||||
api.ReleaseTensorRTProviderOptions(tensorrt_options);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user