Implement context biasing with a Aho Corasick automata (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for effciency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
43
sherpa-onnx/csrc/context-graph-test.cc
Normal file
43
sherpa-onnx/csrc/context-graph-test.cc
Normal file
@@ -0,0 +1,43 @@
|
||||
// sherpa-onnx/csrc/context-graph-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(ContextGraph, TestBasic) {
|
||||
std::vector<std::string> contexts_str(
|
||||
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
||||
std::vector<std::vector<int32_t>> contexts;
|
||||
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
||||
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
||||
}
|
||||
auto context_graph = ContextGraph(contexts, 1);
|
||||
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
|
||||
{"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||
|
||||
for (const auto &iter : queries) {
|
||||
float total_scores = 0;
|
||||
auto state = context_graph.Root();
|
||||
for (auto q : iter.first) {
|
||||
auto res = context_graph.ForwardOneStep(state, q);
|
||||
total_scores += res.first;
|
||||
state = res.second;
|
||||
}
|
||||
auto res = context_graph.Finalize(state);
|
||||
EXPECT_EQ(res.second->token, -1);
|
||||
total_scores += res.first;
|
||||
EXPECT_EQ(total_scores, iter.second);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
Reference in New Issue
Block a user