From 343e732ccb546c4ee2cf398f9210728c0471f72a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 26 Feb 2023 20:33:16 +0800 Subject: [PATCH] Refactor python examples (#67) --- .github/scripts/test-python.sh | 9 ++- .gitignore | 1 + CMakeLists.txt | 2 +- python-api-examples/decode-file.py | 72 ++++++++++++++++--- ...from-microphone-with-endpoint-detection.py | 63 ++++++++++++++-- .../speech-recognition-from-microphone.py | 60 ++++++++++++++-- sherpa-onnx/csrc/onnx-utils.cc | 1 + 7 files changed, 186 insertions(+), 22 deletions(-) mode change 100644 => 100755 python-api-examples/decode-file.py mode change 100644 => 100755 python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py mode change 100644 => 100755 python-api-examples/speech-recognition-from-microphone.py diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index c5d9accb..ca903f43 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -9,7 +9,7 @@ log() { } -repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17 +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 log "Start testing ${repo_url}" repo=$(basename $repo_url) @@ -30,4 +30,9 @@ ls -lh ls -lh $repo -python3 python-api-examples/decode-file.py +python3 ./python-api-examples/decode-file.py \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ + --wave-filename=$repo/test_wavs/4.wav diff --git a/.gitignore b/.gitignore index 2f7060e9..b3ea7a9a 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ decode-file *.dylib tokens.txt *.onnx +log.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index b64d9241..d2c3a443 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.1") +set(SHERPA_ONNX_VERSION "1.2") # Disable warning about # diff --git a/python-api-examples/decode-file.py b/python-api-examples/decode-file.py old mode 100644 new mode 100755 index 5bc2288e..79e846a4 --- a/python-api-examples/decode-file.py +++ b/python-api-examples/decode-file.py @@ -9,27 +9,83 @@ https://k2-fsa.github.io/sherpa/onnx/index.html to install sherpa-onnx and to download the pre-trained models used in this file. """ -import wave +import argparse import time +import wave +from pathlib import Path import numpy as np import sherpa_onnx +def assert_file_exists(filename: str): + assert Path( + filename + ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--wave-filename", + type=str, + help="""Path to the wave filename. Must be 16 kHz, + mono with 16-bit samples""", + ) + + return parser.parse_args() + + def main(): sample_rate = 16000 - num_threads = 4 + num_threads = 2 + + args = get_args() + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + assert_file_exists(args.tokens) + if not Path(args.wave_filename).is_file(): + print(f"{args.wave_filename} does not exist!") + return + recognizer = sherpa_onnx.OnlineRecognizer( - tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt", - encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx", - decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx", - joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx", + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, num_threads=num_threads, sample_rate=sample_rate, feature_dim=80, ) - filename = "./sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav" - with wave.open(filename) as f: + with wave.open(args.wave_filename) as f: assert f.getframerate() == sample_rate, f.getframerate() assert f.getnchannels() == 1, f.getnchannels() assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py old mode 100644 new mode 100755 index 4d148d1a..93571364 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -7,7 +7,9 @@ # https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # to download pre-trained models +import argparse import sys +from pathlib import Path try: import sounddevice as sd @@ -22,18 +24,65 @@ except ImportError as e: import sherpa_onnx +def assert_file_exists(filename: str): + assert Path( + filename + ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--wave-filename", + type=str, + help="""Path to the wave filename. Must be 16 kHz, + mono with 16-bit samples""", + ) + + return parser.parse_args() + + def create_recognizer(): + args = get_args() + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + assert_file_exists(args.tokens) # Please replace the model files if needed. # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # for download links. recognizer = sherpa_onnx.OnlineRecognizer( - tokens="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt", - encoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx", - decoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx", - joiner="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx", - num_threads=4, - sample_rate=16000, - feature_dim=80, + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, enable_endpoint_detection=True, rule1_min_trailing_silence=2.4, rule2_min_trailing_silence=1.2, diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py old mode 100644 new mode 100755 index 6303642e..dcc72f51 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -6,7 +6,9 @@ # https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # to download pre-trained models +import argparse import sys +from pathlib import Path try: import sounddevice as sd @@ -21,15 +23,65 @@ except ImportError as e: import sherpa_onnx +def assert_file_exists(filename: str): + assert Path( + filename + ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--wave-filename", + type=str, + help="""Path to the wave filename. Must be 16 kHz, + mono with 16-bit samples""", + ) + + return parser.parse_args() + + def create_recognizer(): + args = get_args() + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + assert_file_exists(args.tokens) # Please replace the model files if needed. # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # for download links. recognizer = sherpa_onnx.OnlineRecognizer( - tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt", - encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx", - decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx", - joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx", + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, num_threads=4, sample_rate=16000, feature_dim=80, diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index efef86b1..664aac03 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -3,6 +3,7 @@ // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/onnx-utils.h" +#include #include #include #include