Add Python APIs for WeNet CTC models (#428)

This commit is contained in:
Fangjun Kuang
2023-11-16 14:20:41 +08:00
committed by GitHub
parent fac4f6bc7c
commit 049fb9f451
13 changed files with 538 additions and 11 deletions

View File

@@ -58,6 +58,15 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
--num-threads=2 \
/path/to/test.mp4
(4) For WeNet CTC models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
--num-threads=2 \
/path/to/test.mp4
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
@@ -121,6 +130,13 @@ def get_args():
help="Path to the model.onnx from Paraformer",
)
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the CTC model.onnx from WeNet",
)
parser.add_argument(
"--num-threads",
type=int,
@@ -215,6 +231,7 @@ def assert_file_exists(filename: str):
def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
@@ -234,6 +251,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
debug=args.debug,
)
elif args.paraformer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
@@ -248,6 +266,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.wenet_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
model=args.wenet_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)

View File

@@ -58,7 +58,19 @@ python3 ./python-api-examples/non_streaming_server.py \
--nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
--tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt
(4) Use a Whisper model
(4) Use a non-streaming CTC model from WeNet
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
cd sherpa-onnx-zh-wenet-wenetspeech
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
(5) Use a Whisper model
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
@@ -210,6 +222,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
)
def add_wenet_ctc_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the model.onnx from WeNet CTC",
)
def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tdnn-model",
@@ -261,6 +282,7 @@ def add_model_args(parser: argparse.ArgumentParser):
add_transducer_model_args(parser)
add_paraformer_model_args(parser)
add_nemo_ctc_model_args(parser)
add_wenet_ctc_model_args(parser)
add_tdnn_ctc_model_args(parser)
add_whisper_model_args(parser)
@@ -804,6 +826,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -827,6 +850,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
)
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -842,6 +866,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method,
)
elif args.nemo_ctc:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -856,6 +881,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
)
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.wenet_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
model=args.wenet_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)

View File

@@ -59,7 +59,16 @@ python3 ./python-api-examples/offline-decode-files.py \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For tdnn models of the yesno recipe from icefall
(5) For CTC models from WeNet
python3 ./python-api-examples/offline-decode-files.py \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
(6) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
@@ -154,6 +163,13 @@ def get_args():
help="Path to the model.onnx from NeMo CTC",
)
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the model.onnx from WeNet CTC",
)
parser.add_argument(
"--tdnn-model",
default="",
@@ -254,6 +270,7 @@ def assert_file_exists(filename: str):
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
@@ -287,6 +304,7 @@ def main():
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -310,6 +328,7 @@ def main():
)
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -326,6 +345,7 @@ def main():
debug=args.debug,
)
elif args.nemo_ctc:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
@@ -341,6 +361,22 @@ def main():
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.wenet_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
model=args.wenet_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)

View File

@@ -37,8 +37,25 @@ git lfs pull --include "*.onnx"
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
(3) Streaming Conformer CTC from WeNet
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
cd sherpa-onnx-zh-wenet-wenetspeech
git lfs pull --include "*.onnx"
./python-api-examples/online-decode-files.py \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
and
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to install sherpa-onnx and to download streaming pre-trained models.
"""
import argparse
@@ -92,6 +109,26 @@ def get_args():
help="Path to the paraformer decoder model",
)
parser.add_argument(
"--wenet-ctc",
type=str,
help="Path to the wenet ctc model model",
)
parser.add_argument(
"--wenet-ctc-chunk-size",
type=int,
default=16,
help="The --chunk-size parameter for streaming WeNet models",
)
parser.add_argument(
"--wenet-ctc-num-left-chunks",
type=int,
default=4,
help="The --num-left-chunks parameter for streaming WeNet models",
)
parser.add_argument(
"--num-threads",
type=int,
@@ -249,6 +286,18 @@ def main():
feature_dim=80,
decoding_method="greedy_search",
)
elif args.wenet_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
tokens=args.tokens,
model=args.wenet_ctc,
chunk_size=args.wenet_ctc_chunk_size,
num_left_chunks=args.wenet_ctc_num_left_chunks,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method="greedy_search",
)
else:
raise ValueError("Please provide a model")

View File

@@ -40,10 +40,17 @@ python3 ./python-api-examples/streaming_server.py \
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to download pre-trained models.
The model in the above help messages is from
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
To use a WeNet streaming Conformer CTC model, please use
python3 ./python-api-examples/streaming_server.py \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx
"""
import argparse
@@ -130,6 +137,12 @@ def add_model_args(parser: argparse.ArgumentParser):
help="Path to the transducer joiner model.",
)
parser.add_argument(
"--wenet-ctc",
type=str,
help="Path to the model.onnx from WeNet",
)
parser.add_argument(
"--paraformer-encoder",
type=str,
@@ -212,7 +225,6 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
)
def add_modified_beam_search_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-active-paths",
@@ -393,6 +405,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
elif args.wenet_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
tokens=args.tokens,
model=args.wenet_ctc,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
enable_endpoint_detection=args.use_endpoint != 0,
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
else:
raise ValueError("Please provide a model")
@@ -727,6 +753,8 @@ def check_args(args):
assert Path(
args.paraformer_decoder
).is_file(), f"{args.paraformer_decoder} does not exist"
elif args.wenet_ctc:
assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
else:
raise ValueError("Please provide a model")