decoder for open vocabulary keyword spotting (#505)

* various fixes to ContextGraph to support open vocabulary keywords decoder

* Add keyword spotter runtime

* Add binary

* First version works

* Minor fixes

* update text2token

* default values

* Add jni for kws

* Add KWS Android project

* Minor fixes

* Remove unused interface

* Minor fixes

* Add workflow

* handle extra info in texts

* Minor fixes

* Add more comments

* Fix ci

* fix cpp style

* Add input box in Android demo so that users can specify their keywords

* Fix cpp style

* Fix comments

* Minor fixes

* Minor fixes

* minor fixes

* Minor fixes

* Minor fixes

* Add CI

* Fix code style

* cpplint

* Fix comments

* Fix error
This commit is contained in:
Wei Kang
2024-01-20 22:52:41 +08:00
committed by GitHub
parent bf1dd3daf6
commit b6c020901a
77 changed files with 3316 additions and 68 deletions

View File

@@ -4,22 +4,59 @@
#include "sherpa-onnx/csrc/context-graph.h"
#include <algorithm>
#include <cassert>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
const std::vector<float> &scores,
const std::vector<std::string> &phrases,
const std::vector<float> &ac_thresholds) const {
if (!scores.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size());
}
if (!phrases.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size());
}
if (!ac_thresholds.empty()) {
SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size());
}
for (int32_t i = 0; i < token_ids.size(); ++i) {
auto node = root_.get();
float score = scores.empty() ? 0.0f : scores[i];
score = score == 0.0f ? context_score_ : score;
float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
std::string phrase = phrases.empty() ? std::string() : phrases[i];
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
int32_t token = token_ids[i][j];
if (0 == node->next.count(token)) {
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? node->node_score + context_score_ : 0, is_end);
token, score, node->node_score + score,
is_end ? node->node_score + score : 0, j + 1,
is_end ? ac_threshold : 0.0f, is_end,
is_end ? phrase : std::string());
} else {
float token_score = std::max(score, node->next[token]->token_score);
node->next[token]->token_score = token_score;
float node_score = node->node_score + token_score;
node->next[token]->node_score = node_score;
bool is_end =
(j == token_ids[i].size() - 1) || node->next[token]->is_end;
node->next[token]->output_score = is_end ? node_score : 0.0f;
node->next[token]->is_end = is_end;
if (j == token_ids[i].size() - 1) {
node->next[token]->phrase = phrase;
node->next[token]->ac_threshold = ac_threshold;
}
}
node = node->next[token].get();
}
@@ -27,8 +64,9 @@ void ContextGraph::Build(
FillFailOutput();
}
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
const ContextState *state, int32_t token) const {
std::tuple<float, const ContextState *, const ContextState *>
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
bool strict_mode /*= true*/) const {
const ContextState *node;
float score;
if (1 == state->next.count(token)) {
@@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
}
score = node->node_score - state->node_score;
}
SHERPA_ONNX_CHECK(nullptr != node);
return std::make_pair(score + node->output_score, node);
const ContextState *matched_node =
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
if (!strict_mode && node->output_score != 0) {
SHERPA_ONNX_CHECK(nullptr != matched_node);
float output_score =
node->is_end ? node->node_score
: (node->output != nullptr ? node->output->node_score
: node->node_score);
return std::make_tuple(score + output_score - node->node_score, root_.get(),
matched_node);
}
return std::make_tuple(score + node->output_score, node, matched_node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
@@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
return std::make_pair(score, root_.get());
}
std::pair<bool, const ContextState *> ContextGraph::IsMatched(
const ContextState *state) const {
bool status = false;
const ContextState *node = nullptr;
if (state->is_end) {
status = true;
node = state;
} else {
if (state->output != nullptr) {
status = true;
node = state->output;
}
}
return std::make_pair(status, node);
}
void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {