Refactor hotwords,support loading hotwords from file (#296)
This commit is contained in:
@@ -72,6 +72,7 @@ set(sources
|
||||
text-utils.cc
|
||||
transpose.cc
|
||||
unbind.cc
|
||||
utils.cc
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
|
||||
@@ -4,11 +4,14 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) {
|
||||
}
|
||||
}
|
||||
|
||||
// Times the construction of a ContextGraph from 10, 100, 1000 and 10000
// randomly generated contexts and logs the elapsed microseconds for each size.
//
// Every context holds between 3 and 8 token ids drawn uniformly from [0, 25].
// The generator is seeded from std::random_device, so the generated contexts
// differ from run to run.
TEST(ContextGraph, Benchmark) {
  std::random_device rd;
  std::mt19937 mt(rd());
  std::uniform_int_distribution<int32_t> char_dist(0, 25);
  std::uniform_int_distribution<int32_t> len_dist(3, 8);
  for (int32_t num = 10; num <= 10000; num *= 10) {
    std::vector<std::vector<int32_t>> contexts;
    contexts.reserve(num);
    for (int32_t i = 0; i < num; ++i) {
      std::vector<int32_t> tmp;
      const int32_t word_len = len_dist(mt);
      tmp.reserve(word_len);
      for (int32_t j = 0; j < word_len; ++j) {
        tmp.push_back(char_dist(mt));
      }
      contexts.push_back(std::move(tmp));
    }
    auto start = std::chrono::high_resolution_clock::now();
    auto context_graph = ContextGraph(contexts, 1);
    auto stop = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
    // The representation type of std::chrono::microseconds is
    // implementation-defined, so cast to a fixed type that matches the
    // printf-style conversion instead of assuming it is `long`.
    SHERPA_ONNX_LOGE("Construct context graph for %d item takes %lld us.", num,
                     static_cast<long long>(duration.count()));
  }
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -32,7 +33,7 @@ class OfflineRecognizerImpl {
|
||||
virtual ~OfflineRecognizerImpl() = default;
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
const std::string &hotwords) const {
|
||||
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -16,6 +18,7 @@
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
@@ -25,6 +28,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
||||
@@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const override {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
const std::string &hotwords) const override {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
if (!EncodeHotwords(is, symbol_table_, ¤t)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
|
||||
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(context_list, config_.context_score);
|
||||
std::make_shared<ContextGraph>(current, config_.hotwords_score);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config,
|
||||
hotwords_graph_);
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
@@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
void InitHotwords() {
|
||||
// each line in hotwords_file contains space-separated words
|
||||
|
||||
std::ifstream is(config_.hotwords_file);
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
|
||||
config_.hotwords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<OfflineTransducerModel> model_;
|
||||
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||
std::unique_ptr<OfflineLM> lm_;
|
||||
|
||||
@@ -26,7 +26,15 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register("context-score", &context_score,
|
||||
|
||||
po->Register(
|
||||
"hotwords-file", &hotwords_file,
|
||||
"The file containing hotwords, one words/phrases per line, and for each"
|
||||
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||
"▁HE LL O ▁WORLD"
|
||||
"你 好 世 界");
|
||||
|
||||
po->Register("hotwords-score", &hotwords_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
}
|
||||
@@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "context_score=" << context_score << ")";
|
||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||
os << "hotwords_score=" << hotwords_score << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -70,8 +79,8 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
|
||||
OfflineRecognizer::~OfflineRecognizer() = default;
|
||||
|
||||
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
return impl_->CreateStream(context_list);
|
||||
const std::string &hotwords) const {
|
||||
return impl_->CreateStream(hotwords);
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
|
||||
|
||||
@@ -31,7 +31,10 @@ struct OfflineRecognizerConfig {
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
int32_t max_active_paths = 4;
|
||||
float context_score = 1.5;
|
||||
|
||||
std::string hotwords_file;
|
||||
float hotwords_score = 1.5;
|
||||
|
||||
// only greedy_search is implemented
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
|
||||
@@ -40,13 +43,16 @@ struct OfflineRecognizerConfig {
|
||||
const OfflineModelConfig &model_config,
|
||||
const OfflineLMConfig &lm_config,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths, float context_score)
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file,
|
||||
float hotwords_score)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
context_score(context_score) {}
|
||||
hotwords_file(hotwords_file),
|
||||
hotwords_score(hotwords_score) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
@@ -69,9 +75,17 @@ class OfflineRecognizer {
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||
|
||||
/// Create a stream for decoding.
|
||||
/** Create a stream for decoding.
|
||||
*
|
||||
* @param hotwords The hotwords for this stream, it might contain several hotwords,
|
||||
* the hotwords are separated by "/". In each of the hotwords, there
|
||||
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
|
||||
* For example, hotwords I LOVE YOU and HELLO WORLD, looks like:
|
||||
*
|
||||
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
|
||||
*/
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const;
|
||||
const std::string &hotwords) const;
|
||||
|
||||
/** Decode a single stream
|
||||
*
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -29,7 +30,7 @@ class OnlineRecognizerImpl {
|
||||
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
||||
|
||||
virtual std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &contexts) const {
|
||||
const std::string &hotwords) const {
|
||||
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -20,6 +22,7 @@
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
if (sym_.contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
@@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, hotwords_graph_);
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &contexts) const override {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
const std::string &hotwords) const override {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
if (!EncodeHotwords(is, sym_, ¤t)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(contexts, config_.context_score);
|
||||
std::make_shared<ContextGraph>(current, config_.hotwords_score);
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
|
||||
InitOnlineStream(stream.get());
|
||||
@@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
s->Reset();
|
||||
}
|
||||
|
||||
void InitHotwords() {
|
||||
// each line in hotwords_file contains space-separated words
|
||||
|
||||
std::ifstream is(config_.hotwords_file);
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
|
||||
config_.hotwords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, sym_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
}
|
||||
|
||||
private:
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
@@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
|
||||
private:
|
||||
OnlineRecognizerConfig config_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
std::unique_ptr<OnlineLM> lm_;
|
||||
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
||||
|
||||
@@ -57,9 +57,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("context-score", &context_score,
|
||||
po->Register("hotwords-score", &hotwords_score,
|
||||
"The bonus score for each token in context word/phrase. "
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
po->Register(
|
||||
"hotwords-file", &hotwords_file,
|
||||
"The file containing hotwords, one words/phrases per line, and for each"
|
||||
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||
"▁HE LL O ▁WORLD"
|
||||
"你 好 世 界");
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
@@ -87,7 +93,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "context_score=" << context_score << ", ";
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\")";
|
||||
|
||||
return os.str();
|
||||
@@ -109,8 +116,8 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const {
|
||||
return impl_->CreateStream(context_list);
|
||||
const std::string &hotwords) const {
|
||||
return impl_->CreateStream(hotwords);
|
||||
}
|
||||
|
||||
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
||||
|
||||
@@ -78,8 +78,10 @@ struct OnlineRecognizerConfig {
|
||||
|
||||
// used only for modified_beam_search
|
||||
int32_t max_active_paths = 4;
|
||||
|
||||
/// used only for modified_beam_search
|
||||
float context_score = 1.5;
|
||||
float hotwords_score = 1.5;
|
||||
std::string hotwords_file;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
@@ -89,14 +91,16 @@ struct OnlineRecognizerConfig {
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths, float context_score)
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
endpoint_config(endpoint_config),
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
context_score(context_score) {}
|
||||
hotwords_score(hotwords_score),
|
||||
hotwords_file(hotwords_file) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
@@ -119,9 +123,16 @@ class OnlineRecognizer {
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||
|
||||
// Create a stream with context phrases
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const;
|
||||
/** Create a stream for decoding.
|
||||
*
|
||||
* @param hotwords The hotwords for this stream, it might contain several hotwords,
|
||||
* the hotwords are separated by "/". In each of the hotwords, there
|
||||
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
|
||||
* For example, hotwords I LOVE YOU and HELLO WORLD, looks like:
|
||||
*
|
||||
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
|
||||
*/
|
||||
std::unique_ptr<OnlineStream> CreateStream(const std::string &hotwords) const;
|
||||
|
||||
/**
|
||||
* Return true if the given stream has enough frames for decoding.
|
||||
|
||||
54
sherpa-onnx/csrc/utils.cc
Normal file
54
sherpa-onnx/csrc/utils.cc
Normal file
@@ -0,0 +1,54 @@
|
||||
// sherpa-onnx/csrc/utils.cc
|
||||
//
|
||||
// Copyright 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
hotwords->clear();
|
||||
std::vector<int32_t> tmp;
|
||||
std::string line;
|
||||
std::string word;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
std::istringstream iss(line);
|
||||
std::vector<std::string> syms;
|
||||
while (iss >> word) {
|
||||
if (word.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
// Unicode 9601, hex 0x2581, utf8 0xe29681
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
|
||||
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
|
||||
word = word.replace(0, 3, " ");
|
||||
}
|
||||
}
|
||||
if (symbol_table.contains(word)) {
|
||||
int32_t number = symbol_table[word];
|
||||
tmp.push_back(number);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
|
||||
"the "
|
||||
"same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
hotwords->push_back(std::move(tmp));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
33
sherpa-onnx/csrc/utils.h
Normal file
33
sherpa-onnx/csrc/utils.h
Normal file
@@ -0,0 +1,33 @@
|
||||
// sherpa-onnx/csrc/utils.h
|
||||
//
|
||||
// Copyright 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_UTILS_H_
|
||||
#define SHERPA_ONNX_CSRC_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/* Encode the hotwords in an input stream into token ids.
|
||||
*
|
||||
* @param is The input stream, it contains several lines, one hotword for each
|
||||
* line. For each hotword, the tokens (cjkchar or bpe) are separated
|
||||
* by spaces.
|
||||
* @param symbol_table The tokens table mapping symbols to ids. All the symbols
|
||||
* in the stream should be in the symbol_table, if not this
|
||||
* function returns false.
|
||||
*
|
||||
* @param hotwords The encoded ids to be written to.
|
||||
*
|
||||
* @return If all the symbols from ``is`` are in the symbol_table, returns true
|
||||
* otherwise returns false.
|
||||
*/
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_UTILS_H_
|
||||
Reference in New Issue
Block a user