Fix keyword spotting. (#1689)
Reset the stream right after detecting a keyword
This commit is contained in:
@@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
|
||||
add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
|
||||
target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)
|
||||
|
||||
add_executable(kws-cxx-api ./kws-cxx-api.cc)
|
||||
target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api)
|
||||
|
||||
add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
|
||||
target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)
|
||||
|
||||
|
||||
141
cxx-api-examples/kws-cxx-api.cc
Normal file
141
cxx-api-examples/kws-cxx-api.cc
Normal file
@@ -0,0 +1,141 @@
|
||||
// cxx-api-examples/kws-cxx-api.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
//
|
||||
// This file demonstrates how to use keywords spotter with sherpa-onnx's C
|
||||
// clang-format off
|
||||
//
|
||||
// Usage
|
||||
//
|
||||
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
|
||||
//
|
||||
// ./kws-cxx-api
|
||||
//
|
||||
// clang-format on
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
|
||||
#include "sherpa-onnx/c-api/cxx-api.h"
|
||||
|
||||
int32_t main() {
|
||||
using namespace sherpa_onnx::cxx; // NOLINT
|
||||
|
||||
KeywordSpotterConfig config;
|
||||
config.model_config.transducer.encoder =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.transducer.decoder =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.transducer.joiner =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
|
||||
"joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
|
||||
|
||||
config.model_config.tokens =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
|
||||
|
||||
config.model_config.provider = "cpu";
|
||||
config.model_config.num_threads = 1;
|
||||
config.model_config.debug = 1;
|
||||
|
||||
config.keywords_file =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
|
||||
"test_keywords.txt";
|
||||
|
||||
KeywordSpotter kws = KeywordSpotter::Create(config);
|
||||
if (!kws.Get()) {
|
||||
std::cerr << "Please check your config\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n";
|
||||
|
||||
std::string wave_filename =
|
||||
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
|
||||
|
||||
std::array<float, 8000> tail_paddings = {0}; // 0.5 seconds
|
||||
|
||||
Wave wave = ReadWave(wave_filename);
|
||||
if (wave.samples.empty()) {
|
||||
std::cerr << "Failed to read: '" << wave_filename << "'\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
OnlineStream stream = kws.CreateStream();
|
||||
if (!stream.Get()) {
|
||||
std::cerr << "Failed to create stream\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||
wave.samples.size());
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
stream.InputFinished();
|
||||
|
||||
while (kws.IsReady(&stream)) {
|
||||
kws.Decode(&stream);
|
||||
auto r = kws.GetResult(&stream);
|
||||
if (!r.keyword.empty()) {
|
||||
std::cout << "Detected keyword: " << r.json << "\n";
|
||||
|
||||
// Remember to reset the keyword stream right after a keyword is detected
|
||||
kws.Reset(&stream);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
std::cout << "--Use pre-defined keywords + add a new keyword--\n";
|
||||
|
||||
stream = kws.CreateStream("y ǎn y uán @演员");
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||
wave.samples.size());
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
stream.InputFinished();
|
||||
|
||||
while (kws.IsReady(&stream)) {
|
||||
kws.Decode(&stream);
|
||||
auto r = kws.GetResult(&stream);
|
||||
if (!r.keyword.empty()) {
|
||||
std::cout << "Detected keyword: " << r.json << "\n";
|
||||
|
||||
// Remember to reset the keyword stream right after a keyword is detected
|
||||
kws.Reset(&stream);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
std::cout << "--Use pre-defined keywords + add two new keywords--\n";
|
||||
|
||||
stream = kws.CreateStream("y ǎn y uán @演员/zh ī m íng @知名");
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
|
||||
wave.samples.size());
|
||||
|
||||
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
stream.InputFinished();
|
||||
|
||||
while (kws.IsReady(&stream)) {
|
||||
kws.Decode(&stream);
|
||||
auto r = kws.GetResult(&stream);
|
||||
if (!r.keyword.empty()) {
|
||||
std::cout << "Detected keyword: " << r.json << "\n";
|
||||
|
||||
// Remember to reset the keyword stream right after a keyword is detected
|
||||
kws.Reset(&stream);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user