Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* support passing decoding_method from commandline

* Add modified_beam_search to Python API

* fix C API example

* Fix style issues
This commit is contained in:
Fangjun Kuang
2023-03-03 12:10:59 +08:00
committed by GitHub
parent c241f93c40
commit 7f72c13d9a
34 changed files with 744 additions and 374 deletions

View File

@@ -53,6 +53,20 @@ def get_args():
help="Path to the joiner model",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
parser.add_argument(
"--wave-filename",
type=str,
@@ -65,7 +79,6 @@ def get_args():
def main():
sample_rate = 16000
num_threads = 2
args = get_args()
assert_file_exists(args.encoder)
@@ -81,9 +94,10 @@ def main():
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=num_threads,
num_threads=args.num_threads,
sample_rate=sample_rate,
feature_dim=80,
decoding_method=args.decoding_method,
)
with wave.open(args.wave_filename) as f:
assert f.getframerate() == sample_rate, f.getframerate()
@@ -119,7 +133,8 @@ def main():
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
print(f"num_threads: {num_threads}")
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")