Support specifying providers in Python API (#198)
This commit is contained in:
@@ -112,6 +112,7 @@ class OfflineRecognizer(object):
|
|||||||
feature_dim: int = 80,
|
feature_dim: int = 80,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
|
provider: str = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -138,6 +139,8 @@ class OfflineRecognizer(object):
|
|||||||
Valid values are greedy_search, modified_beam_search.
|
Valid values are greedy_search, modified_beam_search.
|
||||||
debug:
|
debug:
|
||||||
True to show debug messages.
|
True to show debug messages.
|
||||||
|
provider:
|
||||||
|
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||||
"""
|
"""
|
||||||
self = cls.__new__(cls)
|
self = cls.__new__(cls)
|
||||||
model_config = OfflineModelConfig(
|
model_config = OfflineModelConfig(
|
||||||
@@ -145,6 +148,7 @@ class OfflineRecognizer(object):
|
|||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=num_threads,
|
num_threads=num_threads,
|
||||||
debug=debug,
|
debug=debug,
|
||||||
|
provider=provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
feat_config = OfflineFeatureExtractorConfig(
|
feat_config = OfflineFeatureExtractorConfig(
|
||||||
@@ -170,6 +174,7 @@ class OfflineRecognizer(object):
|
|||||||
feature_dim: int = 80,
|
feature_dim: int = 80,
|
||||||
decoding_method: str = "greedy_search",
|
decoding_method: str = "greedy_search",
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
|
provider: str = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -196,6 +201,8 @@ class OfflineRecognizer(object):
|
|||||||
Valid values are greedy_search, modified_beam_search.
|
Valid values are greedy_search, modified_beam_search.
|
||||||
debug:
|
debug:
|
||||||
True to show debug messages.
|
True to show debug messages.
|
||||||
|
provider:
|
||||||
|
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||||
"""
|
"""
|
||||||
self = cls.__new__(cls)
|
self = cls.__new__(cls)
|
||||||
model_config = OfflineModelConfig(
|
model_config = OfflineModelConfig(
|
||||||
@@ -203,6 +210,7 @@ class OfflineRecognizer(object):
|
|||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=num_threads,
|
num_threads=num_threads,
|
||||||
debug=debug,
|
debug=debug,
|
||||||
|
provider=provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
feat_config = OfflineFeatureExtractorConfig(
|
feat_config = OfflineFeatureExtractorConfig(
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ class OnlineRecognizer(object):
|
|||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
self.config = recognizer_config
|
self.config = recognizer_config
|
||||||
|
|
||||||
def create_stream(self, contexts_list : Optional[List[List[int]]] = None):
|
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
|
||||||
if contexts_list is None:
|
if contexts_list is None:
|
||||||
return self.recognizer.create_stream()
|
return self.recognizer.create_stream()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class TestOfflineRecognizer(unittest.TestCase):
|
|||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
|
provider="cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
s = recognizer.create_stream()
|
s = recognizer.create_stream()
|
||||||
@@ -106,6 +107,7 @@ class TestOfflineRecognizer(unittest.TestCase):
|
|||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
|
provider="cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
s0 = recognizer.create_stream()
|
s0 = recognizer.create_stream()
|
||||||
@@ -143,6 +145,7 @@ class TestOfflineRecognizer(unittest.TestCase):
|
|||||||
paraformer=model,
|
paraformer=model,
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
|
provider="cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
s = recognizer.create_stream()
|
s = recognizer.create_stream()
|
||||||
@@ -172,6 +175,7 @@ class TestOfflineRecognizer(unittest.TestCase):
|
|||||||
paraformer=model,
|
paraformer=model,
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
|
provider="cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
s0 = recognizer.create_stream()
|
s0 = recognizer.create_stream()
|
||||||
@@ -214,6 +218,7 @@ class TestOfflineRecognizer(unittest.TestCase):
|
|||||||
model=model,
|
model=model,
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
|
provider="cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
s = recognizer.create_stream()
|
s = recognizer.create_stream()
|
||||||
@@ -242,6 +247,7 @@ class TestOfflineRecognizer(unittest.TestCase):
|
|||||||
model=model,
|
model=model,
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
|
provider="cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
s0 = recognizer.create_stream()
|
s0 = recognizer.create_stream()
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class TestOnlineRecognizer(unittest.TestCase):
|
|||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
|
provider="cpu",
|
||||||
)
|
)
|
||||||
s = recognizer.create_stream()
|
s = recognizer.create_stream()
|
||||||
samples, sample_rate = read_wave(wave0)
|
samples, sample_rate = read_wave(wave0)
|
||||||
@@ -115,6 +116,7 @@ class TestOnlineRecognizer(unittest.TestCase):
|
|||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
decoding_method=decoding_method,
|
decoding_method=decoding_method,
|
||||||
|
provider="cpu",
|
||||||
)
|
)
|
||||||
streams = []
|
streams = []
|
||||||
waves = [wave0, wave1, wave2, wave3, wave4]
|
waves = [wave0, wave1, wave2, wave3, wave4]
|
||||||
|
|||||||
Reference in New Issue
Block a user