decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder
* Add keyword spotter runtime
* Add binary
* First version works
* Minor fixes
* update text2token
* default values
* Add jni for kws
* add kws android project
* Minor fixes
* Remove unused interface
* Minor fixes
* Add workflow
* handle extra info in texts
* Minor fixes
* Add more comments
* Fix ci
* fix cpp style
* Add input box in android demo so that users can specify their keywords
* Fix cpp style
* Fix comments
* Minor fixes
* Minor fixes
* minor fixes
* Minor fixes
* Minor fixes
* Add CI
* Fix code style
* cpplint
* Fix comments
* Fix error
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <string>
|
||||
@@ -15,27 +16,25 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(ContextGraph, TestBasic) {
|
||||
static void TestHelper(const std::map<std::string, float> &queries, float score,
|
||||
bool strict_mode) {
|
||||
std::vector<std::string> contexts_str(
|
||||
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
||||
std::vector<std::vector<int32_t>> contexts;
|
||||
std::vector<float> scores;
|
||||
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
||||
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
||||
scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100);
|
||||
}
|
||||
auto context_graph = ContextGraph(contexts, 1);
|
||||
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
||||
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
||||
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||
auto context_graph = ContextGraph(contexts, 1, scores);
|
||||
|
||||
for (const auto &iter : queries) {
|
||||
float total_scores = 0;
|
||||
auto state = context_graph.Root();
|
||||
for (auto q : iter.first) {
|
||||
auto res = context_graph.ForwardOneStep(state, q);
|
||||
total_scores += res.first;
|
||||
state = res.second;
|
||||
auto res = context_graph.ForwardOneStep(state, q, strict_mode);
|
||||
total_scores += std::get<0>(res);
|
||||
state = std::get<1>(res);
|
||||
}
|
||||
auto res = context_graph.Finalize(state);
|
||||
EXPECT_EQ(res.second->token, -1);
|
||||
@@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the shared helper with score = 0 and strict_mode = true.
// Each entry pairs a query string with the float value handed to the helper
// for that query (presumably the expected accumulated score -- the comparison
// itself lives in TestHelper).
TEST(ContextGraph, TestBasic) {
  const std::map<std::string, float> expected_by_query = {
      {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
      {"SHED", 6},      {"SHELF", 6},   {"HELL", 2},
      {"HELLO", 7},     {"DHRHISQ", 4}, {"THEN", 2},
  };
  TestHelper(expected_by_query, /*score=*/0, /*strict_mode=*/true);
}
|
||||
|
||||
// Runs the shared helper with score = 0 and strict_mode = false.
// Each entry pairs a query string with the float value handed to the helper
// for that query (presumably the expected accumulated score -- the comparison
// itself lives in TestHelper).
TEST(ContextGraph, TestBasicNonStrict) {
  const std::map<std::string, float> expected_by_query = {
      {"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5},
      {"SHED", 3},     {"SHELF", 3},  {"HELL", 2},
      {"HELLO", 2},    {"DHRHISQ", 3}, {"THEN", 2},
  };
  TestHelper(expected_by_query, /*score=*/0, /*strict_mode=*/false);
}
|
||||
|
||||
// Runs the shared helper with score = 5 and strict_mode = true.
// Each entry pairs a query string with the float value handed to the helper
// for that query (presumably the expected accumulated score -- the comparison
// itself lives in TestHelper).
TEST(ContextGraph, TestCustomize) {
  const std::map<std::string, float> expected_by_query = {
      {"HEHERSHE", 35.84}, {"HERSHE", 30.84},  {"HISHE", 24.18},
      {"SHED", 18.34},     {"SHELF", 18.34},   {"HELL", 5},
      {"HELLO", 13},       {"DHRHISQ", 10.84}, {"THEN", 5},
  };
  TestHelper(expected_by_query, /*score=*/5, /*strict_mode=*/true);
}
|
||||
|
||||
// Runs the shared helper with score = 5 and strict_mode = false.
// Each entry pairs a query string with the float value handed to the helper
// for that query (presumably the expected accumulated score -- the comparison
// itself lives in TestHelper).
TEST(ContextGraph, TestCustomizeNonStrict) {
  const std::map<std::string, float> expected_by_query = {
      {"HEHERSHE", 20}, {"HERSHE", 15},    {"HISHE", 10.84},
      {"SHED", 10},     {"SHELF", 10},     {"HELL", 5},
      {"HELLO", 5},     {"DHRHISQ", 5.84}, {"THEN", 5},
  };
  TestHelper(expected_by_query, /*score=*/5, /*strict_mode=*/false);
}
|
||||
|
||||
TEST(ContextGraph, Benchmark) {
|
||||
std::random_device rd;
|
||||
std::mt19937 mt(rd());
|
||||
|
||||
Reference in New Issue
Block a user