Refactor hotwords, support loading hotwords from file (#296)

This commit is contained in:
Wei Kang
2023-09-14 19:33:17 +08:00
committed by GitHub
parent 087367d7fe
commit 47184f9db7
34 changed files with 803 additions and 300 deletions

View File

@@ -7,6 +7,8 @@
#include <algorithm>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
@@ -20,6 +22,7 @@
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
@@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
@@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
#endif
// Creates a new stream that uses the recognizer-level hotwords graph.
//
// hotwords_graph_ is built by InitHotwords(), which the constructor only
// calls when config_.hotwords_file is non-empty; otherwise the member is
// passed through in its default-initialized state.
//
// @return The newly created stream, already passed through InitOnlineStream().
std::unique_ptr<OnlineStream> CreateStream() const override {
  // Only one `stream` may be declared here: the earlier construction from
  // config_.feat_config alone redeclared the variable and dropped the
  // hotwords graph.
  auto stream =
      std::make_unique<OnlineStream>(config_.feat_config, hotwords_graph_);
  InitOnlineStream(stream.get());
  return stream;
}
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &contexts) const override {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
const std::string &hotwords) const override {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, sym_, &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
auto context_graph =
std::make_shared<ContextGraph>(contexts, config_.context_score);
std::make_shared<ContextGraph>(current, config_.hotwords_score);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
InitOnlineStream(stream.get());
@@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
s->Reset();
}
// Loads the hotwords listed in config_.hotwords_file, encodes them with the
// symbol table into hotwords_, and builds hotwords_graph_ from the result.
//
// The process is terminated with exit(-1) if the file cannot be opened or if
// EncodeHotwords() reports a failure.
void InitHotwords() {
  // Every line of the hotwords file holds space-separated words.
  std::ifstream hotwords_stream(config_.hotwords_file);
  if (hotwords_stream.fail()) {
    SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                     config_.hotwords_file.c_str());
    exit(-1);
  }

  const bool encoded = EncodeHotwords(hotwords_stream, sym_, &hotwords_);
  if (!encoded) {
    SHERPA_ONNX_LOGE("Encode hotwords failed.");
    exit(-1);
  }

  // The graph is shared by every stream created without per-stream hotwords.
  hotwords_graph_ =
      std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
private:
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
@@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
private:
OnlineRecognizerConfig config_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineLM> lm_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;