Support customize scores for hotwords (#926)
* Support customize scores for hotwords * Skip blank lines
This commit is contained in:
@@ -145,15 +145,35 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
|
||||
std::istringstream is(hws);
|
||||
std::vector<std::vector<int32_t>> current;
|
||||
std::vector<float> current_scores;
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), ¤t)) {
|
||||
bpe_encoder_.get(), ¤t, ¤t_scores)) {
|
||||
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
|
||||
hotwords.c_str());
|
||||
}
|
||||
|
||||
int32_t num_default_hws = hotwords_.size();
|
||||
int32_t num_hws = current.size();
|
||||
|
||||
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
|
||||
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(current, config_.hotwords_score);
|
||||
if (!current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else if (!current_scores.empty() && boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_default_hws,
|
||||
config_.hotwords_score);
|
||||
} else if (current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_hws,
|
||||
config_.hotwords_score);
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
auto context_graph = std::make_shared<ContextGraph>(
|
||||
current, config_.hotwords_score, current_scores);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
|
||||
}
|
||||
|
||||
@@ -226,13 +246,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -250,13 +270,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
|
||||
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
|
||||
bpe_encoder_.get(), &hotwords_)) {
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ =
|
||||
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -264,6 +284,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::vector<std::vector<int32_t>> hotwords_;
|
||||
std::vector<float> boost_scores_;
|
||||
ContextGraphPtr hotwords_graph_;
|
||||
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
std::unique_ptr<OfflineTransducerModel> model_;
|
||||
|
||||
Reference in New Issue
Block a user