Add config for TensorRT and CUDA execution provider (#992)

Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
This commit is contained in:
Manix
2024-07-05 12:48:37 +05:30
committed by GitHub
parent f5e9a162d1
commit 55decb7bee
21 changed files with 622 additions and 49 deletions

View File

@@ -9,6 +9,7 @@ from _sherpa_onnx import (
OnlineModelConfig,
OnlineTransducerModelConfig,
OnlineStream,
ProviderConfig,
)
from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
@@ -41,6 +42,7 @@ class KeywordSpotter(object):
keywords_threshold: float = 0.25,
num_trailing_blanks: int = 1,
provider: str = "cpu",
device: int = 0,
):
"""
Please refer to
@@ -85,6 +87,8 @@ class KeywordSpotter(object):
between each other.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
device:
onnxruntime cuda device index.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
@@ -99,11 +103,16 @@ class KeywordSpotter(object):
joiner=joiner,
)
provider_config = ProviderConfig(
provider=provider,
device = device,
)
model_config = OnlineModelConfig(
transducer=transducer_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
provider_config=provider_config,
)
feat_config = FeatureExtractorConfig(

View File

@@ -11,6 +11,9 @@ from _sherpa_onnx import (
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
CudaConfig,
TensorrtConfig,
ProviderConfig,
OnlineRecognizerConfig,
OnlineRecognizerResult,
OnlineStream,
@@ -56,7 +59,6 @@ class OnlineRecognizer(object):
hotwords_score: float = 1.5,
blank_penalty: float = 0.0,
hotwords_file: str = "",
provider: str = "cpu",
model_type: str = "",
modeling_unit: str = "cjkchar",
bpe_vocab: str = "",
@@ -66,6 +68,19 @@ class OnlineRecognizer(object):
debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
provider: str = "cpu",
device: int = 0,
cudnn_conv_algo_search: int = 1,
trt_max_workspace_size: int = 2147483647,
trt_max_partition_iterations: int = 10,
trt_min_subgraph_size: int = 5,
trt_fp16_enable: bool = True,
trt_detailed_build_log: bool = False,
trt_engine_cache_enable: bool = True,
trt_timing_cache_enable: bool = True,
trt_engine_cache_path: str ="",
trt_timing_cache_path: str ="",
trt_dump_subgraphs: bool = False,
):
"""
Please refer to
@@ -135,8 +150,6 @@ class OnlineRecognizer(object):
Temperature scaling for output symbol confidence estiamation.
It affects only confidence values, the decoding uses the original
logits without temperature.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
model_type:
Online transducer model type. Valid values are: conformer, lstm,
zipformer, zipformer2. All other values lead to loading the model twice.
@@ -156,6 +169,32 @@ class OnlineRecognizer(object):
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
device:
onnxruntime cuda device index.
cudnn_conv_algo_search:
onxrt CuDNN convolution search algorithm selection. CUDA EP
trt_max_workspace_size:
Set TensorRT EP GPU memory usage limit. TensorRT EP
trt_max_partition_iterations:
Limit partitioning iterations for model conversion. TensorRT EP
trt_min_subgraph_size:
Set minimum size for subgraphs in partitioning. TensorRT EP
trt_fp16_enable: bool = True,
Enable FP16 precision for faster performance. TensorRT EP
trt_detailed_build_log: bool = False,
Enable detailed logging of build steps. TensorRT EP
trt_engine_cache_enable: bool = True,
Enable caching of TensorRT engines. TensorRT EP
trt_timing_cache_enable: bool = True,
"Enable use of timing cache to speed up builds." TensorRT EP
trt_engine_cache_path: str ="",
"Set path to store cached TensorRT engines." TensorRT EP
trt_timing_cache_path: str ="",
"Set path for storing timing cache." TensorRT EP
trt_dump_subgraphs: bool = False,
"Dump optimized subgraphs for debugging." TensorRT EP
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
@@ -171,11 +210,35 @@ class OnlineRecognizer(object):
joiner=joiner,
)
cuda_config = CudaConfig(
cudnn_conv_algo_search=cudnn_conv_algo_search,
)
trt_config = TensorrtConfig(
trt_max_workspace_size=trt_max_workspace_size,
trt_max_partition_iterations=trt_max_partition_iterations,
trt_min_subgraph_size=trt_min_subgraph_size,
trt_fp16_enable=trt_fp16_enable,
trt_detailed_build_log=trt_detailed_build_log,
trt_engine_cache_enable=trt_engine_cache_enable,
trt_timing_cache_enable=trt_timing_cache_enable,
trt_engine_cache_path=trt_engine_cache_path,
trt_timing_cache_path=trt_timing_cache_path,
trt_dump_subgraphs=trt_dump_subgraphs,
)
provider_config = ProviderConfig(
trt_config=trt_config,
cuda_config=cuda_config,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
transducer=transducer_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
provider_config=provider_config,
model_type=model_type,
modeling_unit=modeling_unit,
bpe_vocab=bpe_vocab,
@@ -251,6 +314,7 @@ class OnlineRecognizer(object):
debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
):
"""
Please refer to
@@ -301,6 +365,8 @@ class OnlineRecognizer(object):
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
@@ -314,11 +380,16 @@ class OnlineRecognizer(object):
decoder=decoder,
)
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
paraformer=paraformer_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
provider_config=provider_config,
model_type="paraformer",
debug=debug,
)
@@ -367,6 +438,7 @@ class OnlineRecognizer(object):
debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
):
"""
Please refer to
@@ -421,6 +493,8 @@ class OnlineRecognizer(object):
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
@@ -430,11 +504,16 @@ class OnlineRecognizer(object):
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
zipformer2_ctc=zipformer2_ctc_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
provider_config=provider_config,
debug=debug,
)
@@ -486,6 +565,7 @@ class OnlineRecognizer(object):
debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
):
"""
Please refer to
@@ -535,6 +615,8 @@ class OnlineRecognizer(object):
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
@@ -546,11 +628,16 @@ class OnlineRecognizer(object):
model=model,
)
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
nemo_ctc=nemo_ctc_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
provider_config=provider_config,
debug=debug,
)
@@ -598,6 +685,7 @@ class OnlineRecognizer(object):
debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
):
"""
Please refer to
@@ -650,6 +738,8 @@ class OnlineRecognizer(object):
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
device:
onnxruntime cuda device index.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
@@ -663,11 +753,16 @@ class OnlineRecognizer(object):
num_left_chunks=num_left_chunks,
)
provider_config = ProviderConfig(
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
wenet_ctc=wenet_ctc_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
provider_config=provider_config,
debug=debug,
)