118 lines
3.7 KiB
Python
Executable File
118 lines
3.7 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
"""
|
|
This file demonstrates how to use sherpa-onnx Python API to do keyword spotting
|
|
from wave file(s).
|
|
|
|
Please refer to
|
|
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
|
to download pre-trained models.
|
|
"""
|
|
import argparse
|
|
import time
|
|
import wave
|
|
from pathlib import Path
|
|
from typing import List, Tuple
|
|
|
|
import numpy as np
|
|
import sherpa_onnx
|
|
|
|
|
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
|
"""
|
|
Args:
|
|
wave_filename:
|
|
Path to a wave file. It should be single channel and each sample should
|
|
be 16-bit. Its sample rate does not need to be 16kHz.
|
|
Returns:
|
|
Return a tuple containing:
|
|
- A 1-D array of dtype np.float32 containing the samples, which are
|
|
normalized to the range [-1, 1].
|
|
- sample rate of the wave file
|
|
"""
|
|
|
|
with wave.open(wave_filename) as f:
|
|
assert f.getnchannels() == 1, f.getnchannels()
|
|
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
|
num_samples = f.getnframes()
|
|
samples = f.readframes(num_samples)
|
|
samples_int16 = np.frombuffer(samples, dtype=np.int16)
|
|
samples_float32 = samples_int16.astype(np.float32)
|
|
|
|
samples_float32 = samples_float32 / 32768
|
|
return samples_float32, f.getframerate()
|
|
|
|
|
|
def create_keyword_spotter():
|
|
kws = sherpa_onnx.KeywordSpotter(
|
|
tokens="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt",
|
|
encoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
|
decoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
|
joiner="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
|
|
num_threads=2,
|
|
keywords_file="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt",
|
|
provider="cpu",
|
|
)
|
|
|
|
return kws
|
|
|
|
|
|
def main():
|
|
kws = create_keyword_spotter()
|
|
|
|
wave_filename = (
|
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
|
|
)
|
|
|
|
samples, sample_rate = read_wave(wave_filename)
|
|
|
|
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
|
|
|
|
print("----------Use pre-defined keywords----------")
|
|
s = kws.create_stream()
|
|
s.accept_waveform(sample_rate, samples)
|
|
s.accept_waveform(sample_rate, tail_paddings)
|
|
s.input_finished()
|
|
while kws.is_ready(s):
|
|
kws.decode_stream(s)
|
|
r = kws.get_result(s)
|
|
if r != "":
|
|
# Remember to call reset right after detected a keyword
|
|
kws.reset_stream(s)
|
|
|
|
print(f"Detected {r}")
|
|
|
|
print("----------Use pre-defined keywords + add a new keyword----------")
|
|
|
|
s = kws.create_stream("y ǎn y uán @演员")
|
|
s.accept_waveform(sample_rate, samples)
|
|
s.accept_waveform(sample_rate, tail_paddings)
|
|
s.input_finished()
|
|
while kws.is_ready(s):
|
|
kws.decode_stream(s)
|
|
r = kws.get_result(s)
|
|
if r != "":
|
|
# Remember to call reset right after detected a keyword
|
|
kws.reset_stream(s)
|
|
|
|
print(f"Detected {r}")
|
|
|
|
print("----------Use pre-defined keywords + add 2 new keywords----------")
|
|
|
|
s = kws.create_stream("y ǎn y uán @演员/zh ī m íng @知名")
|
|
s.accept_waveform(sample_rate, samples)
|
|
s.accept_waveform(sample_rate, tail_paddings)
|
|
s.input_finished()
|
|
while kws.is_ready(s):
|
|
kws.decode_stream(s)
|
|
r = kws.get_result(s)
|
|
if r != "":
|
|
# Remember to call reset right after detected a keyword
|
|
kws.reset_stream(s)
|
|
|
|
print(f"Detected {r}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|