Add Python APIs for WeNet CTC models (#428)
This commit is contained in:
@@ -58,6 +58,15 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
|
||||
--num-threads=2 \
|
||||
/path/to/test.mp4
|
||||
|
||||
(4) For WeNet CTC models
|
||||
|
||||
./python-api-examples/generate-subtitles.py \
|
||||
--silero-vad-model=/path/to/silero_vad.onnx \
|
||||
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||
--num-threads=2 \
|
||||
/path/to/test.mp4
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
to install sherpa-onnx and to download non-streaming pre-trained models
|
||||
@@ -121,6 +130,13 @@ def get_args():
|
||||
help="Path to the model.onnx from Paraformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the CTC model.onnx from WeNet",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
@@ -215,6 +231,7 @@ def assert_file_exists(filename: str):
|
||||
def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
if args.encoder:
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
|
||||
@@ -234,6 +251,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.paraformer:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
|
||||
@@ -248,6 +266,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.wenet_ctc:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
|
||||
assert_file_exists(args.wenet_ctc)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
|
||||
model=args.wenet_ctc,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.whisper_encoder:
|
||||
assert_file_exists(args.whisper_encoder)
|
||||
assert_file_exists(args.whisper_decoder)
|
||||
|
||||
@@ -58,7 +58,19 @@ python3 ./python-api-examples/non_streaming_server.py \
|
||||
--nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
|
||||
--tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt
|
||||
|
||||
(4) Use a Whisper model
|
||||
(4) Use a non-streaming CTC model from WeNet
|
||||
|
||||
cd /path/to/sherpa-onnx
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
|
||||
cd sherpa-onnx-zh-wenet-wenetspeech
|
||||
git lfs pull --include "*.onnx"
|
||||
cd ..
|
||||
|
||||
python3 ./python-api-examples/non_streaming_server.py \
|
||||
--wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||
--tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt
|
||||
|
||||
(5) Use a Whisper model
|
||||
|
||||
cd /path/to/sherpa-onnx
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
|
||||
@@ -210,6 +222,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
def add_wenet_ctc_model_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from WeNet CTC",
|
||||
)
|
||||
|
||||
|
||||
def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--tdnn-model",
|
||||
@@ -261,6 +282,7 @@ def add_model_args(parser: argparse.ArgumentParser):
|
||||
add_transducer_model_args(parser)
|
||||
add_paraformer_model_args(parser)
|
||||
add_nemo_ctc_model_args(parser)
|
||||
add_wenet_ctc_model_args(parser)
|
||||
add_tdnn_ctc_model_args(parser)
|
||||
add_whisper_model_args(parser)
|
||||
|
||||
@@ -804,6 +826,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
if args.encoder:
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
@@ -827,6 +850,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
)
|
||||
elif args.paraformer:
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
@@ -842,6 +866,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
decoding_method=args.decoding_method,
|
||||
)
|
||||
elif args.nemo_ctc:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
@@ -856,6 +881,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
|
||||
feature_dim=args.feat_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
)
|
||||
elif args.wenet_ctc:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
|
||||
assert_file_exists(args.wenet_ctc)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
|
||||
model=args.wenet_ctc,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feat_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
)
|
||||
elif args.whisper_encoder:
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
assert_file_exists(args.whisper_encoder)
|
||||
|
||||
@@ -59,7 +59,16 @@ python3 ./python-api-examples/offline-decode-files.py \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||
|
||||
(5) For tdnn models of the yesno recipe from icefall
|
||||
(5) For CTC models from WeNet
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
|
||||
|
||||
(6) For tdnn models of the yesno recipe from icefall
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--sample-rate=8000 \
|
||||
@@ -154,6 +163,13 @@ def get_args():
|
||||
help="Path to the model.onnx from NeMo CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from WeNet CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tdnn-model",
|
||||
default="",
|
||||
@@ -254,6 +270,7 @@ def assert_file_exists(filename: str):
|
||||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
@@ -287,6 +304,7 @@ def main():
|
||||
if args.encoder:
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
@@ -310,6 +328,7 @@ def main():
|
||||
)
|
||||
elif args.paraformer:
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
@@ -326,6 +345,7 @@ def main():
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.nemo_ctc:
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
@@ -341,6 +361,22 @@ def main():
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.wenet_ctc:
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
|
||||
assert_file_exists(args.wenet_ctc)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
|
||||
model=args.wenet_ctc,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.whisper_encoder:
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
assert_file_exists(args.whisper_encoder)
|
||||
|
||||
@@ -37,8 +37,25 @@ git lfs pull --include "*.onnx"
|
||||
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
|
||||
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
|
||||
|
||||
(3) Streaming Conformer CTC from WeNet
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech
|
||||
cd sherpa-onnx-zh-wenet-wenetspeech
|
||||
git lfs pull --include "*.onnx"
|
||||
|
||||
./python-api-examples/online-decode-files.py \
|
||||
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
|
||||
|
||||
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
and
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
|
||||
to install sherpa-onnx and to download streaming pre-trained models.
|
||||
"""
|
||||
import argparse
|
||||
@@ -92,6 +109,26 @@ def get_args():
|
||||
help="Path to the paraformer decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
type=str,
|
||||
help="Path to the wenet ctc model model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The --chunk-size parameter for streaming WeNet models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc-num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The --num-left-chunks parameter for streaming WeNet models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
@@ -249,6 +286,18 @@ def main():
|
||||
feature_dim=80,
|
||||
decoding_method="greedy_search",
|
||||
)
|
||||
elif args.wenet_ctc:
|
||||
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
|
||||
tokens=args.tokens,
|
||||
model=args.wenet_ctc,
|
||||
chunk_size=args.wenet_ctc_chunk_size,
|
||||
num_left_chunks=args.wenet_ctc_num_left_chunks,
|
||||
num_threads=args.num_threads,
|
||||
provider=args.provider,
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
decoding_method="greedy_search",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Please provide a model")
|
||||
|
||||
|
||||
@@ -40,10 +40,17 @@ python3 ./python-api-examples/streaming_server.py \
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
|
||||
to download pre-trained models.
|
||||
|
||||
The model in the above help messages is from
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
||||
|
||||
To use a WeNet streaming Conformer CTC model, please use
|
||||
|
||||
python3 ./python-api-examples/streaming_server.py \
|
||||
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -130,6 +137,12 @@ def add_model_args(parser: argparse.ArgumentParser):
|
||||
help="Path to the transducer joiner model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
type=str,
|
||||
help="Path to the model.onnx from WeNet",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--paraformer-encoder",
|
||||
type=str,
|
||||
@@ -212,7 +225,6 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
|
||||
def add_modified_beam_search_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-active-paths",
|
||||
@@ -393,6 +405,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
|
||||
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
||||
provider=args.provider,
|
||||
)
|
||||
elif args.wenet_ctc:
|
||||
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
|
||||
tokens=args.tokens,
|
||||
model=args.wenet_ctc,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feat_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
enable_endpoint_detection=args.use_endpoint != 0,
|
||||
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
||||
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
||||
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
||||
provider=args.provider,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Please provide a model")
|
||||
|
||||
@@ -727,6 +753,8 @@ def check_args(args):
|
||||
assert Path(
|
||||
args.paraformer_decoder
|
||||
).is_file(), f"{args.paraformer_decoder} does not exist"
|
||||
elif args.wenet_ctc:
|
||||
assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
|
||||
else:
|
||||
raise ValueError("Please provide a model")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user