// sherpa-onnx/csrc/context-graph-test.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/context-graph.h" #include // NOLINT #include #include #include #include #include "gtest/gtest.h" #include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { TEST(ContextGraph, TestBasic) { std::vector contexts_str( {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); std::vector> contexts; for (int32_t i = 0; i < contexts_str.size(); ++i) { contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); } auto context_graph = ContextGraph(contexts, 1); auto queries = std::map{ {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; for (const auto &iter : queries) { float total_scores = 0; auto state = context_graph.Root(); for (auto q : iter.first) { auto res = context_graph.ForwardOneStep(state, q); total_scores += res.first; state = res.second; } auto res = context_graph.Finalize(state); EXPECT_EQ(res.second->token, -1); total_scores += res.first; EXPECT_EQ(total_scores, iter.second); } } TEST(ContextGraph, Benchmark) { std::random_device rd; std::mt19937 mt(rd()); std::uniform_int_distribution char_dist(0, 25); std::uniform_int_distribution len_dist(3, 8); for (int32_t num = 10; num <= 10000; num *= 10) { std::vector> contexts; for (int32_t i = 0; i < num; ++i) { std::vector tmp; int32_t word_len = len_dist(mt); for (int32_t j = 0; j < word_len; ++j) { tmp.push_back(char_dist(mt)); } contexts.push_back(std::move(tmp)); } auto start = std::chrono::high_resolution_clock::now(); auto context_graph = ContextGraph(contexts, 1); auto stop = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(stop - start); SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num, duration.count()); } } } // namespace sherpa_onnx