From 8b989a851cbb759976d1f8d40cae91dd9362f816 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 20 Jan 2025 16:41:10 +0800 Subject: [PATCH] Fix keyword spotting. (#1689) Reset the stream right after detecting a keyword --- .github/scripts/test-python.sh | 33 +-- .github/workflows/c-api.yaml | 21 ++ .github/workflows/cxx-api.yaml | 22 ++ .../com/k2fsa/sherpa/onnx/MainActivity.kt | 29 ++- c-api-examples/CMakeLists.txt | 3 + c-api-examples/kws-c-api.c | 150 +++++++++++ cxx-api-examples/CMakeLists.txt | 3 + cxx-api-examples/kws-cxx-api.cc | 141 +++++++++++ .../bin/zipformer-transducer.dart | 2 + .../keyword-spotting-from-files/Program.cs | 6 + .../Program.cs | 13 +- .../sherpa_onnx/lib/src/keyword_spotter.dart | 4 + .../lib/src/sherpa_onnx_bindings.dart | 12 + .../keyword-spotting-from-file/main.go | 2 + .../src/main/cpp/keyword-spotting.cc | 61 ++++- java-api-examples/KeywordSpotterFromFile.java | 2 + .../test-keyword-spotter-transducer.js | 3 + .../keyword-spotter-from-microphone.py | 11 +- python-api-examples/keyword-spotter.py | 235 ++++-------------- scripts/dotnet/KeywordSpotter.cs | 8 + scripts/go/sherpa_onnx.go | 5 + scripts/node-addon-api/lib/keyword-spotter.js | 4 + sherpa-onnx/c-api/c-api.cc | 37 +-- sherpa-onnx/c-api/c-api.h | 33 ++- sherpa-onnx/c-api/cxx-api.cc | 108 ++++++++ sherpa-onnx/c-api/cxx-api.h | 47 ++++ sherpa-onnx/csrc/keyword-spotter-impl.h | 2 + .../csrc/keyword-spotter-transducer-impl.h | 16 ++ sherpa-onnx/csrc/keyword-spotter.cc | 2 + sherpa-onnx/csrc/keyword-spotter.h | 3 + .../csrc/sherpa-onnx-keyword-spotter-alsa.cc | 14 +- .../sherpa-onnx-keyword-spotter-microphone.cc | 14 +- .../com/k2fsa/sherpa/onnx/KeywordSpotter.java | 6 + sherpa-onnx/jni/keyword-spotter.cc | 9 + sherpa-onnx/kotlin-api/KeywordSpotter.kt | 2 + sherpa-onnx/python/csrc/keyword-spotter.cc | 1 + .../python/sherpa_onnx/keyword_spotter.py | 7 +- .../python/tests/test_keyword_spotter.py | 6 + swift-api-examples/SherpaOnnx.swift | 4 + 
.../keyword-spotting-from-file.swift | 3 + wasm/kws/CMakeLists.txt | 1 + wasm/kws/app.js | 14 +- wasm/kws/sherpa-onnx-kws.js | 7 +- 43 files changed, 813 insertions(+), 293 deletions(-) create mode 100644 c-api-examples/kws-c-api.c create mode 100644 cxx-api-examples/kws-cxx-api.cc diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index ad037438..39e6577a 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version" pwd ls -lh -repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 -log "Start testing ${repo}" - -pushd $dir -curl -LS -O https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz -tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz -rm sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz -popd - -repo=$dir/$repo -ls -lh $repo - -python3 ./python-api-examples/keyword-spotter.py \ - --tokens=$repo/tokens.txt \ - --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ - --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ - --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ - --keywords-file=$repo/test_wavs/test_keywords.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav - -rm -rf $repo - if [[ x$OS != x'windows-latest' ]]; then echo "OS: $OS" @@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then repo=$dir/$repo ls -lh $repo - python3 ./python-api-examples/keyword-spotter.py \ - --tokens=$repo/tokens.txt \ - --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ - --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ - --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ - --keywords-file=$repo/test_wavs/test_keywords.txt \ - $repo/test_wavs/3.wav \ - $repo/test_wavs/4.wav \ - $repo/test_wavs/5.wav + python3 ./python-api-examples/keyword-spotter.py python3 
sherpa-onnx/python/tests/test_keyword_spotter.py --verbose diff --git a/.github/workflows/c-api.yaml b/.github/workflows/c-api.yaml index 58820dc0..44379769 100644 --- a/.github/workflows/c-api.yaml +++ b/.github/workflows/c-api.yaml @@ -79,6 +79,27 @@ jobs: otool -L ./install/lib/libsherpa-onnx-c-api.dylib fi + - name: Test kws (zh) + shell: bash + run: | + gcc -o kws-c-api ./c-api-examples/kws-c-api.c \ + -I ./build/install/include \ + -L ./build/install/lib/ \ + -l sherpa-onnx-c-api \ + -l onnxruntime + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH + + ./kws-c-api + + rm ./kws-c-api + rm -rf sherpa-onnx-kws-* + - name: Test Kokoro TTS (en) shell: bash run: | diff --git a/.github/workflows/cxx-api.yaml b/.github/workflows/cxx-api.yaml index 2f6a3b2e..7227dd42 100644 --- a/.github/workflows/cxx-api.yaml +++ b/.github/workflows/cxx-api.yaml @@ -81,6 +81,28 @@ jobs: otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib fi + - name: Test KWS (zh) + shell: bash + run: | + g++ -std=c++17 -o kws-cxx-api ./cxx-api-examples/kws-cxx-api.cc \ + -I ./build/install/include \ + -L ./build/install/lib/ \ + -l sherpa-onnx-cxx-api \ + -l sherpa-onnx-c-api \ + -l onnxruntime + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 + + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH + export 
DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH + + ./kws-cxx-api + + rm kws-cxx-api + rm -rf sherpa-onnx-kws-* + - name: Test Kokoro TTS (en) shell: bash run: | diff --git a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index b17a6ea6..b42937ad 100644 --- a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -151,24 +151,27 @@ class MainActivity : AppCompatActivity() { stream.acceptWaveform(samples, sampleRate = sampleRateInHz) while (kws.isReady(stream)) { kws.decode(stream) - } - val text = kws.getResult(stream).keyword + val text = kws.getResult(stream).keyword - var textToDisplay = lastText + var textToDisplay = lastText - if (text.isNotBlank()) { - if (lastText.isBlank()) { - textToDisplay = "$idx: $text" - } else { - textToDisplay = "$idx: $text\n$lastText" + if (text.isNotBlank()) { + // Remember to reset the stream right after detecting a keyword + + kws.reset(stream) + if (lastText.isBlank()) { + textToDisplay = "$idx: $text" + } else { + textToDisplay = "$idx: $text\n$lastText" + } + lastText = "$idx: $text\n$lastText" + idx += 1 } - lastText = "$idx: $text\n$lastText" - idx += 1 - } - runOnUiThread { - textView.text = textToDisplay + runOnUiThread { + textView.text = textToDisplay + } } } } diff --git a/c-api-examples/CMakeLists.txt b/c-api-examples/CMakeLists.txt index a2bfb6fd..3db3f253 100644 --- a/c-api-examples/CMakeLists.txt +++ b/c-api-examples/CMakeLists.txt @@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR}) add_executable(decode-file-c-api decode-file-c-api.c) target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs) +add_executable(kws-c-api kws-c-api.c) +target_link_libraries(kws-c-api sherpa-onnx-c-api) + if(SHERPA_ONNX_ENABLE_TTS) add_executable(offline-tts-c-api offline-tts-c-api.c) 
target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs) diff --git a/c-api-examples/kws-c-api.c b/c-api-examples/kws-c-api.c new file mode 100644 index 00000000..3ac42758 --- /dev/null +++ b/c-api-examples/kws-c-api.c @@ -0,0 +1,150 @@ +// c-api-examples/kws-c-api.c +// +// Copyright (c) 2025 Xiaomi Corporation +// +// This file shows how to use the keyword spotter with sherpa-onnx's C API. +// clang-format off +// +// Usage +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// +// ./kws-c-api +// +// clang-format on +#include +#include // exit +#include // memset + +#include "sherpa-onnx/c-api/c-api.h" + +int32_t main() { + SherpaOnnxKeywordSpotterConfig config; + + memset(&config, 0, sizeof(config)); + config.model_config.transducer.encoder = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" + "encoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + + config.model_config.transducer.decoder = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + + config.model_config.transducer.joiner = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" + "joiner-epoch-12-avg-2-chunk-16-left-64.onnx"; + + config.model_config.tokens = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"; + + config.model_config.provider = "cpu"; + config.model_config.num_threads = 1; + config.model_config.debug = 1; + + config.keywords_file = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/" + "test_keywords.txt"; + + const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&config); + if (!kws) { + fprintf(stderr, "Please check your config"); + exit(-1); + } + + fprintf(stderr, + "--Test pre-defined keywords from
test_wavs/test_keywords.txt--\n"); + + const char *wav_filename = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"; + + float tail_paddings[8000] = {0}; // 0.5 seconds + + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + exit(-1); + } + + const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws); + if (!stream) { + fprintf(stderr, "Failed to create stream\n"); + exit(-1); + } + + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, + wave->num_samples); + + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + sizeof(tail_paddings) / sizeof(float)); + SherpaOnnxOnlineStreamInputFinished(stream); + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) { + SherpaOnnxDecodeKeywordStream(kws, stream); + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream); + if (r && r->json && strlen(r->keyword)) { + fprintf(stderr, "Detected keyword: %s\n", r->json); + + // Remember to reset the keyword stream right after a keyword is detected + SherpaOnnxResetKeywordStream(kws, stream); + } + SherpaOnnxDestroyKeywordResult(r); + } + SherpaOnnxDestroyOnlineStream(stream); + + // -------------------------------------------------------------------------- + + fprintf(stderr, "--Use pre-defined keywords + add a new keyword--\n"); + + stream = SherpaOnnxCreateKeywordStreamWithKeywords(kws, "y ǎn y uán @演员"); + + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, + wave->num_samples); + + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + sizeof(tail_paddings) / sizeof(float)); + SherpaOnnxOnlineStreamInputFinished(stream); + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) { + SherpaOnnxDecodeKeywordStream(kws, stream); + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream); + if (r && r->json && 
strlen(r->keyword)) { + fprintf(stderr, "Detected keyword: %s\n", r->json); + + // Remember to reset the keyword stream + SherpaOnnxResetKeywordStream(kws, stream); + } + SherpaOnnxDestroyKeywordResult(r); + } + SherpaOnnxDestroyOnlineStream(stream); + + // -------------------------------------------------------------------------- + + fprintf(stderr, "--Use pre-defined keywords + add two new keywords--\n"); + + stream = SherpaOnnxCreateKeywordStreamWithKeywords( + kws, "y ǎn y uán @演员/zh ī m íng @知名"); + + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, + wave->num_samples); + + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + sizeof(tail_paddings) / sizeof(float)); + SherpaOnnxOnlineStreamInputFinished(stream); + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) { + SherpaOnnxDecodeKeywordStream(kws, stream); + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream); + if (r && r->json && strlen(r->keyword)) { + fprintf(stderr, "Detected keyword: %s\n", r->json); + + // Remember to reset the keyword stream + SherpaOnnxResetKeywordStream(kws, stream); + } + SherpaOnnxDestroyKeywordResult(r); + } + SherpaOnnxDestroyOnlineStream(stream); + + SherpaOnnxFreeWave(wave); + SherpaOnnxDestroyKeywordSpotter(kws); + + return 0; +} diff --git a/cxx-api-examples/CMakeLists.txt b/cxx-api-examples/CMakeLists.txt index 040925c2..2250736d 100644 --- a/cxx-api-examples/CMakeLists.txt +++ b/cxx-api-examples/CMakeLists.txt @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR}) add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api) +add_executable(kws-cxx-api ./kws-cxx-api.cc) +target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api) + add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc) target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api) diff --git 
a/cxx-api-examples/kws-cxx-api.cc b/cxx-api-examples/kws-cxx-api.cc new file mode 100644 index 00000000..cdcb86ba --- /dev/null +++ b/cxx-api-examples/kws-cxx-api.cc @@ -0,0 +1,141 @@ +// cxx-api-examples/kws-cxx-api.cc +// +// Copyright (c) 2025 Xiaomi Corporation +// +// This file shows how to use the keyword spotter with sherpa-onnx's C++ API. +// clang-format off +// +// Usage +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// +// ./kws-cxx-api +// +// clang-format on +#include +#include + +#include "sherpa-onnx/c-api/cxx-api.h" + +int32_t main() { + using namespace sherpa_onnx::cxx; // NOLINT + + KeywordSpotterConfig config; + config.model_config.transducer.encoder = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" + "encoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + + config.model_config.transducer.decoder = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + + config.model_config.transducer.joiner = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" + "joiner-epoch-12-avg-2-chunk-16-left-64.onnx"; + + config.model_config.tokens = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"; + + config.model_config.provider = "cpu"; + config.model_config.num_threads = 1; + config.model_config.debug = 1; + + config.keywords_file = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/" + "test_keywords.txt"; + + KeywordSpotter kws = KeywordSpotter::Create(config); + if (!kws.Get()) { + std::cerr << "Please check your config\n"; + return -1; + } + + std::cout + << "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n"; + + std::string wave_filename = +
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"; + + std::array tail_paddings = {0}; // 0.5 seconds + + Wave wave = ReadWave(wave_filename); + if (wave.samples.empty()) { + std::cerr << "Failed to read: '" << wave_filename << "'\n"; + return -1; + } + + OnlineStream stream = kws.CreateStream(); + if (!stream.Get()) { + std::cerr << "Failed to create stream\n"; + return -1; + } + + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(), + tail_paddings.size()); + stream.InputFinished(); + + while (kws.IsReady(&stream)) { + kws.Decode(&stream); + auto r = kws.GetResult(&stream); + if (!r.keyword.empty()) { + std::cout << "Detected keyword: " << r.json << "\n"; + + // Remember to reset the keyword stream right after a keyword is detected + kws.Reset(&stream); + } + } + + // -------------------------------------------------------------------------- + + std::cout << "--Use pre-defined keywords + add a new keyword--\n"; + + stream = kws.CreateStream("y ǎn y uán @演员"); + + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(), + tail_paddings.size()); + stream.InputFinished(); + + while (kws.IsReady(&stream)) { + kws.Decode(&stream); + auto r = kws.GetResult(&stream); + if (!r.keyword.empty()) { + std::cout << "Detected keyword: " << r.json << "\n"; + + // Remember to reset the keyword stream right after a keyword is detected + kws.Reset(&stream); + } + } + + // -------------------------------------------------------------------------- + + std::cout << "--Use pre-defined keywords + add two new keywords--\n"; + + stream = kws.CreateStream("y ǎn y uán @演员/zh ī m íng @知名"); + + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(), + tail_paddings.size()); + 
stream.InputFinished(); + + while (kws.IsReady(&stream)) { + kws.Decode(&stream); + auto r = kws.GetResult(&stream); + if (!r.keyword.empty()) { + std::cout << "Detected keyword: " << r.json << "\n"; + + // Remember to reset the keyword stream right after a keyword is detected + kws.Reset(&stream); + } + } + return 0; +} diff --git a/dart-api-examples/keyword-spotter/bin/zipformer-transducer.dart b/dart-api-examples/keyword-spotter/bin/zipformer-transducer.dart index ebef1fd7..47d58798 100644 --- a/dart-api-examples/keyword-spotter/bin/zipformer-transducer.dart +++ b/dart-api-examples/keyword-spotter/bin/zipformer-transducer.dart @@ -73,6 +73,8 @@ void main(List arguments) async { spotter.decode(stream); final result = spotter.getResult(stream); if (result.keyword != '') { + // Remember to reset the stream right after detecting a keyword + spotter.reset(stream); print('Detected: ${result.keyword}'); } } diff --git a/dotnet-examples/keyword-spotting-from-files/Program.cs b/dotnet-examples/keyword-spotting-from-files/Program.cs index 00ba3777..7ab0da2f 100644 --- a/dotnet-examples/keyword-spotting-from-files/Program.cs +++ b/dotnet-examples/keyword-spotting-from-files/Program.cs @@ -53,6 +53,8 @@ class KeywordSpotterDemo var result = kws.GetResult(s); if (result.Keyword != string.Empty) { + // Remember to call Reset() right after detecting a keyword + kws.Reset(s); Console.WriteLine("Detected: {0}", result.Keyword); } } @@ -70,6 +72,8 @@ class KeywordSpotterDemo var result = kws.GetResult(s); if (result.Keyword != string.Empty) { + // Remember to call Reset() right after detecting a keyword + kws.Reset(s); Console.WriteLine("Detected: {0}", result.Keyword); } } @@ -89,6 +93,8 @@ class KeywordSpotterDemo var result = kws.GetResult(s); if (result.Keyword != string.Empty) { + // Remember to call Reset() right after detecting a keyword + kws.Reset(s); Console.WriteLine("Detected: {0}", result.Keyword); } } diff --git 
a/dotnet-examples/keyword-spotting-from-microphone/Program.cs b/dotnet-examples/keyword-spotting-from-microphone/Program.cs index 05d22aee..140e6a40 100644 --- a/dotnet-examples/keyword-spotting-from-microphone/Program.cs +++ b/dotnet-examples/keyword-spotting-from-microphone/Program.cs @@ -107,12 +107,15 @@ class KeywordSpotterDemo while (kws.IsReady(s)) { kws.Decode(s); - } - var result = kws.GetResult(s); - if (result.Keyword != string.Empty) - { - Console.WriteLine("Detected: {0}", result.Keyword); + var result = kws.GetResult(s); + if (result.Keyword != string.Empty) + { + // Remember to call Reset() right after detecting a keyword + kws.Reset(s); + + Console.WriteLine("Detected: {0}", result.Keyword); + } } Thread.Sleep(200); // ms diff --git a/flutter/sherpa_onnx/lib/src/keyword_spotter.dart b/flutter/sherpa_onnx/lib/src/keyword_spotter.dart index 6e2c669a..310657d1 100644 --- a/flutter/sherpa_onnx/lib/src/keyword_spotter.dart +++ b/flutter/sherpa_onnx/lib/src/keyword_spotter.dart @@ -168,6 +168,10 @@ class KeywordSpotter { SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr); } + void reset(OnlineStream stream) { + SherpaOnnxBindings.resetKeywordStream?.call(ptr, stream.ptr); + } + Pointer ptr; KeywordSpotterConfig config; } diff --git a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart index 1e41d091..e544da95 100644 --- a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart +++ b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart @@ -667,6 +667,12 @@ typedef DecodeKeywordStreamNative = Void Function( typedef DecodeKeywordStream = void Function( Pointer, Pointer); +typedef ResetKeywordStreamNative = Void Function( + Pointer, Pointer); + +typedef ResetKeywordStream = void Function( + Pointer, Pointer); + typedef GetKeywordResultAsJsonNative = Pointer Function( Pointer, Pointer); @@ -1157,6 +1163,7 @@ class SherpaOnnxBindings { static CreateKeywordStreamWithKeywords? 
createKeywordStreamWithKeywords; static IsKeywordStreamReady? isKeywordStreamReady; static DecodeKeywordStream? decodeKeywordStream; + static ResetKeywordStream? resetKeywordStream; static GetKeywordResultAsJson? getKeywordResultAsJson; static FreeKeywordResultJson? freeKeywordResultJson; @@ -1459,6 +1466,11 @@ class SherpaOnnxBindings { 'SherpaOnnxDecodeKeywordStream') .asFunction(); + resetKeywordStream ??= dynamicLibrary + .lookup>( + 'SherpaOnnxResetKeywordStream') + .asFunction(); + getKeywordResultAsJson ??= dynamicLibrary .lookup>( 'SherpaOnnxGetKeywordResultAsJson') diff --git a/go-api-examples/keyword-spotting-from-file/main.go b/go-api-examples/keyword-spotting-from-file/main.go index cf6ffa84..697f9f4d 100644 --- a/go-api-examples/keyword-spotting-from-file/main.go +++ b/go-api-examples/keyword-spotting-from-file/main.go @@ -43,6 +43,8 @@ func main() { spotter.Decode(stream) result := spotter.GetResult(stream) if result.Keyword != "" { + // You have to reset the stream right after detecting a keyword + spotter.Reset(stream) log.Printf("Detected %v\n", result.Keyword) } } diff --git a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/keyword-spotting.cc b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/keyword-spotting.cc index 2b5a2410..08f1b517 100644 --- a/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/keyword-spotting.cc +++ b/harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/keyword-spotting.cc @@ -46,7 +46,7 @@ static Napi::External CreateKeywordSpotterWrapper( SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf); SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize); - SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c); + const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c); if (c.model_config.transducer.encoder) { delete[] c.model_config.transducer.encoder; @@ -100,7 +100,8 @@ static Napi::External CreateKeywordSpotterWrapper( } return Napi::External::New( - env, kws, [](Napi::Env env, 
SherpaOnnxKeywordSpotter *kws) { + env, const_cast(kws), + [](Napi::Env env, SherpaOnnxKeywordSpotter *kws) { SherpaOnnxDestroyKeywordSpotter(kws); }); } @@ -125,13 +126,14 @@ static Napi::External CreateKeywordStreamWrapper( return {}; } - SherpaOnnxKeywordSpotter *kws = + const SherpaOnnxKeywordSpotter *kws = info[0].As>().Data(); - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws); + const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws); return Napi::External::New( - env, stream, [](Napi::Env env, SherpaOnnxOnlineStream *stream) { + env, const_cast(stream), + [](Napi::Env env, SherpaOnnxOnlineStream *stream) { SherpaOnnxDestroyOnlineStream(stream); }); } @@ -162,10 +164,10 @@ static Napi::Boolean IsKeywordStreamReadyWrapper( return {}; } - SherpaOnnxKeywordSpotter *kws = + const SherpaOnnxKeywordSpotter *kws = info[0].As>().Data(); - SherpaOnnxOnlineStream *stream = + const SherpaOnnxOnlineStream *stream = info[1].As>().Data(); int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream); @@ -198,15 +200,49 @@ static void DecodeKeywordStreamWrapper(const Napi::CallbackInfo &info) { return; } - SherpaOnnxKeywordSpotter *kws = + const SherpaOnnxKeywordSpotter *kws = info[0].As>().Data(); - SherpaOnnxOnlineStream *stream = + const SherpaOnnxOnlineStream *stream = info[1].As>().Data(); SherpaOnnxDecodeKeywordStream(kws, stream); } +static void ResetKeywordStreamWrapper(const Napi::CallbackInfo &info) { + Napi::Env env = info.Env(); + if (info.Length() != 2) { + std::ostringstream os; + os << "Expect only 2 arguments. 
Given: " << info.Length(); + + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); + + return; + } + + if (!info[0].IsExternal()) { + Napi::TypeError::New(env, "Argument 0 should be a keyword spotter pointer.") + .ThrowAsJavaScriptException(); + + return; + } + + if (!info[1].IsExternal()) { + Napi::TypeError::New(env, "Argument 1 should be an online stream pointer.") + .ThrowAsJavaScriptException(); + + return; + } + + const SherpaOnnxKeywordSpotter *kws = + info[0].As>().Data(); + + const SherpaOnnxOnlineStream *stream = + info[1].As>().Data(); + + SherpaOnnxResetKeywordStream(kws, stream); +} + static Napi::String GetKeywordResultAsJsonWrapper( const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); @@ -233,10 +269,10 @@ static Napi::String GetKeywordResultAsJsonWrapper( return {}; } - SherpaOnnxKeywordSpotter *kws = + const SherpaOnnxKeywordSpotter *kws = info[0].As>().Data(); - SherpaOnnxOnlineStream *stream = + const SherpaOnnxOnlineStream *stream = info[1].As>().Data(); const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream); @@ -261,6 +297,9 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports) { exports.Set(Napi::String::New(env, "decodeKeywordStream"), Napi::Function::New(env, DecodeKeywordStreamWrapper)); + exports.Set(Napi::String::New(env, "resetKeywordStream"), + Napi::Function::New(env, ResetKeywordStreamWrapper)); + exports.Set(Napi::String::New(env, "getKeywordResultAsJson"), Napi::Function::New(env, GetKeywordResultAsJsonWrapper)); } diff --git a/java-api-examples/KeywordSpotterFromFile.java b/java-api-examples/KeywordSpotterFromFile.java index 1b7a739a..9634800a 100644 --- a/java-api-examples/KeywordSpotterFromFile.java +++ b/java-api-examples/KeywordSpotterFromFile.java @@ -56,6 +56,8 @@ public class KyewordSpotterFromFile { String keyword = kws.getResult(stream).getKeyword(); if (!keyword.isEmpty()) { + // Remember to reset the stream right after detecting a keyword + kws.reset(stream); 
System.out.printf("Detected keyword: %s\n", keyword); } } diff --git a/nodejs-examples/test-keyword-spotter-transducer.js b/nodejs-examples/test-keyword-spotter-transducer.js index 9ead2b19..746d9995 100644 --- a/nodejs-examples/test-keyword-spotter-transducer.js +++ b/nodejs-examples/test-keyword-spotter-transducer.js @@ -41,6 +41,9 @@ while (kws.isReady(stream)) { const keyword = kws.getResult(stream).keyword; if (keyword != '') { detectedKeywords.push(keyword); + + // remember to reset the stream right after detecting a keyword + kws.reset(stream); } } console.log(detectedKeywords); diff --git a/python-api-examples/keyword-spotter-from-microphone.py b/python-api-examples/keyword-spotter-from-microphone.py index 65a59fca..b634c907 100755 --- a/python-api-examples/keyword-spotter-from-microphone.py +++ b/python-api-examples/keyword-spotter-from-microphone.py @@ -169,6 +169,8 @@ def main(): print("Started! Please speak") + idx = 0 + sample_rate = 16000 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms stream = keyword_spotter.create_stream() @@ -179,9 +181,12 @@ def main(): stream.accept_waveform(sample_rate, samples) while keyword_spotter.is_ready(stream): keyword_spotter.decode_stream(stream) - result = keyword_spotter.get_result(stream) - if result: - print("\r{}".format(result), end="", flush=True) + result = keyword_spotter.get_result(stream) + if result: + print(f"{idx}: {result }") + idx += 1 + # Remember to reset stream right after detecting a keyword + keyword_spotter.reset_stream(stream) if __name__ == "__main__": diff --git a/python-api-examples/keyword-spotter.py b/python-api-examples/keyword-spotter.py index 1b1de77e..f3f76420 100755 --- a/python-api-examples/keyword-spotter.py +++ b/python-api-examples/keyword-spotter.py @@ -18,122 +18,6 @@ import numpy as np import sherpa_onnx -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--tokens", - 
type=str, - help="Path to tokens.txt", - ) - - parser.add_argument( - "--encoder", - type=str, - help="Path to the transducer encoder model", - ) - - parser.add_argument( - "--decoder", - type=str, - help="Path to the transducer decoder model", - ) - - parser.add_argument( - "--joiner", - type=str, - help="Path to the transducer joiner model", - ) - - parser.add_argument( - "--num-threads", - type=int, - default=1, - help="Number of threads for neural network computation", - ) - - parser.add_argument( - "--provider", - type=str, - default="cpu", - help="Valid values: cpu, cuda, coreml", - ) - - parser.add_argument( - "--max-active-paths", - type=int, - default=4, - help=""" - It specifies number of active paths to keep during decoding. - """, - ) - - parser.add_argument( - "--num-trailing-blanks", - type=int, - default=1, - help="""The number of trailing blanks a keyword should be followed. Setting - to a larger value (e.g. 8) when your keywords has overlapping tokens - between each other. - """, - ) - - parser.add_argument( - "--keywords-file", - type=str, - help=""" - The file containing keywords, one words/phrases per line, and for each - phrase the bpe/cjkchar/pinyin are separated by a space. For example: - - ▁HE LL O ▁WORLD - x iǎo ài t óng x ué - """, - ) - - parser.add_argument( - "--keywords-score", - type=float, - default=1.0, - help=""" - The boosting score of each token for keywords. The larger the easier to - survive beam search. - """, - ) - - parser.add_argument( - "--keywords-threshold", - type=float, - default=0.25, - help=""" - The trigger threshold (i.e. probability) of the keyword. The larger the - harder to trigger. - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to decode. Each file must be of WAVE" - "format with a single channel, and each sample has 16-bit, " - "i.e., int16_t. 
" - "The sample rate of the file can be arbitrary and does not need to " - "be 16 kHz", - ) - - return parser.parse_args() - - -def assert_file_exists(filename: str): - assert Path(filename).is_file(), ( - f"{filename} does not exist!\n" - "Please refer to " - "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it" - ) - - def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: """ Args: @@ -159,83 +43,74 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: return samples_float32, f.getframerate() +def create_keyword_spotter(): + kws = sherpa_onnx.KeywordSpotter( + tokens="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt", + encoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx", + decoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx", + joiner="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx", + num_threads=2, + keywords_file="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt", + provider="cpu", + ) + + return kws + + def main(): - args = get_args() - assert_file_exists(args.tokens) - assert_file_exists(args.encoder) - assert_file_exists(args.decoder) - assert_file_exists(args.joiner) + kws = create_keyword_spotter() - assert Path( - args.keywords_file - ).is_file(), ( - f"keywords_file : {args.keywords_file} not exist, please provide a valid path." 
+ wave_filename = ( + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" ) - keyword_spotter = sherpa_onnx.KeywordSpotter( - tokens=args.tokens, - encoder=args.encoder, - decoder=args.decoder, - joiner=args.joiner, - num_threads=args.num_threads, - max_active_paths=args.max_active_paths, - keywords_file=args.keywords_file, - keywords_score=args.keywords_score, - keywords_threshold=args.keywords_threshold, - num_trailing_blanks=args.num_trailing_blanks, - provider=args.provider, - ) + samples, sample_rate = read_wave(wave_filename) - print("Started!") - start_time = time.time() + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) - streams = [] - total_duration = 0 - for wave_filename in args.sound_files: - assert_file_exists(wave_filename) - samples, sample_rate = read_wave(wave_filename) - duration = len(samples) / sample_rate - total_duration += duration + print("----------Use pre-defined keywords----------") + s = kws.create_stream() + s.accept_waveform(sample_rate, samples) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + while kws.is_ready(s): + kws.decode_stream(s) + r = kws.get_result(s) + if r != "": + # Remember to call reset right after detected a keyword + kws.reset_stream(s) - s = keyword_spotter.create_stream() + print(f"Detected {r}") - s.accept_waveform(sample_rate, samples) + print("----------Use pre-defined keywords + add a new keyword----------") - tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) - s.accept_waveform(sample_rate, tail_paddings) + s = kws.create_stream("y ǎn y uán @演员") + s.accept_waveform(sample_rate, samples) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + while kws.is_ready(s): + kws.decode_stream(s) + r = kws.get_result(s) + if r != "": + # Remember to call reset right after detected a keyword + kws.reset_stream(s) - s.input_finished() + print(f"Detected {r}") - streams.append(s) + print("----------Use pre-defined keywords + 
add 2 new keywords----------") - results = [""] * len(streams) - while True: - ready_list = [] - for i, s in enumerate(streams): - if keyword_spotter.is_ready(s): - ready_list.append(s) - r = keyword_spotter.get_result(s) - if r: - results[i] += f"{r}/" - print(f"{r} is detected.") - if len(ready_list) == 0: - break - keyword_spotter.decode_streams(ready_list) - end_time = time.time() - print("Done!") + s = kws.create_stream("y ǎn y uán @演员/zh ī m íng @知名") + s.accept_waveform(sample_rate, samples) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + while kws.is_ready(s): + kws.decode_stream(s) + r = kws.get_result(s) + if r != "": + # Remember to call reset right after detected a keyword + kws.reset_stream(s) - for wave_filename, result in zip(args.sound_files, results): - print(f"{wave_filename}\n{result}") - print("-" * 10) - - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - print(f"num_threads: {args.num_threads}") - print(f"Wave duration: {total_duration:.3f} s") - print(f"Elapsed time: {elapsed_seconds:.3f} s") - print( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) + print(f"Detected {r}") if __name__ == "__main__": diff --git a/scripts/dotnet/KeywordSpotter.cs b/scripts/dotnet/KeywordSpotter.cs index 6f19fc94..d71d8924 100644 --- a/scripts/dotnet/KeywordSpotter.cs +++ b/scripts/dotnet/KeywordSpotter.cs @@ -46,6 +46,11 @@ namespace SherpaOnnx Decode(_handle.Handle, stream.Handle); } + public void Reset(OnlineStream stream) + { + Reset(_handle.Handle, stream.Handle); + } + // The caller should ensure all passed streams are ready for decoding. 
public void Decode(IEnumerable streams) { @@ -110,6 +115,9 @@ namespace SherpaOnnx [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")] private static extern void Decode(IntPtr handle, IntPtr stream); + [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxResetKeywordStream")] + private static extern void Reset(IntPtr handle, IntPtr stream); + [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")] private static extern void Decode(IntPtr handle, IntPtr[] streams, int n); diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index 4da12cfb..37f59662 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -1584,6 +1584,11 @@ func (spotter *KeywordSpotter) Decode(s *OnlineStream) { C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl) } +// You MUST call it right after detecting a keyword +func (spotter *KeywordSpotter) Reset(s *OnlineStream) { + C.SherpaOnnxResetKeywordStream(spotter.impl, s.impl) +} + // Get the current result of stream since the last invoke of Reset() func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult { p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl) diff --git a/scripts/node-addon-api/lib/keyword-spotter.js b/scripts/node-addon-api/lib/keyword-spotter.js index 9fbadef4..d06764e8 100644 --- a/scripts/node-addon-api/lib/keyword-spotter.js +++ b/scripts/node-addon-api/lib/keyword-spotter.js @@ -20,6 +20,10 @@ class KeywordSpotter { addon.decodeKeywordStream(this.handle, stream.handle); } + reset(stream) { + addon.resetKeywordStream(this.handle, stream.handle); + } + getResult(stream) { const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle); diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index da7c7317..ce6d5563 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter { std::unique_ptr impl; }; -SherpaOnnxKeywordSpotter 
*SherpaOnnxCreateKeywordSpotter( +const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( const SherpaOnnxKeywordSpotterConfig *config) { sherpa_onnx::KeywordSpotterConfig spotter_config; @@ -755,37 +755,42 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( return spotter; } -void SherpaOnnxDestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) { +void SherpaOnnxDestroyKeywordSpotter(const SherpaOnnxKeywordSpotter *spotter) { delete spotter; } -SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( +const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( const SherpaOnnxKeywordSpotter *spotter) { SherpaOnnxOnlineStream *stream = new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); return stream; } -SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords( +const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords( const SherpaOnnxKeywordSpotter *spotter, const char *keywords) { SherpaOnnxOnlineStream *stream = new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords)); return stream; } -int32_t SherpaOnnxIsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter, - SherpaOnnxOnlineStream *stream) { +int32_t SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream) { return spotter->impl->IsReady(stream->impl.get()); } -void SherpaOnnxDecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter, - SherpaOnnxOnlineStream *stream) { - return spotter->impl->DecodeStream(stream->impl.get()); +void SherpaOnnxDecodeKeywordStream(const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream) { + spotter->impl->DecodeStream(stream->impl.get()); } -void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter, - SherpaOnnxOnlineStream **streams, - int32_t n) { +void SherpaOnnxResetKeywordStream(const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream) { + spotter->impl->Reset(stream->impl.get()); +} + +void 
SherpaOnnxDecodeMultipleKeywordStreams( + const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream **streams, int32_t n) { std::vector ss(n); for (int32_t i = 0; i != n; ++i) { ss[i] = streams[i]->impl.get(); @@ -794,7 +799,8 @@ void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter, } const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { + const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream) { const sherpa_onnx::KeywordResult &result = spotter->impl->GetResult(stream->impl.get()); const auto &keyword = result.keyword; @@ -869,8 +875,9 @@ void SherpaOnnxDestroyKeywordResult(const SherpaOnnxKeywordResult *r) { } } -const char *SherpaOnnxGetKeywordResultAsJson(SherpaOnnxKeywordSpotter *spotter, - SherpaOnnxOnlineStream *stream) { +const char *SherpaOnnxGetKeywordResultAsJson( + const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream) { const sherpa_onnx::KeywordResult &result = spotter->impl->GetResult(stream->impl.get()); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index db832178..10bd101f 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -600,7 +600,7 @@ SHERPA_ONNX_API const char *SherpaOnnxGetOfflineStreamResultAsJson( SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s); // ============================================================ -// For Keyword Spot +// For Keyword Spotter // ============================================================ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { /// The triggered keyword. @@ -660,21 +660,21 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter /// @param config Config for the keyword spotter. /// @return Return a pointer to the spotter. The user has to invoke /// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak. 
-SHERPA_ONNX_API SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( +SHERPA_ONNX_API const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( const SherpaOnnxKeywordSpotterConfig *config); /// Free a pointer returned by SherpaOnnxCreateKeywordSpotter() /// /// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter() SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter( - SherpaOnnxKeywordSpotter *spotter); + const SherpaOnnxKeywordSpotter *spotter); /// Create an online stream for accepting wave samples. /// /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter() /// @return Return a pointer to an OnlineStream. The user has to invoke /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. -SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( +SHERPA_ONNX_API const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( const SherpaOnnxKeywordSpotter *spotter); /// Create an online stream for accepting wave samples with the specified hot @@ -684,7 +684,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( /// @param keywords A pointer points to the keywords that you set /// @return Return a pointer to an OnlineStream. The user has to invoke /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. 
-SHERPA_ONNX_API SherpaOnnxOnlineStream * +SHERPA_ONNX_API const SherpaOnnxOnlineStream * SherpaOnnxCreateKeywordStreamWithKeywords( const SherpaOnnxKeywordSpotter *spotter, const char *keywords); @@ -693,15 +693,22 @@ SherpaOnnxCreateKeywordStreamWithKeywords( /// /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter /// @param stream A pointer returned by SherpaOnnxCreateKeywordStream -SHERPA_ONNX_API int32_t SherpaOnnxIsKeywordStreamReady( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); +SHERPA_ONNX_API int32_t +SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream); /// Call this function to run the neural network model and decoding. // /// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST /// return 1. SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); + const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream); + +/// Please call it right after a keyword is detected +SHERPA_ONNX_API void SherpaOnnxResetKeywordStream( + const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream); /// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes /// multiple OnlineStream in parallel. @@ -714,8 +721,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream( /// SherpaOnnxCreateKeywordStream() /// @param n Number of elements in the given streams array. SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, - int32_t n); + const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream **streams, int32_t n); /// Get the decoding results so far for an OnlineStream. /// @@ -725,7 +732,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams( /// SherpaOnnxDestroyKeywordResult() to free the returned pointer to /// avoid memory leak. 
SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); + const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream); /// Destroy the pointer returned by SherpaOnnxGetKeywordResult(). /// @@ -736,7 +744,8 @@ SHERPA_ONNX_API void SherpaOnnxDestroyKeywordResult( // the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned // pointer to avoid memory leak SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); + const SherpaOnnxKeywordSpotter *spotter, + const SherpaOnnxOnlineStream *stream); SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s); diff --git a/sherpa-onnx/c-api/cxx-api.cc b/sherpa-onnx/c-api/cxx-api.cc index 50a1f4e1..63a47cc6 100644 --- a/sherpa-onnx/c-api/cxx-api.cc +++ b/sherpa-onnx/c-api/cxx-api.cc @@ -391,4 +391,112 @@ GeneratedAudio OfflineTts::Generate(const std::string &text, return ans; } +KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) { + struct SherpaOnnxKeywordSpotterConfig c; + memset(&c, 0, sizeof(c)); + + c.feat_config.sample_rate = config.feat_config.sample_rate; + + c.model_config.transducer.encoder = + config.model_config.transducer.encoder.c_str(); + c.model_config.transducer.decoder = + config.model_config.transducer.decoder.c_str(); + c.model_config.transducer.joiner = + config.model_config.transducer.joiner.c_str(); + c.feat_config.feature_dim = config.feat_config.feature_dim; + + c.model_config.paraformer.encoder = + config.model_config.paraformer.encoder.c_str(); + c.model_config.paraformer.decoder = + config.model_config.paraformer.decoder.c_str(); + + c.model_config.zipformer2_ctc.model = + config.model_config.zipformer2_ctc.model.c_str(); + + c.model_config.tokens = config.model_config.tokens.c_str(); + c.model_config.num_threads = config.model_config.num_threads; + 
c.model_config.provider = config.model_config.provider.c_str(); + c.model_config.debug = config.model_config.debug; + c.model_config.model_type = config.model_config.model_type.c_str(); + c.model_config.modeling_unit = config.model_config.modeling_unit.c_str(); + c.model_config.bpe_vocab = config.model_config.bpe_vocab.c_str(); + c.model_config.tokens_buf = config.model_config.tokens_buf.c_str(); + c.model_config.tokens_buf_size = config.model_config.tokens_buf.size(); + + c.max_active_paths = config.max_active_paths; + c.num_trailing_blanks = config.num_trailing_blanks; + c.keywords_score = config.keywords_score; + c.keywords_threshold = config.keywords_threshold; + c.keywords_file = config.keywords_file.c_str(); + + auto p = SherpaOnnxCreateKeywordSpotter(&c); + return KeywordSpotter(p); +} + +KeywordSpotter::KeywordSpotter(const SherpaOnnxKeywordSpotter *p) + : MoveOnly(p) {} + +void KeywordSpotter::Destroy(const SherpaOnnxKeywordSpotter *p) const { + SherpaOnnxDestroyKeywordSpotter(p); +} + +OnlineStream KeywordSpotter::CreateStream() const { + auto s = SherpaOnnxCreateKeywordStream(p_); + return OnlineStream{s}; +} + +OnlineStream KeywordSpotter::CreateStream(const std::string &keywords) const { + auto s = SherpaOnnxCreateKeywordStreamWithKeywords(p_, keywords.c_str()); + return OnlineStream{s}; +} + +bool KeywordSpotter::IsReady(const OnlineStream *s) const { + return SherpaOnnxIsKeywordStreamReady(p_, s->Get()); +} + +void KeywordSpotter::Decode(const OnlineStream *s) const { + return SherpaOnnxDecodeKeywordStream(p_, s->Get()); +} + +void KeywordSpotter::Decode(const OnlineStream *ss, int32_t n) const { + if (n <= 0) { + return; + } + + std::vector<const SherpaOnnxOnlineStream *> streams(n); + for (int32_t i = 0; i != n; ++i) { + streams[i] = ss[i].Get(); + } + + SherpaOnnxDecodeMultipleKeywordStreams(p_, streams.data(), n); +} + +KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const { + auto r = SherpaOnnxGetKeywordResult(p_, s->Get()); + + KeywordResult ans; + ans.keyword
= r->keyword; + + ans.tokens.resize(r->count); + for (int32_t i = 0; i < r->count; ++i) { + ans.tokens[i] = r->tokens_arr[i]; + } + + if (r->timestamps) { + ans.timestamps.resize(r->count); + std::copy(r->timestamps, r->timestamps + r->count, ans.timestamps.data()); + } + + ans.start_time = r->start_time; + ans.json = r->json; + + SherpaOnnxDestroyKeywordResult(r); + + return ans; +} + +void KeywordSpotter::Reset(const OnlineStream *s) const { + SherpaOnnxResetKeywordStream(p_, s->Get()); +} + } // namespace sherpa_onnx::cxx diff --git a/sherpa-onnx/c-api/cxx-api.h b/sherpa-onnx/c-api/cxx-api.h index ce65a9ec..8416c594 100644 --- a/sherpa-onnx/c-api/cxx-api.h +++ b/sherpa-onnx/c-api/cxx-api.h @@ -406,6 +406,53 @@ class SHERPA_ONNX_API OfflineTts explicit OfflineTts(const SherpaOnnxOfflineTts *p); }; +// ============================================================ +// For Keyword Spotter +// ============================================================ + +struct KeywordResult { + std::string keyword; + std::vector tokens; + std::vector timestamps; + float start_time; + std::string json; +}; + +struct KeywordSpotterConfig { + FeatureConfig feat_config; + OnlineModelConfig model_config; + int32_t max_active_paths = 4; + int32_t num_trailing_blanks = 1; + float keywords_score = 1.0f; + float keywords_threshold = 0.25f; + std::string keywords_file; +}; + +class SHERPA_ONNX_API KeywordSpotter + : public MoveOnly { + public: + static KeywordSpotter Create(const KeywordSpotterConfig &config); + + void Destroy(const SherpaOnnxKeywordSpotter *p) const; + + OnlineStream CreateStream() const; + + OnlineStream CreateStream(const std::string &keywords) const; + + bool IsReady(const OnlineStream *s) const; + + void Decode(const OnlineStream *s) const; + + void Decode(const OnlineStream *ss, int32_t n) const; + + void Reset(const OnlineStream *s) const; + + KeywordResult GetResult(const OnlineStream *s) const; + + private: + explicit KeywordSpotter(const SherpaOnnxKeywordSpotter 
*p); +}; + } // namespace sherpa_onnx::cxx #endif // SHERPA_ONNX_C_API_CXX_API_H_ diff --git a/sherpa-onnx/csrc/keyword-spotter-impl.h b/sherpa-onnx/csrc/keyword-spotter-impl.h index ded735ff..6180f917 100644 --- a/sherpa-onnx/csrc/keyword-spotter-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-impl.h @@ -38,6 +38,8 @@ class KeywordSpotterImpl { virtual bool IsReady(OnlineStream *s) const = 0; + virtual void Reset(OnlineStream *s) const = 0; + virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; virtual KeywordResult GetResult(OnlineStream *s) const = 0; diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index 75963918..d29b8b58 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { return s->GetNumProcessedFrames() + model_->ChunkSize() < s->NumFramesReady(); } + void Reset(OnlineStream *s) const override { InitOnlineStream(s); } void DecodeStreams(OnlineStream **ss, int32_t n) const override { + for (int32_t i = 0; i < n; ++i) { + auto s = ss[i]; + auto r = s->GetKeywordResult(true); + int32_t num_trailing_blanks = r.num_trailing_blanks; + // assume subsampling_factor is 4 + // assume frameshift is 0.01 second + float trailing_slience = num_trailing_blanks * 4 * 0.01; + + // it resets automatically after detecting 1.5 seconds of silence + float threshold = 1.5; + if (trailing_slience > threshold) { + Reset(s); + } + } + int32_t chunk_size = model_->ChunkSize(); int32_t chunk_shift = model_->ChunkShift(); diff --git a/sherpa-onnx/csrc/keyword-spotter.cc b/sherpa-onnx/csrc/keyword-spotter.cc index d1bf6d63..66d0907a 100644 --- a/sherpa-onnx/csrc/keyword-spotter.cc +++ b/sherpa-onnx/csrc/keyword-spotter.cc @@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const { return impl_->IsReady(s); } +void 
KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); } + void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const { impl_->DecodeStreams(ss, n); } diff --git a/sherpa-onnx/csrc/keyword-spotter.h b/sherpa-onnx/csrc/keyword-spotter.h index f0c31bdb..c933f4b2 100644 --- a/sherpa-onnx/csrc/keyword-spotter.h +++ b/sherpa-onnx/csrc/keyword-spotter.h @@ -129,6 +129,9 @@ class KeywordSpotter { */ bool IsReady(OnlineStream *s) const; + // Remember to call it after detecting a keyword + void Reset(OnlineStream *s) const; + /** Decode a single stream. */ void DecodeStream(OnlineStream *s) const { OnlineStream *ss[1] = {s}; diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc index a909ff25..cfa46dc9 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc @@ -106,13 +106,15 @@ as the device_name. while (spotter.IsReady(stream.get())) { spotter.DecodeStream(stream.get()); - } - const auto r = spotter.GetResult(stream.get()); - if (!r.keyword.empty()) { - display.Print(keyword_index, r.AsJsonString()); - fflush(stderr); - keyword_index++; + const auto r = spotter.GetResult(stream.get()); + if (!r.keyword.empty()) { + display.Print(keyword_index, r.AsJsonString()); + fflush(stderr); + keyword_index++; + + spotter.Reset(stream.get()); + } } } diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc index 903debea..4d75f9d4 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc @@ -150,13 +150,15 @@ for a list of pre-trained models to download. 
while (!stop) { while (spotter.IsReady(s.get())) { spotter.DecodeStream(s.get()); - } - const auto r = spotter.GetResult(s.get()); - if (!r.keyword.empty()) { - display.Print(keyword_index, r.AsJsonString()); - fflush(stderr); - keyword_index++; + const auto r = spotter.GetResult(s.get()); + if (!r.keyword.empty()) { + display.Print(keyword_index, r.AsJsonString()); + fflush(stderr); + keyword_index++; + + spotter.Reset(s.get()); + } } Pa_Sleep(20); // sleep for 20ms diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotter.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotter.java index a1b897b0..3565c05e 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotter.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/KeywordSpotter.java @@ -27,6 +27,10 @@ public class KeywordSpotter { decode(ptr, s.getPtr()); } + public void reset(OnlineStream s) { + reset(ptr, s.getPtr()); + } + public boolean isReady(OnlineStream s) { return isReady(ptr, s.getPtr()); } @@ -60,6 +64,8 @@ public class KeywordSpotter { private native void decode(long ptr, long streamPtr); + private native void reset(long ptr, long streamPtr); + private native boolean isReady(long ptr, long streamPtr); private native Object[] getResult(long ptr, long streamPtr); diff --git a/sherpa-onnx/jni/keyword-spotter.cc b/sherpa-onnx/jni/keyword-spotter.cc index 4ac80a29..0dc5685e 100644 --- a/sherpa-onnx/jni/keyword-spotter.cc +++ b/sherpa-onnx/jni/keyword-spotter.cc @@ -161,6 +161,15 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode( kws->DecodeStream(stream); } +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_reset( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { + auto kws = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + kws->Reset(stream); +} + SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream( 
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) { diff --git a/sherpa-onnx/kotlin-api/KeywordSpotter.kt b/sherpa-onnx/kotlin-api/KeywordSpotter.kt index 3801d32a..5b3cdbb7 100644 --- a/sherpa-onnx/kotlin-api/KeywordSpotter.kt +++ b/sherpa-onnx/kotlin-api/KeywordSpotter.kt @@ -49,6 +49,7 @@ class KeywordSpotter( } fun decode(stream: OnlineStream) = decode(ptr, stream.ptr) + fun reset(stream: OnlineStream) = reset(ptr, stream.ptr) fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) fun getResult(stream: OnlineStream): KeywordSpotterResult { val objArray = getResult(ptr, stream.ptr) @@ -74,6 +75,7 @@ class KeywordSpotter( private external fun createStream(ptr: Long, keywords: String): Long private external fun isReady(ptr: Long, streamPtr: Long): Boolean private external fun decode(ptr: Long, streamPtr: Long) + private external fun reset(ptr: Long, streamPtr: Long) private external fun getResult(ptr: Long, streamPtr: Long): Array companion object { diff --git a/sherpa-onnx/python/csrc/keyword-spotter.cc b/sherpa-onnx/python/csrc/keyword-spotter.cc index 14499260..4a48ada4 100644 --- a/sherpa-onnx/python/csrc/keyword-spotter.cc +++ b/sherpa-onnx/python/csrc/keyword-spotter.cc @@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) { py::arg("keywords"), py::call_guard()) .def("is_ready", &PyClass::IsReady, py::call_guard()) + .def("reset", &PyClass::Reset, py::call_guard()) .def("decode_stream", &PyClass::DecodeStream, py::call_guard()) .def( diff --git a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py index 66d71698..a9d8573f 100644 --- a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py +++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py @@ -104,8 +104,8 @@ class KeywordSpotter(object): ) provider_config = ProviderConfig( - provider=provider, - device = device, + provider=provider, + device=device, ) model_config = OnlineModelConfig( @@ -131,6 +131,9 @@ class KeywordSpotter(object): ) 
self.keyword_spotter = _KeywordSpotter(keywords_spotter_config) + def reset_stream(self, s: OnlineStream): + self.keyword_spotter.reset(s) + def create_stream(self, keywords: Optional[str] = None): if keywords is None: return self.keyword_spotter.create_stream() diff --git a/sherpa-onnx/python/tests/test_keyword_spotter.py b/sherpa-onnx/python/tests/test_keyword_spotter.py index f4d79830..62691d18 100755 --- a/sherpa-onnx/python/tests/test_keyword_spotter.py +++ b/sherpa-onnx/python/tests/test_keyword_spotter.py @@ -98,6 +98,9 @@ class TestKeywordSpotter(unittest.TestCase): if r: print(f"{r} is detected.") results[i] += f"{r}/" + + keyword_spotter.reset_stream(s) + if len(ready_list) == 0: break keyword_spotter.decode_streams(ready_list) @@ -158,6 +161,9 @@ class TestKeywordSpotter(unittest.TestCase): if r: print(f"{r} is detected.") results[i] += f"{r}/" + + keyword_spotter.reset_stream(s) + if len(ready_list) == 0: break keyword_spotter.decode_streams(ready_list) diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index d4c39645..81a1c9e4 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -1076,6 +1076,10 @@ class SherpaOnnxKeywordSpotterWrapper { SherpaOnnxDecodeKeywordStream(spotter, stream) } + func reset() { + SherpaOnnxResetKeywordStream(spotter, stream) + } + func getResult() -> SherpaOnnxKeywordResultWrapper { let result: UnsafePointer? 
= SherpaOnnxGetKeywordResult( spotter, stream) diff --git a/swift-api-examples/keyword-spotting-from-file.swift b/swift-api-examples/keyword-spotting-from-file.swift index 08487eb4..498852a8 100644 --- a/swift-api-examples/keyword-spotting-from-file.swift +++ b/swift-api-examples/keyword-spotting-from-file.swift @@ -70,6 +70,9 @@ func run() { spotter.decode() let keyword = spotter.getResult().keyword if keyword != "" { + // Remember to call reset() right after detecting a keyword + spotter.reset() + print("Detected: \(keyword)") } } diff --git a/wasm/kws/CMakeLists.txt b/wasm/kws/CMakeLists.txt index 197aa38b..5620b80d 100644 --- a/wasm/kws/CMakeLists.txt +++ b/wasm/kws/CMakeLists.txt @@ -17,6 +17,7 @@ set(exported_functions SherpaOnnxIsKeywordStreamReady SherpaOnnxOnlineStreamAcceptWaveform SherpaOnnxOnlineStreamInputFinished + SherpaOnnxResetKeywordStream ) set(mangled_exported_functions) foreach(x IN LISTS exported_functions) diff --git a/wasm/kws/app.js b/wasm/kws/app.js index 1e97262a..6df20d23 100644 --- a/wasm/kws/app.js +++ b/wasm/kws/app.js @@ -102,15 +102,17 @@ if (navigator.mediaDevices.getUserMedia) { recognizer_stream.acceptWaveform(expectedSampleRate, samples); while (recognizer.isReady(recognizer_stream)) { recognizer.decode(recognizer_stream); - } + let result = recognizer.getResult(recognizer_stream); - let result = recognizer.getResult(recognizer_stream); + if (result.keyword.length > 0) { + console.log(result) + lastResult = result; + resultList.push(JSON.stringify(result)); - if (result.keyword.length > 0) { - console.log(result) - lastResult = result; - resultList.push(JSON.stringify(result)); + // remember to reset the stream right after detecting a keyword + recognizer.reset(recognizer_stream); + } } diff --git a/wasm/kws/sherpa-onnx-kws.js b/wasm/kws/sherpa-onnx-kws.js index b7c02335..57c81e09 100644 --- a/wasm/kws/sherpa-onnx-kws.js +++ b/wasm/kws/sherpa-onnx-kws.js @@ -296,8 +296,11 @@ class Kws { } decode(stream) { - return 
this.Module._SherpaOnnxDecodeKeywordStream( - this.handle, stream.handle); + this.Module._SherpaOnnxDecodeKeywordStream(this.handle, stream.handle); + } + + reset(stream) { + this.Module._SherpaOnnxResetKeywordStream(this.handle, stream.handle); } getResult(stream) {