Support whisper models (#238)

This commit is contained in:
Fangjun Kuang
2023-08-07 12:34:18 +08:00
committed by GitHub
parent 64efbd82af
commit 45b9d4ab37
39 changed files with 1836 additions and 52 deletions

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
@@ -34,6 +35,27 @@ file(s) with a non-streaming model.
(3) For CTC models from NeMo
python3 ./python-api-examples/offline-decode-files.py \
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
(4) For Whisper models
python3 ./python-api-examples/offline-decode-files.py \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--num-threads=1 \
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
@@ -144,6 +166,20 @@ def get_args():
help="Number of threads for neural network computation",
)
parser.add_argument(
"--whisper-encoder",
default="",
type=str,
help="Path to whisper encoder model",
)
parser.add_argument(
"--whisper-decoder",
default="",
type=str,
help="Path to whisper decoder model",
)
parser.add_argument(
"--decoding-method",
type=str,
@@ -247,6 +283,8 @@ def main():
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
@@ -271,6 +309,9 @@ def main():
)
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
@@ -283,6 +324,11 @@ def main():
debug=args.debug,
)
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert_file_exists(args.nemo_ctc)
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
model=args.nemo_ctc,
tokens=args.tokens,
@@ -292,6 +338,18 @@ def main():
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
decoder=args.whisper_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
print("Please specify at least one model")
return