Begin to support CTC models (#119)
Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||
file(s) with a non-streaming model.
|
||||
|
||||
paraformer Usage:
|
||||
(1) For paraformer
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/paraformer.onnx \
|
||||
@@ -18,7 +18,7 @@ paraformer Usage:
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
transducer Usage:
|
||||
(2) For transducer models from icefall
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
@@ -32,6 +32,8 @@ transducer Usage:
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(3) For CTC models from NeMo
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
to install sherpa-onnx and to download the pre-trained models
|
||||
@@ -83,7 +85,14 @@ def get_args():
|
||||
"--paraformer",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the paraformer model",
|
||||
help="Path to the model.onnx from Paraformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nemo-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from NeMo CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -171,11 +180,14 @@ def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.tokens)
|
||||
assert args.num_threads > 0, args.num_threads
|
||||
if len(args.encoder) > 0:
|
||||
if args.encoder:
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
assert_file_exists(args.joiner)
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
||||
encoder=args.encoder,
|
||||
decoder=args.decoder,
|
||||
@@ -187,8 +199,10 @@ def main():
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
else:
|
||||
elif args.paraformer:
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||
paraformer=args.paraformer,
|
||||
tokens=args.tokens,
|
||||
@@ -198,6 +212,19 @@ def main():
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.nemo_ctc:
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||
model=args.nemo_ctc,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
else:
|
||||
print("Please specify at least one model")
|
||||
return
|
||||
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
@@ -225,12 +252,14 @@ def main():
|
||||
print("-" * 10)
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / duration
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"decoding_method: {args.decoding_method}")
|
||||
print(f"Wave duration: {duration:.3f} s")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -172,12 +172,14 @@ def main():
|
||||
print("-" * 10)
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / duration
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"decoding_method: {args.decoding_method}")
|
||||
print(f"Wave duration: {duration:.3f} s")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user