Add C++ runtime for MeloTTS (#1138)
This commit is contained in:
@@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||
api.ReleaseStatus(status);
|
||||
}
|
||||
|
||||
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
const std::string &provider_str,
|
||||
static Ort::SessionOptions GetSessionOptionsImpl(
|
||||
int32_t num_threads, const std::string &provider_str,
|
||||
const ProviderConfig *provider_config = nullptr) {
|
||||
Provider p = StringToProvider(provider_str);
|
||||
|
||||
@@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
}
|
||||
case Provider::kTRT: {
|
||||
if (provider_config == nullptr) {
|
||||
SHERPA_ONNX_LOGE("Tensorrt support for Online models ony,"
|
||||
"Must be extended for offline and others");
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Tensorrt support for Online models ony,"
|
||||
"Must be extended for offline and others");
|
||||
exit(1);
|
||||
}
|
||||
auto trt_config = provider_config->trt_config;
|
||||
@@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
std::to_string(trt_config.trt_max_partition_iterations);
|
||||
auto trt_min_subgraph_size =
|
||||
std::to_string(trt_config.trt_min_subgraph_size);
|
||||
auto trt_fp16_enable =
|
||||
std::to_string(trt_config.trt_fp16_enable);
|
||||
auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable);
|
||||
auto trt_detailed_build_log =
|
||||
std::to_string(trt_config.trt_detailed_build_log);
|
||||
auto trt_engine_cache_enable =
|
||||
std::to_string(trt_config.trt_engine_cache_enable);
|
||||
auto trt_timing_cache_enable =
|
||||
std::to_string(trt_config.trt_timing_cache_enable);
|
||||
auto trt_dump_subgraphs =
|
||||
std::to_string(trt_config.trt_dump_subgraphs);
|
||||
auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs);
|
||||
std::vector<TrtPairs> trt_options = {
|
||||
{"device_id", device_id.c_str()},
|
||||
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
|
||||
{"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},
|
||||
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
|
||||
{"trt_fp16_enable", trt_fp16_enable.c_str()},
|
||||
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
|
||||
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
|
||||
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
|
||||
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
|
||||
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
|
||||
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
|
||||
};
|
||||
{"device_id", device_id.c_str()},
|
||||
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
|
||||
{"trt_max_partition_iterations",
|
||||
trt_max_partition_iterations.c_str()},
|
||||
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
|
||||
{"trt_fp16_enable", trt_fp16_enable.c_str()},
|
||||
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
|
||||
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
|
||||
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
|
||||
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
|
||||
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
|
||||
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}};
|
||||
// ToDo : Trt configs
|
||||
// "trt_int8_enable"
|
||||
// "trt_int8_use_native_calibration_table"
|
||||
@@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
|
||||
if (provider_config != nullptr) {
|
||||
options.device_id = provider_config->device;
|
||||
options.cudnn_conv_algo_search =
|
||||
OrtCudnnConvAlgoSearch(provider_config->cuda_config
|
||||
.cudnn_conv_algo_search);
|
||||
options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(
|
||||
provider_config->cuda_config.cudnn_conv_algo_search);
|
||||
} else {
|
||||
options.device_id = 0;
|
||||
// Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
|
||||
@@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
config.provider_config.provider, &config.provider_config);
|
||||
config.provider_config.provider,
|
||||
&config.provider_config);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
|
||||
const std::string &model_type) {
|
||||
const std::string &model_type) {
|
||||
/*
|
||||
Transducer models : Only encoder will run with tensorrt,
|
||||
decoder and joiner will run with cuda
|
||||
*/
|
||||
if(config.provider_config.provider == "trt" &&
|
||||
if (config.provider_config.provider == "trt" &&
|
||||
(model_type == "decoder" || model_type == "joiner")) {
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
"cuda", &config.provider_config);
|
||||
return GetSessionOptionsImpl(config.num_threads, "cuda",
|
||||
&config.provider_config);
|
||||
}
|
||||
return GetSessionOptionsImpl(config.num_threads,
|
||||
config.provider_config.provider, &config.provider_config);
|
||||
config.provider_config.provider,
|
||||
&config.provider_config);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
|
||||
|
||||
Reference in New Issue
Block a user