Refactor hotwords,support loading hotwords from file (#296)

This commit is contained in:
Wei Kang
2023-09-14 19:33:17 +08:00
committed by GitHub
parent 087367d7fe
commit 47184f9db7
34 changed files with 803 additions and 300 deletions

View File

@@ -4,11 +4,14 @@
#include "sherpa-onnx/csrc/context-graph.h"
#include <chrono> // NOLINT
#include <map>
#include <random>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
@@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) {
}
}
TEST(ContextGraph, Benchmark) {
std::random_device rd;
std::mt19937 mt(rd());
std::uniform_int_distribution<int32_t> char_dist(0, 25);
std::uniform_int_distribution<int32_t> len_dist(3, 8);
for (int32_t num = 10; num <= 10000; num *= 10) {
std::vector<std::vector<int32_t>> contexts;
for (int32_t i = 0; i < num; ++i) {
std::vector<int32_t> tmp;
int32_t word_len = len_dist(mt);
for (int32_t j = 0; j < word_len; ++j) {
tmp.push_back(char_dist(mt));
}
contexts.push_back(std::move(tmp));
}
auto start = std::chrono::high_resolution_clock::now();
auto context_graph = ContextGraph(contexts, 1);
auto stop = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
duration.count());
}
}
} // namespace sherpa_onnx