This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex_bi_series-sherpa-onnx/python-api-examples/keyword-spotter.py
Fangjun Kuang 8b989a851c Fix keyword spotting. (#1689)
Reset the stream right after detecting a keyword
2025-01-20 16:41:10 +08:00

118 lines
3.7 KiB
Python
Executable File

#!/usr/bin/env python3
"""
This file demonstrates how to use sherpa-onnx Python API to do keyword spotting
from wave file(s).
Please refer to
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
to download pre-trained models.
"""
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple
import numpy as np
import sherpa_onnx
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Read a mono, 16-bit PCM wave file as normalized float32 samples.

    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.

    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1).
       - sample rate of the wave file

    Raises:
      ValueError:
        If the file has more than one channel or its samples are not 16-bit.
    """
    with wave.open(wave_filename) as f:
        # Validate with real exceptions: `assert` statements are removed when
        # Python runs with -O, which would let unsupported files through.
        num_channels = f.getnchannels()
        if num_channels != 1:
            raise ValueError(
                f"Expected a single-channel wave file, got {num_channels} "
                f"channels: {wave_filename}"
            )

        sample_width = f.getsampwidth()  # it is in bytes
        if sample_width != 2:
            raise ValueError(
                f"Expected 16-bit samples (2 bytes), got {sample_width} "
                f"byte(s) per sample: {wave_filename}"
            )

        sample_rate = f.getframerate()
        frames = f.readframes(f.getnframes())

    # PCM data in a wave file is little-endian; decode it explicitly as "<i2"
    # so the result does not depend on the byte order of the host.
    samples_int16 = np.frombuffer(frames, dtype="<i2")

    # Dividing by 32768 (= 2**15) maps the int16 range onto [-1, 1).
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate
def create_keyword_spotter(
    model_dir: str = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01",
):
    """Create a sherpa_onnx.KeywordSpotter from the model files in model_dir.

    Args:
      model_dir:
        Directory that contains tokens.txt, the encoder/decoder/joiner ONNX
        files and test_wavs/test_keywords.txt. The default is the directory
        this example has always used, so calling the function without
        arguments behaves exactly as before.

    Returns:
      Return the keyword spotter, configured to run on CPU with 2 threads.
    """
    # All three ONNX files share the same checkpoint suffix.
    checkpoint = "epoch-12-avg-2-chunk-16-left-64.onnx"

    return sherpa_onnx.KeywordSpotter(
        tokens=f"{model_dir}/tokens.txt",
        encoder=f"{model_dir}/encoder-{checkpoint}",
        decoder=f"{model_dir}/decoder-{checkpoint}",
        joiner=f"{model_dir}/joiner-{checkpoint}",
        num_threads=2,
        keywords_file=f"{model_dir}/test_wavs/test_keywords.txt",
        provider="cpu",
    )
def _detect_keywords(kws, stream, sample_rate, samples, tail_paddings):
    """Feed one utterance to `stream` and print every keyword that is detected.

    Args:
      kws:
        The keyword spotter that created `stream`.
      stream:
        A stream obtained from `kws.create_stream(...)`.
      sample_rate:
        Sample rate of `samples` and `tail_paddings`.
      samples:
        A 1-D float32 array containing the audio samples.
      tail_paddings:
        A 1-D float32 array of zeros that is fed after `samples`.
    """
    stream.accept_waveform(sample_rate, samples)
    stream.accept_waveform(sample_rate, tail_paddings)
    stream.input_finished()

    while kws.is_ready(stream):
        kws.decode_stream(stream)
        result = kws.get_result(stream)
        if result != "":
            # Remember to reset the stream right after a keyword is detected,
            # before decoding continues.
            kws.reset_stream(stream)
            print(f"Detected {result}")


def main():
    """Run keyword spotting on a test wave with three keyword configurations.

    The same wave file is decoded three times: with only the keywords from the
    keywords file, with one extra keyword, and with two extra keywords that
    are passed to `create_stream()`.
    """
    kws = create_keyword_spotter()

    wave_filename = (
        "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
    )
    samples, sample_rate = read_wave(wave_filename)

    # 0.66 seconds of silence that is appended after the real audio.
    tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)

    print("----------Use pre-defined keywords----------")
    _detect_keywords(
        kws, kws.create_stream(), sample_rate, samples, tail_paddings
    )

    print("----------Use pre-defined keywords + add a new keyword----------")
    _detect_keywords(
        kws,
        kws.create_stream("y ǎn y uán @演员"),
        sample_rate,
        samples,
        tail_paddings,
    )

    print("----------Use pre-defined keywords + add 2 new keywords----------")
    # Multiple extra keywords are separated by "/".
    _detect_keywords(
        kws,
        kws.create_stream("y ǎn y uán @演员/zh ī m íng @知名"),
        sample_rate,
        samples,
        tail_paddings,
    )


if __name__ == "__main__":
    main()