diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 9d70fb8a..feec532c 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -73,7 +73,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( SHERPA_ONNX_OR(config->model_config.tokens, ""); recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); - recognizer_config.model_config.provider = + recognizer_config.model_config.provider_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); recognizer_config.model_config.model_type = SHERPA_ONNX_OR(config->model_config.model_type, ""); @@ -570,7 +570,7 @@ SherpaOnnxKeywordSpotter *CreateKeywordSpotter( SHERPA_ONNX_OR(config->model_config.tokens, ""); spotter_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); - spotter_config.model_config.provider = + spotter_config.model_config.provider_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); spotter_config.model_config.model_type = SHERPA_ONNX_OR(config->model_config.model_type, ""); diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 29d78955..48d0c258 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -87,6 +87,7 @@ set(sources packed-sequence.cc pad-sequence.cc parse-options.cc + provider-config.cc provider.cc resample.cc session.cc diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index a8efa870..9913fa9e 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -16,6 +16,7 @@ void OnlineModelConfig::Register(ParseOptions *po) { wenet_ctc.Register(po); zipformer2_ctc.Register(po); nemo_ctc.Register(po); + provider_config.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -29,9 +30,6 @@ void OnlineModelConfig::Register(ParseOptions *po) { po->Register("debug", &debug, "true to print model information while loading it."); - po->Register("provider", &provider, - "Specify a provider to use: cpu, cuda, coreml"); - po->Register("modeling-unit", &modeling_unit, "The modeling unit of the model, commonly used units are bpe, " "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " @@ -87,6 +85,10 @@ bool OnlineModelConfig::Validate() const { return nemo_ctc.Validate(); } + if (!provider_config.Validate()) { + return false; + } + return transducer.Validate(); } @@ -99,11 +101,11 @@ std::string OnlineModelConfig::ToString() const { os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", "; os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; + os << "provider_config=" << provider_config.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "warm_up=" << warm_up << ", "; os << "debug=" << (debug ? 
"True" : "False") << ", "; - os << "provider=\"" << provider << "\", "; os << "model_type=\"" << model_type << "\", "; os << "modeling_unit=\"" << modeling_unit << "\", "; os << "bpe_vocab=\"" << bpe_vocab << "\")"; diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 1509bd5b..0b64e06d 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -11,6 +11,7 @@ #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" #include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h" +#include "sherpa-onnx/csrc/provider-config.h" namespace sherpa_onnx { @@ -20,11 +21,11 @@ struct OnlineModelConfig { OnlineWenetCtcModelConfig wenet_ctc; OnlineZipformer2CtcModelConfig zipformer2_ctc; OnlineNeMoCtcModelConfig nemo_ctc; + ProviderConfig provider_config; std::string tokens; int32_t num_threads = 1; int32_t warm_up = 0; bool debug = false; - std::string provider = "cpu"; // Valid values: // - conformer, conformer transducer from icefall @@ -50,8 +51,9 @@ struct OnlineModelConfig { const OnlineWenetCtcModelConfig &wenet_ctc, const OnlineZipformer2CtcModelConfig &zipformer2_ctc, const OnlineNeMoCtcModelConfig &nemo_ctc, + const ProviderConfig &provider_config, const std::string &tokens, int32_t num_threads, - int32_t warm_up, bool debug, const std::string &provider, + int32_t warm_up, bool debug, const std::string &model_type, const std::string &modeling_unit, const std::string &bpe_vocab) @@ -60,11 +62,11 @@ struct OnlineModelConfig { wenet_ctc(wenet_ctc), zipformer2_ctc(zipformer2_ctc), nemo_ctc(nemo_ctc), + provider_config(provider_config), tokens(tokens), num_threads(num_threads), warm_up(warm_up), debug(debug), - provider(provider), model_type(model_type), modeling_unit(modeling_unit), bpe_vocab(bpe_vocab) {} diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc new file mode 100644 index 00000000..8a58746c --- /dev/null +++ b/sherpa-onnx/csrc/provider-config.cc @@ -0,0 +1,143 @@ +// sherpa-onnx/csrc/provider-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela) + +#include "sherpa-onnx/csrc/provider-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void CudaConfig::Register(ParseOptions *po) { + po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search, + "CuDNN convolution algrorithm search"); +} + +bool CudaConfig::Validate() const { + if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) { + SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option." + "Options : [1,3]. 
See the onnxruntime docs.", + cudnn_conv_algo_search); + return false; + } + return true; +} + +std::string CudaConfig::ToString() const { + std::ostringstream os; + + os << "CudaConfig("; + os << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")"; + + return os.str(); +} + +void TensorrtConfig::Register(ParseOptions *po) { + po->Register("trt-max-workspace-size", &trt_max_workspace_size, + "Set TensorRT EP GPU memory usage limit."); + po->Register("trt-max-partition-iterations", &trt_max_partition_iterations, + "Limit partitioning iterations for model conversion."); + po->Register("trt-min-subgraph-size", &trt_min_subgraph_size, + "Set minimum size for subgraphs in partitioning."); + po->Register("trt-fp16-enable", &trt_fp16_enable, + "Enable FP16 precision for faster performance."); + po->Register("trt-detailed-build-log", &trt_detailed_build_log, + "Enable detailed logging of build steps."); + po->Register("trt-engine-cache-enable", &trt_engine_cache_enable, + "Enable caching of TensorRT engines."); + po->Register("trt-timing-cache-enable", &trt_timing_cache_enable, + "Enable use of timing cache to speed up builds."); + po->Register("trt-engine-cache-path", &trt_engine_cache_path, + "Set path to store cached TensorRT engines."); + po->Register("trt-timing-cache-path", &trt_timing_cache_path, + "Set path for storing timing cache."); + po->Register("trt-dump-subgraphs", &trt_dump_subgraphs, + "Dump optimized subgraphs for debugging."); +} + +bool TensorrtConfig::Validate() const { + if (trt_max_workspace_size < 0) { + SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.", + trt_max_workspace_size); + return false; + } + if (trt_max_partition_iterations < 0) { + SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.", + trt_max_partition_iterations); + return false; + } + if (trt_min_subgraph_size < 0) { + SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.", + trt_min_subgraph_size); + return false; + } + + return true; +} + +std::string TensorrtConfig::ToString() const { + std::ostringstream os; + + os << "TensorrtConfig("; + os << "trt_max_workspace_size=" << trt_max_workspace_size << ", "; + os << "trt_max_partition_iterations=" + << trt_max_partition_iterations << ", "; + os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", "; + os << "trt_fp16_enable=\"" + << (trt_fp16_enable ? "True" : "False") << "\", "; + os << "trt_detailed_build_log=\"" + << (trt_detailed_build_log ? "True" : "False") << "\", "; + os << "trt_engine_cache_enable=\"" + << (trt_engine_cache_enable ? "True" : "False") << "\", "; + os << "trt_engine_cache_path=\"" + << trt_engine_cache_path.c_str() << "\", "; + os << "trt_timing_cache_enable=\"" + << (trt_timing_cache_enable ? "True" : "False") << "\", "; + os << "trt_timing_cache_path=\"" + << trt_timing_cache_path.c_str() << "\", "; + os << "trt_dump_subgraphs=\"" + << (trt_dump_subgraphs ? 
"True" : "False") << "\" )"; + return os.str(); +} + +void ProviderConfig::Register(ParseOptions *po) { + cuda_config.Register(po); + trt_config.Register(po); + + po->Register("device", &device, "GPU device index for CUDA and Trt EP"); + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool ProviderConfig::Validate() const { + if (device < 0) { + SHERPA_ONNX_LOGE("device: '%d' is invalid.", device); + return false; + } + + if (provider == "cuda" && !cuda_config.Validate()) { + return false; + } + + if (provider == "trt" && !trt_config.Validate()) { + return false; + } + + return true; +} + +std::string ProviderConfig::ToString() const { + std::ostringstream os; + + os << "ProviderConfig("; + os << "device=" << device << ", "; + os << "provider=\"" << provider << "\", "; + os << "cuda_config=" << cuda_config.ToString() << ", "; + os << "trt_config=" << trt_config.ToString() << ")"; + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h new file mode 100644 index 00000000..ff960790 --- /dev/null +++ b/sherpa-onnx/csrc/provider-config.h @@ -0,0 +1,95 @@ +// sherpa-onnx/csrc/provider-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela) + +#ifndef SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ +#define SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/macros.h" +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +struct CudaConfig { + int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; + + CudaConfig() = default; + explicit CudaConfig(int32_t cudnn_conv_algo_search) + : cudnn_conv_algo_search(cudnn_conv_algo_search) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +struct TensorrtConfig { + int32_t trt_max_workspace_size = 2147483647; + int32_t trt_max_partition_iterations = 10; + int32_t trt_min_subgraph_size = 5; + bool trt_fp16_enable = true; + bool trt_detailed_build_log = false; + bool trt_engine_cache_enable = true; + bool trt_timing_cache_enable = true; + std::string trt_engine_cache_path = "."; + std::string trt_timing_cache_path = "."; + bool trt_dump_subgraphs = false; + + TensorrtConfig() = default; + TensorrtConfig(int32_t trt_max_workspace_size, + int32_t trt_max_partition_iterations, + int32_t trt_min_subgraph_size, + bool trt_fp16_enable, + bool trt_detailed_build_log, + bool trt_engine_cache_enable, + bool trt_timing_cache_enable, + const std::string &trt_engine_cache_path, + const std::string &trt_timing_cache_path, + bool trt_dump_subgraphs) + : trt_max_workspace_size(trt_max_workspace_size), + trt_max_partition_iterations(trt_max_partition_iterations), + trt_min_subgraph_size(trt_min_subgraph_size), + trt_fp16_enable(trt_fp16_enable), + trt_detailed_build_log(trt_detailed_build_log), + trt_engine_cache_enable(trt_engine_cache_enable), + trt_timing_cache_enable(trt_timing_cache_enable), + trt_engine_cache_path(trt_engine_cache_path), + trt_timing_cache_path(trt_timing_cache_path), + trt_dump_subgraphs(trt_dump_subgraphs) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +struct ProviderConfig { + TensorrtConfig trt_config; + CudaConfig cuda_config; + std::string provider = "cpu"; + int32_t device = 0; + // device only used for cuda and trt + + ProviderConfig() = default; + ProviderConfig(const std::string &provider, + int32_t device) + 
: provider(provider), device(device) {} + ProviderConfig(const TensorrtConfig &trt_config, + const CudaConfig &cuda_config, + const std::string &provider, + int32_t device) + : trt_config(trt_config), cuda_config(cuda_config), + provider(provider), device(device) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h index c104d401..712006f2 100644 --- a/sherpa-onnx/csrc/provider.h +++ b/sherpa-onnx/csrc/provider.h @@ -7,6 +7,7 @@ #include <string> +#include "sherpa-onnx/csrc/provider-config.h" namespace sherpa_onnx { // Please refer to diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 0f6ed89d..b6fdaaa8 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -32,11 +32,13 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { } static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, - std::string provider_str) { - Provider p = StringToProvider(std::move(provider_str)); + const std::string &provider_str, + const ProviderConfig *provider_config = nullptr) { + Provider p = StringToProvider(provider_str); Ort::SessionOptions sess_opts; sess_opts.SetIntraOpNumThreads(num_threads); + sess_opts.SetInterOpNumThreads(num_threads); std::vector<std::string> available_providers = Ort::GetAvailableProviders(); @@ -64,26 +66,51 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, break; } case Provider::kTRT: { + if (provider_config == nullptr) { + SHERPA_ONNX_LOGE("TensorRT is currently supported only for online models; " + "it must be extended for offline models and others"); + exit(1); + } + auto trt_config = provider_config->trt_config; struct TrtPairs { const char *op_keys; const char *op_values; }; + auto device_id = std::to_string(provider_config->device); + auto trt_max_workspace_size = + std::to_string(trt_config.trt_max_workspace_size); + auto trt_max_partition_iterations = + std::to_string(trt_config.trt_max_partition_iterations); + auto trt_min_subgraph_size = + std::to_string(trt_config.trt_min_subgraph_size); + auto trt_fp16_enable = + std::to_string(trt_config.trt_fp16_enable); + auto trt_detailed_build_log = + std::to_string(trt_config.trt_detailed_build_log); + auto trt_engine_cache_enable = + std::to_string(trt_config.trt_engine_cache_enable); + auto trt_timing_cache_enable = + std::to_string(trt_config.trt_timing_cache_enable); + auto trt_dump_subgraphs = + std::to_string(trt_config.trt_dump_subgraphs); + std::vector<TrtPairs> trt_options = { - {"device_id", "0"}, - {"trt_max_workspace_size", "2147483648"}, - {"trt_max_partition_iterations", "10"}, - {"trt_min_subgraph_size", "5"}, - {"trt_fp16_enable", "0"}, - {"trt_detailed_build_log", "0"}, - {"trt_engine_cache_enable", "1"}, - {"trt_engine_cache_path", "."}, - {"trt_timing_cache_enable", "1"}, - {"trt_timing_cache_path", "."}}; + {"device_id", device_id.c_str()}, + {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, + {"trt_max_partition_iterations", trt_max_partition_iterations.c_str()}, + {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()}, + {"trt_fp16_enable", trt_fp16_enable.c_str()}, + {"trt_detailed_build_log", trt_detailed_build_log.c_str()}, + {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()}, + {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()}, + {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()}, + {"trt_timing_cache_path", 
trt_config.trt_timing_cache_path.c_str()}, + {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()} + }; // ToDo : Trt configs // "trt_int8_enable" // "trt_int8_use_native_calibration_table" - // "trt_dump_subgraphs" std::vector<const char *> option_keys, option_values; for (const TrtPairs &pair : trt_options) { @@ -122,10 +149,18 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, "CUDAExecutionProvider") != available_providers.end()) { // The CUDA provider is available, proceed with setting the options OrtCUDAProviderOptions options; - options.device_id = 0; - // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow - options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; - // set more options on need + + if (provider_config != nullptr) { + options.device_id = provider_config->device; + options.cudnn_conv_algo_search = + OrtCudnnConvAlgoSearch(provider_config->cuda_config + .cudnn_conv_algo_search); + } else { + options.device_id = 0; + // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow + options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; + // set more options on need + } sess_opts.AppendExecutionProvider_CUDA(options); } else { SHERPA_ONNX_LOGE( @@ -184,7 +219,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, } Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); + return GetSessionOptionsImpl(config.num_threads, + config.provider_config.provider, &config.provider_config); } Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { diff --git a/sherpa-onnx/jni/keyword-spotter.cc b/sherpa-onnx/jni/keyword-spotter.cc index 7a05b485..ca0c229c 100644 --- a/sherpa-onnx/jni/keyword-spotter.cc +++ b/sherpa-onnx/jni/keyword-spotter.cc @@ -94,7 +94,7 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); - ans.model_config.provider = p; + ans.model_config.provider_config.provider = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc index d8acd0fe..643b037b 100644 --- a/sherpa-onnx/jni/online-recognizer.cc +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -198,7 +198,7 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); - ans.model_config.provider = p; + ans.model_config.provider_config.provider = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 6d61d11d..5e74a172 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR}) set(srcs audio-tagging.cc circular-buffer.cc + cuda-config.cc display.cc endpoint.cc features.cc @@ -30,11 +31,13 @@ set(srcs online-transducer-model-config.cc online-wenet-ctc-model-config.cc online-zipformer2-ctc-model-config.cc + provider-config.cc sherpa-onnx.cc silero-vad-model-config.cc speaker-embedding-extractor.cc 
speaker-embedding-manager.cc spoken-language-identification.cc + tensorrt-config.cc vad-model-config.cc vad-model.cc voice-activity-detector.cc diff --git a/sherpa-onnx/python/csrc/cuda-config.cc b/sherpa-onnx/python/csrc/cuda-config.cc new file mode 100644 index 00000000..43627d3a --- /dev/null +++ b/sherpa-onnx/python/csrc/cuda-config.cc @@ -0,0 +1,24 @@ +// sherpa-onnx/python/csrc/cuda-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#include "sherpa-onnx/python/csrc/cuda-config.h" + +#include <memory> +#include <string> + +#include "sherpa-onnx/csrc/provider-config.h" + +namespace sherpa_onnx { + +void PybindCudaConfig(py::module *m) { + using PyClass = CudaConfig; + py::class_<PyClass>(*m, "CudaConfig") + .def(py::init<>()) + .def(py::init<int32_t>(), + py::arg("cudnn_conv_algo_search") = 1) + .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/cuda-config.h b/sherpa-onnx/python/csrc/cuda-config.h new file mode 100644 index 00000000..012fb29e --- /dev/null +++ b/sherpa-onnx/python/csrc/cuda-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/cuda-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#ifndef SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindCudaConfig(py::module *m); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index d6db809b..4ea13fd6 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -9,11 +9,13 @@ #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/provider-config.h" #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h" +#include "sherpa-onnx/python/csrc/provider-config.h" namespace sherpa_onnx { @@ -23,6 +25,7 @@ void PybindOnlineModelConfig(py::module *m) { PybindOnlineWenetCtcModelConfig(m); PybindOnlineZipformer2CtcModelConfig(m); PybindOnlineNeMoCtcModelConfig(m); + PybindProviderConfig(m); using PyClass = OnlineModelConfig; py::class_<PyClass>(*m, "OnlineModelConfig") @@ -30,33 +33,34 @@ void PybindOnlineModelConfig(py::module *m) { .def(py::init<const OnlineTransducerModelConfig &, const OnlineParaformerModelConfig &, const OnlineWenetCtcModelConfig &, const OnlineZipformer2CtcModelConfig &, - const OnlineNeMoCtcModelConfig &, const std::string &, - int32_t, int32_t, bool, const std::string &, - const std::string &, const std::string &, + const OnlineNeMoCtcModelConfig &, + const ProviderConfig &, + const std::string &, int32_t, int32_t, + bool, const std::string &, const std::string &, const std::string &>(), py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(), py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(), - py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"), - py::arg("num_threads"), py::arg("warm_up") = 0, - py::arg("debug") = false, py::arg("provider") = "cpu", 
- py::arg("model_type") = "", py::arg("modeling_unit") = "", - py::arg("bpe_vocab") = "") + py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), + py::arg("provider_config") = ProviderConfig(), + py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0, + py::arg("debug") = false, py::arg("model_type") = "", + py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) + .def_readwrite("provider_config", &PyClass::provider_config) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("warm_up", &PyClass::warm_up) .def_readwrite("debug", &PyClass::debug) - .def_readwrite("provider", &PyClass::provider) .def_readwrite("model_type", &PyClass::model_type) .def_readwrite("modeling_unit", &PyClass::modeling_unit) .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) .def("validate", &PyClass::Validate) .def("__str__", &PyClass::ToString); } - } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc new file mode 100644 index 00000000..c29d48ab --- /dev/null +++ b/sherpa-onnx/python/csrc/provider-config.cc @@ -0,0 +1,39 @@ +// sherpa-onnx/python/csrc/provider-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + + +#include "sherpa-onnx/python/csrc/provider-config.h" + +#include + +#include "sherpa-onnx/csrc/provider-config.h" +#include "sherpa-onnx/python/csrc/cuda-config.h" +#include "sherpa-onnx/python/csrc/tensorrt-config.h" + +namespace sherpa_onnx { + +void PybindProviderConfig(py::module *m) { + PybindCudaConfig(m); + PybindTensorrtConfig(m); + + using PyClass = ProviderConfig; + py::class_(*m, "ProviderConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("provider") = "cpu", + py::arg("device") = 0) + .def(py::init(), + py::arg("trt_config") = TensorrtConfig{}, + py::arg("cuda_config") = CudaConfig{}, + py::arg("provider") = "cpu", + py::arg("device") = 0) + .def_readwrite("trt_config", &PyClass::trt_config) + .def_readwrite("cuda_config", &PyClass::cuda_config) + .def_readwrite("provider", &PyClass::provider) + .def_readwrite("device", &PyClass::device) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/provider-config.h b/sherpa-onnx/python/csrc/provider-config.h new file mode 100644 index 00000000..76377dde --- /dev/null +++ b/sherpa-onnx/python/csrc/provider-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/provider-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#ifndef SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindProviderConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 242d8597..5b369ed8 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -51,7 +51,6 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindEndpoint(&m); PybindOnlineRecognizer(&m); PybindKeywordSpotter(&m); - PybindDisplay(&m); PybindOfflineStream(&m); diff --git 
a/sherpa-onnx/python/csrc/tensorrt-config.cc b/sherpa-onnx/python/csrc/tensorrt-config.cc new file mode 100644 index 00000000..87962a2d --- /dev/null +++ b/sherpa-onnx/python/csrc/tensorrt-config.cc @@ -0,0 +1,72 @@ +// sherpa-onnx/python/csrc/tensorrt-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#include "sherpa-onnx/python/csrc/tensorrt-config.h" + +#include <memory> +#include <string> +#include "sherpa-onnx/csrc/provider-config.h" + +namespace sherpa_onnx { + +void PybindTensorrtConfig(py::module *m) { + using PyClass = TensorrtConfig; + py::class_<PyClass>(*m, "TensorrtConfig") + .def(py::init<>()) + .def(py::init([](int32_t trt_max_workspace_size, + int32_t trt_max_partition_iterations, + int32_t trt_min_subgraph_size, + bool trt_fp16_enable, + bool trt_detailed_build_log, + bool trt_engine_cache_enable, + bool trt_timing_cache_enable, + const std::string &trt_engine_cache_path, + const std::string &trt_timing_cache_path, + bool trt_dump_subgraphs) -> std::unique_ptr<TensorrtConfig> { + auto ans = std::make_unique<TensorrtConfig>(); + + ans->trt_max_workspace_size = trt_max_workspace_size; + ans->trt_max_partition_iterations = trt_max_partition_iterations; + ans->trt_min_subgraph_size = trt_min_subgraph_size; + ans->trt_fp16_enable = trt_fp16_enable; + ans->trt_detailed_build_log = trt_detailed_build_log; + ans->trt_engine_cache_enable = trt_engine_cache_enable; + ans->trt_timing_cache_enable = trt_timing_cache_enable; + ans->trt_engine_cache_path = trt_engine_cache_path; + ans->trt_timing_cache_path = trt_timing_cache_path; + ans->trt_dump_subgraphs = trt_dump_subgraphs; + + return ans; + }), + py::arg("trt_max_workspace_size") = 2147483647, + py::arg("trt_max_partition_iterations") = 10, + py::arg("trt_min_subgraph_size") = 5, + py::arg("trt_fp16_enable") = true, + py::arg("trt_detailed_build_log") = false, + py::arg("trt_engine_cache_enable") = true, + py::arg("trt_timing_cache_enable") = true, + py::arg("trt_engine_cache_path") = ".", + py::arg("trt_timing_cache_path") = ".", + py::arg("trt_dump_subgraphs") = false) + + .def_readwrite("trt_max_workspace_size", + &PyClass::trt_max_workspace_size) + .def_readwrite("trt_max_partition_iterations", + &PyClass::trt_max_partition_iterations) + .def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size) + .def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable) + .def_readwrite("trt_detailed_build_log", + &PyClass::trt_detailed_build_log) + .def_readwrite("trt_engine_cache_enable", + &PyClass::trt_engine_cache_enable) + .def_readwrite("trt_timing_cache_enable", + &PyClass::trt_timing_cache_enable) + .def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path) + .def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path) + .def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/tensorrt-config.h b/sherpa-onnx/python/csrc/tensorrt-config.h new file mode 100644 index 00000000..d8eea700 --- /dev/null +++ b/sherpa-onnx/python/csrc/tensorrt-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/tensorrt-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#ifndef SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindTensorrtConfig(py::module *m); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ diff --git 
a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py index 218628ea..66d71698 100644 --- a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py +++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py @@ -9,6 +9,7 @@ from _sherpa_onnx import ( OnlineModelConfig, OnlineTransducerModelConfig, OnlineStream, + ProviderConfig, ) from _sherpa_onnx import KeywordSpotter as _KeywordSpotter @@ -41,6 +42,7 @@ class KeywordSpotter(object): keywords_threshold: float = 0.25, num_trailing_blanks: int = 1, provider: str = "cpu", + device: int = 0, ): """ Please refer to @@ -85,6 +87,8 @@ class KeywordSpotter(object): between each other. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + device: + onnxruntime cuda device index. """ _assert_file_exists(tokens) _assert_file_exists(encoder) @@ -99,11 +103,16 @@ class KeywordSpotter(object): joiner=joiner, ) + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( transducer=transducer_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, ) feat_config = FeatureExtractorConfig( diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 82b2e3b4..779ba6e7 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -11,6 +11,9 @@ from _sherpa_onnx import ( ) from _sherpa_onnx import OnlineRecognizer as _Recognizer from _sherpa_onnx import ( + CudaConfig, + TensorrtConfig, + ProviderConfig, OnlineRecognizerConfig, OnlineRecognizerResult, OnlineStream, @@ -56,7 +59,6 @@ class OnlineRecognizer(object): hotwords_score: float = 1.5, blank_penalty: float = 0.0, hotwords_file: str = "", - provider: str = "cpu", model_type: str = "", modeling_unit: str = "cjkchar", bpe_vocab: str = "", @@ -66,6 +68,19 @@ class OnlineRecognizer(object): debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + provider: str = "cpu", + device: int = 0, + cudnn_conv_algo_search: int = 1, + trt_max_workspace_size: int = 2147483647, + trt_max_partition_iterations: int = 10, + trt_min_subgraph_size: int = 5, + trt_fp16_enable: bool = True, + trt_detailed_build_log: bool = False, + trt_engine_cache_enable: bool = True, + trt_timing_cache_enable: bool = True, + trt_engine_cache_path: str = "", + trt_timing_cache_path: str = "", + trt_dump_subgraphs: bool = False, ): """ Please refer to @@ -135,8 +150,6 @@ class OnlineRecognizer(object): Temperature scaling for output symbol confidence estimation. It affects only confidence values, the decoding uses the original logits without temperature. - provider: - onnxruntime execution providers. Valid values are: cpu, cuda, coreml. model_type: Online transducer model type. Valid values are: conformer, lstm, zipformer, zipformer2. All other values lead to loading the model twice. @@ -156,6 +169,32 @@ class OnlineRecognizer(object): rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + device: + onnxruntime cuda device index. + cudnn_conv_algo_search: + onnxruntime cuDNN convolution search algorithm selection. CUDA EP + trt_max_workspace_size: + Set TensorRT EP GPU memory usage limit. 
TensorRT EP + trt_max_partition_iterations: + Limit partitioning iterations for model conversion. TensorRT EP + trt_min_subgraph_size: + Set minimum size for subgraphs in partitioning. TensorRT EP + trt_fp16_enable: + Enable FP16 precision for faster performance. TensorRT EP + trt_detailed_build_log: + Enable detailed logging of build steps. TensorRT EP + trt_engine_cache_enable: + Enable caching of TensorRT engines. TensorRT EP + trt_timing_cache_enable: + Enable use of timing cache to speed up builds. TensorRT EP + trt_engine_cache_path: + Set path to store cached TensorRT engines. TensorRT EP + trt_timing_cache_path: + Set path for storing timing cache. TensorRT EP + trt_dump_subgraphs: + Dump optimized subgraphs for debugging. TensorRT EP """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -171,11 +210,35 @@ class OnlineRecognizer(object): joiner=joiner, ) + cuda_config = CudaConfig( + cudnn_conv_algo_search=cudnn_conv_algo_search, + ) + + trt_config = TensorrtConfig( + trt_max_workspace_size=trt_max_workspace_size, + trt_max_partition_iterations=trt_max_partition_iterations, + trt_min_subgraph_size=trt_min_subgraph_size, + trt_fp16_enable=trt_fp16_enable, + trt_detailed_build_log=trt_detailed_build_log, + trt_engine_cache_enable=trt_engine_cache_enable, + trt_timing_cache_enable=trt_timing_cache_enable, + trt_engine_cache_path=trt_engine_cache_path, + trt_timing_cache_path=trt_timing_cache_path, + trt_dump_subgraphs=trt_dump_subgraphs, + ) + + provider_config = ProviderConfig( + trt_config=trt_config, + cuda_config=cuda_config, + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( transducer=transducer_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, model_type=model_type, modeling_unit=modeling_unit, bpe_vocab=bpe_vocab, @@ -251,6 +314,7 @@ class OnlineRecognizer(object): debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + device: int = 0, ): """ Please refer to @@ -301,6 +365,8 @@ class OnlineRecognizer(object): rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. 
""" self = cls.__new__(cls) _assert_file_exists(tokens) @@ -430,11 +504,16 @@ class OnlineRecognizer(object): zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( zipformer2_ctc=zipformer2_ctc_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, debug=debug, ) @@ -486,6 +565,7 @@ class OnlineRecognizer(object): debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + device: int = 0, ): """ Please refer to @@ -535,6 +615,8 @@ class OnlineRecognizer(object): rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -546,11 +628,16 @@ class OnlineRecognizer(object): model=model, ) + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( nemo_ctc=nemo_ctc_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, debug=debug, ) @@ -598,6 +685,7 @@ class OnlineRecognizer(object): debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + device: int = 0, ): """ Please refer to @@ -650,6 +738,8 @@ class OnlineRecognizer(object): rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -663,11 +753,16 @@ class OnlineRecognizer(object): num_left_chunks=num_left_chunks, ) + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( wenet_ctc=wenet_ctc_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, debug=debug, )