Refactor hotwords,support loading hotwords from file (#296)
This commit is contained in:
@@ -5,7 +5,9 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -16,6 +18,7 @@
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
@@ -25,6 +28,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
||||
@@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const override {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
const std::string &hotwords) const override {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
if (!EncodeHotwords(is, symbol_table_, ¤t)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
|
||||
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(context_list, config_.context_score);
|
||||
std::make_shared<ContextGraph>(current, config_.hotwords_score);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config,
|
||||
hotwords_graph_);
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
@@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
void InitHotwords() {
|
||||
// each line in hotwords_file contains space-separated words
|
||||
|
||||
std::ifstream is(config_.hotwords_file);
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
|
||||
config_.hotwords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed.");
|
||||
exit(-1);
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<OfflineTransducerModel> model_;
|
||||
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||
std::unique_ptr<OfflineLM> lm_;
|
||||
|
||||
Reference in New Issue
Block a user