Begin to support CTC models (#119)

Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
This commit is contained in:
Fangjun Kuang
2023-04-07 23:11:34 +08:00
committed by GitHub
parent 9ac747248b
commit 80060c276d
40 changed files with 1244 additions and 60 deletions

View File

@@ -6,7 +6,7 @@
This file demonstrates how to use the sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.
paraformer Usage:
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
@@ -18,7 +18,7 @@ paraformer Usage:
/path/to/0.wav \
/path/to/1.wav
transducer Usage:
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
@@ -32,6 +32,8 @@ transducer Usage:
/path/to/0.wav \
/path/to/1.wav
(3) For CTC models from NeMo
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
@@ -83,7 +85,14 @@ def get_args():
"--paraformer",
default="",
type=str,
help="Path to the paraformer model",
help="Path to the model.onnx from Paraformer",
)
parser.add_argument(
"--nemo-ctc",
default="",
type=str,
help="Path to the model.onnx from NeMo CTC",
)
parser.add_argument(
@@ -171,11 +180,14 @@ def main():
args = get_args()
assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads
if len(args.encoder) > 0:
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert len(args.paraformer) == 0, args.paraformer
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
encoder=args.encoder,
decoder=args.decoder,
@@ -187,8 +199,10 @@ def main():
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=args.paraformer,
tokens=args.tokens,
@@ -198,6 +212,19 @@ def main():
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.nemo_ctc:
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
model=args.nemo_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
print("Please specify at least one model")
return
print("Started!")
start_time = time.time()
@@ -225,12 +252,14 @@ def main():
print("-" * 10)
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {duration:.3f} s")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
if __name__ == "__main__":

View File

@@ -172,12 +172,14 @@ def main():
print("-" * 10)
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {duration:.3f} s")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
if __name__ == "__main__":