// sherpa-onnx/csrc/context-graph-test.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/context-graph.h" #include #include #include #include "gtest/gtest.h" namespace sherpa_onnx { TEST(ContextGraph, TestBasic) { std::vector contexts_str( {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); std::vector> contexts; for (int32_t i = 0; i < contexts_str.size(); ++i) { contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); } auto context_graph = ContextGraph(contexts, 1); auto queries = std::map{ {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6}, {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; for (const auto &iter : queries) { float total_scores = 0; auto state = context_graph.Root(); for (auto q : iter.first) { auto res = context_graph.ForwardOneStep(state, q); total_scores += res.first; state = res.second; } auto res = context_graph.Finalize(state); EXPECT_EQ(res.second->token, -1); total_scores += res.first; EXPECT_EQ(total_scores, iter.second); } } } // namespace sherpa_onnx