Fix keyword spotting. (#1689)
Reset the stream right after detecting a keyword
This commit is contained in:
@@ -38,6 +38,8 @@ class KeywordSpotterImpl {
|
||||
|
||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||
|
||||
virtual void Reset(OnlineStream *s) const = 0;
|
||||
|
||||
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
||||
|
||||
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
|
||||
|
||||
@@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||
s->NumFramesReady();
|
||||
}
|
||||
void Reset(OnlineStream *s) const override { InitOnlineStream(s); }
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||
for (int32_t i = 0; i < n; ++i) {
|
||||
auto s = ss[i];
|
||||
auto r = s->GetKeywordResult(true);
|
||||
int32_t num_trailing_blanks = r.num_trailing_blanks;
|
||||
// assume subsampling_factor is 4
|
||||
// assume frameshift is 0.01 second
|
||||
float trailing_slience = num_trailing_blanks * 4 * 0.01;
|
||||
|
||||
// it resets automatically after detecting 1.5 seconds of silence
|
||||
float threshold = 1.5;
|
||||
if (trailing_slience > threshold) {
|
||||
Reset(s);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
|
||||
@@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); }
|
||||
|
||||
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
impl_->DecodeStreams(ss, n);
|
||||
}
|
||||
|
||||
@@ -129,6 +129,9 @@ class KeywordSpotter {
|
||||
*/
|
||||
bool IsReady(OnlineStream *s) const;
|
||||
|
||||
// Remember to call it after detecting a keyword
|
||||
void Reset(OnlineStream *s) const;
|
||||
|
||||
/** Decode a single stream. */
|
||||
void DecodeStream(OnlineStream *s) const {
|
||||
OnlineStream *ss[1] = {s};
|
||||
|
||||
@@ -106,13 +106,15 @@ as the device_name.
|
||||
|
||||
while (spotter.IsReady(stream.get())) {
|
||||
spotter.DecodeStream(stream.get());
|
||||
}
|
||||
|
||||
const auto r = spotter.GetResult(stream.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
const auto r = spotter.GetResult(stream.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
|
||||
spotter.Reset(stream.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -150,13 +150,15 @@ for a list of pre-trained models to download.
|
||||
while (!stop) {
|
||||
while (spotter.IsReady(s.get())) {
|
||||
spotter.DecodeStream(s.get());
|
||||
}
|
||||
|
||||
const auto r = spotter.GetResult(s.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
const auto r = spotter.GetResult(s.get());
|
||||
if (!r.keyword.empty()) {
|
||||
display.Print(keyword_index, r.AsJsonString());
|
||||
fflush(stderr);
|
||||
keyword_index++;
|
||||
|
||||
spotter.Reset(s.get());
|
||||
}
|
||||
}
|
||||
|
||||
Pa_Sleep(20); // sleep for 20ms
|
||||
|
||||
Reference in New Issue
Block a user