decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder * Add keyword spotter runtime * Add binary * First version works * Minor fixes * update text2token * default values * Add jni for kws * add kws android project * Minor fixes * Remove unused interface * Minor fixes * Add workflow * handle extra info in texts * Minor fixes * Add more comments * Fix ci * fix cpp style * Add input box in android demo so that users can specify their keywords * Fix cpp style * Fix comments * Minor fixes * Minor fixes * minor fixes * Minor fixes * Minor fixes * Add CI * Fix code style * cpplint * Fix comments * Fix error
This commit is contained in:
@@ -4,22 +4,59 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
void ContextGraph::Build(
|
||||
const std::vector<std::vector<int32_t>> &token_ids) const {
|
||||
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
const std::vector<float> &scores,
|
||||
const std::vector<std::string> &phrases,
|
||||
const std::vector<float> &ac_thresholds) const {
|
||||
if (!scores.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size());
|
||||
}
|
||||
if (!phrases.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size());
|
||||
}
|
||||
if (!ac_thresholds.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size());
|
||||
}
|
||||
for (int32_t i = 0; i < token_ids.size(); ++i) {
|
||||
auto node = root_.get();
|
||||
float score = scores.empty() ? 0.0f : scores[i];
|
||||
score = score == 0.0f ? context_score_ : score;
|
||||
float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
|
||||
ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
|
||||
std::string phrase = phrases.empty() ? std::string() : phrases[i];
|
||||
|
||||
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
|
||||
int32_t token = token_ids[i][j];
|
||||
if (0 == node->next.count(token)) {
|
||||
bool is_end = j == token_ids[i].size() - 1;
|
||||
node->next[token] = std::make_unique<ContextState>(
|
||||
token, context_score_, node->node_score + context_score_,
|
||||
is_end ? node->node_score + context_score_ : 0, is_end);
|
||||
token, score, node->node_score + score,
|
||||
is_end ? node->node_score + score : 0, j + 1,
|
||||
is_end ? ac_threshold : 0.0f, is_end,
|
||||
is_end ? phrase : std::string());
|
||||
} else {
|
||||
float token_score = std::max(score, node->next[token]->token_score);
|
||||
node->next[token]->token_score = token_score;
|
||||
float node_score = node->node_score + token_score;
|
||||
node->next[token]->node_score = node_score;
|
||||
bool is_end =
|
||||
(j == token_ids[i].size() - 1) || node->next[token]->is_end;
|
||||
node->next[token]->output_score = is_end ? node_score : 0.0f;
|
||||
node->next[token]->is_end = is_end;
|
||||
if (j == token_ids[i].size() - 1) {
|
||||
node->next[token]->phrase = phrase;
|
||||
node->next[token]->ac_threshold = ac_threshold;
|
||||
}
|
||||
}
|
||||
node = node->next[token].get();
|
||||
}
|
||||
@@ -27,8 +64,9 @@ void ContextGraph::Build(
|
||||
FillFailOutput();
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
const ContextState *state, int32_t token) const {
|
||||
std::tuple<float, const ContextState *, const ContextState *>
|
||||
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
|
||||
bool strict_mode /*= true*/) const {
|
||||
const ContextState *node;
|
||||
float score;
|
||||
if (1 == state->next.count(token)) {
|
||||
@@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
}
|
||||
score = node->node_score - state->node_score;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_CHECK(nullptr != node);
|
||||
return std::make_pair(score + node->output_score, node);
|
||||
|
||||
const ContextState *matched_node =
|
||||
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
|
||||
|
||||
if (!strict_mode && node->output_score != 0) {
|
||||
SHERPA_ONNX_CHECK(nullptr != matched_node);
|
||||
float output_score =
|
||||
node->is_end ? node->node_score
|
||||
: (node->output != nullptr ? node->output->node_score
|
||||
: node->node_score);
|
||||
return std::make_tuple(score + output_score - node->node_score, root_.get(),
|
||||
matched_node);
|
||||
}
|
||||
return std::make_tuple(score + node->output_score, node, matched_node);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
@@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
return std::make_pair(score, root_.get());
|
||||
}
|
||||
|
||||
std::pair<bool, const ContextState *> ContextGraph::IsMatched(
|
||||
const ContextState *state) const {
|
||||
bool status = false;
|
||||
const ContextState *node = nullptr;
|
||||
if (state->is_end) {
|
||||
status = true;
|
||||
node = state;
|
||||
} else {
|
||||
if (state->output != nullptr) {
|
||||
status = true;
|
||||
node = state->output;
|
||||
}
|
||||
}
|
||||
return std::make_pair(status, node);
|
||||
}
|
||||
|
||||
void ContextGraph::FillFailOutput() const {
|
||||
std::queue<const ContextState *> node_queue;
|
||||
for (auto &kv : root_->next) {
|
||||
|
||||
Reference in New Issue
Block a user