Support TDNN models from the yesno recipe from icefall (#262)

This commit is contained in:
Fangjun Kuang
2023-08-12 19:50:22 +08:00
committed by GitHub
parent b094868fb8
commit a4bff28e21
23 changed files with 612 additions and 36 deletions

View File

@@ -71,6 +71,20 @@ python3 ./python-api-examples/non_streaming_server.py \
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
(5) Use a tdnn model of the yesno recipe from icefall
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
cd sherpa-onnx-tdnn-yesno
git lfs pull --include "*.onnx"
python3 ./python-api-examples/non_streaming_server.py \
--sample-rate=8000 \
--feat-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
----
To use a certificate so that you can use https, please use
@@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
)
def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)
def add_whisper_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--whisper-encoder",
@@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser):
add_transducer_model_args(parser)
add_paraformer_model_args(parser)
add_nemo_ctc_model_args(parser)
add_tdnn_ctc_model_args(parser)
add_whisper_model_args(parser)
parser.add_argument(
@@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
@@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.paraformer)
@@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.nemo_ctc)
@@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
@@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
model=args.tdnn_model,
tokens=args.tokens,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
else:
raise ValueError("Please specify at least one model")