Code refactoring (#74)
* Don't reset model state and feature extractor on endpointing
* Support passing decoding_method from the command line
* Add modified_beam_search to Python API
* Fix C API example
* Fix style issues
This commit is contained in:
@@ -53,6 +53,20 @@ def get_args():
|
||||
help="Path to the joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wave-filename",
|
||||
type=str,
|
||||
@@ -65,7 +79,6 @@ def get_args():
|
||||
|
||||
def main():
|
||||
sample_rate = 16000
|
||||
num_threads = 2
|
||||
|
||||
args = get_args()
|
||||
assert_file_exists(args.encoder)
|
||||
@@ -81,9 +94,10 @@ def main():
|
||||
encoder=args.encoder,
|
||||
decoder=args.decoder,
|
||||
joiner=args.joiner,
|
||||
num_threads=num_threads,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=sample_rate,
|
||||
feature_dim=80,
|
||||
decoding_method=args.decoding_method,
|
||||
)
|
||||
with wave.open(args.wave_filename) as f:
|
||||
assert f.getframerate() == sample_rate, f.getframerate()
|
||||
@@ -119,7 +133,8 @@ def main():
|
||||
end_time = time.time()
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / duration
|
||||
print(f"num_threads: {num_threads}")
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"decoding_method: {args.decoding_method}")
|
||||
print(f"Wave duration: {duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
|
||||
@@ -60,10 +60,10 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wave-filename",
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
help="""Path to the wave filename. Must be 16 kHz,
|
||||
mono with 16-bit samples""",
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
@@ -83,17 +83,23 @@ def create_recognizer():
|
||||
encoder=args.encoder,
|
||||
decoder=args.decoder,
|
||||
joiner=args.joiner,
|
||||
num_threads=1,
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
enable_endpoint_detection=True,
|
||||
rule1_min_trailing_silence=2.4,
|
||||
rule2_min_trailing_silence=1.2,
|
||||
rule3_min_utterance_length=300, # it essentially disables this rule
|
||||
decoding_method=args.decoding_method,
|
||||
max_feature_vectors=100, # 1 second
|
||||
)
|
||||
return recognizer
|
||||
|
||||
|
||||
def main():
|
||||
print("Started! Please speak")
|
||||
recognizer = create_recognizer()
|
||||
print("Started! Please speak")
|
||||
|
||||
sample_rate = 16000
|
||||
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||
last_result = ""
|
||||
@@ -101,6 +107,7 @@ def main():
|
||||
|
||||
last_result = ""
|
||||
segment_id = 0
|
||||
display = sherpa_onnx.Display(max_word_per_line=30)
|
||||
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
|
||||
while True:
|
||||
samples, _ = s.read(samples_per_read) # a blocking read
|
||||
@@ -115,7 +122,7 @@ def main():
|
||||
|
||||
if result and (last_result != result):
|
||||
last_result = result
|
||||
print(f"{segment_id}: {result}")
|
||||
display.print(segment_id, result)
|
||||
|
||||
if is_endpoint:
|
||||
if result:
|
||||
|
||||
@@ -59,10 +59,10 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wave-filename",
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
help="""Path to the wave filename. Must be 16 kHz,
|
||||
mono with 16-bit samples""",
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
@@ -82,9 +82,11 @@ def create_recognizer():
|
||||
encoder=args.encoder,
|
||||
decoder=args.decoder,
|
||||
joiner=args.joiner,
|
||||
num_threads=4,
|
||||
num_threads=1,
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
decoding_method=args.decoding_method,
|
||||
max_feature_vectors=100, # 1 second
|
||||
)
|
||||
return recognizer
|
||||
|
||||
@@ -96,6 +98,7 @@ def main():
|
||||
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||
last_result = ""
|
||||
stream = recognizer.create_stream()
|
||||
display = sherpa_onnx.Display(max_word_per_line=40)
|
||||
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
|
||||
while True:
|
||||
samples, _ = s.read(samples_per_read) # a blocking read
|
||||
@@ -106,7 +109,7 @@ def main():
|
||||
result = recognizer.get_result(stream)
|
||||
if last_result != result:
|
||||
last_result = result
|
||||
print(result)
|
||||
display.print(-1, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user