Refactor hotwords,support loading hotwords from file (#296)

This commit is contained in:
Wei Kang
2023-09-14 19:33:17 +08:00
committed by GitHub
parent 087367d7fe
commit 47184f9db7
34 changed files with 803 additions and 300 deletions

View File

@@ -72,6 +72,7 @@ set(sources
text-utils.cc
transpose.cc
unbind.cc
utils.cc
wave-reader.cc
)

View File

@@ -4,11 +4,14 @@
#include "sherpa-onnx/csrc/context-graph.h"
#include <chrono> // NOLINT
#include <map>
#include <random>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
@@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) {
}
}
// Logs how long it takes to construct a ContextGraph from 10, 100, 1000 and
// 10000 randomly generated contexts. Every context is a sequence of 3 to 8
// token ids drawn uniformly from [0, 25]. Nothing is asserted; the timings are
// only written to the log.
TEST(ContextGraph, Benchmark) {
  std::random_device rd;
  std::mt19937 mt(rd());
  std::uniform_int_distribution<int32_t> char_dist(0, 25);
  std::uniform_int_distribution<int32_t> len_dist(3, 8);

  for (int32_t num = 10; num <= 10000; num *= 10) {
    std::vector<std::vector<int32_t>> contexts;
    contexts.reserve(num);
    for (int32_t i = 0; i < num; ++i) {
      const int32_t word_len = len_dist(mt);
      std::vector<int32_t> tmp;
      tmp.reserve(word_len);
      for (int32_t j = 0; j < word_len; ++j) {
        tmp.push_back(char_dist(mt));
      }
      contexts.push_back(std::move(tmp));
    }

    // Only the graph construction is timed, not the data generation above.
    auto start = std::chrono::high_resolution_clock::now();
    auto context_graph = ContextGraph(contexts, 1);
    auto stop = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::microseconds>(stop - start);

    // The integer type returned by count() is implementation-defined (it is
    // not `long` on every platform), so convert explicitly to match %lld.
    SHERPA_ONNX_LOGE("Construct context graph for %d item takes %lld us.", num,
                     static_cast<long long>(duration.count()));
  }
}
} // namespace sherpa_onnx

View File

@@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
@@ -32,7 +33,7 @@ class OfflineRecognizerImpl {
virtual ~OfflineRecognizerImpl() = default;
virtual std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
const std::string &hotwords) const {
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
exit(-1);
}

View File

@@ -5,7 +5,9 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
#include <fstream>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
@@ -16,6 +18,7 @@
#endif
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -25,6 +28,7 @@
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
@@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
@@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
#endif
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const override {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
const std::string &hotwords) const override {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, symbol_table_, &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
auto context_graph =
std::make_shared<ContextGraph>(context_list, config_.context_score);
std::make_shared<ContextGraph>(current, config_.hotwords_score);
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
return std::make_unique<OfflineStream>(config_.feat_config,
hotwords_graph_);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
@@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}
// Loads the hotwords listed in config_.hotwords_file, encodes them with
// symbol_table_ into hotwords_, and builds hotwords_graph_ from them using
// config_.hotwords_score. If the file cannot be opened or EncodeHotwords()
// returns false, an error is logged and the process exits with code -1.
void InitHotwords() {
  std::ifstream hotwords_stream(config_.hotwords_file);
  if (hotwords_stream.fail()) {
    SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                     config_.hotwords_file.c_str());
    exit(-1);
  }

  // Each line in hotwords_file contains space-separated words.
  const bool encoded =
      EncodeHotwords(hotwords_stream, symbol_table_, &hotwords_);
  if (!encoded) {
    SHERPA_ONNX_LOGE("Encode hotwords failed.");
    exit(-1);
  }

  hotwords_graph_ =
      std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<OfflineTransducerModel> model_;
std::unique_ptr<OfflineTransducerDecoder> decoder_;
std::unique_ptr<OfflineLM> lm_;

View File

@@ -26,7 +26,15 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register("max-active-paths", &max_active_paths,
"Used only when decoding_method is modified_beam_search");
po->Register("context-score", &context_score,
// Adjacent string literals are concatenated verbatim, so every piece must
// carry its own separating space and the two examples need a delimiter.
po->Register(
    "hotwords-file", &hotwords_file,
    "The file containing hotwords, one word/phrase per line. For each "
    "phrase, the bpe/cjkchar tokens are separated by a space. For example: "
    "'▁HE LL O ▁WORLD' or '你 好 世 界'");
po->Register("hotwords-score", &hotwords_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
}
@@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "lm_config=" << lm_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ")";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "hotwords_score=" << hotwords_score << ")";
return os.str();
}
@@ -70,8 +79,8 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
OfflineRecognizer::~OfflineRecognizer() = default;
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
const std::string &hotwords) const {
return impl_->CreateStream(hotwords);
}
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {

View File

@@ -31,7 +31,10 @@ struct OfflineRecognizerConfig {
std::string decoding_method = "greedy_search";
int32_t max_active_paths = 4;
float context_score = 1.5;
std::string hotwords_file;
float hotwords_score = 1.5;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
@@ -40,13 +43,16 @@ struct OfflineRecognizerConfig {
const OfflineModelConfig &model_config,
const OfflineLMConfig &lm_config,
const std::string &decoding_method,
int32_t max_active_paths, float context_score)
int32_t max_active_paths,
const std::string &hotwords_file,
float hotwords_score)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
decoding_method(decoding_method),
max_active_paths(max_active_paths),
context_score(context_score) {}
hotwords_file(hotwords_file),
hotwords_score(hotwords_score) {}
void Register(ParseOptions *po);
bool Validate() const;
@@ -69,9 +75,17 @@ class OfflineRecognizer {
/// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream() const;
/// Create a stream for decoding.
/** Create a stream for decoding.
*
* @param hotwords The hotwords for this stream. It may contain several
* hotwords, separated by "/". Each hotword consists of cjkchars or bpes,
* separated by spaces (" ").
* For example, the hotwords I LOVE YOU and HELLO WORLD look like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
const std::string &hotwords) const;
/** Decode a single stream
*

View File

@@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
@@ -29,7 +30,7 @@ class OnlineRecognizerImpl {
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &contexts) const {
const std::string &hotwords) const {
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
exit(-1);
}

View File

@@ -7,6 +7,8 @@
#include <algorithm>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
@@ -20,6 +22,7 @@
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
@@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
@@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
#endif
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, hotwords_graph_);
InitOnlineStream(stream.get());
return stream;
}
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &contexts) const override {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
const std::string &hotwords) const override {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, sym_, &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
auto context_graph =
std::make_shared<ContextGraph>(contexts, config_.context_score);
std::make_shared<ContextGraph>(current, config_.hotwords_score);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
InitOnlineStream(stream.get());
@@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
s->Reset();
}
// Loads the hotwords listed in config_.hotwords_file, encodes them with sym_
// into hotwords_, and builds hotwords_graph_ from them using
// config_.hotwords_score. If the file cannot be opened or EncodeHotwords()
// returns false, an error is logged and the process exits with code -1.
void InitHotwords() {
  std::ifstream hotwords_stream(config_.hotwords_file);
  if (hotwords_stream.fail()) {
    SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
                     config_.hotwords_file.c_str());
    exit(-1);
  }

  // Each line in hotwords_file contains space-separated words.
  const bool encoded = EncodeHotwords(hotwords_stream, sym_, &hotwords_);
  if (!encoded) {
    SHERPA_ONNX_LOGE("Encode hotwords failed.");
    exit(-1);
  }

  hotwords_graph_ =
      std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
private:
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
@@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
private:
OnlineRecognizerConfig config_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineLM> lm_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;

View File

@@ -57,9 +57,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it.");
po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search.");
po->Register("context-score", &context_score,
po->Register("hotwords-score", &hotwords_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
// Adjacent string literals are concatenated verbatim, so every piece must
// carry its own separating space and the two examples need a delimiter.
po->Register(
    "hotwords-file", &hotwords_file,
    "The file containing hotwords, one word/phrase per line. For each "
    "phrase, the bpe/cjkchar tokens are separated by a space. For example: "
    "'▁HE LL O ▁WORLD' or '你 好 世 界'");
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
@@ -87,7 +93,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ", ";
os << "hotwords_score=" << hotwords_score << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "decoding_method=\"" << decoding_method << "\")";
return os.str();
@@ -109,8 +116,8 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
}
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
const std::string &hotwords) const {
return impl_->CreateStream(hotwords);
}
bool OnlineRecognizer::IsReady(OnlineStream *s) const {

View File

@@ -78,8 +78,10 @@ struct OnlineRecognizerConfig {
// used only for modified_beam_search
int32_t max_active_paths = 4;
/// used only for modified_beam_search
float context_score = 1.5;
float hotwords_score = 1.5;
std::string hotwords_file;
OnlineRecognizerConfig() = default;
@@ -89,14 +91,16 @@ struct OnlineRecognizerConfig {
const EndpointConfig &endpoint_config,
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths, float context_score)
int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score)
: feat_config(feat_config),
model_config(model_config),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths),
context_score(context_score) {}
hotwords_score(hotwords_score),
hotwords_file(hotwords_file) {}
void Register(ParseOptions *po);
bool Validate() const;
@@ -119,9 +123,16 @@ class OnlineRecognizer {
/// Create a stream for decoding.
std::unique_ptr<OnlineStream> CreateStream() const;
// Create a stream with context phrases
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
/** Create a stream for decoding.
*
* @param hotwords The hotwords for this stream. It may contain several
* hotwords, separated by "/". Each hotword consists of cjkchars or bpes,
* separated by spaces (" ").
* For example, the hotwords I LOVE YOU and HELLO WORLD look like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
std::unique_ptr<OnlineStream> CreateStream(const std::string &hotwords) const;
/**
* Return true if the given stream has enough frames for decoding.

54
sherpa-onnx/csrc/utils.cc Normal file
View File

@@ -0,0 +1,54 @@
// sherpa-onnx/csrc/utils.cc
//
// Copyright 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/utils.h"
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Reads hotwords from `is`, one hotword per line, and converts every
// whitespace-separated token on a line to its id in `symbol_table`.
//
// @param is            Input stream with one hotword per line.
// @param symbol_table  Table used to look up the id of each token.
// @param hotwords      Output. Cleared first; receives one id sequence per
//                      non-empty line, in input order.
// @return true if every token was found in `symbol_table`. Returns false at
//         the first unknown token; in that case `hotwords` holds only the
//         lines that were fully encoded before the failing line.
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
                    std::vector<std::vector<int32_t>> *hotwords) {
  hotwords->clear();

  std::string line;
  std::string word;
  while (std::getline(is, line)) {
    std::istringstream iss(line);

    // A fresh vector per line, so nothing depends on the state of a
    // moved-from object.
    std::vector<int32_t> ids;
    while (iss >> word) {
      if (word.size() >= 3) {
        // For BPE-based models, we replace ▁ with a space
        // Unicode 9601, hex 0x2581, utf8 0xe29681
        const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
        if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
          word.replace(0, 3, " ");
        }
      }

      if (!symbol_table.contains(word)) {
        SHERPA_ONNX_LOGE(
            "Cannot find ID for hotword %s at line: %s. (Hint: words on "
            "the same line are separated by spaces)",
            word.c_str(), line.c_str());
        return false;
      }

      ids.push_back(symbol_table[word]);
    }

    // Blank or whitespace-only lines carry no tokens; do not record them as
    // empty hotwords.
    if (ids.empty()) {
      continue;
    }

    hotwords->push_back(std::move(ids));
  }

  return true;
}
} // namespace sherpa_onnx

33
sherpa-onnx/csrc/utils.h Normal file
View File

@@ -0,0 +1,33 @@
// sherpa-onnx/csrc/utils.h
//
// Copyright 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_UTILS_H_
#define SHERPA_ONNX_CSRC_UTILS_H_
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
/* Encode the hotwords in an input stream into token ids.
*
* @param is The input stream. It contains several lines, one hotword per
*           line. For each hotword, the tokens (cjkchar or bpe) are separated
*           by spaces.
* @param symbol_table The token table mapping symbols to ids. All the symbols
*                     in the stream should be in the symbol_table; if not,
*                     this function returns false.
*
* @param hotwords The encoded ids are written to this vector.
*
* @return Returns true if all the symbols from ``is`` are in the
*         symbol_table; otherwise returns false.
*/
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_UTILS_H_