re-pull-request allow tokens and hotwords be loaded from buffered string driectly (#1339)

Co-authored-by: xiao <shawl336@163.com>
This commit is contained in:
lxiao336
2024-09-13 09:58:17 +08:00
committed by GitHub
parent 6b6e7635ed
commit 65cfa7548a
12 changed files with 414 additions and 16 deletions

View File

@@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
: OnlineRecognizerImpl(config),
config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (!config.model_config.tokens_buf.empty()) {
sym_ = SymbolTable(config.model_config.tokens_buf, false);
} else {
/// assuming tokens_buf and tokens are guaranteed not being both empty
sym_ = SymbolTable(config.model_config.tokens, true);
}
if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
@@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
config_.model_config.bpe_vocab);
}
if (!config_.hotwords_file.empty()) {
if (!config_.hotwords_buf.empty()) {
InitHotwordsFromBufStr();
} else if (!config_.hotwords_file.empty()) {
InitHotwords();
}
@@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
config_.blank_penalty,
config_.temperature_scale);
config_.blank_penalty, config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
@@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
config_.blank_penalty,
config_.temperature_scale);
config_.blank_penalty, config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
@@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
#endif
void InitHotwordsFromBufStr() {
// each line in hotwords_file contains space-separated words
std::istringstream iss(config_.hotwords_buf);
if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ = std::make_shared<ContextGraph>(
hotwords_, config_.hotwords_score, boost_scores_);
}
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();