diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 0e1a0494..9c321384 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -112,6 +112,7 @@ class OfflineRecognizer(object): feature_dim: int = 80, decoding_method: str = "greedy_search", debug: bool = False, + provider: str = "cpu", ): """ Please refer to @@ -138,6 +139,8 @@ class OfflineRecognizer(object): Valid values are greedy_search, modified_beam_search. debug: True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -145,6 +148,7 @@ class OfflineRecognizer(object): tokens=tokens, num_threads=num_threads, debug=debug, + provider=provider, ) feat_config = OfflineFeatureExtractorConfig( @@ -170,6 +174,7 @@ class OfflineRecognizer(object): feature_dim: int = 80, decoding_method: str = "greedy_search", debug: bool = False, + provider: str = "cpu", ): """ Please refer to @@ -196,6 +201,8 @@ class OfflineRecognizer(object): Valid values are greedy_search, modified_beam_search. debug: True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. """ self = cls.__new__(cls) model_config = OfflineModelConfig( @@ -203,6 +210,7 @@ class OfflineRecognizer(object): tokens=tokens, num_threads=num_threads, debug=debug, + provider=provider, ) feat_config = OfflineFeatureExtractorConfig( diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index c981bc04..20a84d7d 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -131,7 +131,7 @@ class OnlineRecognizer(object): self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config - def create_stream(self, contexts_list : Optional[List[List[int]]] = None): + def create_stream(self, contexts_list: Optional[List[List[int]]] = None): if contexts_list is None: return self.recognizer.create_stream() else: diff --git a/sherpa-onnx/python/tests/test_offline_recognizer.py b/sherpa-onnx/python/tests/test_offline_recognizer.py index bb6c994c..f6d36a53 100755 --- a/sherpa-onnx/python/tests/test_offline_recognizer.py +++ b/sherpa-onnx/python/tests/test_offline_recognizer.py @@ -72,6 +72,7 @@ class TestOfflineRecognizer(unittest.TestCase): joiner=joiner, tokens=tokens, num_threads=1, + provider="cpu", ) s = recognizer.create_stream() @@ -106,6 +107,7 @@ class TestOfflineRecognizer(unittest.TestCase): joiner=joiner, tokens=tokens, num_threads=1, + provider="cpu", ) s0 = recognizer.create_stream() @@ -143,6 +145,7 @@ class TestOfflineRecognizer(unittest.TestCase): paraformer=model, tokens=tokens, num_threads=1, + provider="cpu", ) s = recognizer.create_stream() @@ -172,6 +175,7 @@ class TestOfflineRecognizer(unittest.TestCase): paraformer=model, tokens=tokens, num_threads=1, + provider="cpu", ) s0 = recognizer.create_stream() @@ -214,6 +218,7 @@ class TestOfflineRecognizer(unittest.TestCase): model=model, tokens=tokens, num_threads=1, + provider="cpu", ) s = recognizer.create_stream() @@ -242,6 +247,7 @@ class TestOfflineRecognizer(unittest.TestCase): model=model, tokens=tokens, num_threads=1, + provider="cpu", ) s0 = recognizer.create_stream() diff --git a/sherpa-onnx/python/tests/test_online_recognizer.py b/sherpa-onnx/python/tests/test_online_recognizer.py index 157cfd8d..0769b5e6 100755 --- a/sherpa-onnx/python/tests/test_online_recognizer.py +++ b/sherpa-onnx/python/tests/test_online_recognizer.py @@ -72,6 +72,7 @@ class TestOnlineRecognizer(unittest.TestCase): tokens=tokens, num_threads=1, decoding_method=decoding_method, + provider="cpu", ) s = recognizer.create_stream() samples, sample_rate = read_wave(wave0) @@ -115,6 +116,7 @@ class TestOnlineRecognizer(unittest.TestCase): tokens=tokens, num_threads=1, decoding_method=decoding_method, + provider="cpu", ) streams = [] waves = [wave0, wave1, wave2, wave3, wave4]