Fix keyword spotting. (#1689)
Reset the stream right after detecting a keyword
This commit is contained in:
33
.github/scripts/test-python.sh
vendored
33
.github/scripts/test-python.sh
vendored
@@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version"
|
|||||||
pwd
|
pwd
|
||||||
ls -lh
|
ls -lh
|
||||||
|
|
||||||
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
|
|
||||||
log "Start testing ${repo}"
|
|
||||||
|
|
||||||
pushd $dir
|
|
||||||
curl -LS -O https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
|
|
||||||
tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
|
|
||||||
rm sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
|
|
||||||
popd
|
|
||||||
|
|
||||||
repo=$dir/$repo
|
|
||||||
ls -lh $repo
|
|
||||||
|
|
||||||
python3 ./python-api-examples/keyword-spotter.py \
|
|
||||||
--tokens=$repo/tokens.txt \
|
|
||||||
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
|
||||||
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
|
||||||
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
|
||||||
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
|
||||||
$repo/test_wavs/0.wav \
|
|
||||||
$repo/test_wavs/1.wav
|
|
||||||
|
|
||||||
rm -rf $repo
|
|
||||||
|
|
||||||
if [[ x$OS != x'windows-latest' ]]; then
|
if [[ x$OS != x'windows-latest' ]]; then
|
||||||
echo "OS: $OS"
|
echo "OS: $OS"
|
||||||
|
|
||||||
@@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then
|
|||||||
repo=$dir/$repo
|
repo=$dir/$repo
|
||||||
ls -lh $repo
|
ls -lh $repo
|
||||||
|
|
||||||
python3 ./python-api-examples/keyword-spotter.py \
|
python3 ./python-api-examples/keyword-spotter.py
|
||||||
--tokens=$repo/tokens.txt \
|
|
||||||
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
|
||||||
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
|
||||||
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
|
||||||
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
|
||||||
$repo/test_wavs/3.wav \
|
|
||||||
$repo/test_wavs/4.wav \
|
|
||||||
$repo/test_wavs/5.wav
|
|
||||||
|
|
||||||
python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
|
python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
|
||||||
|
|
||||||
|
|||||||
21
.github/workflows/c-api.yaml
vendored
21
.github/workflows/c-api.yaml
vendored
@@ -79,6 +79,27 @@ jobs:
|
|||||||
otool -L ./install/lib/libsherpa-onnx-c-api.dylib
|
otool -L ./install/lib/libsherpa-onnx-c-api.dylib
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
- name: Test kws (zh)
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
gcc -o kws-c-api ./c-api-examples/kws-c-api.c \
|
||||||
|
-I ./build/install/include \
|
||||||
|
-L ./build/install/lib/ \
|
||||||
|
-l sherpa-onnx-c-api \
|
||||||
|
-l onnxruntime
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
|
||||||
|
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
|
||||||
|
|
||||||
|
./kws-c-api
|
||||||
|
|
||||||
|
rm ./kws-c-api
|
||||||
|
rm -rf sherpa-onnx-kws-*
|
||||||
|
|
||||||
- name: Test Kokoro TTS (en)
|
- name: Test Kokoro TTS (en)
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
22
.github/workflows/cxx-api.yaml
vendored
22
.github/workflows/cxx-api.yaml
vendored
@@ -81,6 +81,28 @@ jobs:
|
|||||||
otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib
|
otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
- name: Test KWS (zh)
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
g++ -std=c++17 -o kws-cxx-api ./cxx-api-examples/kws-cxx-api.cc \
|
||||||
|
-I ./build/install/include \
|
||||||
|
-L ./build/install/lib/ \
|
||||||
|
-l sherpa-onnx-cxx-api \
|
||||||
|
-l sherpa-onnx-c-api \
|
||||||
|
-l onnxruntime
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
|
||||||
|
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
|
||||||
|
|
||||||
|
./kws-cxx-api
|
||||||
|
|
||||||
|
rm kws-cxx-api
|
||||||
|
rm -rf sherpa-onnx-kws-*
|
||||||
|
|
||||||
- name: Test Kokoro TTS (en)
|
- name: Test Kokoro TTS (en)
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -151,13 +151,15 @@ class MainActivity : AppCompatActivity() {
|
|||||||
stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
|
stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
|
||||||
while (kws.isReady(stream)) {
|
while (kws.isReady(stream)) {
|
||||||
kws.decode(stream)
|
kws.decode(stream)
|
||||||
}
|
|
||||||
|
|
||||||
val text = kws.getResult(stream).keyword
|
val text = kws.getResult(stream).keyword
|
||||||
|
|
||||||
var textToDisplay = lastText
|
var textToDisplay = lastText
|
||||||
|
|
||||||
if (text.isNotBlank()) {
|
if (text.isNotBlank()) {
|
||||||
|
// Remember to reset the stream right after detecting a keyword
|
||||||
|
|
||||||
|
kws.reset(stream)
|
||||||
if (lastText.isBlank()) {
|
if (lastText.isBlank()) {
|
||||||
textToDisplay = "$idx: $text"
|
textToDisplay = "$idx: $text"
|
||||||
} else {
|
} else {
|
||||||
@@ -173,6 +175,7 @@ class MainActivity : AppCompatActivity() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private fun initMicrophone(): Boolean {
|
private fun initMicrophone(): Boolean {
|
||||||
if (ActivityCompat.checkSelfPermission(
|
if (ActivityCompat.checkSelfPermission(
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR})
|
|||||||
add_executable(decode-file-c-api decode-file-c-api.c)
|
add_executable(decode-file-c-api decode-file-c-api.c)
|
||||||
target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs)
|
target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs)
|
||||||
|
|
||||||
|
add_executable(kws-c-api kws-c-api.c)
|
||||||
|
target_link_libraries(kws-c-api sherpa-onnx-c-api)
|
||||||
|
|
||||||
if(SHERPA_ONNX_ENABLE_TTS)
|
if(SHERPA_ONNX_ENABLE_TTS)
|
||||||
add_executable(offline-tts-c-api offline-tts-c-api.c)
|
add_executable(offline-tts-c-api offline-tts-c-api.c)
|
||||||
target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs)
|
target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs)
|
||||||
|
|||||||
150
c-api-examples/kws-c-api.c
Normal file
150
c-api-examples/kws-c-api.c
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
// c-api-examples/kws-c-api.c
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
//
|
||||||
|
// This file demonstrates how to use the keyword spotter with sherpa-onnx's C API.
|
||||||
|
// clang-format off
|
||||||
|
//
|
||||||
|
// Usage
|
||||||
|
//
|
||||||
|
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
//
|
||||||
|
// ./kws-c-api
|
||||||
|
//
|
||||||
|
// clang-format on
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h> // exit
|
||||||
|
#include <string.h> // memset
|
||||||
|
|
||||||
|
#include "sherpa-onnx/c-api/c-api.h"
|
||||||
|
|
||||||
|
int32_t main() {
|
||||||
|
SherpaOnnxKeywordSpotterConfig config;
|
||||||
|
|
||||||
|
memset(&config, 0, sizeof(config));
|
||||||
|
config.model_config.transducer.encoder =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||||
|
"encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||||
|
|
||||||
|
config.model_config.transducer.decoder =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||||
|
"decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||||
|
|
||||||
|
config.model_config.transducer.joiner =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||||
|
"joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||||
|
|
||||||
|
config.model_config.tokens =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
|
||||||
|
|
||||||
|
config.model_config.provider = "cpu";
|
||||||
|
config.model_config.num_threads = 1;
|
||||||
|
config.model_config.debug = 1;
|
||||||
|
|
||||||
|
config.keywords_file =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
|
||||||
|
"test_keywords.txt";
|
||||||
|
|
||||||
|
const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&config);
|
||||||
|
if (!kws) {
|
||||||
|
fprintf(stderr, "Please check your config");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr,
|
||||||
|
"--Test pre-defined keywords from test_wavs/test_keywords.txt--\n");
|
||||||
|
|
||||||
|
const char *wav_filename =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
|
||||||
|
|
||||||
|
float tail_paddings[8000] = {0}; // 0.5 seconds
|
||||||
|
|
||||||
|
const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
|
||||||
|
if (wave == NULL) {
|
||||||
|
fprintf(stderr, "Failed to read %s\n", wav_filename);
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
|
||||||
|
if (!stream) {
|
||||||
|
fprintf(stderr, "Failed to create stream\n");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
|
||||||
|
wave->num_samples);
|
||||||
|
|
||||||
|
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
|
||||||
|
sizeof(tail_paddings) / sizeof(float));
|
||||||
|
SherpaOnnxOnlineStreamInputFinished(stream);
|
||||||
|
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
|
||||||
|
SherpaOnnxDecodeKeywordStream(kws, stream);
|
||||||
|
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
|
||||||
|
if (r && r->json && strlen(r->keyword)) {
|
||||||
|
fprintf(stderr, "Detected keyword: %s\n", r->json);
|
||||||
|
|
||||||
|
// Remember to reset the keyword stream right after a keyword is detected
|
||||||
|
SherpaOnnxResetKeywordStream(kws, stream);
|
||||||
|
}
|
||||||
|
SherpaOnnxDestroyKeywordResult(r);
|
||||||
|
}
|
||||||
|
SherpaOnnxDestroyOnlineStream(stream);
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fprintf(stderr, "--Use pre-defined keywords + add a new keyword--\n");
|
||||||
|
|
||||||
|
stream = SherpaOnnxCreateKeywordStreamWithKeywords(kws, "y ǎn y uán @演员");
|
||||||
|
|
||||||
|
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
|
||||||
|
wave->num_samples);
|
||||||
|
|
||||||
|
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
|
||||||
|
sizeof(tail_paddings) / sizeof(float));
|
||||||
|
SherpaOnnxOnlineStreamInputFinished(stream);
|
||||||
|
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
|
||||||
|
SherpaOnnxDecodeKeywordStream(kws, stream);
|
||||||
|
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
|
||||||
|
if (r && r->json && strlen(r->keyword)) {
|
||||||
|
fprintf(stderr, "Detected keyword: %s\n", r->json);
|
||||||
|
|
||||||
|
// Remember to reset the keyword stream
|
||||||
|
SherpaOnnxResetKeywordStream(kws, stream);
|
||||||
|
}
|
||||||
|
SherpaOnnxDestroyKeywordResult(r);
|
||||||
|
}
|
||||||
|
SherpaOnnxDestroyOnlineStream(stream);
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fprintf(stderr, "--Use pre-defined keywords + add two new keywords--\n");
|
||||||
|
|
||||||
|
stream = SherpaOnnxCreateKeywordStreamWithKeywords(
|
||||||
|
kws, "y ǎn y uán @演员/zh ī m íng @知名");
|
||||||
|
|
||||||
|
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
|
||||||
|
wave->num_samples);
|
||||||
|
|
||||||
|
SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
|
||||||
|
sizeof(tail_paddings) / sizeof(float));
|
||||||
|
SherpaOnnxOnlineStreamInputFinished(stream);
|
||||||
|
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
|
||||||
|
SherpaOnnxDecodeKeywordStream(kws, stream);
|
||||||
|
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
|
||||||
|
if (r && r->json && strlen(r->keyword)) {
|
||||||
|
fprintf(stderr, "Detected keyword: %s\n", r->json);
|
||||||
|
|
||||||
|
// Remember to reset the keyword stream
|
||||||
|
SherpaOnnxResetKeywordStream(kws, stream);
|
||||||
|
}
|
||||||
|
SherpaOnnxDestroyKeywordResult(r);
|
||||||
|
}
|
||||||
|
SherpaOnnxDestroyOnlineStream(stream);
|
||||||
|
|
||||||
|
SherpaOnnxFreeWave(wave);
|
||||||
|
SherpaOnnxDestroyKeywordSpotter(kws);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
|
|||||||
add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
|
add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
|
||||||
target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)
|
target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)
|
||||||
|
|
||||||
|
add_executable(kws-cxx-api ./kws-cxx-api.cc)
|
||||||
|
target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api)
|
||||||
|
|
||||||
add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
|
add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
|
||||||
target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)
|
target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)
|
||||||
|
|
||||||
|
|||||||
141
cxx-api-examples/kws-cxx-api.cc
Normal file
141
cxx-api-examples/kws-cxx-api.cc
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
// cxx-api-examples/kws-cxx-api.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
//
|
||||||
|
// This file demonstrates how to use the keyword spotter with sherpa-onnx's C++ API.
|
||||||
|
// clang-format off
|
||||||
|
//
|
||||||
|
// Usage
|
||||||
|
//
|
||||||
|
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||||
|
//
|
||||||
|
// ./kws-cxx-api
|
||||||
|
//
|
||||||
|
// clang-format on
|
||||||
|
#include <array>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/c-api/cxx-api.h"
|
||||||
|
|
||||||
|
int32_t main() {
|
||||||
|
using namespace sherpa_onnx::cxx; // NOLINT
|
||||||
|
|
||||||
|
KeywordSpotterConfig config;
|
||||||
|
config.model_config.transducer.encoder =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||||
|
"encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||||
|
|
||||||
|
config.model_config.transducer.decoder =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||||
|
"decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||||
|
|
||||||
|
config.model_config.transducer.joiner =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||||
|
"joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||||
|
|
||||||
|
config.model_config.tokens =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
|
||||||
|
|
||||||
|
config.model_config.provider = "cpu";
|
||||||
|
config.model_config.num_threads = 1;
|
||||||
|
config.model_config.debug = 1;
|
||||||
|
|
||||||
|
config.keywords_file =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
|
||||||
|
"test_keywords.txt";
|
||||||
|
|
||||||
|
KeywordSpotter kws = KeywordSpotter::Create(config);
|
||||||
|
if (!kws.Get()) {
|
||||||
|
std::cerr << "Please check your config\n";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout
|
||||||
|
<< "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n";
|
||||||
|
|
||||||
|
std::string wave_filename =
|
||||||
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
|
||||||
|
|
||||||
|
std::array<float, 8000> tail_paddings = {0}; // 0.5 seconds
|
||||||
|
|
||||||
|
Wave wave = ReadWave(wave_filename);
|
||||||
|
if (wave.samples.empty()) {
|
||||||
|
std::cerr << "Failed to read: '" << wave_filename << "'\n";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
OnlineStream stream = kws.CreateStream();
|
||||||
|
if (!stream.Get()) {
|
||||||
|
std::cerr << "Failed to create stream\n";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||||
|
wave.samples.size());
|
||||||
|
|
||||||
|
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||||
|
tail_paddings.size());
|
||||||
|
stream.InputFinished();
|
||||||
|
|
||||||
|
while (kws.IsReady(&stream)) {
|
||||||
|
kws.Decode(&stream);
|
||||||
|
auto r = kws.GetResult(&stream);
|
||||||
|
if (!r.keyword.empty()) {
|
||||||
|
std::cout << "Detected keyword: " << r.json << "\n";
|
||||||
|
|
||||||
|
// Remember to reset the keyword stream right after a keyword is detected
|
||||||
|
kws.Reset(&stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
std::cout << "--Use pre-defined keywords + add a new keyword--\n";
|
||||||
|
|
||||||
|
stream = kws.CreateStream("y ǎn y uán @演员");
|
||||||
|
|
||||||
|
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||||
|
wave.samples.size());
|
||||||
|
|
||||||
|
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||||
|
tail_paddings.size());
|
||||||
|
stream.InputFinished();
|
||||||
|
|
||||||
|
while (kws.IsReady(&stream)) {
|
||||||
|
kws.Decode(&stream);
|
||||||
|
auto r = kws.GetResult(&stream);
|
||||||
|
if (!r.keyword.empty()) {
|
||||||
|
std::cout << "Detected keyword: " << r.json << "\n";
|
||||||
|
|
||||||
|
// Remember to reset the keyword stream right after a keyword is detected
|
||||||
|
kws.Reset(&stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
std::cout << "--Use pre-defined keywords + add two new keywords--\n";
|
||||||
|
|
||||||
|
stream = kws.CreateStream("y ǎn y uán @演员/zh ī m íng @知名");
|
||||||
|
|
||||||
|
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||||
|
wave.samples.size());
|
||||||
|
|
||||||
|
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||||
|
tail_paddings.size());
|
||||||
|
stream.InputFinished();
|
||||||
|
|
||||||
|
while (kws.IsReady(&stream)) {
|
||||||
|
kws.Decode(&stream);
|
||||||
|
auto r = kws.GetResult(&stream);
|
||||||
|
if (!r.keyword.empty()) {
|
||||||
|
std::cout << "Detected keyword: " << r.json << "\n";
|
||||||
|
|
||||||
|
// Remember to reset the keyword stream right after a keyword is detected
|
||||||
|
kws.Reset(&stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -73,6 +73,8 @@ void main(List<String> arguments) async {
|
|||||||
spotter.decode(stream);
|
spotter.decode(stream);
|
||||||
final result = spotter.getResult(stream);
|
final result = spotter.getResult(stream);
|
||||||
if (result.keyword != '') {
|
if (result.keyword != '') {
|
||||||
|
// Remember to reset the stream right after detecting a keyword
|
||||||
|
spotter.reset(stream);
|
||||||
print('Detected: ${result.keyword}');
|
print('Detected: ${result.keyword}');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ class KeywordSpotterDemo
|
|||||||
var result = kws.GetResult(s);
|
var result = kws.GetResult(s);
|
||||||
if (result.Keyword != string.Empty)
|
if (result.Keyword != string.Empty)
|
||||||
{
|
{
|
||||||
|
// Remember to call Reset() right after detecting a keyword
|
||||||
|
kws.Reset(s);
|
||||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -70,6 +72,8 @@ class KeywordSpotterDemo
|
|||||||
var result = kws.GetResult(s);
|
var result = kws.GetResult(s);
|
||||||
if (result.Keyword != string.Empty)
|
if (result.Keyword != string.Empty)
|
||||||
{
|
{
|
||||||
|
// Remember to call Reset() right after detecting a keyword
|
||||||
|
kws.Reset(s);
|
||||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,6 +93,8 @@ class KeywordSpotterDemo
|
|||||||
var result = kws.GetResult(s);
|
var result = kws.GetResult(s);
|
||||||
if (result.Keyword != string.Empty)
|
if (result.Keyword != string.Empty)
|
||||||
{
|
{
|
||||||
|
// Remember to call Reset() right after detecting a keyword
|
||||||
|
kws.Reset(s);
|
||||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,13 +107,16 @@ class KeywordSpotterDemo
|
|||||||
while (kws.IsReady(s))
|
while (kws.IsReady(s))
|
||||||
{
|
{
|
||||||
kws.Decode(s);
|
kws.Decode(s);
|
||||||
}
|
|
||||||
|
|
||||||
var result = kws.GetResult(s);
|
var result = kws.GetResult(s);
|
||||||
if (result.Keyword != string.Empty)
|
if (result.Keyword != string.Empty)
|
||||||
{
|
{
|
||||||
|
// Remember to call Reset() right after detecting a keyword
|
||||||
|
kws.Reset(s);
|
||||||
|
|
||||||
Console.WriteLine("Detected: {0}", result.Keyword);
|
Console.WriteLine("Detected: {0}", result.Keyword);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Thread.Sleep(200); // ms
|
Thread.Sleep(200); // ms
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -168,6 +168,10 @@ class KeywordSpotter {
|
|||||||
SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr);
|
SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resets [stream] by calling the native `SherpaOnnxResetKeywordStream`
/// binding with this spotter's pointer and the stream's pointer.
///
/// Does nothing when the binding has not been loaded (it is still null).
void reset(OnlineStream stream) {
  final resetFn = SherpaOnnxBindings.resetKeywordStream;
  if (resetFn == null) {
    return;
  }

  resetFn(ptr, stream.ptr);
}
|
||||||
|
|
||||||
Pointer<SherpaOnnxKeywordSpotter> ptr;
|
Pointer<SherpaOnnxKeywordSpotter> ptr;
|
||||||
KeywordSpotterConfig config;
|
KeywordSpotterConfig config;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -667,6 +667,12 @@ typedef DecodeKeywordStreamNative = Void Function(
|
|||||||
typedef DecodeKeywordStream = void Function(
|
typedef DecodeKeywordStream = void Function(
|
||||||
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
||||||
|
|
||||||
|
typedef ResetKeywordStreamNative = Void Function(
|
||||||
|
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
||||||
|
|
||||||
|
typedef ResetKeywordStream = void Function(
|
||||||
|
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
||||||
|
|
||||||
typedef GetKeywordResultAsJsonNative = Pointer<Utf8> Function(
|
typedef GetKeywordResultAsJsonNative = Pointer<Utf8> Function(
|
||||||
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
|
||||||
|
|
||||||
@@ -1157,6 +1163,7 @@ class SherpaOnnxBindings {
|
|||||||
static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords;
|
static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords;
|
||||||
static IsKeywordStreamReady? isKeywordStreamReady;
|
static IsKeywordStreamReady? isKeywordStreamReady;
|
||||||
static DecodeKeywordStream? decodeKeywordStream;
|
static DecodeKeywordStream? decodeKeywordStream;
|
||||||
|
static ResetKeywordStream? resetKeywordStream;
|
||||||
static GetKeywordResultAsJson? getKeywordResultAsJson;
|
static GetKeywordResultAsJson? getKeywordResultAsJson;
|
||||||
static FreeKeywordResultJson? freeKeywordResultJson;
|
static FreeKeywordResultJson? freeKeywordResultJson;
|
||||||
|
|
||||||
@@ -1459,6 +1466,11 @@ class SherpaOnnxBindings {
|
|||||||
'SherpaOnnxDecodeKeywordStream')
|
'SherpaOnnxDecodeKeywordStream')
|
||||||
.asFunction();
|
.asFunction();
|
||||||
|
|
||||||
|
resetKeywordStream ??= dynamicLibrary
|
||||||
|
.lookup<NativeFunction<ResetKeywordStreamNative>>(
|
||||||
|
'SherpaOnnxResetKeywordStream')
|
||||||
|
.asFunction();
|
||||||
|
|
||||||
getKeywordResultAsJson ??= dynamicLibrary
|
getKeywordResultAsJson ??= dynamicLibrary
|
||||||
.lookup<NativeFunction<GetKeywordResultAsJsonNative>>(
|
.lookup<NativeFunction<GetKeywordResultAsJsonNative>>(
|
||||||
'SherpaOnnxGetKeywordResultAsJson')
|
'SherpaOnnxGetKeywordResultAsJson')
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ func main() {
|
|||||||
spotter.Decode(stream)
|
spotter.Decode(stream)
|
||||||
result := spotter.GetResult(stream)
|
result := spotter.GetResult(stream)
|
||||||
if result.Keyword != "" {
|
if result.Keyword != "" {
|
||||||
|
// You have to reset the stream right after detecting a keyword
|
||||||
|
spotter.Reset(stream)
|
||||||
log.Printf("Detected %v\n", result.Keyword)
|
log.Printf("Detected %v\n", result.Keyword)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper(
|
|||||||
SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf);
|
SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf);
|
||||||
SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize);
|
SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize);
|
||||||
|
|
||||||
SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c);
|
const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c);
|
||||||
|
|
||||||
if (c.model_config.transducer.encoder) {
|
if (c.model_config.transducer.encoder) {
|
||||||
delete[] c.model_config.transducer.encoder;
|
delete[] c.model_config.transducer.encoder;
|
||||||
@@ -100,7 +100,8 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return Napi::External<SherpaOnnxKeywordSpotter>::New(
|
return Napi::External<SherpaOnnxKeywordSpotter>::New(
|
||||||
env, kws, [](Napi::Env env, SherpaOnnxKeywordSpotter *kws) {
|
env, const_cast<SherpaOnnxKeywordSpotter *>(kws),
|
||||||
|
[](Napi::Env env, SherpaOnnxKeywordSpotter *kws) {
|
||||||
SherpaOnnxDestroyKeywordSpotter(kws);
|
SherpaOnnxDestroyKeywordSpotter(kws);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -125,13 +126,14 @@ static Napi::External<SherpaOnnxOnlineStream> CreateKeywordStreamWrapper(
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
SherpaOnnxKeywordSpotter *kws =
|
const SherpaOnnxKeywordSpotter *kws =
|
||||||
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
||||||
|
|
||||||
SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
|
const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
|
||||||
|
|
||||||
return Napi::External<SherpaOnnxOnlineStream>::New(
|
return Napi::External<SherpaOnnxOnlineStream>::New(
|
||||||
env, stream, [](Napi::Env env, SherpaOnnxOnlineStream *stream) {
|
env, const_cast<SherpaOnnxOnlineStream *>(stream),
|
||||||
|
[](Napi::Env env, SherpaOnnxOnlineStream *stream) {
|
||||||
SherpaOnnxDestroyOnlineStream(stream);
|
SherpaOnnxDestroyOnlineStream(stream);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -162,10 +164,10 @@ static Napi::Boolean IsKeywordStreamReadyWrapper(
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
SherpaOnnxKeywordSpotter *kws =
|
const SherpaOnnxKeywordSpotter *kws =
|
||||||
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
||||||
|
|
||||||
SherpaOnnxOnlineStream *stream =
|
const SherpaOnnxOnlineStream *stream =
|
||||||
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
||||||
|
|
||||||
int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream);
|
int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream);
|
||||||
@@ -198,15 +200,49 @@ static void DecodeKeywordStreamWrapper(const Napi::CallbackInfo &info) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
SherpaOnnxKeywordSpotter *kws =
|
const SherpaOnnxKeywordSpotter *kws =
|
||||||
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
||||||
|
|
||||||
SherpaOnnxOnlineStream *stream =
|
const SherpaOnnxOnlineStream *stream =
|
||||||
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
||||||
|
|
||||||
SherpaOnnxDecodeKeywordStream(kws, stream);
|
SherpaOnnxDecodeKeywordStream(kws, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JavaScript entry point for resetting a keyword stream.
//
// Expects exactly two arguments:
//   info[0] - an external wrapping a SherpaOnnxKeywordSpotter pointer
//   info[1] - an external wrapping a SherpaOnnxOnlineStream pointer
//
// Throws a JavaScript TypeError and returns without touching the stream when
// the argument count or either argument kind is wrong.
static void ResetKeywordStreamWrapper(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  const auto num_args = info.Length();
  if (num_args != 2) {
    std::ostringstream message;
    message << "Expect only 2 arguments. Given: " << num_args;

    Napi::TypeError::New(env, message.str()).ThrowAsJavaScriptException();
    return;
  }

  if (!info[0].IsExternal()) {
    Napi::TypeError::New(env, "Argument 0 should be a keyword spotter pointer.")
        .ThrowAsJavaScriptException();
    return;
  }

  if (!info[1].IsExternal()) {
    Napi::TypeError::New(env, "Argument 1 should be an online stream pointer.")
        .ThrowAsJavaScriptException();
    return;
  }

  // Both handles were validated above; unwrap them and forward to the C API.
  const SherpaOnnxKeywordSpotter *spotter =
      info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
  const SherpaOnnxOnlineStream *online_stream =
      info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();

  SherpaOnnxResetKeywordStream(spotter, online_stream);
}
|
||||||
|
|
||||||
static Napi::String GetKeywordResultAsJsonWrapper(
|
static Napi::String GetKeywordResultAsJsonWrapper(
|
||||||
const Napi::CallbackInfo &info) {
|
const Napi::CallbackInfo &info) {
|
||||||
Napi::Env env = info.Env();
|
Napi::Env env = info.Env();
|
||||||
@@ -233,10 +269,10 @@ static Napi::String GetKeywordResultAsJsonWrapper(
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
SherpaOnnxKeywordSpotter *kws =
|
const SherpaOnnxKeywordSpotter *kws =
|
||||||
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
|
||||||
|
|
||||||
SherpaOnnxOnlineStream *stream =
|
const SherpaOnnxOnlineStream *stream =
|
||||||
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
|
||||||
|
|
||||||
const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream);
|
const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream);
|
||||||
@@ -261,6 +297,9 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports) {
|
|||||||
exports.Set(Napi::String::New(env, "decodeKeywordStream"),
|
exports.Set(Napi::String::New(env, "decodeKeywordStream"),
|
||||||
Napi::Function::New(env, DecodeKeywordStreamWrapper));
|
Napi::Function::New(env, DecodeKeywordStreamWrapper));
|
||||||
|
|
||||||
|
exports.Set(Napi::String::New(env, "resetKeywordStream"),
|
||||||
|
Napi::Function::New(env, ResetKeywordStreamWrapper));
|
||||||
|
|
||||||
exports.Set(Napi::String::New(env, "getKeywordResultAsJson"),
|
exports.Set(Napi::String::New(env, "getKeywordResultAsJson"),
|
||||||
Napi::Function::New(env, GetKeywordResultAsJsonWrapper));
|
Napi::Function::New(env, GetKeywordResultAsJsonWrapper));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ public class KyewordSpotterFromFile {
|
|||||||
|
|
||||||
String keyword = kws.getResult(stream).getKeyword();
|
String keyword = kws.getResult(stream).getKeyword();
|
||||||
if (!keyword.isEmpty()) {
|
if (!keyword.isEmpty()) {
|
||||||
|
// Remember to reset the stream right after detecting a keyword
|
||||||
|
kws.reset(stream);
|
||||||
System.out.printf("Detected keyword: %s\n", keyword);
|
System.out.printf("Detected keyword: %s\n", keyword);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,6 +41,9 @@ while (kws.isReady(stream)) {
|
|||||||
const keyword = kws.getResult(stream).keyword;
|
const keyword = kws.getResult(stream).keyword;
|
||||||
if (keyword != '') {
|
if (keyword != '') {
|
||||||
detectedKeywords.push(keyword);
|
detectedKeywords.push(keyword);
|
||||||
|
|
||||||
|
// remember to reset the stream right after detecting a keyword
|
||||||
|
kws.reset(stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
console.log(detectedKeywords);
|
console.log(detectedKeywords);
|
||||||
|
|||||||
@@ -169,6 +169,8 @@ def main():
|
|||||||
|
|
||||||
print("Started! Please speak")
|
print("Started! Please speak")
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
|
||||||
sample_rate = 16000
|
sample_rate = 16000
|
||||||
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||||
stream = keyword_spotter.create_stream()
|
stream = keyword_spotter.create_stream()
|
||||||
@@ -181,7 +183,10 @@ def main():
|
|||||||
keyword_spotter.decode_stream(stream)
|
keyword_spotter.decode_stream(stream)
|
||||||
result = keyword_spotter.get_result(stream)
|
result = keyword_spotter.get_result(stream)
|
||||||
if result:
|
if result:
|
||||||
print("\r{}".format(result), end="", flush=True)
|
print(f"{idx}: {result }")
|
||||||
|
idx += 1
|
||||||
|
# Remember to reset stream right after detecting a keyword
|
||||||
|
keyword_spotter.reset_stream(stream)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -18,122 +18,6 @@ import numpy as np
|
|||||||
import sherpa_onnx
|
import sherpa_onnx
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--tokens",
|
|
||||||
type=str,
|
|
||||||
help="Path to tokens.txt",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--encoder",
|
|
||||||
type=str,
|
|
||||||
help="Path to the transducer encoder model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decoder",
|
|
||||||
type=str,
|
|
||||||
help="Path to the transducer decoder model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--joiner",
|
|
||||||
type=str,
|
|
||||||
help="Path to the transducer joiner model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-threads",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of threads for neural network computation",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--provider",
|
|
||||||
type=str,
|
|
||||||
default="cpu",
|
|
||||||
help="Valid values: cpu, cuda, coreml",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-active-paths",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="""
|
|
||||||
It specifies number of active paths to keep during decoding.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-trailing-blanks",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="""The number of trailing blanks a keyword should be followed. Setting
|
|
||||||
to a larger value (e.g. 8) when your keywords has overlapping tokens
|
|
||||||
between each other.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--keywords-file",
|
|
||||||
type=str,
|
|
||||||
help="""
|
|
||||||
The file containing keywords, one words/phrases per line, and for each
|
|
||||||
phrase the bpe/cjkchar/pinyin are separated by a space. For example:
|
|
||||||
|
|
||||||
▁HE LL O ▁WORLD
|
|
||||||
x iǎo ài t óng x ué
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--keywords-score",
|
|
||||||
type=float,
|
|
||||||
default=1.0,
|
|
||||||
help="""
|
|
||||||
The boosting score of each token for keywords. The larger the easier to
|
|
||||||
survive beam search.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--keywords-threshold",
|
|
||||||
type=float,
|
|
||||||
default=0.25,
|
|
||||||
help="""
|
|
||||||
The trigger threshold (i.e. probability) of the keyword. The larger the
|
|
||||||
harder to trigger.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"sound_files",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="The input sound file(s) to decode. Each file must be of WAVE"
|
|
||||||
"format with a single channel, and each sample has 16-bit, "
|
|
||||||
"i.e., int16_t. "
|
|
||||||
"The sample rate of the file can be arbitrary and does not need to "
|
|
||||||
"be 16 kHz",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def assert_file_exists(filename: str):
|
|
||||||
assert Path(filename).is_file(), (
|
|
||||||
f"{filename} does not exist!\n"
|
|
||||||
"Please refer to "
|
|
||||||
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -159,83 +43,74 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
|||||||
return samples_float32, f.getframerate()
|
return samples_float32, f.getframerate()
|
||||||
|
|
||||||
|
|
||||||
|
def create_keyword_spotter():
    """Build the keyword spotter used by this example.

    Every model file is looked up inside the
    ``sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01`` directory,
    relative to the current working directory.

    Returns:
      The ``sherpa_onnx.KeywordSpotter`` constructed from those files.
    """
    model_dir = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
    # encoder/decoder/joiner share the same checkpoint suffix
    checkpoint = "epoch-12-avg-2-chunk-16-left-64.onnx"

    options = dict(
        tokens=f"{model_dir}/tokens.txt",
        encoder=f"{model_dir}/encoder-{checkpoint}",
        decoder=f"{model_dir}/decoder-{checkpoint}",
        joiner=f"{model_dir}/joiner-{checkpoint}",
        num_threads=2,
        keywords_file=f"{model_dir}/test_wavs/test_keywords.txt",
        provider="cpu",
    )
    return sherpa_onnx.KeywordSpotter(**options)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
kws = create_keyword_spotter()
|
||||||
assert_file_exists(args.tokens)
|
|
||||||
assert_file_exists(args.encoder)
|
|
||||||
assert_file_exists(args.decoder)
|
|
||||||
assert_file_exists(args.joiner)
|
|
||||||
|
|
||||||
assert Path(
|
wave_filename = (
|
||||||
args.keywords_file
|
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
|
||||||
).is_file(), (
|
|
||||||
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
keyword_spotter = sherpa_onnx.KeywordSpotter(
|
|
||||||
tokens=args.tokens,
|
|
||||||
encoder=args.encoder,
|
|
||||||
decoder=args.decoder,
|
|
||||||
joiner=args.joiner,
|
|
||||||
num_threads=args.num_threads,
|
|
||||||
max_active_paths=args.max_active_paths,
|
|
||||||
keywords_file=args.keywords_file,
|
|
||||||
keywords_score=args.keywords_score,
|
|
||||||
keywords_threshold=args.keywords_threshold,
|
|
||||||
num_trailing_blanks=args.num_trailing_blanks,
|
|
||||||
provider=args.provider,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Started!")
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
streams = []
|
|
||||||
total_duration = 0
|
|
||||||
for wave_filename in args.sound_files:
|
|
||||||
assert_file_exists(wave_filename)
|
|
||||||
samples, sample_rate = read_wave(wave_filename)
|
samples, sample_rate = read_wave(wave_filename)
|
||||||
duration = len(samples) / sample_rate
|
|
||||||
total_duration += duration
|
|
||||||
|
|
||||||
s = keyword_spotter.create_stream()
|
|
||||||
|
|
||||||
s.accept_waveform(sample_rate, samples)
|
|
||||||
|
|
||||||
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
|
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
|
||||||
|
|
||||||
|
print("----------Use pre-defined keywords----------")
|
||||||
|
s = kws.create_stream()
|
||||||
|
s.accept_waveform(sample_rate, samples)
|
||||||
s.accept_waveform(sample_rate, tail_paddings)
|
s.accept_waveform(sample_rate, tail_paddings)
|
||||||
|
|
||||||
s.input_finished()
|
s.input_finished()
|
||||||
|
while kws.is_ready(s):
|
||||||
|
kws.decode_stream(s)
|
||||||
|
r = kws.get_result(s)
|
||||||
|
if r != "":
|
||||||
|
# Remember to call reset right after detected a keyword
|
||||||
|
kws.reset_stream(s)
|
||||||
|
|
||||||
streams.append(s)
|
print(f"Detected {r}")
|
||||||
|
|
||||||
results = [""] * len(streams)
|
print("----------Use pre-defined keywords + add a new keyword----------")
|
||||||
while True:
|
|
||||||
ready_list = []
|
|
||||||
for i, s in enumerate(streams):
|
|
||||||
if keyword_spotter.is_ready(s):
|
|
||||||
ready_list.append(s)
|
|
||||||
r = keyword_spotter.get_result(s)
|
|
||||||
if r:
|
|
||||||
results[i] += f"{r}/"
|
|
||||||
print(f"{r} is detected.")
|
|
||||||
if len(ready_list) == 0:
|
|
||||||
break
|
|
||||||
keyword_spotter.decode_streams(ready_list)
|
|
||||||
end_time = time.time()
|
|
||||||
print("Done!")
|
|
||||||
|
|
||||||
for wave_filename, result in zip(args.sound_files, results):
|
s = kws.create_stream("y ǎn y uán @演员")
|
||||||
print(f"{wave_filename}\n{result}")
|
s.accept_waveform(sample_rate, samples)
|
||||||
print("-" * 10)
|
s.accept_waveform(sample_rate, tail_paddings)
|
||||||
|
s.input_finished()
|
||||||
|
while kws.is_ready(s):
|
||||||
|
kws.decode_stream(s)
|
||||||
|
r = kws.get_result(s)
|
||||||
|
if r != "":
|
||||||
|
# Remember to call reset right after detected a keyword
|
||||||
|
kws.reset_stream(s)
|
||||||
|
|
||||||
elapsed_seconds = end_time - start_time
|
print(f"Detected {r}")
|
||||||
rtf = elapsed_seconds / total_duration
|
|
||||||
print(f"num_threads: {args.num_threads}")
|
print("----------Use pre-defined keywords + add 2 new keywords----------")
|
||||||
print(f"Wave duration: {total_duration:.3f} s")
|
|
||||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
s = kws.create_stream("y ǎn y uán @演员/zh ī m íng @知名")
|
||||||
print(
|
s.accept_waveform(sample_rate, samples)
|
||||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
s.accept_waveform(sample_rate, tail_paddings)
|
||||||
)
|
s.input_finished()
|
||||||
|
while kws.is_ready(s):
|
||||||
|
kws.decode_stream(s)
|
||||||
|
r = kws.get_result(s)
|
||||||
|
if r != "":
|
||||||
|
# Remember to call reset right after detected a keyword
|
||||||
|
kws.reset_stream(s)
|
||||||
|
|
||||||
|
print(f"Detected {r}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -46,6 +46,11 @@ namespace SherpaOnnx
|
|||||||
Decode(_handle.Handle, stream.Handle);
|
Decode(_handle.Handle, stream.Handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forwards to the native SherpaOnnxResetKeywordStream binding with this
// spotter's handle and the handle of the given stream.
// NOTE(review): per the C header this should be called right after a
// keyword is detected.
public void Reset(OnlineStream stream)
|
||||||
|
{
|
||||||
|
Reset(_handle.Handle, stream.Handle);
|
||||||
|
}
|
||||||
|
|
||||||
// The caller should ensure all passed streams are ready for decoding.
|
// The caller should ensure all passed streams are ready for decoding.
|
||||||
public void Decode(IEnumerable<OnlineStream> streams)
|
public void Decode(IEnumerable<OnlineStream> streams)
|
||||||
{
|
{
|
||||||
@@ -110,6 +115,9 @@ namespace SherpaOnnx
|
|||||||
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")]
|
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")]
|
||||||
private static extern void Decode(IntPtr handle, IntPtr stream);
|
private static extern void Decode(IntPtr handle, IntPtr stream);
|
||||||
|
|
||||||
|
// P/Invoke declaration bound to the native entry point
// "SherpaOnnxResetKeywordStream"; takes the spotter handle and the stream handle.
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxResetKeywordStream")]
|
||||||
|
private static extern void Reset(IntPtr handle, IntPtr stream);
|
||||||
|
|
||||||
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")]
|
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")]
|
||||||
private static extern void Decode(IntPtr handle, IntPtr[] streams, int n);
|
private static extern void Decode(IntPtr handle, IntPtr[] streams, int n);
|
||||||
|
|
||||||
|
|||||||
@@ -1584,6 +1584,11 @@ func (spotter *KeywordSpotter) Decode(s *OnlineStream) {
|
|||||||
C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl)
|
C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset forwards to the C function SherpaOnnxResetKeywordStream for the
// given stream s.
//
// You MUST call it right after detecting a keyword
|
||||||
|
func (spotter *KeywordSpotter) Reset(s *OnlineStream) {
|
||||||
|
C.SherpaOnnxResetKeywordStream(spotter.impl, s.impl)
|
||||||
|
}
|
||||||
|
|
||||||
// Get the current result of stream since the last invoke of Reset()
|
// Get the current result of stream since the last invoke of Reset()
|
||||||
func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult {
|
func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult {
|
||||||
p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl)
|
p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl)
|
||||||
|
|||||||
@@ -20,6 +20,10 @@ class KeywordSpotter {
|
|||||||
addon.decodeKeywordStream(this.handle, stream.handle);
|
addon.decodeKeywordStream(this.handle, stream.handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Calls the native addon's resetKeywordStream with this spotter's handle
 * and the stream's handle. Intended to be called right after a keyword
 * has been detected.
 *
 * @param {object} stream - Stream object exposing a `handle` property.
 */
reset(stream) {
|
||||||
|
addon.resetKeywordStream(this.handle, stream.handle);
|
||||||
|
}
|
||||||
|
|
||||||
getResult(stream) {
|
getResult(stream) {
|
||||||
const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle);
|
const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle);
|
||||||
|
|
||||||
|
|||||||
@@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter {
|
|||||||
std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
|
std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
|
||||||
};
|
};
|
||||||
|
|
||||||
SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||||
const SherpaOnnxKeywordSpotterConfig *config) {
|
const SherpaOnnxKeywordSpotterConfig *config) {
|
||||||
sherpa_onnx::KeywordSpotterConfig spotter_config;
|
sherpa_onnx::KeywordSpotterConfig spotter_config;
|
||||||
|
|
||||||
@@ -755,37 +755,42 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
|||||||
return spotter;
|
return spotter;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SherpaOnnxDestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) {
|
void SherpaOnnxDestroyKeywordSpotter(const SherpaOnnxKeywordSpotter *spotter) {
|
||||||
delete spotter;
|
delete spotter;
|
||||||
}
|
}
|
||||||
|
|
||||||
SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
||||||
const SherpaOnnxKeywordSpotter *spotter) {
|
const SherpaOnnxKeywordSpotter *spotter) {
|
||||||
SherpaOnnxOnlineStream *stream =
|
SherpaOnnxOnlineStream *stream =
|
||||||
new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
|
new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
|
||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
|
||||||
SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords(
|
const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords(
|
||||||
const SherpaOnnxKeywordSpotter *spotter, const char *keywords) {
|
const SherpaOnnxKeywordSpotter *spotter, const char *keywords) {
|
||||||
SherpaOnnxOnlineStream *stream =
|
SherpaOnnxOnlineStream *stream =
|
||||||
new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords));
|
new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords));
|
||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t SherpaOnnxIsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
|
int32_t SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter,
|
||||||
SherpaOnnxOnlineStream *stream) {
|
const SherpaOnnxOnlineStream *stream) {
|
||||||
return spotter->impl->IsReady(stream->impl.get());
|
return spotter->impl->IsReady(stream->impl.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void SherpaOnnxDecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
|
void SherpaOnnxDecodeKeywordStream(const SherpaOnnxKeywordSpotter *spotter,
|
||||||
SherpaOnnxOnlineStream *stream) {
|
const SherpaOnnxOnlineStream *stream) {
|
||||||
return spotter->impl->DecodeStream(stream->impl.get());
|
spotter->impl->DecodeStream(stream->impl.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
|
void SherpaOnnxResetKeywordStream(const SherpaOnnxKeywordSpotter *spotter,
|
||||||
SherpaOnnxOnlineStream **streams,
|
const SherpaOnnxOnlineStream *stream) {
|
||||||
int32_t n) {
|
spotter->impl->Reset(stream->impl.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void SherpaOnnxDecodeMultipleKeywordStreams(
|
||||||
|
const SherpaOnnxKeywordSpotter *spotter,
|
||||||
|
const SherpaOnnxOnlineStream **streams, int32_t n) {
|
||||||
std::vector<sherpa_onnx::OnlineStream *> ss(n);
|
std::vector<sherpa_onnx::OnlineStream *> ss(n);
|
||||||
for (int32_t i = 0; i != n; ++i) {
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
ss[i] = streams[i]->impl.get();
|
ss[i] = streams[i]->impl.get();
|
||||||
@@ -794,7 +799,8 @@ void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
|
const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
|
||||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) {
|
const SherpaOnnxKeywordSpotter *spotter,
|
||||||
|
const SherpaOnnxOnlineStream *stream) {
|
||||||
const sherpa_onnx::KeywordResult &result =
|
const sherpa_onnx::KeywordResult &result =
|
||||||
spotter->impl->GetResult(stream->impl.get());
|
spotter->impl->GetResult(stream->impl.get());
|
||||||
const auto &keyword = result.keyword;
|
const auto &keyword = result.keyword;
|
||||||
@@ -869,8 +875,9 @@ void SherpaOnnxDestroyKeywordResult(const SherpaOnnxKeywordResult *r) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const char *SherpaOnnxGetKeywordResultAsJson(SherpaOnnxKeywordSpotter *spotter,
|
const char *SherpaOnnxGetKeywordResultAsJson(
|
||||||
SherpaOnnxOnlineStream *stream) {
|
const SherpaOnnxKeywordSpotter *spotter,
|
||||||
|
const SherpaOnnxOnlineStream *stream) {
|
||||||
const sherpa_onnx::KeywordResult &result =
|
const sherpa_onnx::KeywordResult &result =
|
||||||
spotter->impl->GetResult(stream->impl.get());
|
spotter->impl->GetResult(stream->impl.get());
|
||||||
|
|
||||||
|
|||||||
@@ -600,7 +600,7 @@ SHERPA_ONNX_API const char *SherpaOnnxGetOfflineStreamResultAsJson(
|
|||||||
SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s);
|
SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s);
|
||||||
|
|
||||||
// ============================================================
|
// ============================================================
|
||||||
// For Keyword Spot
|
// For Keyword Spotter
|
||||||
// ============================================================
|
// ============================================================
|
||||||
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
|
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
|
||||||
/// The triggered keyword.
|
/// The triggered keyword.
|
||||||
@@ -660,21 +660,21 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
|
|||||||
/// @param config Config for the keyword spotter.
|
/// @param config Config for the keyword spotter.
|
||||||
/// @return Return a pointer to the spotter. The user has to invoke
|
/// @return Return a pointer to the spotter. The user has to invoke
|
||||||
/// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak.
|
/// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak.
|
||||||
SHERPA_ONNX_API SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
SHERPA_ONNX_API const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
|
||||||
const SherpaOnnxKeywordSpotterConfig *config);
|
const SherpaOnnxKeywordSpotterConfig *config);
|
||||||
|
|
||||||
/// Free a pointer returned by SherpaOnnxCreateKeywordSpotter()
|
/// Free a pointer returned by SherpaOnnxCreateKeywordSpotter()
|
||||||
///
|
///
|
||||||
/// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter()
|
/// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter()
|
||||||
SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter(
|
SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter(
|
||||||
SherpaOnnxKeywordSpotter *spotter);
|
const SherpaOnnxKeywordSpotter *spotter);
|
||||||
|
|
||||||
/// Create an online stream for accepting wave samples.
|
/// Create an online stream for accepting wave samples.
|
||||||
///
|
///
|
||||||
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter()
|
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter()
|
||||||
/// @return Return a pointer to an OnlineStream. The user has to invoke
|
/// @return Return a pointer to an OnlineStream. The user has to invoke
|
||||||
/// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
|
/// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
|
||||||
SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
SHERPA_ONNX_API const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
||||||
const SherpaOnnxKeywordSpotter *spotter);
|
const SherpaOnnxKeywordSpotter *spotter);
|
||||||
|
|
||||||
/// Create an online stream for accepting wave samples with the specified hot
|
/// Create an online stream for accepting wave samples with the specified hot
|
||||||
@@ -684,7 +684,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
|
|||||||
/// @param keywords A pointer points to the keywords that you set
|
/// @param keywords A pointer points to the keywords that you set
|
||||||
/// @return Return a pointer to an OnlineStream. The user has to invoke
|
/// @return Return a pointer to an OnlineStream. The user has to invoke
|
||||||
/// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
|
/// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
|
||||||
SHERPA_ONNX_API SherpaOnnxOnlineStream *
|
SHERPA_ONNX_API const SherpaOnnxOnlineStream *
|
||||||
SherpaOnnxCreateKeywordStreamWithKeywords(
|
SherpaOnnxCreateKeywordStreamWithKeywords(
|
||||||
const SherpaOnnxKeywordSpotter *spotter, const char *keywords);
|
const SherpaOnnxKeywordSpotter *spotter, const char *keywords);
|
||||||
|
|
||||||
@@ -693,15 +693,22 @@ SherpaOnnxCreateKeywordStreamWithKeywords(
|
|||||||
///
|
///
|
||||||
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter
|
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter
|
||||||
/// @param stream A pointer returned by SherpaOnnxCreateKeywordStream
|
/// @param stream A pointer returned by SherpaOnnxCreateKeywordStream
|
||||||
SHERPA_ONNX_API int32_t SherpaOnnxIsKeywordStreamReady(
|
SHERPA_ONNX_API int32_t
|
||||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
|
SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter,
|
||||||
|
const SherpaOnnxOnlineStream *stream);
|
||||||
|
|
||||||
/// Call this function to run the neural network model and decoding.
|
/// Call this function to run the neural network model and decoding.
|
||||||
//
|
//
|
||||||
/// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST
|
/// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST
|
||||||
/// return 1.
|
/// return 1.
|
||||||
SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
|
SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
|
||||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
|
const SherpaOnnxKeywordSpotter *spotter,
|
||||||
|
const SherpaOnnxOnlineStream *stream);
|
||||||
|
|
||||||
|
/// Please call it right after a keyword is detected
|
||||||
|
SHERPA_ONNX_API void SherpaOnnxResetKeywordStream(
|
||||||
|
const SherpaOnnxKeywordSpotter *spotter,
|
||||||
|
const SherpaOnnxOnlineStream *stream);
|
||||||
|
|
||||||
/// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes
|
/// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes
|
||||||
/// multiple OnlineStream in parallel.
|
/// multiple OnlineStream in parallel.
|
||||||
@@ -714,8 +721,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
|
|||||||
/// SherpaOnnxCreateKeywordStream()
|
/// SherpaOnnxCreateKeywordStream()
|
||||||
/// @param n Number of elements in the given streams array.
|
/// @param n Number of elements in the given streams array.
|
||||||
SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
|
SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
|
||||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
|
const SherpaOnnxKeywordSpotter *spotter,
|
||||||
int32_t n);
|
const SherpaOnnxOnlineStream **streams, int32_t n);
|
||||||
|
|
||||||
/// Get the decoding results so far for an OnlineStream.
|
/// Get the decoding results so far for an OnlineStream.
|
||||||
///
|
///
|
||||||
@@ -725,7 +732,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
|
|||||||
/// SherpaOnnxDestroyKeywordResult() to free the returned pointer to
|
/// SherpaOnnxDestroyKeywordResult() to free the returned pointer to
|
||||||
/// avoid memory leak.
|
/// avoid memory leak.
|
||||||
SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
|
SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
|
||||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
|
const SherpaOnnxKeywordSpotter *spotter,
|
||||||
|
const SherpaOnnxOnlineStream *stream);
|
||||||
|
|
||||||
/// Destroy the pointer returned by SherpaOnnxGetKeywordResult().
|
/// Destroy the pointer returned by SherpaOnnxGetKeywordResult().
|
||||||
///
|
///
|
||||||
@@ -736,7 +744,8 @@ SHERPA_ONNX_API void SherpaOnnxDestroyKeywordResult(
|
|||||||
// the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned
|
// the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned
|
||||||
// pointer to avoid memory leak
|
// pointer to avoid memory leak
|
||||||
SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson(
|
SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson(
|
||||||
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
|
const SherpaOnnxKeywordSpotter *spotter,
|
||||||
|
const SherpaOnnxOnlineStream *stream);
|
||||||
|
|
||||||
SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s);
|
SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s);
|
||||||
|
|
||||||
|
|||||||
@@ -391,4 +391,112 @@ GeneratedAudio OfflineTts::Generate(const std::string &text,
|
|||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Translate the C++ KeywordSpotterConfig into the C API config struct and
// create a KeywordSpotter from it.
//
// All string fields are passed as c_str() pointers borrowed from `config`,
// so they are only guaranteed to be valid while `config` is alive, i.e. for
// the duration of this call.
KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) {
  const auto &feat = config.feat_config;
  const auto &model = config.model_config;

  struct SherpaOnnxKeywordSpotterConfig c;
  memset(&c, 0, sizeof(c));

  // Feature extraction settings
  c.feat_config.sample_rate = feat.sample_rate;
  c.feat_config.feature_dim = feat.feature_dim;

  // Model files
  auto &dst = c.model_config;
  dst.transducer.encoder = model.transducer.encoder.c_str();
  dst.transducer.decoder = model.transducer.decoder.c_str();
  dst.transducer.joiner = model.transducer.joiner.c_str();

  dst.paraformer.encoder = model.paraformer.encoder.c_str();
  dst.paraformer.decoder = model.paraformer.decoder.c_str();

  dst.zipformer2_ctc.model = model.zipformer2_ctc.model.c_str();

  // Remaining model options
  dst.tokens = model.tokens.c_str();
  dst.num_threads = model.num_threads;
  dst.provider = model.provider.c_str();
  dst.debug = model.debug;
  dst.model_type = model.model_type.c_str();
  dst.modeling_unit = model.modeling_unit.c_str();
  dst.bpe_vocab = model.bpe_vocab.c_str();
  dst.tokens_buf = model.tokens_buf.c_str();
  dst.tokens_buf_size = model.tokens_buf.size();

  // Keyword spotting specific options
  c.max_active_paths = config.max_active_paths;
  c.num_trailing_blanks = config.num_trailing_blanks;
  c.keywords_score = config.keywords_score;
  c.keywords_threshold = config.keywords_threshold;
  c.keywords_file = config.keywords_file.c_str();

  return KeywordSpotter(SherpaOnnxCreateKeywordSpotter(&c));
}
|
||||||
|
|
||||||
|
// Construct a KeywordSpotter from a C API handle by handing `p` to the
// MoveOnly base class.
KeywordSpotter::KeywordSpotter(const SherpaOnnxKeywordSpotter *p)
|
||||||
|
: MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter>(p) {}
|
||||||
|
|
||||||
|
// Release the C API handle `p` via SherpaOnnxDestroyKeywordSpotter().
// NOTE(review): presumably invoked by the MoveOnly base on destruction --
// confirm against MoveOnly.
void KeywordSpotter::Destroy(const SherpaOnnxKeywordSpotter *p) const {
|
||||||
|
SherpaOnnxDestroyKeywordSpotter(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a stream that uses the keywords the spotter was configured with.
// The handle returned by the C API is wrapped in an OnlineStream.
OnlineStream KeywordSpotter::CreateStream() const {
  return OnlineStream{SherpaOnnxCreateKeywordStream(p_)};
}
|
||||||
|
|
||||||
|
// Create a stream with the given `keywords` string passed through to
// SherpaOnnxCreateKeywordStreamWithKeywords(). The returned C handle is
// wrapped in an OnlineStream.
OnlineStream KeywordSpotter::CreateStream(const std::string &keywords) const {
  const char *keywords_cstr = keywords.c_str();
  return OnlineStream{
      SherpaOnnxCreateKeywordStreamWithKeywords(p_, keywords_cstr)};
}
|
||||||
|
|
||||||
|
bool KeywordSpotter::IsReady(const OnlineStream *s) const {
|
||||||
|
return SherpaOnnxIsKeywordStreamReady(p_, s->Get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeywordSpotter::Decode(const OnlineStream *s) const {
|
||||||
|
return SherpaOnnxDecodeKeywordStream(p_, s->Get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeywordSpotter::Decode(const OnlineStream *ss, int32_t n) const {
|
||||||
|
if (n <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const SherpaOnnxOnlineStream *> streams(n);
|
||||||
|
for (int32_t i = 0; i != n; ++n) {
|
||||||
|
streams[i] = ss[i].Get();
|
||||||
|
}
|
||||||
|
|
||||||
|
SherpaOnnxDecodeMultipleKeywordStreams(p_, streams.data(), n);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch the current keyword result for stream `s` and copy it into a
// self-contained KeywordResult. The C result object is destroyed before
// returning, so the returned value owns all of its data.
KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const {
  const auto *c_result = SherpaOnnxGetKeywordResult(p_, s->Get());
  const auto num_tokens = c_result->count;

  KeywordResult result;
  result.keyword = c_result->keyword;

  // One std::string per token.
  result.tokens.assign(c_result->tokens_arr, c_result->tokens_arr + num_tokens);

  // Timestamps are optional; leave the vector empty when absent.
  if (c_result->timestamps) {
    result.timestamps.assign(c_result->timestamps,
                             c_result->timestamps + num_tokens);
  }

  result.start_time = c_result->start_time;
  result.json = c_result->json;

  SherpaOnnxDestroyKeywordResult(c_result);

  return result;
}
|
||||||
|
|
||||||
|
// Resets stream `s` via SherpaOnnxResetKeywordStream.
void KeywordSpotter::Reset(const OnlineStream *s) const {
  auto stream_handle = s->Get();
  SherpaOnnxResetKeywordStream(p_, stream_handle);
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx::cxx
|
} // namespace sherpa_onnx::cxx
|
||||||
|
|||||||
@@ -406,6 +406,53 @@ class SHERPA_ONNX_API OfflineTts
|
|||||||
explicit OfflineTts(const SherpaOnnxOfflineTts *p);
|
explicit OfflineTts(const SherpaOnnxOfflineTts *p);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// For Keyword Spotter
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// Result of keyword spotting for one stream, as filled in by
// KeywordSpotter::GetResult().
struct KeywordResult {
  // The detected keyword.
  std::string keyword;

  // Tokens copied from the C result's tokens_arr.
  std::vector<std::string> tokens;

  // Timestamps copied from the C result; left empty when the C result
  // provides none.
  std::vector<float> timestamps;

  // Start time copied from the C result. Zero-initialized so that a
  // default-constructed KeywordResult never holds an indeterminate value.
  float start_time = 0.0f;

  // JSON string copied from the C result.
  std::string json;
};
|
||||||
|
|
||||||
|
struct KeywordSpotterConfig {
|
||||||
|
FeatureConfig feat_config;
|
||||||
|
OnlineModelConfig model_config;
|
||||||
|
int32_t max_active_paths = 4;
|
||||||
|
int32_t num_trailing_blanks = 1;
|
||||||
|
float keywords_score = 1.0f;
|
||||||
|
float keywords_threshold = 0.25f;
|
||||||
|
std::string keywords_file;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SHERPA_ONNX_API KeywordSpotter
|
||||||
|
: public MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter> {
|
||||||
|
public:
|
||||||
|
static KeywordSpotter Create(const KeywordSpotterConfig &config);
|
||||||
|
|
||||||
|
void Destroy(const SherpaOnnxKeywordSpotter *p) const;
|
||||||
|
|
||||||
|
OnlineStream CreateStream() const;
|
||||||
|
|
||||||
|
OnlineStream CreateStream(const std::string &keywords) const;
|
||||||
|
|
||||||
|
bool IsReady(const OnlineStream *s) const;
|
||||||
|
|
||||||
|
void Decode(const OnlineStream *s) const;
|
||||||
|
|
||||||
|
void Decode(const OnlineStream *ss, int32_t n) const;
|
||||||
|
|
||||||
|
void Reset(const OnlineStream *s) const;
|
||||||
|
|
||||||
|
KeywordResult GetResult(const OnlineStream *s) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p);
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx::cxx
|
} // namespace sherpa_onnx::cxx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_C_API_CXX_API_H_
|
#endif // SHERPA_ONNX_C_API_CXX_API_H_
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ class KeywordSpotterImpl {
|
|||||||
|
|
||||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||||
|
|
||||||
|
virtual void Reset(OnlineStream *s) const = 0;
|
||||||
|
|
||||||
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
||||||
|
|
||||||
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
|
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
|
||||||
|
|||||||
@@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
|||||||
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||||
s->NumFramesReady();
|
s->NumFramesReady();
|
||||||
}
|
}
|
||||||
|
// Resets stream `s` by running InitOnlineStream on it again.
void Reset(OnlineStream *s) const override {
  InitOnlineStream(s);
}
|
||||||
|
|
||||||
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||||
|
for (int32_t i = 0; i < n; ++i) {
|
||||||
|
auto s = ss[i];
|
||||||
|
auto r = s->GetKeywordResult(true);
|
||||||
|
int32_t num_trailing_blanks = r.num_trailing_blanks;
|
||||||
|
// assume subsampling_factor is 4
|
||||||
|
// assume frameshift is 0.01 second
|
||||||
|
float trailing_slience = num_trailing_blanks * 4 * 0.01;
|
||||||
|
|
||||||
|
// it resets automatically after detecting 1.5 seconds of silence
|
||||||
|
float threshold = 1.5;
|
||||||
|
if (trailing_slience > threshold) {
|
||||||
|
Reset(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int32_t chunk_size = model_->ChunkSize();
|
int32_t chunk_size = model_->ChunkSize();
|
||||||
int32_t chunk_shift = model_->ChunkShift();
|
int32_t chunk_shift = model_->ChunkShift();
|
||||||
|
|
||||||
|
|||||||
@@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const {
|
|||||||
return impl_->IsReady(s);
|
return impl_->IsReady(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forwards the reset request for stream `s` to the implementation object.
void KeywordSpotter::Reset(OnlineStream *s) const {
  impl_->Reset(s);
}
|
||||||
|
|
||||||
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||||
impl_->DecodeStreams(ss, n);
|
impl_->DecodeStreams(ss, n);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -129,6 +129,9 @@ class KeywordSpotter {
|
|||||||
*/
|
*/
|
||||||
bool IsReady(OnlineStream *s) const;
|
bool IsReady(OnlineStream *s) const;
|
||||||
|
|
||||||
|
// Remember to call it after detecting a keyword
|
||||||
|
void Reset(OnlineStream *s) const;
|
||||||
|
|
||||||
/** Decode a single stream. */
|
/** Decode a single stream. */
|
||||||
void DecodeStream(OnlineStream *s) const {
|
void DecodeStream(OnlineStream *s) const {
|
||||||
OnlineStream *ss[1] = {s};
|
OnlineStream *ss[1] = {s};
|
||||||
|
|||||||
@@ -106,13 +106,15 @@ as the device_name.
|
|||||||
|
|
||||||
while (spotter.IsReady(stream.get())) {
|
while (spotter.IsReady(stream.get())) {
|
||||||
spotter.DecodeStream(stream.get());
|
spotter.DecodeStream(stream.get());
|
||||||
}
|
|
||||||
|
|
||||||
const auto r = spotter.GetResult(stream.get());
|
const auto r = spotter.GetResult(stream.get());
|
||||||
if (!r.keyword.empty()) {
|
if (!r.keyword.empty()) {
|
||||||
display.Print(keyword_index, r.AsJsonString());
|
display.Print(keyword_index, r.AsJsonString());
|
||||||
fflush(stderr);
|
fflush(stderr);
|
||||||
keyword_index++;
|
keyword_index++;
|
||||||
|
|
||||||
|
spotter.Reset(stream.get());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -150,13 +150,15 @@ for a list of pre-trained models to download.
|
|||||||
while (!stop) {
|
while (!stop) {
|
||||||
while (spotter.IsReady(s.get())) {
|
while (spotter.IsReady(s.get())) {
|
||||||
spotter.DecodeStream(s.get());
|
spotter.DecodeStream(s.get());
|
||||||
}
|
|
||||||
|
|
||||||
const auto r = spotter.GetResult(s.get());
|
const auto r = spotter.GetResult(s.get());
|
||||||
if (!r.keyword.empty()) {
|
if (!r.keyword.empty()) {
|
||||||
display.Print(keyword_index, r.AsJsonString());
|
display.Print(keyword_index, r.AsJsonString());
|
||||||
fflush(stderr);
|
fflush(stderr);
|
||||||
keyword_index++;
|
keyword_index++;
|
||||||
|
|
||||||
|
spotter.Reset(s.get());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Pa_Sleep(20); // sleep for 20ms
|
Pa_Sleep(20); // sleep for 20ms
|
||||||
|
|||||||
@@ -27,6 +27,10 @@ public class KeywordSpotter {
|
|||||||
decode(ptr, s.getPtr());
|
decode(ptr, s.getPtr());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void reset(OnlineStream s) {
|
||||||
|
reset(ptr, s.getPtr());
|
||||||
|
}
|
||||||
|
|
||||||
public boolean isReady(OnlineStream s) {
|
public boolean isReady(OnlineStream s) {
|
||||||
return isReady(ptr, s.getPtr());
|
return isReady(ptr, s.getPtr());
|
||||||
}
|
}
|
||||||
@@ -60,6 +64,8 @@ public class KeywordSpotter {
|
|||||||
|
|
||||||
private native void decode(long ptr, long streamPtr);
|
private native void decode(long ptr, long streamPtr);
|
||||||
|
|
||||||
|
private native void reset(long ptr, long streamPtr);
|
||||||
|
|
||||||
private native boolean isReady(long ptr, long streamPtr);
|
private native boolean isReady(long ptr, long streamPtr);
|
||||||
|
|
||||||
private native Object[] getResult(long ptr, long streamPtr);
|
private native Object[] getResult(long ptr, long streamPtr);
|
||||||
|
|||||||
@@ -161,6 +161,15 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode(
|
|||||||
kws->DecodeStream(stream);
|
kws->DecodeStream(stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JNI entry point for KeywordSpotter.reset(): reinterprets the two jlong
// values as native KeywordSpotter / OnlineStream pointers and resets the
// stream.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_reset(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  auto *spotter = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  auto *online_stream =
      reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);

  spotter->Reset(online_stream);
}
|
||||||
|
|
||||||
SHERPA_ONNX_EXTERN_C
|
SHERPA_ONNX_EXTERN_C
|
||||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
|
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
|
||||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class KeywordSpotter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
|
fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
|
||||||
|
/** Resets [stream] by calling the native reset with both pointers. */
fun reset(stream: OnlineStream) {
    reset(ptr, stream.ptr)
}
|
||||||
fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
|
fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
|
||||||
fun getResult(stream: OnlineStream): KeywordSpotterResult {
|
fun getResult(stream: OnlineStream): KeywordSpotterResult {
|
||||||
val objArray = getResult(ptr, stream.ptr)
|
val objArray = getResult(ptr, stream.ptr)
|
||||||
@@ -74,6 +75,7 @@ class KeywordSpotter(
|
|||||||
private external fun createStream(ptr: Long, keywords: String): Long
|
private external fun createStream(ptr: Long, keywords: String): Long
|
||||||
private external fun isReady(ptr: Long, streamPtr: Long): Boolean
|
private external fun isReady(ptr: Long, streamPtr: Long): Boolean
|
||||||
private external fun decode(ptr: Long, streamPtr: Long)
|
private external fun decode(ptr: Long, streamPtr: Long)
|
||||||
|
private external fun reset(ptr: Long, streamPtr: Long)
|
||||||
private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
|
private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) {
|
|||||||
py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
|
py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
|
||||||
.def("is_ready", &PyClass::IsReady,
|
.def("is_ready", &PyClass::IsReady,
|
||||||
py::call_guard<py::gil_scoped_release>())
|
py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
|
||||||
.def("decode_stream", &PyClass::DecodeStream,
|
.def("decode_stream", &PyClass::DecodeStream,
|
||||||
py::call_guard<py::gil_scoped_release>())
|
py::call_guard<py::gil_scoped_release>())
|
||||||
.def(
|
.def(
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ class KeywordSpotter(object):
|
|||||||
|
|
||||||
provider_config = ProviderConfig(
|
provider_config = ProviderConfig(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
device = device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = OnlineModelConfig(
|
model_config = OnlineModelConfig(
|
||||||
@@ -131,6 +131,9 @@ class KeywordSpotter(object):
|
|||||||
)
|
)
|
||||||
self.keyword_spotter = _KeywordSpotter(keywords_spotter_config)
|
self.keyword_spotter = _KeywordSpotter(keywords_spotter_config)
|
||||||
|
|
||||||
|
def reset_stream(self, s: OnlineStream):
    """Reset the given stream.

    Args:
      s:
        The stream to reset. It is passed to the ``reset`` method of the
        underlying keyword spotter object.
    """
    spotter = self.keyword_spotter
    spotter.reset(s)
|
||||||
|
|
||||||
def create_stream(self, keywords: Optional[str] = None):
|
def create_stream(self, keywords: Optional[str] = None):
|
||||||
if keywords is None:
|
if keywords is None:
|
||||||
return self.keyword_spotter.create_stream()
|
return self.keyword_spotter.create_stream()
|
||||||
|
|||||||
@@ -98,6 +98,9 @@ class TestKeywordSpotter(unittest.TestCase):
|
|||||||
if r:
|
if r:
|
||||||
print(f"{r} is detected.")
|
print(f"{r} is detected.")
|
||||||
results[i] += f"{r}/"
|
results[i] += f"{r}/"
|
||||||
|
|
||||||
|
keyword_spotter.reset_stream(s)
|
||||||
|
|
||||||
if len(ready_list) == 0:
|
if len(ready_list) == 0:
|
||||||
break
|
break
|
||||||
keyword_spotter.decode_streams(ready_list)
|
keyword_spotter.decode_streams(ready_list)
|
||||||
@@ -158,6 +161,9 @@ class TestKeywordSpotter(unittest.TestCase):
|
|||||||
if r:
|
if r:
|
||||||
print(f"{r} is detected.")
|
print(f"{r} is detected.")
|
||||||
results[i] += f"{r}/"
|
results[i] += f"{r}/"
|
||||||
|
|
||||||
|
keyword_spotter.reset_stream(s)
|
||||||
|
|
||||||
if len(ready_list) == 0:
|
if len(ready_list) == 0:
|
||||||
break
|
break
|
||||||
keyword_spotter.decode_streams(ready_list)
|
keyword_spotter.decode_streams(ready_list)
|
||||||
|
|||||||
@@ -1076,6 +1076,10 @@ class SherpaOnnxKeywordSpotterWrapper {
|
|||||||
SherpaOnnxDecodeKeywordStream(spotter, stream)
|
SherpaOnnxDecodeKeywordStream(spotter, stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resets the wrapped keyword stream by calling `SherpaOnnxResetKeywordStream`.
func reset() {
  SherpaOnnxResetKeywordStream(spotter, stream)
}
|
||||||
|
|
||||||
func getResult() -> SherpaOnnxKeywordResultWrapper {
|
func getResult() -> SherpaOnnxKeywordResultWrapper {
|
||||||
let result: UnsafePointer<SherpaOnnxKeywordResult>? = SherpaOnnxGetKeywordResult(
|
let result: UnsafePointer<SherpaOnnxKeywordResult>? = SherpaOnnxGetKeywordResult(
|
||||||
spotter, stream)
|
spotter, stream)
|
||||||
|
|||||||
@@ -70,6 +70,9 @@ func run() {
|
|||||||
spotter.decode()
|
spotter.decode()
|
||||||
let keyword = spotter.getResult().keyword
|
let keyword = spotter.getResult().keyword
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
|
// Remember to call reset() right after detecting a keyword
|
||||||
|
spotter.reset()
|
||||||
|
|
||||||
print("Detected: \(keyword)")
|
print("Detected: \(keyword)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ set(exported_functions
|
|||||||
SherpaOnnxIsKeywordStreamReady
|
SherpaOnnxIsKeywordStreamReady
|
||||||
SherpaOnnxOnlineStreamAcceptWaveform
|
SherpaOnnxOnlineStreamAcceptWaveform
|
||||||
SherpaOnnxOnlineStreamInputFinished
|
SherpaOnnxOnlineStreamInputFinished
|
||||||
|
SherpaOnnxResetKeywordStream
|
||||||
)
|
)
|
||||||
set(mangled_exported_functions)
|
set(mangled_exported_functions)
|
||||||
foreach(x IN LISTS exported_functions)
|
foreach(x IN LISTS exported_functions)
|
||||||
|
|||||||
@@ -102,8 +102,6 @@ if (navigator.mediaDevices.getUserMedia) {
|
|||||||
recognizer_stream.acceptWaveform(expectedSampleRate, samples);
|
recognizer_stream.acceptWaveform(expectedSampleRate, samples);
|
||||||
while (recognizer.isReady(recognizer_stream)) {
|
while (recognizer.isReady(recognizer_stream)) {
|
||||||
recognizer.decode(recognizer_stream);
|
recognizer.decode(recognizer_stream);
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
let result = recognizer.getResult(recognizer_stream);
|
let result = recognizer.getResult(recognizer_stream);
|
||||||
|
|
||||||
@@ -111,6 +109,10 @@ if (navigator.mediaDevices.getUserMedia) {
|
|||||||
console.log(result)
|
console.log(result)
|
||||||
lastResult = result;
|
lastResult = result;
|
||||||
resultList.push(JSON.stringify(result));
|
resultList.push(JSON.stringify(result));
|
||||||
|
|
||||||
|
// remember to reset the stream right after detecting a keyword
|
||||||
|
recognizer.reset(recognizer_stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -296,8 +296,11 @@ class Kws {
|
|||||||
}
|
}
|
||||||
|
|
||||||
decode(stream) {
|
decode(stream) {
|
||||||
return this.Module._SherpaOnnxDecodeKeywordStream(
|
this.Module._SherpaOnnxDecodeKeywordStream(this.handle, stream.handle);
|
||||||
this.handle, stream.handle);
|
}
|
||||||
|
|
||||||
|
reset(stream) {
|
||||||
|
this.Module._SherpaOnnxResetKeywordStream(this.handle, stream.handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
getResult(stream) {
|
getResult(stream) {
|
||||||
|
|||||||
Reference in New Issue
Block a user