Fix keyword spotting. (#1689)

Reset the stream right after detecting a keyword
This commit is contained in:
Fangjun Kuang
2025-01-20 16:41:10 +08:00
committed by GitHub
parent b943341fb1
commit 8b989a851c
43 changed files with 813 additions and 293 deletions

View File

@@ -38,6 +38,8 @@ class KeywordSpotterImpl {
virtual bool IsReady(OnlineStream *s) const = 0;
virtual void Reset(OnlineStream *s) const = 0;
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
virtual KeywordResult GetResult(OnlineStream *s) const = 0;

View File

@@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
return s->GetNumProcessedFrames() + model_->ChunkSize() <
s->NumFramesReady();
}
void Reset(OnlineStream *s) const override { InitOnlineStream(s); }
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
for (int32_t i = 0; i < n; ++i) {
auto s = ss[i];
auto r = s->GetKeywordResult(true);
int32_t num_trailing_blanks = r.num_trailing_blanks;
// assume subsampling_factor is 4
// assume frameshift is 0.01 second
float trailing_slience = num_trailing_blanks * 4 * 0.01;
// it resets automatically after detecting 1.5 seconds of silence
float threshold = 1.5;
if (trailing_slience > threshold) {
Reset(s);
}
}
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();

View File

@@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); }
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}

View File

@@ -129,6 +129,9 @@ class KeywordSpotter {
*/
bool IsReady(OnlineStream *s) const;
// Remember to call it after detecting a keyword
void Reset(OnlineStream *s) const;
/** Decode a single stream. */
void DecodeStream(OnlineStream *s) const {
OnlineStream *ss[1] = {s};

View File

@@ -106,13 +106,15 @@ as the device_name.
while (spotter.IsReady(stream.get())) {
spotter.DecodeStream(stream.get());
}
const auto r = spotter.GetResult(stream.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
const auto r = spotter.GetResult(stream.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
spotter.Reset(stream.get());
}
}
}

View File

@@ -150,13 +150,15 @@ for a list of pre-trained models to download.
while (!stop) {
while (spotter.IsReady(s.get())) {
spotter.DecodeStream(s.get());
}
const auto r = spotter.GetResult(s.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
const auto r = spotter.GetResult(s.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
spotter.Reset(s.get());
}
}
Pa_Sleep(20); // sleep for 20ms