Add config for TensorRT and CUDA execution provider (#992)
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com> Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from _sherpa_onnx import (
|
||||
OnlineModelConfig,
|
||||
OnlineTransducerModelConfig,
|
||||
OnlineStream,
|
||||
ProviderConfig,
|
||||
)
|
||||
|
||||
from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
|
||||
@@ -41,6 +42,7 @@ class KeywordSpotter(object):
|
||||
keywords_threshold: float = 0.25,
|
||||
num_trailing_blanks: int = 1,
|
||||
provider: str = "cpu",
|
||||
device: int = 0,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -85,6 +87,8 @@ class KeywordSpotter(object):
|
||||
between each other.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
device:
|
||||
onnxruntime cuda device index.
|
||||
"""
|
||||
_assert_file_exists(tokens)
|
||||
_assert_file_exists(encoder)
|
||||
@@ -99,11 +103,16 @@ class KeywordSpotter(object):
|
||||
joiner=joiner,
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device = device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
transducer=transducer_config,
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
provider=provider,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
feat_config = FeatureExtractorConfig(
|
||||
|
||||
@@ -11,6 +11,9 @@ from _sherpa_onnx import (
|
||||
)
|
||||
from _sherpa_onnx import OnlineRecognizer as _Recognizer
|
||||
from _sherpa_onnx import (
|
||||
CudaConfig,
|
||||
TensorrtConfig,
|
||||
ProviderConfig,
|
||||
OnlineRecognizerConfig,
|
||||
OnlineRecognizerResult,
|
||||
OnlineStream,
|
||||
@@ -56,7 +59,6 @@ class OnlineRecognizer(object):
|
||||
hotwords_score: float = 1.5,
|
||||
blank_penalty: float = 0.0,
|
||||
hotwords_file: str = "",
|
||||
provider: str = "cpu",
|
||||
model_type: str = "",
|
||||
modeling_unit: str = "cjkchar",
|
||||
bpe_vocab: str = "",
|
||||
@@ -66,6 +68,19 @@ class OnlineRecognizer(object):
|
||||
debug: bool = False,
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
provider: str = "cpu",
|
||||
device: int = 0,
|
||||
cudnn_conv_algo_search: int = 1,
|
||||
trt_max_workspace_size: int = 2147483647,
|
||||
trt_max_partition_iterations: int = 10,
|
||||
trt_min_subgraph_size: int = 5,
|
||||
trt_fp16_enable: bool = True,
|
||||
trt_detailed_build_log: bool = False,
|
||||
trt_engine_cache_enable: bool = True,
|
||||
trt_timing_cache_enable: bool = True,
|
||||
trt_engine_cache_path: str ="",
|
||||
trt_timing_cache_path: str ="",
|
||||
trt_dump_subgraphs: bool = False,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -135,8 +150,6 @@ class OnlineRecognizer(object):
|
||||
Temperature scaling for output symbol confidence estiamation.
|
||||
It affects only confidence values, the decoding uses the original
|
||||
logits without temperature.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
model_type:
|
||||
Online transducer model type. Valid values are: conformer, lstm,
|
||||
zipformer, zipformer2. All other values lead to loading the model twice.
|
||||
@@ -156,6 +169,32 @@ class OnlineRecognizer(object):
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
device:
|
||||
onnxruntime cuda device index.
|
||||
cudnn_conv_algo_search:
|
||||
onxrt CuDNN convolution search algorithm selection. CUDA EP
|
||||
trt_max_workspace_size:
|
||||
Set TensorRT EP GPU memory usage limit. TensorRT EP
|
||||
trt_max_partition_iterations:
|
||||
Limit partitioning iterations for model conversion. TensorRT EP
|
||||
trt_min_subgraph_size:
|
||||
Set minimum size for subgraphs in partitioning. TensorRT EP
|
||||
trt_fp16_enable: bool = True,
|
||||
Enable FP16 precision for faster performance. TensorRT EP
|
||||
trt_detailed_build_log: bool = False,
|
||||
Enable detailed logging of build steps. TensorRT EP
|
||||
trt_engine_cache_enable: bool = True,
|
||||
Enable caching of TensorRT engines. TensorRT EP
|
||||
trt_timing_cache_enable: bool = True,
|
||||
"Enable use of timing cache to speed up builds." TensorRT EP
|
||||
trt_engine_cache_path: str ="",
|
||||
"Set path to store cached TensorRT engines." TensorRT EP
|
||||
trt_timing_cache_path: str ="",
|
||||
"Set path for storing timing cache." TensorRT EP
|
||||
trt_dump_subgraphs: bool = False,
|
||||
"Dump optimized subgraphs for debugging." TensorRT EP
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
_assert_file_exists(tokens)
|
||||
@@ -171,11 +210,35 @@ class OnlineRecognizer(object):
|
||||
joiner=joiner,
|
||||
)
|
||||
|
||||
cuda_config = CudaConfig(
|
||||
cudnn_conv_algo_search=cudnn_conv_algo_search,
|
||||
)
|
||||
|
||||
trt_config = TensorrtConfig(
|
||||
trt_max_workspace_size=trt_max_workspace_size,
|
||||
trt_max_partition_iterations=trt_max_partition_iterations,
|
||||
trt_min_subgraph_size=trt_min_subgraph_size,
|
||||
trt_fp16_enable=trt_fp16_enable,
|
||||
trt_detailed_build_log=trt_detailed_build_log,
|
||||
trt_engine_cache_enable=trt_engine_cache_enable,
|
||||
trt_timing_cache_enable=trt_timing_cache_enable,
|
||||
trt_engine_cache_path=trt_engine_cache_path,
|
||||
trt_timing_cache_path=trt_timing_cache_path,
|
||||
trt_dump_subgraphs=trt_dump_subgraphs,
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
trt_config=trt_config,
|
||||
cuda_config=cuda_config,
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
transducer=transducer_config,
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
provider=provider,
|
||||
provider_config=provider_config,
|
||||
model_type=model_type,
|
||||
modeling_unit=modeling_unit,
|
||||
bpe_vocab=bpe_vocab,
|
||||
@@ -251,6 +314,7 @@ class OnlineRecognizer(object):
|
||||
debug: bool = False,
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
device: int = 0,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -301,6 +365,8 @@ class OnlineRecognizer(object):
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
device:
|
||||
onnxruntime cuda device index.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
_assert_file_exists(tokens)
|
||||
@@ -314,11 +380,16 @@ class OnlineRecognizer(object):
|
||||
decoder=decoder,
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
paraformer=paraformer_config,
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
provider=provider,
|
||||
provider_config=provider_config,
|
||||
model_type="paraformer",
|
||||
debug=debug,
|
||||
)
|
||||
@@ -367,6 +438,7 @@ class OnlineRecognizer(object):
|
||||
debug: bool = False,
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
device: int = 0,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -421,6 +493,8 @@ class OnlineRecognizer(object):
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
device:
|
||||
onnxruntime cuda device index.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
_assert_file_exists(tokens)
|
||||
@@ -430,11 +504,16 @@ class OnlineRecognizer(object):
|
||||
|
||||
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
zipformer2_ctc=zipformer2_ctc_config,
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
provider=provider,
|
||||
provider_config=provider_config,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
@@ -486,6 +565,7 @@ class OnlineRecognizer(object):
|
||||
debug: bool = False,
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
device: int = 0,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -535,6 +615,8 @@ class OnlineRecognizer(object):
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
device:
|
||||
onnxruntime cuda device index.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
_assert_file_exists(tokens)
|
||||
@@ -546,11 +628,16 @@ class OnlineRecognizer(object):
|
||||
model=model,
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
nemo_ctc=nemo_ctc_config,
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
provider=provider,
|
||||
provider_config=provider_config,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
@@ -598,6 +685,7 @@ class OnlineRecognizer(object):
|
||||
debug: bool = False,
|
||||
rule_fsts: str = "",
|
||||
rule_fars: str = "",
|
||||
device: int = 0,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -650,6 +738,8 @@ class OnlineRecognizer(object):
|
||||
rule_fars:
|
||||
If not empty, it specifies fst archives for inverse text normalization.
|
||||
If there are multiple archives, they are separated by a comma.
|
||||
device:
|
||||
onnxruntime cuda device index.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
_assert_file_exists(tokens)
|
||||
@@ -663,11 +753,16 @@ class OnlineRecognizer(object):
|
||||
num_left_chunks=num_left_chunks,
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
wenet_ctc=wenet_ctc_config,
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
provider=provider,
|
||||
provider_config=provider_config,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user