Refactor python examples (#67)

This commit is contained in:
Fangjun Kuang
2023-02-26 20:33:16 +08:00
committed by GitHub
parent 5a8c3a6d10
commit 343e732ccb
7 changed files with 186 additions and 22 deletions

View File

@@ -9,7 +9,7 @@ log() {
} }
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-lstm-en-2023-02-17 repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
log "Start testing ${repo_url}" log "Start testing ${repo_url}"
repo=$(basename $repo_url) repo=$(basename $repo_url)
@@ -30,4 +30,9 @@ ls -lh
ls -lh $repo ls -lh $repo
python3 python-api-examples/decode-file.py python3 ./python-api-examples/decode-file.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
--wave-filename=$repo/test_wavs/4.wav

1
.gitignore vendored
View File

@@ -33,3 +33,4 @@ decode-file
*.dylib *.dylib
tokens.txt tokens.txt
*.onnx *.onnx
log.txt

View File

@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR) cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx) project(sherpa-onnx)
set(SHERPA_ONNX_VERSION "1.1") set(SHERPA_ONNX_VERSION "1.2")
# Disable warning about # Disable warning about
# #

72
python-api-examples/decode-file.py Normal file → Executable file
View File

@@ -9,27 +9,83 @@ https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models to install sherpa-onnx and to download the pre-trained models
used in this file. used in this file.
""" """
import wave import argparse
import time import time
import wave
from pathlib import Path
import numpy as np import numpy as np
import sherpa_onnx import sherpa_onnx
def assert_file_exists(filename: str):
    """Abort with a helpful message unless ``filename`` is an existing file.

    Args:
      filename:
        Path to a file (model or tokens file) that must exist.

    Raises:
      AssertionError: If ``filename`` does not point to a regular file.
    """
    if Path(filename).is_file():
        return

    # Raise explicitly instead of using an ``assert`` statement: ``assert`` is
    # removed when Python runs with ``-O``, which would silently disable this
    # check. AssertionError is kept so the exception type stays the same.
    raise AssertionError(
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html "
        "to download it"
    )
def get_args():
    """Parse the command-line options of this script.

    Returns:
      An ``argparse.Namespace`` with the string attributes ``tokens``,
      ``encoder``, ``decoder``, ``joiner`` and ``wave_filename``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # All options are required. Without ``required=True`` a missing option is
    # parsed as None and the caller fails later with an opaque TypeError from
    # ``Path(None)`` instead of an argparse usage message.
    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--encoder",
        type=str,
        required=True,
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        type=str,
        required=True,
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        type=str,
        required=True,
        help="Path to the joiner model",
    )

    parser.add_argument(
        "--wave-filename",
        type=str,
        required=True,
        help="""Path to the wave filename. Must be 16 kHz,
        mono with 16-bit samples""",
    )

    return parser.parse_args()
def main(): def main():
sample_rate = 16000 sample_rate = 16000
num_threads = 4 num_threads = 2
args = get_args()
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert_file_exists(args.tokens)
if not Path(args.wave_filename).is_file():
print(f"{args.wave_filename} does not exist!")
return
recognizer = sherpa_onnx.OnlineRecognizer( recognizer = sherpa_onnx.OnlineRecognizer(
tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt", tokens=args.tokens,
encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx", encoder=args.encoder,
decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx", decoder=args.decoder,
joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx", joiner=args.joiner,
num_threads=num_threads, num_threads=num_threads,
sample_rate=sample_rate, sample_rate=sample_rate,
feature_dim=80, feature_dim=80,
) )
filename = "./sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav" with wave.open(args.wave_filename) as f:
with wave.open(filename) as f:
assert f.getframerate() == sample_rate, f.getframerate() assert f.getframerate() == sample_rate, f.getframerate()
assert f.getnchannels() == 1, f.getnchannels() assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes

View File

@@ -7,7 +7,9 @@
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
# to download pre-trained models # to download pre-trained models
import argparse
import sys import sys
from pathlib import Path
try: try:
import sounddevice as sd import sounddevice as sd
@@ -22,18 +24,65 @@ except ImportError as e:
import sherpa_onnx import sherpa_onnx
def assert_file_exists(filename: str):
    """Abort with a helpful message unless ``filename`` is an existing file.

    Args:
      filename:
        Path to a file (model or tokens file) that must exist.

    Raises:
      AssertionError: If ``filename`` does not point to a regular file.
    """
    if Path(filename).is_file():
        return

    # Raise explicitly instead of using an ``assert`` statement: ``assert`` is
    # removed when Python runs with ``-O``, which would silently disable this
    # check. AssertionError is kept so the exception type stays the same.
    raise AssertionError(
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html "
        "to download it"
    )
def get_args():
    """Parse the command-line options of this script.

    Returns:
      An ``argparse.Namespace`` with the attributes ``tokens``, ``encoder``,
      ``decoder``, ``joiner`` and ``wave_filename``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The model/tokens options are required. Without ``required=True`` a
    # missing option is parsed as None and the caller fails later with an
    # opaque TypeError from ``Path(None)`` instead of an argparse usage message.
    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--encoder",
        type=str,
        required=True,
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        type=str,
        required=True,
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        type=str,
        required=True,
        help="Path to the joiner model",
    )

    # NOTE(review): the visible part of create_recognizer() never reads
    # args.wave_filename; this option looks like a leftover from the
    # file-decoding example. It stays optional so existing invocations keep
    # working -- confirm whether it can be removed.
    parser.add_argument(
        "--wave-filename",
        type=str,
        help="""Path to the wave filename. Must be 16 kHz,
        mono with 16-bit samples""",
    )

    return parser.parse_args()
def create_recognizer(): def create_recognizer():
args = get_args()
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert_file_exists(args.tokens)
# Please replace the model files if needed. # Please replace the model files if needed.
# See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
# for download links. # for download links.
recognizer = sherpa_onnx.OnlineRecognizer( recognizer = sherpa_onnx.OnlineRecognizer(
tokens="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt", tokens=args.tokens,
encoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx", encoder=args.encoder,
decoder="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx", decoder=args.decoder,
joiner="./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx", joiner=args.joiner,
num_threads=4,
sample_rate=16000,
feature_dim=80,
enable_endpoint_detection=True, enable_endpoint_detection=True,
rule1_min_trailing_silence=2.4, rule1_min_trailing_silence=2.4,
rule2_min_trailing_silence=1.2, rule2_min_trailing_silence=1.2,

View File

@@ -6,7 +6,9 @@
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
# to download pre-trained models # to download pre-trained models
import argparse
import sys import sys
from pathlib import Path
try: try:
import sounddevice as sd import sounddevice as sd
@@ -21,15 +23,65 @@ except ImportError as e:
import sherpa_onnx import sherpa_onnx
def assert_file_exists(filename: str):
    """Abort with a helpful message unless ``filename`` is an existing file.

    Args:
      filename:
        Path to a file (model or tokens file) that must exist.

    Raises:
      AssertionError: If ``filename`` does not point to a regular file.
    """
    if Path(filename).is_file():
        return

    # Raise explicitly instead of using an ``assert`` statement: ``assert`` is
    # removed when Python runs with ``-O``, which would silently disable this
    # check. AssertionError is kept so the exception type stays the same.
    raise AssertionError(
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html "
        "to download it"
    )
def get_args():
    """Parse the command-line options of this script.

    Returns:
      An ``argparse.Namespace`` with the attributes ``tokens``, ``encoder``,
      ``decoder``, ``joiner`` and ``wave_filename``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The model/tokens options are required. Without ``required=True`` a
    # missing option is parsed as None and the caller fails later with an
    # opaque TypeError from ``Path(None)`` instead of an argparse usage message.
    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--encoder",
        type=str,
        required=True,
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        type=str,
        required=True,
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        type=str,
        required=True,
        help="Path to the joiner model",
    )

    # NOTE(review): the visible part of create_recognizer() never reads
    # args.wave_filename; this option looks like a leftover from the
    # file-decoding example. It stays optional so existing invocations keep
    # working -- confirm whether it can be removed.
    parser.add_argument(
        "--wave-filename",
        type=str,
        help="""Path to the wave filename. Must be 16 kHz,
        mono with 16-bit samples""",
    )

    return parser.parse_args()
def create_recognizer(): def create_recognizer():
args = get_args()
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert_file_exists(args.tokens)
# Please replace the model files if needed. # Please replace the model files if needed.
# See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
# for download links. # for download links.
recognizer = sherpa_onnx.OnlineRecognizer( recognizer = sherpa_onnx.OnlineRecognizer(
tokens="./sherpa-onnx-lstm-en-2023-02-17/tokens.txt", tokens=args.tokens,
encoder="./sherpa-onnx-lstm-en-2023-02-17/encoder-epoch-99-avg-1.onnx", encoder=args.encoder,
decoder="./sherpa-onnx-lstm-en-2023-02-17/decoder-epoch-99-avg-1.onnx", decoder=args.decoder,
joiner="./sherpa-onnx-lstm-en-2023-02-17/joiner-epoch-99-avg-1.onnx", joiner=args.joiner,
num_threads=4, num_threads=4,
sample_rate=16000, sample_rate=16000,
feature_dim=80, feature_dim=80,

View File

@@ -3,6 +3,7 @@
// Copyright (c) 2023 Xiaomi Corporation // Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
#include <algorithm>
#include <fstream> #include <fstream>
#include <string> #include <string>
#include <vector> #include <vector>