Support TDNN models from the yesno recipe from icefall (#262)

This commit is contained in:
Fangjun Kuang
2023-08-12 19:50:22 +08:00
committed by GitHub
parent b094868fb8
commit a4bff28e21
23 changed files with 612 additions and 36 deletions

View File

@@ -8,6 +8,7 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
@@ -20,6 +21,7 @@ file(s) with a non-streaming model.
/path/to/1.wav
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
@@ -56,9 +58,20 @@ python3 ./python-api-examples/offline-decode-files.py \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
--feature-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
@@ -159,6 +172,13 @@ def get_args():
help="Path to the model.onnx from NeMo CTC",
)
parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)
parser.add_argument(
"--num-threads",
type=int,
@@ -285,6 +305,7 @@ def main():
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
@@ -311,6 +332,7 @@ def main():
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.paraformer)
@@ -326,6 +348,7 @@ def main():
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.nemo_ctc)
@@ -339,6 +362,7 @@ def main():
debug=args.debug,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
@@ -347,6 +371,20 @@ def main():
decoder=args.whisper_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
model=args.tdnn_model,
tokens=args.tokens,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)