Add config for TensorRT and CUDA execution provider (#992)

Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
This commit is contained in:
Manix
2024-07-05 12:48:37 +05:30
committed by GitHub
parent f5e9a162d1
commit 55decb7bee
21 changed files with 622 additions and 49 deletions

View File

@@ -9,6 +9,7 @@ from _sherpa_onnx import (
OnlineModelConfig,
OnlineTransducerModelConfig,
OnlineStream,
ProviderConfig,
)
from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
@@ -41,6 +42,7 @@ class KeywordSpotter(object):
keywords_threshold: float = 0.25,
num_trailing_blanks: int = 1,
provider: str = "cpu",
device: int = 0,
):
"""
Please refer to
@@ -85,6 +87,8 @@ class KeywordSpotter(object):
between each other.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
device:
onnxruntime cuda device index.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
@@ -99,11 +103,16 @@ class KeywordSpotter(object):
joiner=joiner,
)
provider_config = ProviderConfig(
provider=provider,
device = device,
)
model_config = OnlineModelConfig(
transducer=transducer_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
provider_config=provider_config,
)
feat_config = FeatureExtractorConfig(