73 lines
2.2 KiB
C++
73 lines
2.2 KiB
C++
// sherpa-onnx/csrc/context-graph-test.cc
|
|
//
|
|
// Copyright (c) 2023 Xiaomi Corporation
|
|
|
|
#include "sherpa-onnx/csrc/context-graph.h"
|
|
|
|
#include <chrono> // NOLINT
|
|
#include <map>
|
|
#include <random>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "gtest/gtest.h"
|
|
#include "sherpa-onnx/csrc/macros.h"
|
|
|
|
namespace sherpa_onnx {
|
|
|
|
TEST(ContextGraph, TestBasic) {
|
|
std::vector<std::string> contexts_str(
|
|
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
|
std::vector<std::vector<int32_t>> contexts;
|
|
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
|
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
|
}
|
|
auto context_graph = ContextGraph(contexts, 1);
|
|
|
|
auto queries = std::map<std::string, float>{
|
|
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
|
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
|
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
|
|
|
for (const auto &iter : queries) {
|
|
float total_scores = 0;
|
|
auto state = context_graph.Root();
|
|
for (auto q : iter.first) {
|
|
auto res = context_graph.ForwardOneStep(state, q);
|
|
total_scores += res.first;
|
|
state = res.second;
|
|
}
|
|
auto res = context_graph.Finalize(state);
|
|
EXPECT_EQ(res.second->token, -1);
|
|
total_scores += res.first;
|
|
EXPECT_EQ(total_scores, iter.second);
|
|
}
|
|
}
|
|
|
|
TEST(ContextGraph, Benchmark) {
|
|
std::random_device rd;
|
|
std::mt19937 mt(rd());
|
|
std::uniform_int_distribution<int32_t> char_dist(0, 25);
|
|
std::uniform_int_distribution<int32_t> len_dist(3, 8);
|
|
for (int32_t num = 10; num <= 10000; num *= 10) {
|
|
std::vector<std::vector<int32_t>> contexts;
|
|
for (int32_t i = 0; i < num; ++i) {
|
|
std::vector<int32_t> tmp;
|
|
int32_t word_len = len_dist(mt);
|
|
for (int32_t j = 0; j < word_len; ++j) {
|
|
tmp.push_back(char_dist(mt));
|
|
}
|
|
contexts.push_back(std::move(tmp));
|
|
}
|
|
auto start = std::chrono::high_resolution_clock::now();
|
|
auto context_graph = ContextGraph(contexts, 1);
|
|
auto stop = std::chrono::high_resolution_clock::now();
|
|
auto duration =
|
|
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
|
|
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
|
|
duration.count());
|
|
}
|
|
}
|
|
|
|
} // namespace sherpa_onnx
|