Refactor hotwords,support loading hotwords from file (#296)
This commit is contained in:
@@ -4,11 +4,14 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContextGraph, Benchmark) {
|
||||
std::random_device rd;
|
||||
std::mt19937 mt(rd());
|
||||
std::uniform_int_distribution<int32_t> char_dist(0, 25);
|
||||
std::uniform_int_distribution<int32_t> len_dist(3, 8);
|
||||
for (int32_t num = 10; num <= 10000; num *= 10) {
|
||||
std::vector<std::vector<int32_t>> contexts;
|
||||
for (int32_t i = 0; i < num; ++i) {
|
||||
std::vector<int32_t> tmp;
|
||||
int32_t word_len = len_dist(mt);
|
||||
for (int32_t j = 0; j < word_len; ++j) {
|
||||
tmp.push_back(char_dist(mt));
|
||||
}
|
||||
contexts.push_back(std::move(tmp));
|
||||
}
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
auto context_graph = ContextGraph(contexts, 1);
|
||||
auto stop = std::chrono::high_resolution_clock::now();
|
||||
auto duration =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
|
||||
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
|
||||
duration.count());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user