Add Python APIs for WeNet CTC models (#428)

This commit is contained in:
Fangjun Kuang
2023-11-16 14:20:41 +08:00
committed by GitHub
parent fac4f6bc7c
commit 049fb9f451
13 changed files with 538 additions and 11 deletions

View File

@@ -40,10 +40,17 @@ python3 ./python-api-examples/streaming_server.py \
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
to download pre-trained models.
The model in the above help messages is from
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
To use a WeNet streaming Conformer CTC model, please use
python3 ./python-api-examples/streaming_server.py \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx
"""
import argparse
@@ -130,6 +137,12 @@ def add_model_args(parser: argparse.ArgumentParser):
help="Path to the transducer joiner model.",
)
parser.add_argument(
"--wenet-ctc",
type=str,
help="Path to the model.onnx from WeNet",
)
parser.add_argument(
"--paraformer-encoder",
type=str,
@@ -212,7 +225,6 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
)
def add_modified_beam_search_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-active-paths",
@@ -393,6 +405,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
elif args.wenet_ctc:
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
tokens=args.tokens,
model=args.wenet_ctc,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
enable_endpoint_detection=args.use_endpoint != 0,
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
rule3_min_utterance_length=args.rule3_min_utterance_length,
provider=args.provider,
)
else:
raise ValueError("Please provide a model")
@@ -727,6 +753,8 @@ def check_args(args):
assert Path(
args.paraformer_decoder
).is_file(), f"{args.paraformer_decoder} does not exist"
elif args.wenet_ctc:
assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
else:
raise ValueError("Please provide a model")