Fix keyword spotting. (#1689)
Reset the stream right after detecting a keyword
This commit is contained in:
33
.github/scripts/test-python.sh
vendored
33
.github/scripts/test-python.sh
vendored
@@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version"
|
||||
pwd
|
||||
ls -lh
|
||||
|
||||
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
|
||||
log "Start testing ${repo}"
|
||||
|
||||
pushd $dir
|
||||
curl -LS -O https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
|
||||
tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
|
||||
rm sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
|
||||
popd
|
||||
|
||||
repo=$dir/$repo
|
||||
ls -lh $repo
|
||||
|
||||
python3 ./python-api-examples/keyword-spotter.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
if [[ x$OS != x'windows-latest' ]]; then
|
||||
echo "OS: $OS"
|
||||
|
||||
@@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then
|
||||
repo=$dir/$repo
|
||||
ls -lh $repo
|
||||
|
||||
python3 ./python-api-examples/keyword-spotter.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
||||
$repo/test_wavs/3.wav \
|
||||
$repo/test_wavs/4.wav \
|
||||
$repo/test_wavs/5.wav
|
||||
python3 ./python-api-examples/keyword-spotter.py
|
||||
|
||||
python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
|
||||
|
||||
|
||||
21
.github/workflows/c-api.yaml
vendored
21
.github/workflows/c-api.yaml
vendored
@@ -79,6 +79,27 @@ jobs:
|
||||
otool -L ./install/lib/libsherpa-onnx-c-api.dylib
|
||||
fi
|
||||
|
||||
- name: Test kws (zh)
|
||||
shell: bash
|
||||
run: |
|
||||
gcc -o kws-c-api ./c-api-examples/kws-c-api.c \
|
||||
-I ./build/install/include \
|
||||
-L ./build/install/lib/ \
|
||||
-l sherpa-onnx-c-api \
|
||||
-l onnxruntime
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
|
||||
export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
|
||||
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
|
||||
|
||||
./kws-c-api
|
||||
|
||||
rm ./kws-c-api
|
||||
rm -rf sherpa-onnx-kws-*
|
||||
|
||||
- name: Test Kokoro TTS (en)
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
22
.github/workflows/cxx-api.yaml
vendored
22
.github/workflows/cxx-api.yaml
vendored
@@ -81,6 +81,28 @@ jobs:
|
||||
otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib
|
||||
fi
|
||||
|
||||
- name: Test KWS (zh)
|
||||
shell: bash
|
||||
run: |
|
||||
g++ -std=c++17 -o kws-cxx-api ./cxx-api-examples/kws-cxx-api.cc \
|
||||
-I ./build/install/include \
|
||||
-L ./build/install/lib/ \
|
||||
-l sherpa-onnx-cxx-api \
|
||||
-l sherpa-onnx-c-api \
|
||||
-l onnxruntime
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
|
||||
export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
|
||||
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
|
||||
|
||||
./kws-cxx-api
|
||||
|
||||
rm kws-cxx-api
|
||||
rm -rf sherpa-onnx-kws-*
|
||||
|
||||
- name: Test Kokoro TTS (en)
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
@@ -151,24 +151,27 @@ class MainActivity : AppCompatActivity() {
|
||||
stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
|
||||
while (kws.isReady(stream)) {
|
||||
kws.decode(stream)
|
||||
}
|
||||
|
||||
val text = kws.getResult(stream).keyword
|
||||
val text = kws.getResult(stream).keyword
|
||||
|
||||
var textToDisplay = lastText
|
||||
var textToDisplay = lastText
|
||||
|
||||
if (text.isNotBlank()) {
|
||||
if (lastText.isBlank()) {
|
||||
textToDisplay = "$idx: $text"
|
||||
} else {
|
||||
textToDisplay = "$idx: $text\n$lastText"
|
||||
if (text.isNotBlank()) {
|
||||
// Remember to reset the stream right after detecting a keyword
|
||||
|
||||
kws.reset(stream)
|
||||
if (lastText.isBlank()) {
|
||||
textToDisplay = "$idx: $text"
|
||||
} else {
|
||||
textToDisplay = "$idx: $text\n$lastText"
|
||||
}
|
||||
lastText = "$idx: $text\n$lastText"
|
||||
idx += 1
|
||||
}
|
||||
lastText = "$idx: $text\n$lastText"
|
||||
idx += 1
|
||||
}
|
||||
|
||||
runOnUiThread {
|
||||
textView.text = textToDisplay
|
||||
runOnUiThread {
|
||||
textView.text = textToDisplay
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR})
|
||||
add_executable(decode-file-c-api decode-file-c-api.c)
|
||||
target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs)
|
||||
|
||||
add_executable(kws-c-api kws-c-api.c)
|
||||
target_link_libraries(kws-c-api sherpa-onnx-c-api)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
add_executable(offline-tts-c-api offline-tts-c-api.c)
|
||||
target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs)
|
||||
|
||||
150
c-api-examples/kws-c-api.c
Normal file
150
c-api-examples/kws-c-api.c
Normal file
@@ -0,0 +1,150 @@
|
||||
// c-api-examples/kws-c-api.c
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
//
|
||||
// This file demonstrates how to use the keyword spotter with sherpa-onnx's C API.
|
||||
// clang-format off
|
||||
//
|
||||
// Usage
|
||||
//
|
||||
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
//
|
||||
// ./kws-c-api
|
||||
//
|
||||
// clang-format on
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h> // exit
|
||||
#include <string.h> // memset
|
||||
|
||||
#include "sherpa-onnx/c-api/c-api.h"
|
||||
|
||||
int32_t main() {
|
||||
SherpaOnnxKeywordSpotterConfig config;
|
||||
|
||||
memset(&config, 0, sizeof(config));
|
||||
config.model_config.transducer.encoder =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.transducer.decoder =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.transducer.joiner =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.tokens =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
|
||||
|
||||
config.model_config.provider = "cpu";
|
||||
config.model_config.num_threads = 1;
|
||||
config.model_config.debug = 1;
|
||||
|
||||
config.keywords_file =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
|
||||
"test_keywords.txt";
|
||||
|
||||
const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&config);
|
||||
if (!kws) {
|
||||
fprintf(stderr, "Please check your config");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
fprintf(stderr,
|
||||
"--Test pre-defined keywords from test_wavs/test_keywords.txt--\n");
|
||||
|
||||
const char *wav_filename =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
|
||||
|
||||
float tail_paddings[8000] = {0}; // 0.5 seconds
|
||||
|
||||
const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
|
||||
if (wave == NULL) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
|
||||
if (!stream) {
|
||||
fprintf(stderr, "Failed to create stream\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
|
||||
wave->num_samples);
|
||||
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
|
||||
sizeof(tail_paddings) / sizeof(float));
|
||||
SherpaOnnxOnlineStreamInputFinished(stream);
|
||||
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
|
||||
SherpaOnnxDecodeKeywordStream(kws, stream);
|
||||
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
|
||||
if (r && r->json && strlen(r->keyword)) {
|
||||
fprintf(stderr, "Detected keyword: %s\n", r->json);
|
||||
|
||||
// Remember to reset the keyword stream right after a keyword is detected
|
||||
SherpaOnnxResetKeywordStream(kws, stream);
|
||||
}
|
||||
SherpaOnnxDestroyKeywordResult(r);
|
||||
}
|
||||
SherpaOnnxDestroyOnlineStream(stream);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
fprintf(stderr, "--Use pre-defined keywords + add a new keyword--\n");
|
||||
|
||||
stream = SherpaOnnxCreateKeywordStreamWithKeywords(kws, "y ǎn y uán @演员");
|
||||
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
|
||||
wave->num_samples);
|
||||
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
|
||||
sizeof(tail_paddings) / sizeof(float));
|
||||
SherpaOnnxOnlineStreamInputFinished(stream);
|
||||
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
|
||||
SherpaOnnxDecodeKeywordStream(kws, stream);
|
||||
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
|
||||
if (r && r->json && strlen(r->keyword)) {
|
||||
fprintf(stderr, "Detected keyword: %s\n", r->json);
|
||||
|
||||
// Remember to reset the keyword stream
|
||||
SherpaOnnxResetKeywordStream(kws, stream);
|
||||
}
|
||||
SherpaOnnxDestroyKeywordResult(r);
|
||||
}
|
||||
SherpaOnnxDestroyOnlineStream(stream);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
fprintf(stderr, "--Use pre-defined keywords + add two new keywords--\n");
|
||||
|
||||
stream = SherpaOnnxCreateKeywordStreamWithKeywords(
|
||||
kws, "y ǎn y uán @演员/zh ī m íng @知名");
|
||||
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
|
||||
wave->num_samples);
|
||||
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
|
||||
sizeof(tail_paddings) / sizeof(float));
|
||||
SherpaOnnxOnlineStreamInputFinished(stream);
|
||||
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
|
||||
SherpaOnnxDecodeKeywordStream(kws, stream);
|
||||
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
|
||||
if (r && r->json && strlen(r->keyword)) {
|
||||
fprintf(stderr, "Detected keyword: %s\n", r->json);
|
||||
|
||||
// Remember to reset the keyword stream
|
||||
SherpaOnnxResetKeywordStream(kws, stream);
|
||||
}
|
||||
SherpaOnnxDestroyKeywordResult(r);
|
||||
}
|
||||
SherpaOnnxDestroyOnlineStream(stream);
|
||||
|
||||
SherpaOnnxFreeWave(wave);
|
||||
SherpaOnnxDestroyKeywordSpotter(kws);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
|
||||
add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
|
||||
target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)
|
||||
|
||||
add_executable(kws-cxx-api ./kws-cxx-api.cc)
|
||||
target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api)
|
||||
|
||||
add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
|
||||
target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)
|
||||
|
||||
|
||||
141
cxx-api-examples/kws-cxx-api.cc
Normal file
141
cxx-api-examples/kws-cxx-api.cc
Normal file
@@ -0,0 +1,141 @@
|
||||
// cxx-api-examples/kws-cxx-api.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
//
|
||||
// This file demonstrates how to use the keyword spotter with sherpa-onnx's C++ API.
|
||||
// clang-format off
|
||||
//
|
||||
// Usage
|
||||
//
|
||||
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
//
|
||||
// ./kws-cxx-api
|
||||
//
|
||||
// clang-format on
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
|
||||
#include "sherpa-onnx/c-api/cxx-api.h"
|
||||
|
||||
int32_t main() {
|
||||
using namespace sherpa_onnx::cxx; // NOLINT
|
||||
|
||||
KeywordSpotterConfig config;
|
||||
config.model_config.transducer.encoder =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.transducer.decoder =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.transducer.joiner =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.tokens =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
|
||||
|
||||
config.model_config.provider = "cpu";
|
||||
config.model_config.num_threads = 1;
|
||||
config.model_config.debug = 1;
|
||||
|
||||
config.keywords_file =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
|
||||
"test_keywords.txt";
|
||||
|
||||
KeywordSpotter kws = KeywordSpotter::Create(config);
|
||||
if (!kws.Get()) {
|
||||
std::cerr << "Please check your config\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n";
|
||||
|
||||
std::string wave_filename =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
|
||||
|
||||
std::array<float, 8000> tail_paddings = {0}; // 0.5 seconds
|
||||
|
||||
Wave wave = ReadWave(wave_filename);
|
||||
if (wave.samples.empty()) {
|
||||
std::cerr << "Failed to read: '" << wave_filename << "'\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
OnlineStream stream = kws.CreateStream();
|
||||
if (!stream.Get()) {
|
||||
std::cerr << "Failed to create stream\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||
wave.samples.size());
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
stream.InputFinished();
|
||||
|
||||
while (kws.IsReady(&stream)) {
|
||||
kws.Decode(&stream);
|
||||
auto r = kws.GetResult(&stream);
|
||||
if (!r.keyword.empty()) {
|
||||
std::cout << "Detected keyword: " << r.json << "\n";
|
||||
|
||||
// Remember to reset the keyword stream right after a keyword is detected
|
||||
kws.Reset(&stream);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
std::cout << "--Use pre-defined keywords + add a new keyword--\n";
|
||||
|
||||
stream = kws.CreateStream("y ǎn y uán @演员");
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||
wave.samples.size());
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
stream.InputFinished();
|
||||
|
||||
while (kws.IsReady(&stream)) {
|
||||
kws.Decode(&stream);
|
||||
auto r = kws.GetResult(&stream);
|
||||
if (!r.keyword.empty()) {
|
||||
std::cout << "Detected keyword: " << r.json << "\n";
|
||||
|
||||
// Remember to reset the keyword stream right after a keyword is detected
|
||||
kws.Reset(&stream);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
std::cout << "--Use pre-defined keywords + add two new keywords--\n";
|
||||
|
||||
stream = kws.CreateStream("y ǎn y uán @演员/zh ī m íng @知名");
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||
wave.samples.size());
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
stream.InputFinished();
|
||||
|
||||
while (kws.IsReady(&stream)) {
|
||||
kws.Decode(&stream);
|
||||
auto r = kws.GetResult(&stream);
|
||||
if (!r.keyword.empty()) {
|
||||
std::cout << "Detected keyword: " << r.json << "\n";
|
||||
|
||||
// Remember to reset the keyword stream right after a keyword is detected
|
||||
kws.Reset(&stream);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -73,6 +73,8 @@ void main(List<String> arguments) async {
|
||||
spotter.decode(stream);
|
||||
final result = spotter.getResult(stream);
|
||||
if (result.keyword != '') {
|
||||
// Remember to reset the stream right after detecting a keyword
|
||||
spotter.reset(stream);
|
||||
print('Detected: ${result.keyword}');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,8 @@ class KeywordSpotterDemo
|
||||
var result = kws.GetResult(s);
|
||||
if (result.Keyword != string.Empty)
|
||||
{
|
||||
// Remember to call Reset() right after detecting a keyword
|
||||
kws.Reset(s);
|
||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||
}
|
||||
}
|
||||
@@ -70,6 +72,8 @@ class KeywordSpotterDemo
|
||||
var result = kws.GetResult(s);
|
||||
if (result.Keyword != string.Empty)
|
||||
{
|
||||
// Remember to call Reset() right after detecting a keyword
|
||||
kws.Reset(s);
|
||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||
}
|
||||
}
|
||||
@@ -89,6 +93,8 @@ class KeywordSpotterDemo
|
||||
var result = kws.GetResult(s);
|
||||
if (result.Keyword != string.Empty)
|
||||
{
|
||||
// Remember to call Reset() right after detecting a keyword
|
||||
kws.Reset(s);
|
||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,12 +107,15 @@ class KeywordSpotterDemo
|
||||
while (kws.IsReady(s))
|
||||
{
|
||||
kws.Decode(s);
|
||||
}
|
||||
|
||||
var result = kws.GetResult(s);
|
||||
if (result.Keyword != string.Empty)
|
||||
{
|
||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||
var result = kws.GetResult(s);
|
||||
if (result.Keyword != string.Empty)
|
||||
{
|
||||
// Remember to call Reset() right after detecting a keyword
|
||||
kws.Reset(s);
|
||||
|
||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||
}
|
||||
}
|
||||
|
||||
Thread.Sleep(200); // ms
|
||||
|
||||
@@ -168,6 +168,10 @@ class KeywordSpotter {
|
||||
SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr);
|
||||
}
|
||||
|
||||
/// Resets [stream] through the native `resetKeywordStream` binding.
///
/// Does nothing when the binding has not been looked up yet (it is null).
void reset(OnlineStream stream) {
  final resetFn = SherpaOnnxBindings.resetKeywordStream;
  if (resetFn == null) {
    return;
  }

  resetFn(ptr, stream.ptr);
}
|
||||
|
||||
Pointer<SherpaOnnxKeywordSpotter> ptr;
|
||||
KeywordSpotterConfig config;
|
||||
}
|
||||
|
||||
@@ -667,6 +667,12 @@ typedef DecodeKeywordStreamNative = Void Function(
|
||||
typedef DecodeKeywordStream = void Function(
|
||||
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
||||
|
||||
typedef ResetKeywordStreamNative = Void Function(
|
||||
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
||||
|
||||
typedef ResetKeywordStream = void Function(
|
||||
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
||||
|
||||
typedef GetKeywordResultAsJsonNative = Pointer<Utf8> Function(
|
||||
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
||||
|
||||
@@ -1157,6 +1163,7 @@ class SherpaOnnxBindings {
|
||||
static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords;
|
||||
static IsKeywordStreamReady? isKeywordStreamReady;
|
||||
static DecodeKeywordStream? decodeKeywordStream;
|
||||
static ResetKeywordStream? resetKeywordStream;
|
||||
static GetKeywordResultAsJson? getKeywordResultAsJson;
|
||||
static FreeKeywordResultJson? freeKeywordResultJson;
|
||||
|
||||
@@ -1459,6 +1466,11 @@ class SherpaOnnxBindings {
|
||||
'SherpaOnnxDecodeKeywordStream')
|
||||
.asFunction();
|
||||
|
||||
resetKeywordStream ??= dynamicLibrary
|
||||
.lookup<NativeFunction<ResetKeywordStreamNative>>(
|
||||
'SherpaOnnxResetKeywordStream')
|
||||
.asFunction();
|
||||
|
||||
getKeywordResultAsJson ??= dynamicLibrary
|
||||
.lookup<NativeFunction<GetKeywordResultAsJsonNative>>(
|
||||
'SherpaOnnxGetKeywordResultAsJson')
|
||||
|
||||
@@ -43,6 +43,8 @@ func main() {
|
||||
spotter.Decode(stream)
|
||||
result := spotter.GetResult(stream)
|
||||
if result.Keyword != "" {
|
||||
// You have to reset the stream right after detecting a keyword
|
||||
spotter.Reset(stream)
|
||||
log.Printf("Detected %v\n", result.Keyword)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper(
|
||||
SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf);
|
||||
SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize);
|
||||
|
||||
SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c);
|
||||
const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c);
|
||||
|
||||
if (c.model_config.transducer.encoder) {
|
||||
delete[] c.model_config.transducer.encoder;
|
||||
@@ -100,7 +100,8 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper(
|
||||
}
|
||||
|
||||
return Napi::External<SherpaOnnxKeywordSpotter>::New(
|
||||
env, kws, [](Napi::Env env, SherpaOnnxKeywordSpotter *kws) {
|
||||
env, const_cast<SherpaOnnxKeywordSpotter *>(kws),
|
||||
[](Napi::Env env, SherpaOnnxKeywordSpotter *kws) {
|
||||
SherpaOnnxDestroyKeywordSpotter(kws);
|
||||
});
|
||||
}
|
||||
@@ -125,13 +126,14 @@ static Napi::External<SherpaOnnxOnlineStream> CreateKeywordStreamWrapper(
|
||||
return {};
|
||||
}
|
||||
|
||||
SherpaOnnxKeywordSpotter *kws =
|
||||
const SherpaOnnxKeywordSpotter *kws =
|
||||
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
||||
|
||||
SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
|
||||
const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
|
||||
|
||||
return Napi::External<SherpaOnnxOnlineStream>::New(
|
||||
env, stream, [](Napi::Env env, SherpaOnnxOnlineStream *stream) {
|
||||
env, const_cast<SherpaOnnxOnlineStream *>(stream),
|
||||
[](Napi::Env env, SherpaOnnxOnlineStream *stream) {
|
||||
SherpaOnnxDestroyOnlineStream(stream);
|
||||
});
|
||||
}
|
||||
@@ -162,10 +164,10 @@ static Napi::Boolean IsKeywordStreamReadyWrapper(
|
||||
return {};
|
||||
}
|
||||
|
||||
SherpaOnnxKeywordSpotter *kws =
|
||||
const SherpaOnnxKeywordSpotter *kws =
|
||||
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
||||
|
||||
SherpaOnnxOnlineStream *stream =
|
||||
const SherpaOnnxOnlineStream *stream =
|
||||
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
||||
|
||||
int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream);
|
||||
@@ -198,15 +200,49 @@ static void DecodeKeywordStreamWrapper(const Napi::CallbackInfo &info) {
|
||||
return;
|
||||
}
|
||||
|
||||
SherpaOnnxKeywordSpotter *kws =
|
||||
const SherpaOnnxKeywordSpotter *kws =
|
||||
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
||||
|
||||
SherpaOnnxOnlineStream *stream =
|
||||
const SherpaOnnxOnlineStream *stream =
|
||||
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
||||
|
||||
SherpaOnnxDecodeKeywordStream(kws, stream);
|
||||
}
|
||||
|
||||
// JavaScript-facing wrapper around SherpaOnnxResetKeywordStream().
//
// Expects exactly two arguments, both Napi externals:
//   info[0] - keyword spotter pointer
//   info[1] - online stream pointer
//
// Throws a JavaScript TypeError and returns without resetting anything when
// the argument count or either argument kind is wrong.
static void ResetKeywordStreamWrapper(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  const auto num_args = info.Length();
  if (num_args != 2) {
    std::ostringstream msg;
    msg << "Expect only 2 arguments. Given: " << num_args;

    Napi::TypeError::New(env, msg.str()).ThrowAsJavaScriptException();

    return;
  }

  // Validate the arguments in order; only the first failure is reported.
  const char *type_error = nullptr;
  if (!info[0].IsExternal()) {
    type_error = "Argument 0 should be a keyword spotter pointer.";
  } else if (!info[1].IsExternal()) {
    type_error = "Argument 1 should be an online stream pointer.";
  }

  if (type_error != nullptr) {
    Napi::TypeError::New(env, type_error).ThrowAsJavaScriptException();

    return;
  }

  const SherpaOnnxKeywordSpotter *spotter =
      info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
  const SherpaOnnxOnlineStream *online_stream =
      info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();

  SherpaOnnxResetKeywordStream(spotter, online_stream);
}
|
||||
|
||||
static Napi::String GetKeywordResultAsJsonWrapper(
|
||||
const Napi::CallbackInfo &info) {
|
||||
Napi::Env env = info.Env();
|
||||
@@ -233,10 +269,10 @@ static Napi::String GetKeywordResultAsJsonWrapper(
|
||||
return {};
|
||||
}
|
||||
|
||||
SherpaOnnxKeywordSpotter *kws =
|
||||
const SherpaOnnxKeywordSpotter *kws =
|
||||
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
||||
|
||||
SherpaOnnxOnlineStream *stream =
|
||||
const SherpaOnnxOnlineStream *stream =
|
||||
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
||||
|
||||
const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream);
|
||||
@@ -261,6 +297,9 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports) {
|
||||
exports.Set(Napi::String::New(env, "decodeKeywordStream"),
|
||||
Napi::Function::New(env, DecodeKeywordStreamWrapper));
|
||||
|
||||
exports.Set(Napi::String::New(env, "resetKeywordStream"),
|
||||
Napi::Function::New(env, ResetKeywordStreamWrapper));
|
||||
|
||||
exports.Set(Napi::String::New(env, "getKeywordResultAsJson"),
|
||||
Napi::Function::New(env, GetKeywordResultAsJsonWrapper));
|
||||
}
|
||||
|
||||
@@ -56,6 +56,8 @@ public class KyewordSpotterFromFile {
|
||||
|
||||
String keyword = kws.getResult(stream).getKeyword();
|
||||
if (!keyword.isEmpty()) {
|
||||
// Remember to reset the stream right after detecting a keyword
|
||||
kws.reset(stream);
|
||||
System.out.printf("Detected keyword: %s\n", keyword);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,9 @@ while (kws.isReady(stream)) {
|
||||
const keyword = kws.getResult(stream).keyword;
|
||||
if (keyword != '') {
|
||||
detectedKeywords.push(keyword);
|
||||
|
||||
// remember to reset the stream right after detecting a keyword
|
||||
kws.reset(stream);
|
||||
}
|
||||
}
|
||||
console.log(detectedKeywords);
|
||||
|
||||
@@ -169,6 +169,8 @@ def main():
|
||||
|
||||
print("Started! Please speak")
|
||||
|
||||
idx = 0
|
||||
|
||||
sample_rate = 16000
|
||||
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||
stream = keyword_spotter.create_stream()
|
||||
@@ -179,9 +181,12 @@ def main():
|
||||
stream.accept_waveform(sample_rate, samples)
|
||||
while keyword_spotter.is_ready(stream):
|
||||
keyword_spotter.decode_stream(stream)
|
||||
result = keyword_spotter.get_result(stream)
|
||||
if result:
|
||||
print("\r{}".format(result), end="", flush=True)
|
||||
result = keyword_spotter.get_result(stream)
|
||||
if result:
|
||||
print(f"{idx}: {result }")
|
||||
idx += 1
|
||||
# Remember to reset stream right after detecting a keyword
|
||||
keyword_spotter.reset_stream(stream)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -18,122 +18,6 @@ import numpy as np
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
type=str,
|
||||
help="Path to the transducer encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
help="Path to the transducer decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner",
|
||||
type=str,
|
||||
help="Path to the transducer joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Valid values: cpu, cuda, coreml",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-active-paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""
|
||||
It specifies number of active paths to keep during decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-trailing-blanks",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""The number of trailing blanks a keyword should be followed. Setting
|
||||
to a larger value (e.g. 8) when your keywords has overlapping tokens
|
||||
between each other.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-file",
|
||||
type=str,
|
||||
help="""
|
||||
The file containing keywords, one words/phrases per line, and for each
|
||||
phrase the bpe/cjkchar/pinyin are separated by a space. For example:
|
||||
|
||||
▁HE LL O ▁WORLD
|
||||
x iǎo ài t óng x ué
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-score",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""
|
||||
The boosting score of each token for keywords. The larger the easier to
|
||||
survive beam search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keywords-threshold",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="""
|
||||
The trigger threshold (i.e. probability) of the keyword. The larger the
|
||||
harder to trigger.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||
"format with a single channel, and each sample has 16-bit, "
|
||||
"i.e., int16_t. "
|
||||
"The sample rate of the file can be arbitrary and does not need to "
|
||||
"be 16 kHz",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
@@ -159,83 +43,74 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
return samples_float32, f.getframerate()
|
||||
|
||||
|
||||
def create_keyword_spotter():
    """Build a ``sherpa_onnx.KeywordSpotter`` with hard-coded model paths.

    All model files, the tokens file and the keywords file are read from
    ``./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01`` relative to
    the current working directory. The spotter uses 2 threads on the CPU
    provider.

    Returns:
      The constructed keyword spotter.
    """
    spotter_options = dict(
        tokens="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt",
        encoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
        decoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
        joiner="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
        num_threads=2,
        keywords_file="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt",
        provider="cpu",
    )

    return sherpa_onnx.KeywordSpotter(**spotter_options)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.tokens)
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
assert_file_exists(args.joiner)
|
||||
kws = create_keyword_spotter()
|
||||
|
||||
assert Path(
|
||||
args.keywords_file
|
||||
).is_file(), (
|
||||
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
|
||||
wave_filename = (
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
|
||||
)
|
||||
|
||||
keyword_spotter = sherpa_onnx.KeywordSpotter(
|
||||
tokens=args.tokens,
|
||||
encoder=args.encoder,
|
||||
decoder=args.decoder,
|
||||
joiner=args.joiner,
|
||||
num_threads=args.num_threads,
|
||||
max_active_paths=args.max_active_paths,
|
||||
keywords_file=args.keywords_file,
|
||||
keywords_score=args.keywords_score,
|
||||
keywords_threshold=args.keywords_threshold,
|
||||
num_trailing_blanks=args.num_trailing_blanks,
|
||||
provider=args.provider,
|
||||
)
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
|
||||
|
||||
streams = []
|
||||
total_duration = 0
|
||||
for wave_filename in args.sound_files:
|
||||
assert_file_exists(wave_filename)
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
duration = len(samples) / sample_rate
|
||||
total_duration += duration
|
||||
print("----------Use pre-defined keywords----------")
|
||||
s = kws.create_stream()
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
s.input_finished()
|
||||
while kws.is_ready(s):
|
||||
kws.decode_stream(s)
|
||||
r = kws.get_result(s)
|
||||
if r != "":
|
||||
# Remember to call reset right after detected a keyword
|
||||
kws.reset_stream(s)
|
||||
|
||||
s = keyword_spotter.create_stream()
|
||||
print(f"Detected {r}")
|
||||
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
print("----------Use pre-defined keywords + add a new keyword----------")
|
||||
|
||||
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
s = kws.create_stream("y ǎn y uán @演员")
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
s.input_finished()
|
||||
while kws.is_ready(s):
|
||||
kws.decode_stream(s)
|
||||
r = kws.get_result(s)
|
||||
if r != "":
|
||||
# Remember to call reset right after detected a keyword
|
||||
kws.reset_stream(s)
|
||||
|
||||
s.input_finished()
|
||||
print(f"Detected {r}")
|
||||
|
||||
streams.append(s)
|
||||
print("----------Use pre-defined keywords + add 2 new keywords----------")
|
||||
|
||||
results = [""] * len(streams)
|
||||
while True:
|
||||
ready_list = []
|
||||
for i, s in enumerate(streams):
|
||||
if keyword_spotter.is_ready(s):
|
||||
ready_list.append(s)
|
||||
r = keyword_spotter.get_result(s)
|
||||
if r:
|
||||
results[i] += f"{r}/"
|
||||
print(f"{r} is detected.")
|
||||
if len(ready_list) == 0:
|
||||
break
|
||||
keyword_spotter.decode_streams(ready_list)
|
||||
end_time = time.time()
|
||||
print("Done!")
|
||||
s = kws.create_stream("y ǎn y uán @演员/zh ī m íng @知名")
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
s.input_finished()
|
||||
while kws.is_ready(s):
|
||||
kws.decode_stream(s)
|
||||
r = kws.get_result(s)
|
||||
if r != "":
|
||||
# Remember to call reset right after detected a keyword
|
||||
kws.reset_stream(s)
|
||||
|
||||
for wave_filename, result in zip(args.sound_files, results):
|
||||
print(f"{wave_filename}\n{result}")
|
||||
print("-" * 10)
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
print(f"Detected {r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -46,6 +46,11 @@ namespace SherpaOnnx
|
||||
Decode(_handle.Handle, stream.Handle);
|
||||
}
|
||||
|
||||
public void Reset(OnlineStream stream)
|
||||
{
|
||||
Reset(_handle.Handle, stream.Handle);
|
||||
}
|
||||
|
||||
// The caller should ensure all passed streams are ready for decoding.
|
||||
public void Decode(IEnumerable<OnlineStream> streams)
|
||||
{
|
||||
@@ -110,6 +115,9 @@ namespace SherpaOnnx
|
||||
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")]
|
||||
private static extern void Decode(IntPtr handle, IntPtr stream);
|
||||
|
||||
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxResetKeywordStream")]
|
||||
private static extern void Reset(IntPtr handle, IntPtr stream);
|
||||
|
||||
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")]
|
||||
private static extern void Decode(IntPtr handle, IntPtr[] streams, int n);
|
||||
|
||||
|
||||
@@ -1584,6 +1584,11 @@ func (spotter *KeywordSpotter) Decode(s *OnlineStream) {
|
||||
C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl)
|
||||
}
|
||||
|
||||
// You MUST call it right after detecting a keyword
|
||||
func (spotter *KeywordSpotter) Reset(s *OnlineStream) {
|
||||
C.SherpaOnnxResetKeywordStream(spotter.impl, s.impl)
|
||||
}
|
||||
|
||||
// Get the current result of stream since the last invocation of Reset()
|
||||
func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult {
|
||||
p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl)
|
||||
|
||||
@@ -20,6 +20,10 @@ class KeywordSpotter {
|
||||
addon.decodeKeywordStream(this.handle, stream.handle);
|
||||
}
|
||||
|
||||
reset(stream) {
|
||||
addon.resetKeywordStream(this.handle, stream.handle);
|
||||
}
|
||||
|
||||
getResult(stream) {
|
||||
const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle);
|
||||
|
||||
|
||||
@@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter {
|
||||
std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
|
||||
};
|
||||
|
||||
SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||
const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||
const SherpaOnnxKeywordSpotterConfig *config) {
|
||||
sherpa_onnx::KeywordSpotterConfig spotter_config;
|
||||
|
||||
@@ -755,37 +755,42 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||
return spotter;
|
||||
}
|
||||
|
||||
void SherpaOnnxDestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) {
|
||||
void SherpaOnnxDestroyKeywordSpotter(const SherpaOnnxKeywordSpotter *spotter) {
|
||||
delete spotter;
|
||||
}
|
||||
|
||||
SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
||||
const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
||||
const SherpaOnnxKeywordSpotter *spotter) {
|
||||
SherpaOnnxOnlineStream *stream =
|
||||
new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
|
||||
return stream;
|
||||
}
|
||||
|
||||
SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords(
|
||||
const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords(
|
||||
const SherpaOnnxKeywordSpotter *spotter, const char *keywords) {
|
||||
SherpaOnnxOnlineStream *stream =
|
||||
new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords));
|
||||
return stream;
|
||||
}
|
||||
|
||||
int32_t SherpaOnnxIsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
|
||||
SherpaOnnxOnlineStream *stream) {
|
||||
int32_t SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream) {
|
||||
return spotter->impl->IsReady(stream->impl.get());
|
||||
}
|
||||
|
||||
void SherpaOnnxDecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
|
||||
SherpaOnnxOnlineStream *stream) {
|
||||
return spotter->impl->DecodeStream(stream->impl.get());
|
||||
void SherpaOnnxDecodeKeywordStream(const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream) {
|
||||
spotter->impl->DecodeStream(stream->impl.get());
|
||||
}
|
||||
|
||||
void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
|
||||
SherpaOnnxOnlineStream **streams,
|
||||
int32_t n) {
|
||||
void SherpaOnnxResetKeywordStream(const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream) {
|
||||
spotter->impl->Reset(stream->impl.get());
|
||||
}
|
||||
|
||||
void SherpaOnnxDecodeMultipleKeywordStreams(
|
||||
const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream **streams, int32_t n) {
|
||||
std::vector<sherpa_onnx::OnlineStream *> ss(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
ss[i] = streams[i]->impl.get();
|
||||
@@ -794,7 +799,8 @@ void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
|
||||
}
|
||||
|
||||
const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
|
||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) {
|
||||
const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream) {
|
||||
const sherpa_onnx::KeywordResult &result =
|
||||
spotter->impl->GetResult(stream->impl.get());
|
||||
const auto &keyword = result.keyword;
|
||||
@@ -869,8 +875,9 @@ void SherpaOnnxDestroyKeywordResult(const SherpaOnnxKeywordResult *r) {
|
||||
}
|
||||
}
|
||||
|
||||
const char *SherpaOnnxGetKeywordResultAsJson(SherpaOnnxKeywordSpotter *spotter,
|
||||
SherpaOnnxOnlineStream *stream) {
|
||||
const char *SherpaOnnxGetKeywordResultAsJson(
|
||||
const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream) {
|
||||
const sherpa_onnx::KeywordResult &result =
|
||||
spotter->impl->GetResult(stream->impl.get());
|
||||
|
||||
|
||||
@@ -600,7 +600,7 @@ SHERPA_ONNX_API const char *SherpaOnnxGetOfflineStreamResultAsJson(
|
||||
SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s);
|
||||
|
||||
// ============================================================
|
||||
// For Keyword Spot
|
||||
// For Keyword Spotter
|
||||
// ============================================================
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
|
||||
/// The triggered keyword.
|
||||
@@ -660,21 +660,21 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
|
||||
/// @param config Config for the keyword spotter.
|
||||
/// @return Return a pointer to the spotter. The user has to invoke
|
||||
/// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak.
|
||||
SHERPA_ONNX_API SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||
SHERPA_ONNX_API const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||
const SherpaOnnxKeywordSpotterConfig *config);
|
||||
|
||||
/// Free a pointer returned by SherpaOnnxCreateKeywordSpotter()
|
||||
///
|
||||
/// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter()
|
||||
SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter(
|
||||
SherpaOnnxKeywordSpotter *spotter);
|
||||
const SherpaOnnxKeywordSpotter *spotter);
|
||||
|
||||
/// Create an online stream for accepting wave samples.
|
||||
///
|
||||
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter()
|
||||
/// @return Return a pointer to an OnlineStream. The user has to invoke
|
||||
/// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
|
||||
SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
||||
SHERPA_ONNX_API const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
||||
const SherpaOnnxKeywordSpotter *spotter);
|
||||
|
||||
/// Create an online stream for accepting wave samples with the specified hot
|
||||
@@ -684,7 +684,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
||||
/// @param keywords A pointer points to the keywords that you set
|
||||
/// @return Return a pointer to an OnlineStream. The user has to invoke
|
||||
/// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
|
||||
SHERPA_ONNX_API SherpaOnnxOnlineStream *
|
||||
SHERPA_ONNX_API const SherpaOnnxOnlineStream *
|
||||
SherpaOnnxCreateKeywordStreamWithKeywords(
|
||||
const SherpaOnnxKeywordSpotter *spotter, const char *keywords);
|
||||
|
||||
@@ -693,15 +693,22 @@ SherpaOnnxCreateKeywordStreamWithKeywords(
|
||||
///
|
||||
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter
|
||||
/// @param stream A pointer returned by SherpaOnnxCreateKeywordStream
|
||||
SHERPA_ONNX_API int32_t SherpaOnnxIsKeywordStreamReady(
|
||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
|
||||
SHERPA_ONNX_API int32_t
|
||||
SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream);
|
||||
|
||||
/// Call this function to run the neural network model and decoding.
|
||||
///
|
||||
/// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST
|
||||
/// return 1.
|
||||
SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
|
||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
|
||||
const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream);
|
||||
|
||||
/// Please call it right after a keyword is detected
|
||||
SHERPA_ONNX_API void SherpaOnnxResetKeywordStream(
|
||||
const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream);
|
||||
|
||||
/// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes
|
||||
/// multiple OnlineStream in parallel.
|
||||
@@ -714,8 +721,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
|
||||
/// SherpaOnnxCreateKeywordStream()
|
||||
/// @param n Number of elements in the given streams array.
|
||||
SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
|
||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
|
||||
int32_t n);
|
||||
const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream **streams, int32_t n);
|
||||
|
||||
/// Get the decoding results so far for an OnlineStream.
|
||||
///
|
||||
@@ -725,7 +732,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
|
||||
/// SherpaOnnxDestroyKeywordResult() to free the returned pointer to
|
||||
/// avoid memory leak.
|
||||
SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
|
||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
|
||||
const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream);
|
||||
|
||||
/// Destroy the pointer returned by SherpaOnnxGetKeywordResult().
|
||||
///
|
||||
@@ -736,7 +744,8 @@ SHERPA_ONNX_API void SherpaOnnxDestroyKeywordResult(
|
||||
// the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned
|
||||
// pointer to avoid memory leak
|
||||
SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson(
|
||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
|
||||
const SherpaOnnxKeywordSpotter *spotter,
|
||||
const SherpaOnnxOnlineStream *stream);
|
||||
|
||||
SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s);
|
||||
|
||||
|
||||
@@ -391,4 +391,112 @@ GeneratedAudio OfflineTts::Generate(const std::string &text,
|
||||
return ans;
|
||||
}
|
||||
|
||||
// Creates a KeywordSpotter by mirroring the C++ config into the C config
// struct and passing it to SherpaOnnxCreateKeywordSpotter().
//
// Every string field of the C struct is filled with a c_str() pointer into
// the corresponding std::string of `config`, so those pointers refer to
// storage owned by `config`.
KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) {
  SherpaOnnxKeywordSpotterConfig c;
  // Zero everything first so that fields not set below are 0/NULL.
  memset(&c, 0, sizeof(c));

  const auto &feat = config.feat_config;
  const auto &model = config.model_config;
  auto &c_model = c.model_config;

  // Feature extraction
  c.feat_config.sample_rate = feat.sample_rate;
  c.feat_config.feature_dim = feat.feature_dim;

  // Model files
  c_model.transducer.encoder = model.transducer.encoder.c_str();
  c_model.transducer.decoder = model.transducer.decoder.c_str();
  c_model.transducer.joiner = model.transducer.joiner.c_str();

  c_model.paraformer.encoder = model.paraformer.encoder.c_str();
  c_model.paraformer.decoder = model.paraformer.decoder.c_str();

  c_model.zipformer2_ctc.model = model.zipformer2_ctc.model.c_str();

  // Remaining model options
  c_model.tokens = model.tokens.c_str();
  c_model.num_threads = model.num_threads;
  c_model.provider = model.provider.c_str();
  c_model.debug = model.debug;
  c_model.model_type = model.model_type.c_str();
  c_model.modeling_unit = model.modeling_unit.c_str();
  c_model.bpe_vocab = model.bpe_vocab.c_str();
  c_model.tokens_buf = model.tokens_buf.c_str();
  c_model.tokens_buf_size = model.tokens_buf.size();

  // Keyword-spotting options
  c.max_active_paths = config.max_active_paths;
  c.num_trailing_blanks = config.num_trailing_blanks;
  c.keywords_score = config.keywords_score;
  c.keywords_threshold = config.keywords_threshold;
  c.keywords_file = config.keywords_file.c_str();

  return KeywordSpotter(SherpaOnnxCreateKeywordSpotter(&c));
}
|
||||
|
||||
// Wraps a spotter handle obtained from the C API; the handle is passed on to
// the MoveOnly base class.
KeywordSpotter::KeywordSpotter(const SherpaOnnxKeywordSpotter *p)
    : MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter>(p) {
}
|
||||
|
||||
// Frees a spotter handle through SherpaOnnxDestroyKeywordSpotter().
void KeywordSpotter::Destroy(const SherpaOnnxKeywordSpotter *p) const {
  SherpaOnnxDestroyKeywordSpotter(p);
}
|
||||
|
||||
// Creates a stream from this spotter via SherpaOnnxCreateKeywordStream().
OnlineStream KeywordSpotter::CreateStream() const {
  return OnlineStream{SherpaOnnxCreateKeywordStream(p_)};
}
|
||||
|
||||
// Creates a stream via SherpaOnnxCreateKeywordStreamWithKeywords(), passing
// `keywords` through as a C string.
OnlineStream KeywordSpotter::CreateStream(const std::string &keywords) const {
  const char *c_keywords = keywords.c_str();
  return OnlineStream{SherpaOnnxCreateKeywordStreamWithKeywords(p_, c_keywords)};
}
|
||||
|
||||
bool KeywordSpotter::IsReady(const OnlineStream *s) const {
|
||||
return SherpaOnnxIsKeywordStreamReady(p_, s->Get());
|
||||
}
|
||||
|
||||
void KeywordSpotter::Decode(const OnlineStream *s) const {
|
||||
return SherpaOnnxDecodeKeywordStream(p_, s->Get());
|
||||
}
|
||||
|
||||
void KeywordSpotter::Decode(const OnlineStream *ss, int32_t n) const {
|
||||
if (n <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<const SherpaOnnxOnlineStream *> streams(n);
|
||||
for (int32_t i = 0; i != n; ++n) {
|
||||
streams[i] = ss[i].Get();
|
||||
}
|
||||
|
||||
SherpaOnnxDecodeMultipleKeywordStreams(p_, streams.data(), n);
|
||||
}
|
||||
|
||||
// Fetches the current keyword result of stream `s` from the C API, copies it
// into a C++ KeywordResult, and frees the C result before returning.
KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const {
  const auto *raw = SherpaOnnxGetKeywordResult(p_, s->Get());

  KeywordResult ans;
  ans.keyword = raw->keyword;

  const int32_t num_tokens = raw->count;
  ans.tokens.resize(num_tokens);
  for (int32_t k = 0; k < num_tokens; ++k) {
    ans.tokens[k] = raw->tokens_arr[k];
  }

  // Timestamps are copied only when the C result provides them.
  if (raw->timestamps) {
    ans.timestamps.assign(raw->timestamps, raw->timestamps + num_tokens);
  }

  ans.start_time = raw->start_time;
  ans.json = raw->json;

  // Everything has been copied, so the C result can be released now.
  SherpaOnnxDestroyKeywordResult(raw);

  return ans;
}
|
||||
|
||||
// Resets stream `s` via SherpaOnnxResetKeywordStream().
void KeywordSpotter::Reset(const OnlineStream *s) const {
  const auto *stream = s->Get();
  SherpaOnnxResetKeywordStream(p_, stream);
}
|
||||
|
||||
} // namespace sherpa_onnx::cxx
|
||||
|
||||
@@ -406,6 +406,53 @@ class SHERPA_ONNX_API OfflineTts
|
||||
explicit OfflineTts(const SherpaOnnxOfflineTts *p);
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// For Keyword Spotter
|
||||
// ============================================================
|
||||
|
||||
// Result of keyword spotting for one stream, as a C++ value type.
struct KeywordResult {
  // The triggered keyword.
  std::string keyword;

  // One entry per token of the result.
  std::vector<std::string> tokens;

  // Per-token timestamps; left empty when the underlying result has none.
  std::vector<float> timestamps;

  // Default-initialized so that a default-constructed KeywordResult never
  // carries an indeterminate value.
  float start_time = 0.0f;

  // The result serialized as a JSON string.
  std::string json;
};
|
||||
|
||||
struct KeywordSpotterConfig {
|
||||
FeatureConfig feat_config;
|
||||
OnlineModelConfig model_config;
|
||||
int32_t max_active_paths = 4;
|
||||
int32_t num_trailing_blanks = 1;
|
||||
float keywords_score = 1.0f;
|
||||
float keywords_threshold = 0.25f;
|
||||
std::string keywords_file;
|
||||
};
|
||||
|
||||
class SHERPA_ONNX_API KeywordSpotter
|
||||
: public MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter> {
|
||||
public:
|
||||
static KeywordSpotter Create(const KeywordSpotterConfig &config);
|
||||
|
||||
void Destroy(const SherpaOnnxKeywordSpotter *p) const;
|
||||
|
||||
OnlineStream CreateStream() const;
|
||||
|
||||
OnlineStream CreateStream(const std::string &keywords) const;
|
||||
|
||||
bool IsReady(const OnlineStream *s) const;
|
||||
|
||||
void Decode(const OnlineStream *s) const;
|
||||
|
||||
void Decode(const OnlineStream *ss, int32_t n) const;
|
||||
|
||||
void Reset(const OnlineStream *s) const;
|
||||
|
||||
KeywordResult GetResult(const OnlineStream *s) const;
|
||||
|
||||
private:
|
||||
explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p);
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx::cxx
|
||||
|
||||
#endif // SHERPA_ONNX_C_API_CXX_API_H_
|
||||
|
||||
@@ -38,6 +38,8 @@ class KeywordSpotterImpl {
|
||||
|
||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||
|
||||
virtual void Reset(OnlineStream *s) const = 0;
|
||||
|
||||
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
||||
|
||||
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
|
||||
|
||||
@@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||
s->NumFramesReady();
|
||||
}
|
||||
void Reset(OnlineStream *s) const override { InitOnlineStream(s); }
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||
for (int32_t i = 0; i < n; ++i) {
|
||||
auto s = ss[i];
|
||||
auto r = s->GetKeywordResult(true);
|
||||
int32_t num_trailing_blanks = r.num_trailing_blanks;
|
||||
// assume subsampling_factor is 4
|
||||
// assume frameshift is 0.01 second
|
||||
float trailing_slience = num_trailing_blanks * 4 * 0.01;
|
||||
|
||||
// it resets automatically after detecting 1.5 seconds of silence
|
||||
float threshold = 1.5;
|
||||
if (trailing_slience > threshold) {
|
||||
Reset(s);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
|
||||
@@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); }
|
||||
|
||||
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
impl_->DecodeStreams(ss, n);
|
||||
}
|
||||
|
||||
@@ -129,6 +129,9 @@ class KeywordSpotter {
|
||||
*/
|
||||
bool IsReady(OnlineStream *s) const;
|
||||
|
||||
// Remember to call it after detecting a keyword
|
||||
void Reset(OnlineStream *s) const;
|
||||
|
||||
/** Decode a single stream. */
|
||||
void DecodeStream(OnlineStream *s) const {
|
||||
OnlineStream *ss[1] = {s};
|
||||
|
||||
@@ -106,13 +106,15 @@ as the device_name.
|
||||
|
||||
while (spotter.IsReady(stream.get())) {
|
||||
spotter.DecodeStream(stream.get());
|
||||
}
|
||||
|
||||
const auto r = spotter.GetResult(stream.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
const auto r = spotter.GetResult(stream.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
|
||||
spotter.Reset(stream.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -150,13 +150,15 @@ for a list of pre-trained models to download.
|
||||
while (!stop) {
|
||||
while (spotter.IsReady(s.get())) {
|
||||
spotter.DecodeStream(s.get());
|
||||
}
|
||||
|
||||
const auto r = spotter.GetResult(s.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
const auto r = spotter.GetResult(s.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
|
||||
spotter.Reset(s.get());
|
||||
}
|
||||
}
|
||||
|
||||
Pa_Sleep(20); // sleep for 20ms
|
||||
|
||||
@@ -27,6 +27,10 @@ public class KeywordSpotter {
|
||||
decode(ptr, s.getPtr());
|
||||
}
|
||||
|
||||
public void reset(OnlineStream s) {
|
||||
reset(ptr, s.getPtr());
|
||||
}
|
||||
|
||||
public boolean isReady(OnlineStream s) {
|
||||
return isReady(ptr, s.getPtr());
|
||||
}
|
||||
@@ -60,6 +64,8 @@ public class KeywordSpotter {
|
||||
|
||||
private native void decode(long ptr, long streamPtr);
|
||||
|
||||
private native void reset(long ptr, long streamPtr);
|
||||
|
||||
private native boolean isReady(long ptr, long streamPtr);
|
||||
|
||||
private native Object[] getResult(long ptr, long streamPtr);
|
||||
|
||||
@@ -161,6 +161,15 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode(
|
||||
kws->DecodeStream(stream);
|
||||
}
|
||||
|
||||
// JNI entry point for KeywordSpotter.reset(long ptr, long streamPtr).
//
// Both jlong arguments are reinterpreted as native pointers: `ptr` as a
// sherpa_onnx::KeywordSpotter and `stream_ptr` as a sherpa_onnx::OnlineStream.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_reset(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  auto *spotter = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  auto *online_stream =
      reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);

  spotter->Reset(online_stream);
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
|
||||
|
||||
@@ -49,6 +49,7 @@ class KeywordSpotter(
|
||||
}
|
||||
|
||||
fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
|
||||
fun reset(stream: OnlineStream) = reset(ptr, stream.ptr)
|
||||
fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
|
||||
fun getResult(stream: OnlineStream): KeywordSpotterResult {
|
||||
val objArray = getResult(ptr, stream.ptr)
|
||||
@@ -74,6 +75,7 @@ class KeywordSpotter(
|
||||
private external fun createStream(ptr: Long, keywords: String): Long
|
||||
private external fun isReady(ptr: Long, streamPtr: Long): Boolean
|
||||
private external fun decode(ptr: Long, streamPtr: Long)
|
||||
private external fun reset(ptr: Long, streamPtr: Long)
|
||||
private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
|
||||
|
||||
companion object {
|
||||
|
||||
@@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) {
|
||||
py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_ready", &PyClass::IsReady,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
|
||||
.def("decode_stream", &PyClass::DecodeStream,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
|
||||
@@ -104,8 +104,8 @@ class KeywordSpotter(object):
|
||||
)
|
||||
|
||||
provider_config = ProviderConfig(
|
||||
provider=provider,
|
||||
device = device,
|
||||
provider=provider,
|
||||
device=device,
|
||||
)
|
||||
|
||||
model_config = OnlineModelConfig(
|
||||
@@ -131,6 +131,9 @@ class KeywordSpotter(object):
|
||||
)
|
||||
self.keyword_spotter = _KeywordSpotter(keywords_spotter_config)
|
||||
|
||||
def reset_stream(self, s: OnlineStream):
|
||||
self.keyword_spotter.reset(s)
|
||||
|
||||
def create_stream(self, keywords: Optional[str] = None):
|
||||
if keywords is None:
|
||||
return self.keyword_spotter.create_stream()
|
||||
|
||||
@@ -98,6 +98,9 @@ class TestKeywordSpotter(unittest.TestCase):
|
||||
if r:
|
||||
print(f"{r} is detected.")
|
||||
results[i] += f"{r}/"
|
||||
|
||||
keyword_spotter.reset_stream(s)
|
||||
|
||||
if len(ready_list) == 0:
|
||||
break
|
||||
keyword_spotter.decode_streams(ready_list)
|
||||
@@ -158,6 +161,9 @@ class TestKeywordSpotter(unittest.TestCase):
|
||||
if r:
|
||||
print(f"{r} is detected.")
|
||||
results[i] += f"{r}/"
|
||||
|
||||
keyword_spotter.reset_stream(s)
|
||||
|
||||
if len(ready_list) == 0:
|
||||
break
|
||||
keyword_spotter.decode_streams(ready_list)
|
||||
|
||||
@@ -1076,6 +1076,10 @@ class SherpaOnnxKeywordSpotterWrapper {
|
||||
SherpaOnnxDecodeKeywordStream(spotter, stream)
|
||||
}
|
||||
|
||||
func reset() {
|
||||
SherpaOnnxResetKeywordStream(spotter, stream)
|
||||
}
|
||||
|
||||
func getResult() -> SherpaOnnxKeywordResultWrapper {
|
||||
let result: UnsafePointer<SherpaOnnxKeywordResult>? = SherpaOnnxGetKeywordResult(
|
||||
spotter, stream)
|
||||
|
||||
@@ -70,6 +70,9 @@ func run() {
|
||||
spotter.decode()
|
||||
let keyword = spotter.getResult().keyword
|
||||
if keyword != "" {
|
||||
// Remember to call reset() right after detecting a keyword
|
||||
spotter.reset()
|
||||
|
||||
print("Detected: \(keyword)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ set(exported_functions
|
||||
SherpaOnnxIsKeywordStreamReady
|
||||
SherpaOnnxOnlineStreamAcceptWaveform
|
||||
SherpaOnnxOnlineStreamInputFinished
|
||||
SherpaOnnxResetKeywordStream
|
||||
)
|
||||
set(mangled_exported_functions)
|
||||
foreach(x IN LISTS exported_functions)
|
||||
|
||||
@@ -102,15 +102,17 @@ if (navigator.mediaDevices.getUserMedia) {
|
||||
recognizer_stream.acceptWaveform(expectedSampleRate, samples);
|
||||
while (recognizer.isReady(recognizer_stream)) {
|
||||
recognizer.decode(recognizer_stream);
|
||||
}
|
||||
|
||||
let result = recognizer.getResult(recognizer_stream);
|
||||
|
||||
let result = recognizer.getResult(recognizer_stream);
|
||||
if (result.keyword.length > 0) {
|
||||
console.log(result)
|
||||
lastResult = result;
|
||||
resultList.push(JSON.stringify(result));
|
||||
|
||||
if (result.keyword.length > 0) {
|
||||
console.log(result)
|
||||
lastResult = result;
|
||||
resultList.push(JSON.stringify(result));
|
||||
// remember to reset the stream right after detecting a keyword
|
||||
recognizer.reset(recognizer_stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -296,8 +296,11 @@ class Kws {
|
||||
}
|
||||
|
||||
decode(stream) {
|
||||
return this.Module._SherpaOnnxDecodeKeywordStream(
|
||||
this.handle, stream.handle);
|
||||
this.Module._SherpaOnnxDecodeKeywordStream(this.handle, stream.handle);
|
||||
}
|
||||
|
||||
reset(stream) {
|
||||
this.Module._SherpaOnnxResetKeywordStream(this.handle, stream.handle);
|
||||
}
|
||||
|
||||
getResult(stream) {
|
||||
|
||||
Reference in New Issue
Block a user