Add C++ runtime for MeloTTS (#1138)

This commit is contained in:
Fangjun Kuang
2024-07-16 15:55:02 +08:00
committed by GitHub
parent 95485411fa
commit 960eb7529e
51 changed files with 693 additions and 156 deletions

View File

@@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
api.ReleaseStatus(status);
}
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
const std::string &provider_str,
static Ort::SessionOptions GetSessionOptionsImpl(
int32_t num_threads, const std::string &provider_str,
const ProviderConfig *provider_config = nullptr) {
Provider p = StringToProvider(provider_str);
@@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
}
case Provider::kTRT: {
if (provider_config == nullptr) {
SHERPA_ONNX_LOGE("Tensorrt support for Online models ony,"
"Must be extended for offline and others");
SHERPA_ONNX_LOGE(
"Tensorrt support for Online models ony,"
"Must be extended for offline and others");
exit(1);
}
auto trt_config = provider_config->trt_config;
@@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
std::to_string(trt_config.trt_max_partition_iterations);
auto trt_min_subgraph_size =
std::to_string(trt_config.trt_min_subgraph_size);
auto trt_fp16_enable =
std::to_string(trt_config.trt_fp16_enable);
auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable);
auto trt_detailed_build_log =
std::to_string(trt_config.trt_detailed_build_log);
auto trt_engine_cache_enable =
std::to_string(trt_config.trt_engine_cache_enable);
auto trt_timing_cache_enable =
std::to_string(trt_config.trt_timing_cache_enable);
auto trt_dump_subgraphs =
std::to_string(trt_config.trt_dump_subgraphs);
auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs);
std::vector<TrtPairs> trt_options = {
{"device_id", device_id.c_str()},
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
{"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
{"trt_fp16_enable", trt_fp16_enable.c_str()},
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
};
{"device_id", device_id.c_str()},
{"trt_max_workspace_size", trt_max_workspace_size.c_str()},
{"trt_max_partition_iterations",
trt_max_partition_iterations.c_str()},
{"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
{"trt_fp16_enable", trt_fp16_enable.c_str()},
{"trt_detailed_build_log", trt_detailed_build_log.c_str()},
{"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
{"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
{"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
{"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}};
// ToDo : Trt configs
// "trt_int8_enable"
// "trt_int8_use_native_calibration_table"
@@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
if (provider_config != nullptr) {
options.device_id = provider_config->device;
options.cudnn_conv_algo_search =
OrtCudnnConvAlgoSearch(provider_config->cuda_config
.cudnn_conv_algo_search);
options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(
provider_config->cuda_config.cudnn_conv_algo_search);
} else {
options.device_id = 0;
// Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
@@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads,
config.provider_config.provider, &config.provider_config);
config.provider_config.provider,
&config.provider_config);
}
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
const std::string &model_type) {
const std::string &model_type) {
/*
Transducer models : Only encoder will run with tensorrt,
decoder and joiner will run with cuda
*/
if(config.provider_config.provider == "trt" &&
if (config.provider_config.provider == "trt" &&
(model_type == "decoder" || model_type == "joiner")) {
return GetSessionOptionsImpl(config.num_threads,
"cuda", &config.provider_config);
return GetSessionOptionsImpl(config.num_threads, "cuda",
&config.provider_config);
}
return GetSessionOptionsImpl(config.num_threads,
config.provider_config.provider, &config.provider_config);
config.provider_config.provider,
&config.provider_config);
}
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {