Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* support passing decoding_method from commandline

* Add modified_beam_search to Python API

* fix C API example

* Fix style issues
This commit is contained in:
Fangjun Kuang
2023-03-03 12:10:59 +08:00
committed by GitHub
parent c241f93c40
commit 7f72c13d9a
34 changed files with 744 additions and 374 deletions

View File

@@ -53,6 +53,20 @@ def get_args():
help="Path to the joiner model",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
parser.add_argument(
"--wave-filename",
type=str,
@@ -65,7 +79,6 @@ def get_args():
def main():
sample_rate = 16000
num_threads = 2
args = get_args()
assert_file_exists(args.encoder)
@@ -81,9 +94,10 @@ def main():
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=num_threads,
num_threads=args.num_threads,
sample_rate=sample_rate,
feature_dim=80,
decoding_method=args.decoding_method,
)
with wave.open(args.wave_filename) as f:
assert f.getframerate() == sample_rate, f.getframerate()
@@ -119,7 +133,8 @@ def main():
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
print(f"num_threads: {num_threads}")
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")

View File

@@ -60,10 +60,10 @@ def get_args():
)
parser.add_argument(
"--wave-filename",
"--decoding-method",
type=str,
help="""Path to the wave filename. Must be 16 kHz,
mono with 16-bit samples""",
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser.parse_args()
@@ -83,17 +83,23 @@ def create_recognizer():
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=1,
sample_rate=16000,
feature_dim=80,
enable_endpoint_detection=True,
rule1_min_trailing_silence=2.4,
rule2_min_trailing_silence=1.2,
rule3_min_utterance_length=300, # it essentially disables this rule
decoding_method=args.decoding_method,
max_feature_vectors=100, # 1 second
)
return recognizer
def main():
print("Started! Please speak")
recognizer = create_recognizer()
print("Started! Please speak")
sample_rate = 16000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
last_result = ""
@@ -101,6 +107,7 @@ def main():
last_result = ""
segment_id = 0
display = sherpa_onnx.Display(max_word_per_line=30)
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
@@ -115,7 +122,7 @@ def main():
if result and (last_result != result):
last_result = result
print(f"{segment_id}: {result}")
display.print(segment_id, result)
if is_endpoint:
if result:

View File

@@ -59,10 +59,10 @@ def get_args():
)
parser.add_argument(
"--wave-filename",
"--decoding-method",
type=str,
help="""Path to the wave filename. Must be 16 kHz,
mono with 16-bit samples""",
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser.parse_args()
@@ -82,9 +82,11 @@ def create_recognizer():
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=4,
num_threads=1,
sample_rate=16000,
feature_dim=80,
decoding_method=args.decoding_method,
max_feature_vectors=100, # 1 second
)
return recognizer
@@ -96,6 +98,7 @@ def main():
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
last_result = ""
stream = recognizer.create_stream()
display = sherpa_onnx.Display(max_word_per_line=40)
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
@@ -106,7 +109,7 @@ def main():
result = recognizer.get_result(stream)
if last_result != result:
last_result = result
print(result)
display.print(-1, result)
if __name__ == "__main__":