Support streaming paraformer (#263)

This commit is contained in:
Fangjun Kuang
2023-08-14 10:32:14 +08:00
committed by GitHub
parent a4bff28e21
commit 6038e2aa62
38 changed files with 1488 additions and 112 deletions

View File

@@ -5,16 +5,41 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a streaming model.
Usage:
./online-decode-files.py \
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/16kHz.wav \
/path/to/8kHz.wav
(1) Streaming transducer
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
cd sherpa-onnx-streaming-zipformer-en-2023-06-26
git lfs pull --include "*.onnx"
./python-api-examples/online-decode-files.py \
--tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
--encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
--decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
--joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
(2) Streaming paraformer
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
git lfs pull --include "*.onnx"
./python-api-examples/online-decode-files.py \
--tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
--paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
--paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
used in this file.
to install sherpa-onnx and to download streaming pre-trained models.
"""
import argparse
import time
@@ -41,19 +66,31 @@ def get_args():
parser.add_argument(
"--encoder",
type=str,
help="Path to the encoder model",
help="Path to the transducer encoder model",
)
parser.add_argument(
"--decoder",
type=str,
help="Path to the decoder model",
help="Path to the transducer decoder model",
)
parser.add_argument(
"--joiner",
type=str,
help="Path to the joiner model",
help="Path to the transducer joiner model",
)
parser.add_argument(
"--paraformer-encoder",
type=str,
help="Path to the paraformer encoder model",
)
parser.add_argument(
"--paraformer-decoder",
type=str,
help="Path to the paraformer decoder model",
)
parser.add_argument(
@@ -200,24 +237,42 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
def main():
args = get_args()
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert_file_exists(args.tokens)
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
tokens=args.tokens,
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
context_score=args.context_score,
)
if args.encoder:
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert not args.paraformer_encoder, args.paraformer_encoder
assert not args.paraformer_decoder, args.paraformer_decoder
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
tokens=args.tokens,
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
context_score=args.context_score,
)
elif args.paraformer_encoder:
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
tokens=args.tokens,
encoder=args.paraformer_encoder,
decoder=args.paraformer_decoder,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
decoding_method="greedy_search",
)
else:
raise ValueError("Please provide a model")
print("Started!")
start_time = time.time()
@@ -243,7 +298,7 @@ def main():
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()