Add Python APIs for WeNet CTC models (#428)
This commit is contained in:
@@ -40,10 +40,17 @@ python3 ./python-api-examples/streaming_server.py \
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html
|
||||
to download pre-trained models.
|
||||
|
||||
The model in the above help messages is from
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
||||
|
||||
To use a WeNet streaming Conformer CTC model, please use
|
||||
|
||||
python3 ./python-api-examples/streaming_server.py \
|
||||
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -130,6 +137,12 @@ def add_model_args(parser: argparse.ArgumentParser):
|
||||
help="Path to the transducer joiner model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
type=str,
|
||||
help="Path to the model.onnx from WeNet",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--paraformer-encoder",
|
||||
type=str,
|
||||
@@ -212,7 +225,6 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
|
||||
def add_modified_beam_search_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-active-paths",
|
||||
@@ -393,6 +405,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
|
||||
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
||||
provider=args.provider,
|
||||
)
|
||||
elif args.wenet_ctc:
|
||||
recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc(
|
||||
tokens=args.tokens,
|
||||
model=args.wenet_ctc,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feat_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
enable_endpoint_detection=args.use_endpoint != 0,
|
||||
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
|
||||
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
|
||||
rule3_min_utterance_length=args.rule3_min_utterance_length,
|
||||
provider=args.provider,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Please provide a model")
|
||||
|
||||
@@ -727,6 +753,8 @@ def check_args(args):
|
||||
assert Path(
|
||||
args.paraformer_decoder
|
||||
).is_file(), f"{args.paraformer_decoder} does not exist"
|
||||
elif args.wenet_ctc:
|
||||
assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist"
|
||||
else:
|
||||
raise ValueError("Please provide a model")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user