diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index 44fda24e..03d0e6f5 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -79,6 +79,13 @@ def get_args(): """, ) + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + parser.add_argument( "--bpe-model", type=str, @@ -204,6 +211,7 @@ def main(): decoder=args.decoder, joiner=args.joiner, num_threads=args.num_threads, + provider=args.provider, sample_rate=16000, feature_dim=80, decoding_method=args.decoding_method, @@ -220,7 +228,6 @@ def main(): print(f"Contexts list: {contexts}") contexts_list = encode_contexts(args, contexts) - streams = [] total_duration = 0 for wave_filename in args.sound_files: diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py index c0faaa0f..45cdb1e6 100755 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -72,6 +72,13 @@ def get_args(): help="Valid values are greedy_search and modified_beam_search", ) + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + return parser.parse_args() @@ -97,6 +104,7 @@ def create_recognizer(): rule2_min_trailing_silence=1.2, rule3_min_utterance_length=300, # it essentially disables this rule decoding_method=args.decoding_method, + provider=args.provider, ) return recognizer diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index 8818513c..6edbb804 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -82,6 +82,13 @@ def get_args(): """, ) + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + parser.add_argument( "--bpe-model", type=str, @@ -148,10 +155,12 @@ def create_recognizer(): feature_dim=80, decoding_method=args.decoding_method, max_active_paths=args.max_active_paths, + provider=args.provider, context_score=args.context_score, ) return recognizer + def encode_contexts(args, contexts: List[str]) -> List[List[int]]: sp = None if "bpe" in args.modeling_unit: @@ -172,6 +181,7 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]: tokens_table=tokens, ) + def main(): args = get_args() @@ -205,6 +215,7 @@ def main(): last_result = result print("\r{}".format(result), end="", flush=True) + if __name__ == "__main__": devices = sd.query_devices() print(devices) diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py index 229e73fa..ea9d111f 100755 --- a/python-api-examples/streaming_server.py +++ b/python-api-examples/streaming_server.py @@ -129,6 +129,13 @@ def add_model_args(parser: argparse.ArgumentParser): help="Feature dimension of the model", ) + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + def add_decoding_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -301,6 +308,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: rule1_min_trailing_silence=args.rule1_min_trailing_silence, rule2_min_trailing_silence=args.rule2_min_trailing_silence, rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, ) return recognizer