updating trt workspace int64 (#1094)

Signed-off-by: Manix <manickavela1998@gmail.com>
This commit is contained in:
Manix
2024-07-08 18:08:16 +05:30
committed by GitHub
parent 4fd0493037
commit 3e4307e2fb
3 changed files with 4 additions and 4 deletions

View File

@@ -60,7 +60,7 @@ void TensorrtConfig::Register(ParseOptions *po) {
bool TensorrtConfig::Validate() const { bool TensorrtConfig::Validate() const {
if (trt_max_workspace_size < 0) { if (trt_max_workspace_size < 0) {
SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.", SHERPA_ONNX_LOGE("trt_max_workspace_size: %lld is not valid.",
trt_max_workspace_size); trt_max_workspace_size);
return false; return false;
} }

View File

@@ -27,7 +27,7 @@ struct CudaConfig {
}; };
struct TensorrtConfig { struct TensorrtConfig {
int32_t trt_max_workspace_size = 2147483647; int64_t trt_max_workspace_size = 2147483647;
int32_t trt_max_partition_iterations = 10; int32_t trt_max_partition_iterations = 10;
int32_t trt_min_subgraph_size = 5; int32_t trt_min_subgraph_size = 5;
bool trt_fp16_enable = true; bool trt_fp16_enable = true;
@@ -39,7 +39,7 @@ struct TensorrtConfig {
bool trt_dump_subgraphs = false; bool trt_dump_subgraphs = false;
TensorrtConfig() = default; TensorrtConfig() = default;
TensorrtConfig(int32_t trt_max_workspace_size, TensorrtConfig(int64_t trt_max_workspace_size,
int32_t trt_max_partition_iterations, int32_t trt_max_partition_iterations,
int32_t trt_min_subgraph_size, int32_t trt_min_subgraph_size,
bool trt_fp16_enable, bool trt_fp16_enable,

View File

@@ -14,7 +14,7 @@ void PybindTensorrtConfig(py::module *m) {
using PyClass = TensorrtConfig; using PyClass = TensorrtConfig;
py::class_<PyClass>(*m, "TensorrtConfig") py::class_<PyClass>(*m, "TensorrtConfig")
.def(py::init<>()) .def(py::init<>())
.def(py::init([](int32_t trt_max_workspace_size, .def(py::init([](int64_t trt_max_workspace_size,
int32_t trt_max_partition_iterations, int32_t trt_max_partition_iterations,
int32_t trt_min_subgraph_size, int32_t trt_min_subgraph_size,
bool trt_fp16_enable, bool trt_fp16_enable,