diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index 45a8936d..8818513c 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -10,6 +10,9 @@ import argparse import sys from pathlib import Path +from typing import List, Tuple +import sentencepiece as spm + try: import sounddevice as sd except ImportError: @@ -70,6 +73,59 @@ def get_args(): help="Valid values are greedy_search and modified_beam_search", ) + parser.add_argument( + "--max-active-paths", + type=int, + default=4, + help="""Used only when --decoding-method is modified_beam_search. + It specifies number of active paths to keep during decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="", + help=""" + Path to bpe.model, it will be used to tokenize contexts biasing phrases. + Used only when --decoding-method=modified_beam_search + """, + ) + + parser.add_argument( + "--modeling-unit", + type=str, + default="char", + help=""" + The type of modeling unit, it will be used to tokenize contexts biasing phrases. + Valid values are bpe, bpe+char, char. + Note: the char here means characters in CJK languages. + Used only when --decoding-method=modified_beam_search + """, + ) + + parser.add_argument( + "--contexts", + type=str, + default="", + help=""" + The context list, it is a string containing some words/phrases separated + with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY". + Used only when --decoding-method=modified_beam_search + """, + ) + + parser.add_argument( + "--context-score", + type=float, + default=1.5, + help=""" + The context score of each token for biasing word/phrase. Used only if + --contexts is given. + Used only when --decoding-method=modified_beam_search + """, + ) + return parser.parse_args() @@ -91,11 +147,40 @@ def create_recognizer(): sample_rate=16000, feature_dim=80, decoding_method=args.decoding_method, + max_active_paths=args.max_active_paths, + context_score=args.context_score, ) return recognizer +def encode_contexts(args, contexts: List[str]) -> List[List[int]]: + sp = None + if "bpe" in args.modeling_unit: + assert_file_exists(args.bpe_model) + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + tokens = {} + with open(args.tokens, "r", encoding="utf-8") as f: + for line in f: + toks = line.strip().split() + assert len(toks) == 2, len(toks) + assert toks[0] not in tokens, f"Duplicate token: {toks} " + tokens[toks[0]] = int(toks[1]) + return sherpa_onnx.encode_contexts( + modeling_unit=args.modeling_unit, + contexts=contexts, + sp=sp, + tokens_table=tokens, + ) def main(): + args = get_args() + + contexts_list = [] + contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] + if contexts: + print(f"Contexts list: {contexts}") + contexts_list = encode_contexts(args, contexts) + recognizer = create_recognizer() print("Started! Please speak") @@ -104,7 +189,10 @@ def main(): sample_rate = 48000 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms last_result = "" - stream = recognizer.create_stream() + if contexts_list: + stream = recognizer.create_stream(contexts_list=contexts_list) + else: + stream = recognizer.create_stream() with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: while True: samples, _ = s.read(samples_per_read) # a blocking read @@ -117,7 +205,6 @@ def main(): last_result = result print("\r{}".format(result), end="", flush=True) - if __name__ == "__main__": devices = sd.query_devices() print(devices)