Add Python API for keyword spotting (#576)
* Add alsa & microphone support for keyword spotting * Add python wrapper
This commit is contained in:
58
.github/scripts/test-python.sh
vendored
58
.github/scripts/test-python.sh
vendored
@@ -293,3 +293,61 @@ git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data
|
||||
python3 sherpa-onnx/python/tests/test_text2token.py --verbose
|
||||
|
||||
rm -rf /tmp/sherpa-test-data
|
||||
|
||||
mkdir -p /tmp/onnx-models
|
||||
dir=/tmp/onnx-models
|
||||
|
||||
log "Test keyword spotting models"
|
||||
|
||||
python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
|
||||
sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)")
|
||||
|
||||
echo "sherpa_onnx version: $sherpa_onnx_version"
|
||||
|
||||
pwd
|
||||
ls -lh
|
||||
|
||||
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
|
||||
log "Start testing ${repo}"
|
||||
|
||||
pushd $dir
|
||||
wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
|
||||
tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
|
||||
popd
|
||||
|
||||
repo=$dir/$repo
|
||||
ls -lh $repo
|
||||
|
||||
python3 ./python-api-examples/keyword-spotter.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav
|
||||
|
||||
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
|
||||
log "Start testing ${repo}"
|
||||
|
||||
pushd $dir
|
||||
wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz
|
||||
tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz
|
||||
popd
|
||||
|
||||
repo=$dir/$repo
|
||||
ls -lh $repo
|
||||
|
||||
python3 ./python-api-examples/keyword-spotter.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
||||
$repo/test_wavs/3.wav \
|
||||
$repo/test_wavs/4.wav \
|
||||
$repo/test_wavs/5.wav
|
||||
|
||||
python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
|
||||
|
||||
rm -r $dir
|
||||
|
||||
191
python-api-examples/keyword-spotter-from-microphone.py
Executable file
191
python-api-examples/keyword-spotter-from-microphone.py
Executable file
@@ -0,0 +1,191 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Real-time keyword spotting from a microphone with sherpa-onnx Python API
|
||||
#
|
||||
# Please refer to
|
||||
# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||
# to download pre-trained models
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from typing import List
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError:
|
||||
print("Please install sounddevice first. You can use")
|
||||
print()
|
||||
print(" pip install sounddevice")
|
||||
print()
|
||||
print("to install it")
|
||||
sys.exit(-1)
|
||||
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
type=str,
|
||||
help="Path to the transducer encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
help="Path to the transducer decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner",
|
||||
type=str,
|
||||
help="Path to the transducer joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Valid values: cpu, cuda, coreml",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-active-paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""
|
||||
It specifies number of active paths to keep during decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-trailing-blanks",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""The number of trailing blanks a keyword should be followed. Setting
|
||||
to a larger value (e.g. 8) when your keywords has overlapping tokens
|
||||
between each other.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-file",
|
||||
type=str,
|
||||
help="""
|
||||
The file containing keywords, one words/phrases per line, and for each
|
||||
phrase the bpe/cjkchar/pinyin are separated by a space. For example:
|
||||
|
||||
▁HE LL O ▁WORLD
|
||||
x iǎo ài t óng x ué
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-score",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""
|
||||
The boosting score of each token for keywords. The larger the easier to
|
||||
survive beam search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-threshold",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="""
|
||||
The trigger threshold (i.e. probability) of the keyword. The larger the
|
||||
harder to trigger.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
devices = sd.query_devices()
|
||||
if len(devices) == 0:
|
||||
print("No microphone devices found")
|
||||
sys.exit(0)
|
||||
|
||||
print(devices)
|
||||
default_input_device_idx = sd.default.device[0]
|
||||
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||
|
||||
assert_file_exists(args.tokens)
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
assert_file_exists(args.joiner)
|
||||
|
||||
assert Path(
|
||||
args.keywords_file
|
||||
).is_file(), (
|
||||
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
|
||||
)
|
||||
|
||||
keyword_spotter = sherpa_onnx.KeywordSpotter(
|
||||
tokens=args.tokens,
|
||||
encoder=args.encoder,
|
||||
decoder=args.decoder,
|
||||
joiner=args.joiner,
|
||||
num_threads=args.num_threads,
|
||||
max_active_paths=args.max_active_paths,
|
||||
keywords_file=args.keywords_file,
|
||||
keywords_score=args.keywords_score,
|
||||
keywords_threshold=args.keywords_threshold,
|
||||
num_tailing_blanks=args.rnum_tailing_blanks,
|
||||
provider=args.provider,
|
||||
)
|
||||
|
||||
print("Started! Please speak")
|
||||
|
||||
sample_rate = 16000
|
||||
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||
stream = keyword_spotter.create_stream()
|
||||
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
|
||||
while True:
|
||||
samples, _ = s.read(samples_per_read) # a blocking read
|
||||
samples = samples.reshape(-1)
|
||||
stream.accept_waveform(sample_rate, samples)
|
||||
while keyword_spotter.is_ready(stream):
|
||||
keyword_spotter.decode_stream(stream)
|
||||
result = keyword_spotter.get_result(stream)
|
||||
if result:
|
||||
print("\r{}".format(result), end="", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nCaught Ctrl + C. Exiting")
|
||||
242
python-api-examples/keyword-spotter.py
Executable file
242
python-api-examples/keyword-spotter.py
Executable file
@@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This file demonstrates how to use sherpa-onnx Python API to do keyword spotting
|
||||
from wave file(s).
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||
to download pre-trained models.
|
||||
"""
|
||||
import argparse
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
type=str,
|
||||
help="Path to the transducer encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
help="Path to the transducer decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner",
|
||||
type=str,
|
||||
help="Path to the transducer joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Valid values: cpu, cuda, coreml",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-active-paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""
|
||||
It specifies number of active paths to keep during decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-trailing-blanks",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""The number of trailing blanks a keyword should be followed. Setting
|
||||
to a larger value (e.g. 8) when your keywords has overlapping tokens
|
||||
between each other.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-file",
|
||||
type=str,
|
||||
help="""
|
||||
The file containing keywords, one words/phrases per line, and for each
|
||||
phrase the bpe/cjkchar/pinyin are separated by a space. For example:
|
||||
|
||||
▁HE LL O ▁WORLD
|
||||
x iǎo ài t óng x ué
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-score",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""
|
||||
The boosting score of each token for keywords. The larger the easier to
|
||||
survive beam search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-threshold",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="""
|
||||
The trigger threshold (i.e. probability) of the keyword. The larger the
|
||||
harder to trigger.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||
"format with a single channel, and each sample has 16-bit, "
|
||||
"i.e., int16_t. "
|
||||
"The sample rate of the file can be arbitrary and does not need to "
|
||||
"be 16 kHz",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
wave_filename:
|
||||
Path to a wave file. It should be single channel and each sample should
|
||||
be 16-bit. Its sample rate does not need to be 16kHz.
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A 1-D array of dtype np.float32 containing the samples, which are
|
||||
normalized to the range [-1, 1].
|
||||
- sample rate of the wave file
|
||||
"""
|
||||
|
||||
with wave.open(wave_filename) as f:
|
||||
assert f.getnchannels() == 1, f.getnchannels()
|
||||
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||
num_samples = f.getnframes()
|
||||
samples = f.readframes(num_samples)
|
||||
samples_int16 = np.frombuffer(samples, dtype=np.int16)
|
||||
samples_float32 = samples_int16.astype(np.float32)
|
||||
|
||||
samples_float32 = samples_float32 / 32768
|
||||
return samples_float32, f.getframerate()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.tokens)
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
assert_file_exists(args.joiner)
|
||||
|
||||
assert Path(
|
||||
args.keywords_file
|
||||
).is_file(), (
|
||||
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
|
||||
)
|
||||
|
||||
keyword_spotter = sherpa_onnx.KeywordSpotter(
|
||||
tokens=args.tokens,
|
||||
encoder=args.encoder,
|
||||
decoder=args.decoder,
|
||||
joiner=args.joiner,
|
||||
num_threads=args.num_threads,
|
||||
max_active_paths=args.max_active_paths,
|
||||
keywords_file=args.keywords_file,
|
||||
keywords_score=args.keywords_score,
|
||||
keywords_threshold=args.keywords_threshold,
|
||||
num_trailing_blanks=args.num_trailing_blanks,
|
||||
provider=args.provider,
|
||||
)
|
||||
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
|
||||
streams = []
|
||||
total_duration = 0
|
||||
for wave_filename in args.sound_files:
|
||||
assert_file_exists(wave_filename)
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
duration = len(samples) / sample_rate
|
||||
total_duration += duration
|
||||
|
||||
s = keyword_spotter.create_stream()
|
||||
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
|
||||
s.input_finished()
|
||||
|
||||
streams.append(s)
|
||||
|
||||
results = [""] * len(streams)
|
||||
while True:
|
||||
ready_list = []
|
||||
for i, s in enumerate(streams):
|
||||
if keyword_spotter.is_ready(s):
|
||||
ready_list.append(s)
|
||||
r = keyword_spotter.get_result(s)
|
||||
if r:
|
||||
results[i] += f"{r}/"
|
||||
print(f"{r} is detected.")
|
||||
if len(ready_list) == 0:
|
||||
break
|
||||
keyword_spotter.decode_streams(ready_list)
|
||||
end_time = time.time()
|
||||
print("Done!")
|
||||
|
||||
for wave_filename, result in zip(args.sound_files, results):
|
||||
print(f"{wave_filename}\n{result}")
|
||||
print("-" * 10)
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -230,12 +230,14 @@ endif()
|
||||
|
||||
if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-offline-tts-play-alsa sherpa-onnx-offline-tts-play-alsa.cc alsa-play.cc)
|
||||
add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc)
|
||||
add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc)
|
||||
|
||||
set(exes
|
||||
sherpa-onnx-alsa
|
||||
sherpa-onnx-keyword-spotter-alsa
|
||||
sherpa-onnx-alsa-offline
|
||||
sherpa-onnx-offline-tts-play-alsa
|
||||
sherpa-onnx-alsa-offline-speaker-identification
|
||||
@@ -278,6 +280,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
|
||||
microphone.cc
|
||||
)
|
||||
|
||||
add_executable(sherpa-onnx-keyword-spotter-microphone
|
||||
sherpa-onnx-keyword-spotter-microphone.cc
|
||||
microphone.cc
|
||||
)
|
||||
|
||||
add_executable(sherpa-onnx-microphone
|
||||
sherpa-onnx-microphone.cc
|
||||
microphone.cc
|
||||
@@ -311,6 +318,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
|
||||
|
||||
set(exes
|
||||
sherpa-onnx-microphone
|
||||
sherpa-onnx-keyword-spotter-microphone
|
||||
sherpa-onnx-microphone-offline
|
||||
sherpa-onnx-microphone-offline-speaker-identification
|
||||
sherpa-onnx-offline-tts-play
|
||||
|
||||
124
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc
Normal file
124
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc
Normal file
@@ -0,0 +1,124 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include <signal.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
|
||||
#include "sherpa-onnx/csrc/alsa.h"
|
||||
#include "sherpa-onnx/csrc/display.h"
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
bool stop = false;
|
||||
|
||||
static void Handler(int sig) {
|
||||
stop = true;
|
||||
fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
|
||||
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
signal(SIGINT, Handler);
|
||||
|
||||
const char *kUsageMessage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx-keyword-spotter-alsa \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--provider=cpu \
|
||||
--num-threads=2 \
|
||||
--keywords-file=keywords.txt \
|
||||
device_name
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
|
||||
The device name specifies which microphone to use in case there are several
|
||||
on you system. You can use
|
||||
|
||||
arecord -l
|
||||
|
||||
to find all available microphones on your computer. For instance, if it outputs
|
||||
|
||||
**** List of CAPTURE Hardware Devices ****
|
||||
card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
|
||||
Subdevices: 1/1
|
||||
Subdevice #0: subdevice #0
|
||||
|
||||
and if you want to select card 3 and the device 0 on that card, please use:
|
||||
|
||||
hw:3,0
|
||||
|
||||
or
|
||||
|
||||
plughw:3,0
|
||||
|
||||
as the device_name.
|
||||
)usage";
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::KeywordSpotterConfig config;
|
||||
|
||||
config.Register(&po);
|
||||
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() != 1) {
|
||||
fprintf(stderr, "Please provide only 1 argument: the device name\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
sherpa_onnx::KeywordSpotter spotter(config);
|
||||
|
||||
int32_t expected_sample_rate = config.feat_config.sampling_rate;
|
||||
|
||||
std::string device_name = po.GetArg(1);
|
||||
sherpa_onnx::Alsa alsa(device_name.c_str());
|
||||
fprintf(stderr, "Use recording device: %s\n", device_name.c_str());
|
||||
|
||||
if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
|
||||
fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
|
||||
expected_sample_rate);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t chunk = 0.1 * alsa.GetActualSampleRate();
|
||||
|
||||
std::string last_text;
|
||||
|
||||
auto stream = spotter.CreateStream();
|
||||
|
||||
sherpa_onnx::Display display;
|
||||
|
||||
int32_t keyword_index = 0;
|
||||
while (!stop) {
|
||||
const std::vector<float> &samples = alsa.Read(chunk);
|
||||
|
||||
stream->AcceptWaveform(expected_sample_rate, samples.data(),
|
||||
samples.size());
|
||||
|
||||
while (spotter.IsReady(stream.get())) {
|
||||
spotter.DecodeStream(stream.get());
|
||||
}
|
||||
|
||||
const auto r = spotter.GetResult(stream.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
148
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
Normal file
148
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
Normal file
@@ -0,0 +1,148 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include <signal.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "portaudio.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/display.h"
|
||||
#include "sherpa-onnx/csrc/microphone.h"
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
|
||||
bool stop = false;
|
||||
|
||||
static int32_t RecordCallback(const void *input_buffer,
|
||||
void * /*output_buffer*/,
|
||||
unsigned long frames_per_buffer, // NOLINT
|
||||
const PaStreamCallbackTimeInfo * /*time_info*/,
|
||||
PaStreamCallbackFlags /*status_flags*/,
|
||||
void *user_data) {
|
||||
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(user_data);
|
||||
|
||||
stream->AcceptWaveform(16000, reinterpret_cast<const float *>(input_buffer),
|
||||
frames_per_buffer);
|
||||
|
||||
return stop ? paComplete : paContinue;
|
||||
}
|
||||
|
||||
static void Handler(int32_t sig) {
|
||||
stop = true;
|
||||
fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
|
||||
}
|
||||
|
||||
int32_t main(int32_t argc, char *argv[]) {
|
||||
signal(SIGINT, Handler);
|
||||
|
||||
const char *kUsageMessage = R"usage(
|
||||
This program uses streaming models with microphone for keyword spotting.
|
||||
Usage:
|
||||
|
||||
./bin/sherpa-onnx-keyword-spotter-microphone \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--provider=cpu \
|
||||
--num-threads=1 \
|
||||
--keywords-file=keywords.txt
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::KeywordSpotterConfig config;
|
||||
|
||||
config.Register(&po);
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() != 0) {
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
sherpa_onnx::KeywordSpotter spotter(config);
|
||||
auto s = spotter.CreateStream();
|
||||
|
||||
sherpa_onnx::Microphone mic;
|
||||
|
||||
PaDeviceIndex num_devices = Pa_GetDeviceCount();
|
||||
fprintf(stderr, "Num devices: %d\n", num_devices);
|
||||
|
||||
PaStreamParameters param;
|
||||
|
||||
param.device = Pa_GetDefaultInputDevice();
|
||||
if (param.device == paNoDevice) {
|
||||
fprintf(stderr, "No default input device found\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
fprintf(stderr, "Use default device: %d\n", param.device);
|
||||
|
||||
const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
|
||||
fprintf(stderr, " Name: %s\n", info->name);
|
||||
fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels);
|
||||
|
||||
param.channelCount = 1;
|
||||
param.sampleFormat = paFloat32;
|
||||
|
||||
param.suggestedLatency = info->defaultLowInputLatency;
|
||||
param.hostApiSpecificStreamInfo = nullptr;
|
||||
float sample_rate = 16000;
|
||||
|
||||
PaStream *stream;
|
||||
PaError err =
|
||||
Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */
|
||||
sample_rate,
|
||||
0, // frames per buffer
|
||||
paClipOff, // we won't output out of range samples
|
||||
// so don't bother clipping them
|
||||
RecordCallback, s.get());
|
||||
if (err != paNoError) {
|
||||
fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
err = Pa_StartStream(stream);
|
||||
fprintf(stderr, "Started\n");
|
||||
|
||||
if (err != paNoError) {
|
||||
fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
int32_t keyword_index = 0;
|
||||
sherpa_onnx::Display display;
|
||||
while (!stop) {
|
||||
while (spotter.IsReady(s.get())) {
|
||||
spotter.DecodeStream(s.get());
|
||||
}
|
||||
|
||||
const auto r = spotter.GetResult(s.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
}
|
||||
|
||||
Pa_Sleep(20); // sleep for 20ms
|
||||
}
|
||||
|
||||
err = Pa_CloseStream(stream);
|
||||
if (err != paNoError) {
|
||||
fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -12,7 +12,6 @@
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
typedef struct {
|
||||
|
||||
@@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
display.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
keyword-spotter.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-lm-config.cc
|
||||
offline-model-config.cc
|
||||
|
||||
82
sherpa-onnx/python/csrc/keyword-spotter.cc
Normal file
82
sherpa-onnx/python/csrc/keyword-spotter.cc
Normal file
@@ -0,0 +1,82 @@
|
||||
// sherpa-onnx/python/csrc/keyword-spotter.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void PybindKeywordResult(py::module *m) {
|
||||
using PyClass = KeywordResult;
|
||||
py::class_<PyClass>(*m, "KeywordResult")
|
||||
.def_property_readonly(
|
||||
"keyword",
|
||||
[](PyClass &self) -> py::str {
|
||||
return py::str(PyUnicode_DecodeUTF8(self.keyword.c_str(),
|
||||
self.keyword.size(), "ignore"));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"tokens",
|
||||
[](PyClass &self) -> std::vector<std::string> { return self.tokens; })
|
||||
.def_property_readonly(
|
||||
"timestamps",
|
||||
[](PyClass &self) -> std::vector<float> { return self.timestamps; });
|
||||
}
|
||||
|
||||
static void PybindKeywordSpotterConfig(py::module *m) {
|
||||
using PyClass = KeywordSpotterConfig;
|
||||
py::class_<PyClass>(*m, "KeywordSpotterConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
||||
int32_t, int32_t, float, float, const std::string &>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("max_active_paths") = 4, py::arg("num_trailing_blanks") = 1,
|
||||
py::arg("keywords_score") = 1.0,
|
||||
py::arg("keywords_threshold") = 0.25, py::arg("keywords_file") = "")
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def_readwrite("num_trailing_blanks", &PyClass::num_trailing_blanks)
|
||||
.def_readwrite("keywords_score", &PyClass::keywords_score)
|
||||
.def_readwrite("keywords_threshold", &PyClass::keywords_threshold)
|
||||
.def_readwrite("keywords_file", &PyClass::keywords_file)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
void PybindKeywordSpotter(py::module *m) {
|
||||
PybindKeywordResult(m);
|
||||
PybindKeywordSpotterConfig(m);
|
||||
|
||||
using PyClass = KeywordSpotter;
|
||||
py::class_<PyClass>(*m, "KeywordSpotter")
|
||||
.def(py::init<const KeywordSpotterConfig &>(), py::arg("config"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"create_stream",
|
||||
[](const PyClass &self) { return self.CreateStream(); },
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"create_stream",
|
||||
[](PyClass &self, const std::string &keywords) {
|
||||
return self.CreateStream(keywords);
|
||||
},
|
||||
py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_ready", &PyClass::IsReady,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("decode_stream", &PyClass::DecodeStream,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"decode_streams",
|
||||
[](PyClass &self, std::vector<OnlineStream *> ss) {
|
||||
self.DecodeStreams(ss.data(), ss.size());
|
||||
},
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("get_result", &PyClass::GetResult,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/keyword-spotter.h
Normal file
16
sherpa-onnx/python/csrc/keyword-spotter.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/keyword-spotter.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindKeywordSpotter(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
@@ -35,6 +36,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindOnlineStream(&m);
|
||||
PybindEndpoint(&m);
|
||||
PybindOnlineRecognizer(&m);
|
||||
PybindKeywordSpotter(&m);
|
||||
|
||||
PybindDisplay(&m);
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from _sherpa_onnx import (
|
||||
VoiceActivityDetector,
|
||||
)
|
||||
|
||||
from .keyword_spotter import KeywordSpotter
|
||||
from .offline_recognizer import OfflineRecognizer
|
||||
from .online_recognizer import OnlineRecognizer
|
||||
from .utils import text2token
|
||||
|
||||
147
sherpa-onnx/python/sherpa_onnx/keyword_spotter.py
Normal file
147
sherpa-onnx/python/sherpa_onnx/keyword_spotter.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from _sherpa_onnx import (
|
||||
FeatureExtractorConfig,
|
||||
KeywordSpotterConfig,
|
||||
OnlineModelConfig,
|
||||
OnlineTransducerModelConfig,
|
||||
OnlineStream,
|
||||
)
|
||||
|
||||
from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
|
||||
|
||||
|
||||
def _assert_file_exists(f: str):
|
||||
assert Path(f).is_file(), f"{f} does not exist"
|
||||
|
||||
|
||||
class KeywordSpotter(object):
|
||||
"""A class for keyword spotting.
|
||||
|
||||
Please refer to the following files for usages
|
||||
- https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter.py
|
||||
- https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter-from-microphone.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokens: str,
|
||||
encoder: str,
|
||||
decoder: str,
|
||||
joiner: str,
|
||||
keywords_file: str,
|
||||
num_threads: int = 2,
|
||||
sample_rate: float = 16000,
|
||||
feature_dim: int = 80,
|
||||
max_active_paths: int = 4,
|
||||
keywords_score: float = 1.0,
|
||||
keywords_threshold: float = 0.25,
|
||||
num_trailing_blanks: int = 1,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
`<https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html>`_
|
||||
to download pre-trained models for different languages, e.g., Chinese,
|
||||
English, etc.
|
||||
|
||||
Args:
|
||||
tokens:
|
||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||
columns::
|
||||
|
||||
symbol integer_id
|
||||
|
||||
encoder:
|
||||
Path to ``encoder.onnx``.
|
||||
decoder:
|
||||
Path to ``decoder.onnx``.
|
||||
joiner:
|
||||
Path to ``joiner.onnx``.
|
||||
keywords_file:
|
||||
The file containing keywords, one word/phrase per line, and for each
|
||||
phrase the bpe/cjkchar/pinyin are separated by a space.
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
sample_rate:
|
||||
Sample rate of the training data used to train the model.
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
max_active_paths:
|
||||
Use only when decoding_method is modified_beam_search. It specifies
|
||||
the maximum number of active paths during beam search.
|
||||
keywords_score:
|
||||
The boosting score of each token for keywords. The larger the easier to
|
||||
survive beam search.
|
||||
keywords_threshold:
|
||||
The trigger threshold (i.e. probability) of the keyword. The larger the
|
||||
harder to trigger.
|
||||
num_trailing_blanks:
|
||||
The number of trailing blanks a keyword should be followed. Setting
|
||||
to a larger value (e.g. 8) when your keywords has overlapping tokens
|
||||
between each other.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
"""
|
||||
_assert_file_exists(tokens)
|
||||
_assert_file_exists(encoder)
|
||||
_assert_file_exists(decoder)
|
||||
_assert_file_exists(joiner)
|
||||
|
||||
assert num_threads > 0, num_threads
|
||||
|
||||
transducer_config = OnlineTransducerModelConfig(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
transducer=transducer_config,
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
|
||||
keywords_spotter_config = KeywordSpotterConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
max_active_paths=max_active_paths,
|
||||
num_trailing_blanks=num_trailing_blanks,
|
||||
keywords_score=keywords_score,
|
||||
keywords_threshold=keywords_threshold,
|
||||
keywords_file=keywords_file,
|
||||
)
|
||||
self.keyword_spotter = _KeywordSpotter(keywords_spotter_config)
|
||||
|
||||
def create_stream(self, keywords: Optional[str] = None):
|
||||
if keywords is None:
|
||||
return self.keyword_spotter.create_stream()
|
||||
else:
|
||||
return self.keyword_spotter.create_stream(keywords)
|
||||
|
||||
def decode_stream(self, s: OnlineStream):
|
||||
self.keyword_spotter.decode_stream(s)
|
||||
|
||||
def decode_streams(self, ss: List[OnlineStream]):
|
||||
self.keyword_spotter.decode_streams(ss)
|
||||
|
||||
def is_ready(self, s: OnlineStream) -> bool:
|
||||
return self.keyword_spotter.is_ready(s)
|
||||
|
||||
def get_result(self, s: OnlineStream) -> str:
|
||||
return self.keyword_spotter.get_result(s).keyword.strip()
|
||||
|
||||
def tokens(self, s: OnlineStream) -> List[str]:
|
||||
return self.keyword_spotter.get_result(s).tokens
|
||||
|
||||
def timestamps(self, s: OnlineStream) -> List[float]:
|
||||
return self.keyword_spotter.get_result(s).timestamps
|
||||
@@ -20,6 +20,7 @@ endfunction()
|
||||
# please sort the files in alphabetic order
|
||||
set(py_test_files
|
||||
test_feature_extractor_config.py
|
||||
test_keyword_spotter.py
|
||||
test_offline_recognizer.py
|
||||
test_online_recognizer.py
|
||||
test_online_transducer_model_config.py
|
||||
|
||||
170
sherpa-onnx/python/tests/test_keyword_spotter.py
Executable file
170
sherpa-onnx/python/tests/test_keyword_spotter.py
Executable file
@@ -0,0 +1,170 @@
|
||||
# sherpa-onnx/python/tests/test_keyword_spotter.py
|
||||
#
|
||||
# Copyright (c) 2024 Xiaomi Corporation
|
||||
#
|
||||
# To run this single test, use
|
||||
#
|
||||
# ctest --verbose -R test_keyword_spotter_py
|
||||
|
||||
import unittest
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
|
||||
d = "/tmp/onnx-models"
|
||||
# Please refer to
|
||||
# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||
# to download pre-trained models for testing
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
wave_filename:
|
||||
Path to a wave file. It should be single channel and each sample should
|
||||
be 16-bit. Its sample rate does not need to be 16kHz.
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A 1-D array of dtype np.float32 containing the samples, which are
|
||||
normalized to the range [-1, 1].
|
||||
- sample rate of the wave file
|
||||
"""
|
||||
|
||||
with wave.open(wave_filename) as f:
|
||||
assert f.getnchannels() == 1, f.getnchannels()
|
||||
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||
num_samples = f.getnframes()
|
||||
samples = f.readframes(num_samples)
|
||||
samples_int16 = np.frombuffer(samples, dtype=np.int16)
|
||||
samples_float32 = samples_int16.astype(np.float32)
|
||||
|
||||
samples_float32 = samples_float32 / 32768
|
||||
return samples_float32, f.getframerate()
|
||||
|
||||
|
||||
class TestKeywordSpotter(unittest.TestCase):
|
||||
def test_zipformer_transducer_en(self):
|
||||
for use_int8 in [True, False]:
|
||||
if use_int8:
|
||||
encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
else:
|
||||
encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
|
||||
tokens = (
|
||||
f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt"
|
||||
)
|
||||
keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
|
||||
wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav"
|
||||
wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav"
|
||||
|
||||
if not Path(encoder).is_file():
|
||||
print("skipping test_zipformer_transducer_en()")
|
||||
return
|
||||
keyword_spotter = sherpa_onnx.KeywordSpotter(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
tokens=tokens,
|
||||
num_threads=1,
|
||||
keywords_file=keywords_file,
|
||||
provider="cpu",
|
||||
)
|
||||
streams = []
|
||||
waves = [wave0, wave1]
|
||||
for wave in waves:
|
||||
s = keyword_spotter.create_stream()
|
||||
samples, sample_rate = read_wave(wave)
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
s.input_finished()
|
||||
streams.append(s)
|
||||
|
||||
results = [""] * len(streams)
|
||||
while True:
|
||||
ready_list = []
|
||||
for i, s in enumerate(streams):
|
||||
if keyword_spotter.is_ready(s):
|
||||
ready_list.append(s)
|
||||
r = keyword_spotter.get_result(s)
|
||||
if r:
|
||||
print(f"{r} is detected.")
|
||||
results[i] += f"{r}/"
|
||||
if len(ready_list) == 0:
|
||||
break
|
||||
keyword_spotter.decode_streams(ready_list)
|
||||
for wave_filename, result in zip(waves, results):
|
||||
print(f"{wave_filename}\n{result[0:-1]}")
|
||||
print("-" * 10)
|
||||
|
||||
def test_zipformer_transducer_cn(self):
|
||||
for use_int8 in [True, False]:
|
||||
if use_int8:
|
||||
encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
else:
|
||||
encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
|
||||
|
||||
tokens = (
|
||||
f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"
|
||||
)
|
||||
keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
|
||||
wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
|
||||
wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav"
|
||||
wave2 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/5.wav"
|
||||
|
||||
if not Path(encoder).is_file():
|
||||
print("skipping test_zipformer_transducer_cn()")
|
||||
return
|
||||
keyword_spotter = sherpa_onnx.KeywordSpotter(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
tokens=tokens,
|
||||
num_threads=1,
|
||||
keywords_file=keywords_file,
|
||||
provider="cpu",
|
||||
)
|
||||
streams = []
|
||||
waves = [wave0, wave1, wave2]
|
||||
for wave in waves:
|
||||
s = keyword_spotter.create_stream()
|
||||
samples, sample_rate = read_wave(wave)
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
s.input_finished()
|
||||
streams.append(s)
|
||||
|
||||
results = [""] * len(streams)
|
||||
while True:
|
||||
ready_list = []
|
||||
for i, s in enumerate(streams):
|
||||
if keyword_spotter.is_ready(s):
|
||||
ready_list.append(s)
|
||||
r = keyword_spotter.get_result(s)
|
||||
if r:
|
||||
print(f"{r} is detected.")
|
||||
results[i] += f"{r}/"
|
||||
if len(ready_list) == 0:
|
||||
break
|
||||
keyword_spotter.decode_streams(ready_list)
|
||||
for wave_filename, result in zip(waves, results):
|
||||
print(f"{wave_filename}\n{result[0:-1]}")
|
||||
print("-" * 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user