Support customized scores for hotwords (#926)

* Support customized scores for hotwords

* Skip blank lines
This commit is contained in:
Wei Kang
2024-05-31 12:34:30 +08:00
committed by GitHub
parent a689249f88
commit a38881817c
6 changed files with 103 additions and 35 deletions

View File

@@ -182,14 +182,35 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
std::vector<float> current_scores;
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &current)) {
bpe_encoder_.get(), &current, &current_scores)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
int32_t num_default_hws = hotwords_.size();
int32_t num_hws = current.size();
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
auto context_graph =
std::make_shared<ContextGraph>(current, config_.hotwords_score);
if (!current_scores.empty() && !boost_scores_.empty()) {
current_scores.insert(current_scores.end(), boost_scores_.begin(),
boost_scores_.end());
} else if (!current_scores.empty() && boost_scores_.empty()) {
current_scores.insert(current_scores.end(), num_default_hws,
config_.hotwords_score);
} else if (current_scores.empty() && !boost_scores_.empty()) {
current_scores.insert(current_scores.end(), num_hws,
config_.hotwords_score);
current_scores.insert(current_scores.end(), boost_scores_.begin(),
boost_scores_.end());
} else {
// Do nothing.
}
auto context_graph = std::make_shared<ContextGraph>(
current, config_.hotwords_score, current_scores);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
InitOnlineStream(stream.get());
@@ -376,13 +397,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_)) {
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
hotwords_graph_ = std::make_shared<ContextGraph>(
hotwords_, config_.hotwords_score, boost_scores_);
}
#if __ANDROID_API__ >= 9
@@ -400,13 +421,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_)) {
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
hotwords_graph_ = std::make_shared<ContextGraph>(
hotwords_, config_.hotwords_score, boost_scores_);
}
#endif
@@ -428,6 +449,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
private:
OnlineRecognizerConfig config_;
std::vector<std::vector<int32_t>> hotwords_;
std::vector<float> boost_scores_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
std::unique_ptr<OnlineTransducerModel> model_;