Support specifying provider for python examples (#244)
This commit is contained in:
@@ -79,6 +79,13 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
help="Valid values: cpu, cuda, coreml",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -204,6 +211,7 @@ def main():
|
|||||||
decoder=args.decoder,
|
decoder=args.decoder,
|
||||||
joiner=args.joiner,
|
joiner=args.joiner,
|
||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
|
provider=args.provider,
|
||||||
sample_rate=16000,
|
sample_rate=16000,
|
||||||
feature_dim=80,
|
feature_dim=80,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
@@ -220,7 +228,6 @@ def main():
|
|||||||
print(f"Contexts list: {contexts}")
|
print(f"Contexts list: {contexts}")
|
||||||
contexts_list = encode_contexts(args, contexts)
|
contexts_list = encode_contexts(args, contexts)
|
||||||
|
|
||||||
|
|
||||||
streams = []
|
streams = []
|
||||||
total_duration = 0
|
total_duration = 0
|
||||||
for wave_filename in args.sound_files:
|
for wave_filename in args.sound_files:
|
||||||
|
|||||||
@@ -72,6 +72,13 @@ def get_args():
|
|||||||
help="Valid values are greedy_search and modified_beam_search",
|
help="Valid values are greedy_search and modified_beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
help="Valid values: cpu, cuda, coreml",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -97,6 +104,7 @@ def create_recognizer():
|
|||||||
rule2_min_trailing_silence=1.2,
|
rule2_min_trailing_silence=1.2,
|
||||||
rule3_min_utterance_length=300, # it essentially disables this rule
|
rule3_min_utterance_length=300, # it essentially disables this rule
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
|
provider=args.provider,
|
||||||
)
|
)
|
||||||
return recognizer
|
return recognizer
|
||||||
|
|
||||||
|
|||||||
@@ -82,6 +82,13 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
help="Valid values: cpu, cuda, coreml",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -148,10 +155,12 @@ def create_recognizer():
|
|||||||
feature_dim=80,
|
feature_dim=80,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
max_active_paths=args.max_active_paths,
|
max_active_paths=args.max_active_paths,
|
||||||
|
provider=args.provider,
|
||||||
context_score=args.context_score,
|
context_score=args.context_score,
|
||||||
)
|
)
|
||||||
return recognizer
|
return recognizer
|
||||||
|
|
||||||
|
|
||||||
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
|
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
|
||||||
sp = None
|
sp = None
|
||||||
if "bpe" in args.modeling_unit:
|
if "bpe" in args.modeling_unit:
|
||||||
@@ -172,6 +181,7 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
|
|||||||
tokens_table=tokens,
|
tokens_table=tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
|
||||||
@@ -205,6 +215,7 @@ def main():
|
|||||||
last_result = result
|
last_result = result
|
||||||
print("\r{}".format(result), end="", flush=True)
|
print("\r{}".format(result), end="", flush=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
devices = sd.query_devices()
|
devices = sd.query_devices()
|
||||||
print(devices)
|
print(devices)
|
||||||
|
|||||||
@@ -129,6 +129,13 @@ def add_model_args(parser: argparse.ArgumentParser):
|
|||||||
help="Feature dimension of the model",
|
help="Feature dimension of the model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
help="Valid values: cpu, cuda, coreml",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_decoding_args(parser: argparse.ArgumentParser):
|
def add_decoding_args(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -301,6 +308,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
|
|||||||
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
||||||
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
||||||
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
||||||
|
provider=args.provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
return recognizer
|
return recognizer
|
||||||
|
|||||||
Reference in New Issue
Block a user