Encode hotwords in C++ side (#828)

* Encode hotwords in C++ side
This commit is contained in:
Wei Kang
2024-05-20 19:41:36 +08:00
committed by GitHub
parent 8af2af8466
commit b012b78ceb
43 changed files with 714 additions and 102 deletions

View File

@@ -15,8 +15,6 @@
#include <vector>
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
@@ -33,6 +31,7 @@
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
@@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_->SetFeatureDim(config.feat_config.feature_dim);
if (config.decoding_method == "modified_beam_search") {
if (!config_.model_config.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model_config.bpe_vocab);
}
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
@@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
#endif
if (!config_.model_config.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}
if (!config_.hotwords_file.empty()) {
InitHotwords(mgr);
}
@@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, sym_, &current)) {
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
@@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
exit(-1);
}
if (!EncodeHotwords(is, sym_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
@@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
auto buf = ReadFile(mgr, config_.hotwords_file);
std::istrstream is(buf.data(), buf.size());
std::istringstream is(std::string(buf.begin(), buf.end()));
if (!is) {
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
@@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
exit(-1);
}
if (!EncodeHotwords(is, sym_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
@@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
OnlineRecognizerConfig config_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineLM> lm_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;