Add config for TensorRT and CUDA execution provider (#992)

Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
This commit is contained in:
Manix
2024-07-05 12:48:37 +05:30
committed by GitHub
parent f5e9a162d1
commit 55decb7bee
21 changed files with 622 additions and 49 deletions

View File

@@ -73,7 +73,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
SHERPA_ONNX_OR(config->model_config.tokens, ""); SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads = recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1); SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.provider = recognizer_config.model_config.provider_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu"); SHERPA_ONNX_OR(config->model_config.provider, "cpu");
recognizer_config.model_config.model_type = recognizer_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, ""); SHERPA_ONNX_OR(config->model_config.model_type, "");
@@ -570,7 +570,7 @@ SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
SHERPA_ONNX_OR(config->model_config.tokens, ""); SHERPA_ONNX_OR(config->model_config.tokens, "");
spotter_config.model_config.num_threads = spotter_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1); SHERPA_ONNX_OR(config->model_config.num_threads, 1);
spotter_config.model_config.provider = spotter_config.model_config.provider_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu"); SHERPA_ONNX_OR(config->model_config.provider, "cpu");
spotter_config.model_config.model_type = spotter_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, ""); SHERPA_ONNX_OR(config->model_config.model_type, "");

View File

@@ -87,6 +87,7 @@ set(sources
packed-sequence.cc packed-sequence.cc
pad-sequence.cc pad-sequence.cc
parse-options.cc parse-options.cc
provider-config.cc
provider.cc provider.cc
resample.cc resample.cc
session.cc session.cc

View File

@@ -16,6 +16,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
wenet_ctc.Register(po); wenet_ctc.Register(po);
zipformer2_ctc.Register(po); zipformer2_ctc.Register(po);
nemo_ctc.Register(po); nemo_ctc.Register(po);
provider_config.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt"); po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -29,9 +30,6 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("debug", &debug, po->Register("debug", &debug,
"true to print model information while loading it."); "true to print model information while loading it.");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
po->Register("modeling-unit", &modeling_unit, po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, " "The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
@@ -87,6 +85,10 @@ bool OnlineModelConfig::Validate() const {
return nemo_ctc.Validate(); return nemo_ctc.Validate();
} }
if (!provider_config.Validate()) {
return false;
}
return transducer.Validate(); return transducer.Validate();
} }
@@ -99,11 +101,11 @@ std::string OnlineModelConfig::ToString() const {
os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", "; os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "provider_config=" << provider_config.ToString() << ", ";
os << "tokens=\"" << tokens << "\", "; os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", "; os << "num_threads=" << num_threads << ", ";
os << "warm_up=" << warm_up << ", "; os << "warm_up=" << warm_up << ", ";
os << "debug=" << (debug ? "True" : "False") << ", "; os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\", "; os << "model_type=\"" << model_type << "\", ";
os << "modeling_unit=\"" << modeling_unit << "\", "; os << "modeling_unit=\"" << modeling_unit << "\", ";
os << "bpe_vocab=\"" << bpe_vocab << "\")"; os << "bpe_vocab=\"" << bpe_vocab << "\")";

View File

@@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h" #include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
#include "sherpa-onnx/csrc/provider-config.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -20,11 +21,11 @@ struct OnlineModelConfig {
OnlineWenetCtcModelConfig wenet_ctc; OnlineWenetCtcModelConfig wenet_ctc;
OnlineZipformer2CtcModelConfig zipformer2_ctc; OnlineZipformer2CtcModelConfig zipformer2_ctc;
OnlineNeMoCtcModelConfig nemo_ctc; OnlineNeMoCtcModelConfig nemo_ctc;
ProviderConfig provider_config;
std::string tokens; std::string tokens;
int32_t num_threads = 1; int32_t num_threads = 1;
int32_t warm_up = 0; int32_t warm_up = 0;
bool debug = false; bool debug = false;
std::string provider = "cpu";
// Valid values: // Valid values:
// - conformer, conformer transducer from icefall // - conformer, conformer transducer from icefall
@@ -50,8 +51,9 @@ struct OnlineModelConfig {
const OnlineWenetCtcModelConfig &wenet_ctc, const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc, const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const OnlineNeMoCtcModelConfig &nemo_ctc, const OnlineNeMoCtcModelConfig &nemo_ctc,
const ProviderConfig &provider_config,
const std::string &tokens, int32_t num_threads, const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &provider, int32_t warm_up, bool debug,
const std::string &model_type, const std::string &model_type,
const std::string &modeling_unit, const std::string &modeling_unit,
const std::string &bpe_vocab) const std::string &bpe_vocab)
@@ -60,11 +62,11 @@ struct OnlineModelConfig {
wenet_ctc(wenet_ctc), wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc), zipformer2_ctc(zipformer2_ctc),
nemo_ctc(nemo_ctc), nemo_ctc(nemo_ctc),
provider_config(provider_config),
tokens(tokens), tokens(tokens),
num_threads(num_threads), num_threads(num_threads),
warm_up(warm_up), warm_up(warm_up),
debug(debug), debug(debug),
provider(provider),
model_type(model_type), model_type(model_type),
modeling_unit(modeling_unit), modeling_unit(modeling_unit),
bpe_vocab(bpe_vocab) {} bpe_vocab(bpe_vocab) {}

View File

@@ -0,0 +1,143 @@
// sherpa-onnx/csrc/provider-config.cc
//
// Copyright (c) 2024 Uniphore (Author: Manickavela)
#include "sherpa-onnx/csrc/provider-config.h"
#include <sstream>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void CudaConfig::Register(ParseOptions *po) {
po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search,
"CuDNN convolution algrorithm search");
}
bool CudaConfig::Validate() const {
if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) {
SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option."
"Options : [1,3]. Check OnnxRT docs",
cudnn_conv_algo_search);
return false;
}
return true;
}
std::string CudaConfig::ToString() const {
std::ostringstream os;
os << "CudaConfig(";
os << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")";
return os.str();
}
void TensorrtConfig::Register(ParseOptions *po) {
po->Register("trt-max-workspace-size", &trt_max_workspace_size,
"Set TensorRT EP GPU memory usage limit.");
po->Register("trt-max-partition-iterations", &trt_max_partition_iterations,
"Limit partitioning iterations for model conversion.");
po->Register("trt-min-subgraph-size", &trt_min_subgraph_size,
"Set minimum size for subgraphs in partitioning.");
po->Register("trt-fp16-enable", &trt_fp16_enable,
"Enable FP16 precision for faster performance.");
po->Register("trt-detailed-build-log", &trt_detailed_build_log,
"Enable detailed logging of build steps.");
po->Register("trt-engine-cache-enable", &trt_engine_cache_enable,
"Enable caching of TensorRT engines.");
po->Register("trt-timing-cache-enable", &trt_timing_cache_enable,
"Enable use of timing cache to speed up builds.");
po->Register("trt-engine-cache-path", &trt_engine_cache_path,
"Set path to store cached TensorRT engines.");
po->Register("trt-timing-cache-path", &trt_timing_cache_path,
"Set path for storing timing cache.");
po->Register("trt-dump-subgraphs", &trt_dump_subgraphs,
"Dump optimized subgraphs for debugging.");
}
bool TensorrtConfig::Validate() const {
if (trt_max_workspace_size < 0) {
SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.",
trt_max_workspace_size);
return false;
}
if (trt_max_partition_iterations < 0) {
SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.",
trt_max_partition_iterations);
return false;
}
if (trt_min_subgraph_size < 0) {
SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.",
trt_min_subgraph_size);
return false;
}
return true;
}
std::string TensorrtConfig::ToString() const {
std::ostringstream os;
os << "TensorrtConfig(";
os << "trt_max_workspace_size=" << trt_max_workspace_size << ", ";
os << "trt_max_partition_iterations="
<< trt_max_partition_iterations << ", ";
os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", ";
os << "trt_fp16_enable=\""
<< (trt_fp16_enable? "True" : "False") << "\", ";
os << "trt_detailed_build_log=\""
<< (trt_detailed_build_log? "True" : "False") << "\", ";
os << "trt_engine_cache_enable=\""
<< (trt_engine_cache_enable? "True" : "False") << "\", ";
os << "trt_engine_cache_path=\""
<< trt_engine_cache_path.c_str() << "\", ";
os << "trt_timing_cache_enable=\""
<< (trt_timing_cache_enable? "True" : "False") << "\", ";
os << "trt_timing_cache_path=\""
<< trt_timing_cache_path.c_str() << "\",";
os << "trt_dump_subgraphs=\""
<< (trt_dump_subgraphs? "True" : "False") << "\" )";
return os.str();
}
void ProviderConfig::Register(ParseOptions *po) {
cuda_config.Register(po);
trt_config.Register(po);
po->Register("device", &device, "GPU device index for CUDA and Trt EP");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}
bool ProviderConfig::Validate() const {
if (device < 0) {
SHERPA_ONNX_LOGE("device: '%d' is invalid.", device);
return false;
}
if (provider == "cuda" && !cuda_config.Validate()) {
return false;
}
if (provider == "trt" && !trt_config.Validate()) {
return false;
}
return true;
}
std::string ProviderConfig::ToString() const {
std::ostringstream os;
os << "ProviderConfig(";
os << "device=" << device << ", ";
os << "provider=\"" << provider << "\", ";
os << "cuda_config=" << cuda_config.ToString() << ", ";
os << "trt_config=" << trt_config.ToString() << ")";
return os.str();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,95 @@
// sherpa-onnx/csrc/provider-config.h
//
// Copyright (c) 2024 Uniphore (Author: Manickavela)
#ifndef SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
#define SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/macros.h"
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
struct CudaConfig {
int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
CudaConfig() = default;
explicit CudaConfig(int32_t cudnn_conv_algo_search)
: cudnn_conv_algo_search(cudnn_conv_algo_search) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
struct TensorrtConfig {
int32_t trt_max_workspace_size = 2147483647;
int32_t trt_max_partition_iterations = 10;
int32_t trt_min_subgraph_size = 5;
bool trt_fp16_enable = true;
bool trt_detailed_build_log = false;
bool trt_engine_cache_enable = true;
bool trt_timing_cache_enable = true;
std::string trt_engine_cache_path = ".";
std::string trt_timing_cache_path = ".";
bool trt_dump_subgraphs = false;
TensorrtConfig() = default;
TensorrtConfig(int32_t trt_max_workspace_size,
int32_t trt_max_partition_iterations,
int32_t trt_min_subgraph_size,
bool trt_fp16_enable,
bool trt_detailed_build_log,
bool trt_engine_cache_enable,
bool trt_timing_cache_enable,
const std::string &trt_engine_cache_path,
const std::string &trt_timing_cache_path,
bool trt_dump_subgraphs)
: trt_max_workspace_size(trt_max_workspace_size),
trt_max_partition_iterations(trt_max_partition_iterations),
trt_min_subgraph_size(trt_min_subgraph_size),
trt_fp16_enable(trt_fp16_enable),
trt_detailed_build_log(trt_detailed_build_log),
trt_engine_cache_enable(trt_engine_cache_enable),
trt_timing_cache_enable(trt_timing_cache_enable),
trt_engine_cache_path(trt_engine_cache_path),
trt_timing_cache_path(trt_timing_cache_path),
trt_dump_subgraphs(trt_dump_subgraphs) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
struct ProviderConfig {
TensorrtConfig trt_config;
CudaConfig cuda_config;
std::string provider = "cpu";
int32_t device = 0;
// device only used for cuda and trt
ProviderConfig() = default;
ProviderConfig(const std::string &provider,
int32_t device)
: provider(provider), device(device) {}
ProviderConfig(const TensorrtConfig &trt_config,
const CudaConfig &cuda_config,
const std::string &provider,
int32_t device)
: trt_config(trt_config), cuda_config(cuda_config),
provider(provider), device(device) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_

View File

@@ -7,6 +7,7 @@
#include <string> #include <string>
#include "sherpa-onnx/csrc/provider-config.h"
namespace sherpa_onnx { namespace sherpa_onnx {
// Please refer to // Please refer to

View File

@@ -32,11 +32,13 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
} }
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
std::string provider_str) { const std::string &provider_str,
Provider p = StringToProvider(std::move(provider_str)); const ProviderConfig *provider_config = nullptr) {
Provider p = StringToProvider(provider_str);
Ort::SessionOptions sess_opts; Ort::SessionOptions sess_opts;
sess_opts.SetIntraOpNumThreads(num_threads); sess_opts.SetIntraOpNumThreads(num_threads);
sess_opts.SetInterOpNumThreads(num_threads); sess_opts.SetInterOpNumThreads(num_threads);
std::vector<std::string> available_providers = Ort::GetAvailableProviders(); std::vector<std::string> available_providers = Ort::GetAvailableProviders();
@@ -64,26 +66,51 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
break; break;
} }
case Provider::kTRT: { case Provider::kTRT: {
if (provider_config == nullptr) {
SHERPA_ONNX_LOGE("Tensorrt support for Online models ony,"
"Must be extended for offline and others");
exit(1);
}
auto trt_config = provider_config->trt_config;
struct TrtPairs { struct TrtPairs {
const char *op_keys; const char *op_keys;
const char *op_values; const char *op_values;
}; };
auto device_id = std::to_string(provider_config->device);
auto trt_max_workspace_size =
std::to_string(trt_config.trt_max_workspace_size);
auto trt_max_partition_iterations =
std::to_string(trt_config.trt_max_partition_iterations);
auto trt_min_subgraph_size =
std::to_string(trt_config.trt_min_subgraph_size);
auto trt_fp16_enable =
std::to_string(trt_config.trt_fp16_enable);
auto trt_detailed_build_log =
std::to_string(trt_config.trt_detailed_build_log);
auto trt_engine_cache_enable =
std::to_string(trt_config.trt_engine_cache_enable);
auto trt_timing_cache_enable =
std::to_string(trt_config.trt_timing_cache_enable);
auto trt_dump_subgraphs =
std::to_string(trt_config.trt_dump_subgraphs);
std::vector<TrtPairs> trt_options = { std::vector<TrtPairs> trt_options = {
{"device_id", "0"}, {"device_id", device_id.c_str()},
{"trt_max_workspace_size", "2147483648"}, {"trt_max_workspace_size", trt_max_workspace_size.c_str()},
{"trt_max_partition_iterations", "10"}, {"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},
{"trt_min_subgraph_size", "5"}, {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
{"trt_fp16_enable", "0"}, {"trt_fp16_enable", trt_fp16_enable.c_str()},
{"trt_detailed_build_log", "0"}, {"trt_detailed_build_log", trt_detailed_build_log.c_str()},
{"trt_engine_cache_enable", "1"}, {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
{"trt_engine_cache_path", "."}, {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
{"trt_timing_cache_enable", "1"}, {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
{"trt_timing_cache_path", "."}}; {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
{"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
};
// ToDo : Trt configs // ToDo : Trt configs
// "trt_int8_enable" // "trt_int8_enable"
// "trt_int8_use_native_calibration_table" // "trt_int8_use_native_calibration_table"
// "trt_dump_subgraphs"
std::vector<const char *> option_keys, option_values; std::vector<const char *> option_keys, option_values;
for (const TrtPairs &pair : trt_options) { for (const TrtPairs &pair : trt_options) {
@@ -122,10 +149,18 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
"CUDAExecutionProvider") != available_providers.end()) { "CUDAExecutionProvider") != available_providers.end()) {
// The CUDA provider is available, proceed with setting the options // The CUDA provider is available, proceed with setting the options
OrtCUDAProviderOptions options; OrtCUDAProviderOptions options;
if (provider_config != nullptr) {
options.device_id = provider_config->device;
options.cudnn_conv_algo_search =
OrtCudnnConvAlgoSearch(provider_config->cuda_config
.cudnn_conv_algo_search);
} else {
options.device_id = 0; options.device_id = 0;
// Default OrtCudnnConvAlgoSearchExhaustive is extremely slow // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
// set more options on need // set more options on need
}
sess_opts.AppendExecutionProvider_CUDA(options); sess_opts.AppendExecutionProvider_CUDA(options);
} else { } else {
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
@@ -184,7 +219,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
} }
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider); return GetSessionOptionsImpl(config.num_threads,
config.provider_config.provider, &config.provider_config);
} }
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {

View File

@@ -94,7 +94,7 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid); s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p; ans.model_config.provider_config.provider = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");

View File

@@ -198,7 +198,7 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid); s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p; ans.model_config.provider_config.provider = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");

View File

@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
set(srcs set(srcs
audio-tagging.cc audio-tagging.cc
circular-buffer.cc circular-buffer.cc
cuda-config.cc
display.cc display.cc
endpoint.cc endpoint.cc
features.cc features.cc
@@ -30,11 +31,13 @@ set(srcs
online-transducer-model-config.cc online-transducer-model-config.cc
online-wenet-ctc-model-config.cc online-wenet-ctc-model-config.cc
online-zipformer2-ctc-model-config.cc online-zipformer2-ctc-model-config.cc
provider-config.cc
sherpa-onnx.cc sherpa-onnx.cc
silero-vad-model-config.cc silero-vad-model-config.cc
speaker-embedding-extractor.cc speaker-embedding-extractor.cc
speaker-embedding-manager.cc speaker-embedding-manager.cc
spoken-language-identification.cc spoken-language-identification.cc
tensorrt-config.cc
vad-model-config.cc vad-model-config.cc
vad-model.cc vad-model.cc
voice-activity-detector.cc voice-activity-detector.cc

View File

@@ -0,0 +1,24 @@
// sherpa-onnx/python/csrc/cuda-config.cc
//
// Copyright (c) 2024 Uniphore (Author: Manickavela A)
#include "sherpa-onnx/python/csrc/cuda-config.h"
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/provider-config.h"
namespace sherpa_onnx {
void PybindCudaConfig(py::module *m) {
using PyClass = CudaConfig;
py::class_<PyClass>(*m, "CudaConfig")
.def(py::init<>())
.def(py::init<int32_t>(),
py::arg("cudnn_conv_algo_search") = 1)
.def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/cuda-config.h
//
// Copyright (c) 2024 Uniphore (Author: Manickavela A)
#ifndef SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindCudaConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_

View File

@@ -9,11 +9,13 @@
#include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/provider-config.h"
#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/provider-config.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -23,6 +25,7 @@ void PybindOnlineModelConfig(py::module *m) {
PybindOnlineWenetCtcModelConfig(m); PybindOnlineWenetCtcModelConfig(m);
PybindOnlineZipformer2CtcModelConfig(m); PybindOnlineZipformer2CtcModelConfig(m);
PybindOnlineNeMoCtcModelConfig(m); PybindOnlineNeMoCtcModelConfig(m);
PybindProviderConfig(m);
using PyClass = OnlineModelConfig; using PyClass = OnlineModelConfig;
py::class_<PyClass>(*m, "OnlineModelConfig") py::class_<PyClass>(*m, "OnlineModelConfig")
@@ -30,33 +33,34 @@ void PybindOnlineModelConfig(py::module *m) {
const OnlineParaformerModelConfig &, const OnlineParaformerModelConfig &,
const OnlineWenetCtcModelConfig &, const OnlineWenetCtcModelConfig &,
const OnlineZipformer2CtcModelConfig &, const OnlineZipformer2CtcModelConfig &,
const OnlineNeMoCtcModelConfig &, const std::string &, const OnlineNeMoCtcModelConfig &,
int32_t, int32_t, bool, const std::string &, const ProviderConfig &,
const std::string &, const std::string &, const std::string &, int32_t, int32_t,
bool, const std::string &, const std::string &,
const std::string &>(), const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(), py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"), py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(),
py::arg("num_threads"), py::arg("warm_up") = 0, py::arg("provider_config") = ProviderConfig(),
py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("model_type") = "", py::arg("modeling_unit") = "", py::arg("debug") = false, py::arg("model_type") = "",
py::arg("bpe_vocab") = "") py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer) .def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc) .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
.def_readwrite("provider_config", &PyClass::provider_config)
.def_readwrite("tokens", &PyClass::tokens) .def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("warm_up", &PyClass::warm_up)
.def_readwrite("debug", &PyClass::debug) .def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type) .def_readwrite("model_type", &PyClass::model_type)
.def_readwrite("modeling_unit", &PyClass::modeling_unit) .def_readwrite("modeling_unit", &PyClass::modeling_unit)
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab) .def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
.def("validate", &PyClass::Validate) .def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString); .def("__str__", &PyClass::ToString);
} }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -0,0 +1,39 @@
// sherpa-onnx/python/csrc/provider-config.cc
//
// Copyright (c) 2024 Uniphore (Author: Manickavela A)
#include "sherpa-onnx/python/csrc/provider-config.h"
#include <string>
#include "sherpa-onnx/csrc/provider-config.h"
#include "sherpa-onnx/python/csrc/cuda-config.h"
#include "sherpa-onnx/python/csrc/tensorrt-config.h"
namespace sherpa_onnx {
void PybindProviderConfig(py::module *m) {
PybindCudaConfig(m);
PybindTensorrtConfig(m);
using PyClass = ProviderConfig;
py::class_<PyClass>(*m, "ProviderConfig")
.def(py::init<>())
.def(py::init<const std::string &, int32_t>(),
py::arg("provider") = "cpu",
py::arg("device") = 0)
.def(py::init<const TensorrtConfig &, const CudaConfig &,
const std::string &, int32_t>(),
py::arg("trt_config") = TensorrtConfig{},
py::arg("cuda_config") = CudaConfig{},
py::arg("provider") = "cpu",
py::arg("device") = 0)
.def_readwrite("trt_config", &PyClass::trt_config)
.def_readwrite("cuda_config", &PyClass::cuda_config)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("device", &PyClass::device)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/provider-config.h
//
// Copyright (c) 2024 Uniphore (Author: Manickavela A)
#ifndef SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindProviderConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_

View File

@@ -51,7 +51,6 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindEndpoint(&m); PybindEndpoint(&m);
PybindOnlineRecognizer(&m); PybindOnlineRecognizer(&m);
PybindKeywordSpotter(&m); PybindKeywordSpotter(&m);
PybindDisplay(&m); PybindDisplay(&m);
PybindOfflineStream(&m); PybindOfflineStream(&m);

View File

@@ -0,0 +1,72 @@
// sherpa-onnx/python/csrc/tensorrt-config.cc
//
// Copyright (c) 2024 Uniphore (Author: Manickavela A)
#include "sherpa-onnx/python/csrc/tensorrt-config.h"
#include <string>
#include <memory>
#include "sherpa-onnx/csrc/provider-config.h"
namespace sherpa_onnx {
void PybindTensorrtConfig(py::module *m) {
using PyClass = TensorrtConfig;
py::class_<PyClass>(*m, "TensorrtConfig")
.def(py::init<>())
.def(py::init([](int32_t trt_max_workspace_size,
int32_t trt_max_partition_iterations,
int32_t trt_min_subgraph_size,
bool trt_fp16_enable,
bool trt_detailed_build_log,
bool trt_engine_cache_enable,
bool trt_timing_cache_enable,
const std::string &trt_engine_cache_path,
const std::string &trt_timing_cache_path,
bool trt_dump_subgraphs) -> std::unique_ptr<PyClass> {
auto ans = std::make_unique<PyClass>();
ans->trt_max_workspace_size = trt_max_workspace_size;
ans->trt_max_partition_iterations = trt_max_partition_iterations;
ans->trt_min_subgraph_size = trt_min_subgraph_size;
ans->trt_fp16_enable = trt_fp16_enable;
ans->trt_detailed_build_log = trt_detailed_build_log;
ans->trt_engine_cache_enable = trt_engine_cache_enable;
ans->trt_timing_cache_enable = trt_timing_cache_enable;
ans->trt_engine_cache_path = trt_engine_cache_path;
ans->trt_timing_cache_path = trt_timing_cache_path;
ans->trt_dump_subgraphs = trt_dump_subgraphs;
return ans;
}),
py::arg("trt_max_workspace_size") = 2147483647,
py::arg("trt_max_partition_iterations") = 10,
py::arg("trt_min_subgraph_size") = 5,
py::arg("trt_fp16_enable") = true,
py::arg("trt_detailed_build_log") = false,
py::arg("trt_engine_cache_enable") = true,
py::arg("trt_timing_cache_enable") = true,
py::arg("trt_engine_cache_path") = ".",
py::arg("trt_timing_cache_path") = ".",
py::arg("trt_dump_subgraphs") = false)
.def_readwrite("trt_max_workspace_size",
&PyClass::trt_max_workspace_size)
.def_readwrite("trt_max_partition_iterations",
&PyClass::trt_max_partition_iterations)
.def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size)
.def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable)
.def_readwrite("trt_detailed_build_log",
&PyClass::trt_detailed_build_log)
.def_readwrite("trt_engine_cache_enable",
&PyClass::trt_engine_cache_enable)
.def_readwrite("trt_timing_cache_enable",
&PyClass::trt_timing_cache_enable)
.def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path)
.def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path)
.def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/tensorrt-config.h
//
// Copyright (c) 2024 Uniphore (Author: Manickavela A)
#ifndef SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindTensorrtConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_

View File

@@ -9,6 +9,7 @@ from _sherpa_onnx import (
OnlineModelConfig, OnlineModelConfig,
OnlineTransducerModelConfig, OnlineTransducerModelConfig,
OnlineStream, OnlineStream,
ProviderConfig,
) )
from _sherpa_onnx import KeywordSpotter as _KeywordSpotter from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
@@ -41,6 +42,7 @@ class KeywordSpotter(object):
keywords_threshold: float = 0.25, keywords_threshold: float = 0.25,
num_trailing_blanks: int = 1, num_trailing_blanks: int = 1,
provider: str = "cpu", provider: str = "cpu",
device: int = 0,
): ):
""" """
Please refer to Please refer to
@@ -85,6 +87,8 @@ class KeywordSpotter(object):
between each other. between each other.
provider: provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml. onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
device:
onnxruntime cuda device index.
""" """
_assert_file_exists(tokens) _assert_file_exists(tokens)
_assert_file_exists(encoder) _assert_file_exists(encoder)
@@ -99,11 +103,16 @@ class KeywordSpotter(object):
joiner=joiner, joiner=joiner,
) )
provider_config = ProviderConfig(
provider=provider,
device = device,
)
model_config = OnlineModelConfig( model_config = OnlineModelConfig(
transducer=transducer_config, transducer=transducer_config,
tokens=tokens, tokens=tokens,
num_threads=num_threads, num_threads=num_threads,
provider=provider, provider_config=provider_config,
) )
feat_config = FeatureExtractorConfig( feat_config = FeatureExtractorConfig(

View File

@@ -11,6 +11,9 @@ from _sherpa_onnx import (
) )
from _sherpa_onnx import OnlineRecognizer as _Recognizer from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import ( from _sherpa_onnx import (
CudaConfig,
TensorrtConfig,
ProviderConfig,
OnlineRecognizerConfig, OnlineRecognizerConfig,
OnlineRecognizerResult, OnlineRecognizerResult,
OnlineStream, OnlineStream,
@@ -56,7 +59,6 @@ class OnlineRecognizer(object):
hotwords_score: float = 1.5, hotwords_score: float = 1.5,
blank_penalty: float = 0.0, blank_penalty: float = 0.0,
hotwords_file: str = "", hotwords_file: str = "",
provider: str = "cpu",
model_type: str = "", model_type: str = "",
modeling_unit: str = "cjkchar", modeling_unit: str = "cjkchar",
bpe_vocab: str = "", bpe_vocab: str = "",
@@ -66,6 +68,19 @@ class OnlineRecognizer(object):
debug: bool = False, debug: bool = False,
rule_fsts: str = "", rule_fsts: str = "",
rule_fars: str = "", rule_fars: str = "",
provider: str = "cpu",
device: int = 0,
cudnn_conv_algo_search: int = 1,
trt_max_workspace_size: int = 2147483647,
trt_max_partition_iterations: int = 10,
trt_min_subgraph_size: int = 5,
trt_fp16_enable: bool = True,
trt_detailed_build_log: bool = False,
trt_engine_cache_enable: bool = True,
trt_timing_cache_enable: bool = True,
trt_engine_cache_path: str ="",
trt_timing_cache_path: str ="",
trt_dump_subgraphs: bool = False,
): ):
""" """
Please refer to Please refer to
@@ -135,8 +150,6 @@ class OnlineRecognizer(object):
Temperature scaling for output symbol confidence estiamation. Temperature scaling for output symbol confidence estiamation.
It affects only confidence values, the decoding uses the original It affects only confidence values, the decoding uses the original
logits without temperature. logits without temperature.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
model_type: model_type:
Online transducer model type. Valid values are: conformer, lstm, Online transducer model type. Valid values are: conformer, lstm,
zipformer, zipformer2. All other values lead to loading the model twice. zipformer, zipformer2. All other values lead to loading the model twice.
@@ -156,6 +169,32 @@ class OnlineRecognizer(object):
rule_fars: rule_fars:
If not empty, it specifies fst archives for inverse text normalization. If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma. If there are multiple archives, they are separated by a comma.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
device:
onnxruntime cuda device index.
cudnn_conv_algo_search:
onxrt CuDNN convolution search algorithm selection. CUDA EP
trt_max_workspace_size:
Set TensorRT EP GPU memory usage limit. TensorRT EP
trt_max_partition_iterations:
Limit partitioning iterations for model conversion. TensorRT EP
trt_min_subgraph_size:
Set minimum size for subgraphs in partitioning. TensorRT EP
trt_fp16_enable: bool = True,
Enable FP16 precision for faster performance. TensorRT EP
trt_detailed_build_log: bool = False,
Enable detailed logging of build steps. TensorRT EP
trt_engine_cache_enable: bool = True,
Enable caching of TensorRT engines. TensorRT EP
trt_timing_cache_enable: bool = True,
"Enable use of timing cache to speed up builds." TensorRT EP
trt_engine_cache_path: str ="",
"Set path to store cached TensorRT engines." TensorRT EP
trt_timing_cache_path: str ="",
"Set path for storing timing cache." TensorRT EP
trt_dump_subgraphs: bool = False,
"Dump optimized subgraphs for debugging." TensorRT EP
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -171,11 +210,35 @@ class OnlineRecognizer(object):
joiner=joiner, joiner=joiner,
) )
cuda_config = CudaConfig(
cudnn_conv_algo_search=cudnn_conv_algo_search,
)
trt_config = TensorrtConfig(
trt_max_workspace_size=trt_max_workspace_size,
trt_max_partition_iterations=trt_max_partition_iterations,
trt_min_subgraph_size=trt_min_subgraph_size,
trt_fp16_enable=trt_fp16_enable,
trt_detailed_build_log=trt_detailed_build_log,
trt_engine_cache_enable=trt_engine_cache_enable,
trt_timing_cache_enable=trt_timing_cache_enable,
trt_engine_cache_path=trt_engine_cache_path,
trt_timing_cache_path=trt_timing_cache_path,
trt_dump_subgraphs=trt_dump_subgraphs,
)
provider_config = ProviderConfig(
trt_config=trt_config,
cuda_config=cuda_config,
provider=provider,
device=device,
)
model_config = OnlineModelConfig( model_config = OnlineModelConfig(
transducer=transducer_config, transducer=transducer_config,
tokens=tokens, tokens=tokens,
num_threads=num_threads, num_threads=num_threads,
provider=provider, provider_config=provider_config,
model_type=model_type, model_type=model_type,
modeling_unit=modeling_unit, modeling_unit=modeling_unit,
bpe_vocab=bpe_vocab, bpe_vocab=bpe_vocab,
@@ -251,6 +314,7 @@ class OnlineRecognizer(object):
debug: bool = False, debug: bool = False,
rule_fsts: str = "", rule_fsts: str = "",
rule_fars: str = "", rule_fars: str = "",
device: int = 0,
): ):
""" """
Please refer to Please refer to
@@ -301,6 +365,8 @@ class OnlineRecognizer(object):
rule_fars: rule_fars:
If not empty, it specifies fst archives for inverse text normalization. If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma. If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -314,11 +380,16 @@ class OnlineRecognizer(object):
decoder=decoder, decoder=decoder,
) )
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig( model_config = OnlineModelConfig(
paraformer=paraformer_config, paraformer=paraformer_config,
tokens=tokens, tokens=tokens,
num_threads=num_threads, num_threads=num_threads,
provider=provider, provider_config=provider_config,
model_type="paraformer", model_type="paraformer",
debug=debug, debug=debug,
) )
@@ -367,6 +438,7 @@ class OnlineRecognizer(object):
debug: bool = False, debug: bool = False,
rule_fsts: str = "", rule_fsts: str = "",
rule_fars: str = "", rule_fars: str = "",
device: int = 0,
): ):
""" """
Please refer to Please refer to
@@ -421,6 +493,8 @@ class OnlineRecognizer(object):
rule_fars: rule_fars:
If not empty, it specifies fst archives for inverse text normalization. If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma. If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -430,11 +504,16 @@ class OnlineRecognizer(object):
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig( model_config = OnlineModelConfig(
zipformer2_ctc=zipformer2_ctc_config, zipformer2_ctc=zipformer2_ctc_config,
tokens=tokens, tokens=tokens,
num_threads=num_threads, num_threads=num_threads,
provider=provider, provider_config=provider_config,
debug=debug, debug=debug,
) )
@@ -486,6 +565,7 @@ class OnlineRecognizer(object):
debug: bool = False, debug: bool = False,
rule_fsts: str = "", rule_fsts: str = "",
rule_fars: str = "", rule_fars: str = "",
device: int = 0,
): ):
""" """
Please refer to Please refer to
@@ -535,6 +615,8 @@ class OnlineRecognizer(object):
rule_fars: rule_fars:
If not empty, it specifies fst archives for inverse text normalization. If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma. If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -546,11 +628,16 @@ class OnlineRecognizer(object):
model=model, model=model,
) )
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig( model_config = OnlineModelConfig(
nemo_ctc=nemo_ctc_config, nemo_ctc=nemo_ctc_config,
tokens=tokens, tokens=tokens,
num_threads=num_threads, num_threads=num_threads,
provider=provider, provider_config=provider_config,
debug=debug, debug=debug,
) )
@@ -598,6 +685,7 @@ class OnlineRecognizer(object):
debug: bool = False, debug: bool = False,
rule_fsts: str = "", rule_fsts: str = "",
rule_fars: str = "", rule_fars: str = "",
device: int = 0,
): ):
""" """
Please refer to Please refer to
@@ -650,6 +738,8 @@ class OnlineRecognizer(object):
rule_fars: rule_fars:
If not empty, it specifies fst archives for inverse text normalization. If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma. If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
""" """
self = cls.__new__(cls) self = cls.__new__(cls)
_assert_file_exists(tokens) _assert_file_exists(tokens)
@@ -663,11 +753,16 @@ class OnlineRecognizer(object):
num_left_chunks=num_left_chunks, num_left_chunks=num_left_chunks,
) )
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig( model_config = OnlineModelConfig(
wenet_ctc=wenet_ctc_config, wenet_ctc=wenet_ctc_config,
tokens=tokens, tokens=tokens,
num_threads=num_threads, num_threads=num_threads,
provider=provider, provider_config=provider_config,
debug=debug, debug=debug,
) )