re-pull-request allow tokens and hotwords be loaded from buffered string driectly (#1339)
Co-authored-by: xiao <shawl336@163.com>
This commit is contained in:
@@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
: OnlineRecognizerImpl(config),
|
||||
config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (!config.model_config.tokens_buf.empty()) {
|
||||
sym_ = SymbolTable(config.model_config.tokens_buf, false);
|
||||
} else {
|
||||
/// assuming tokens_buf and tokens are guaranteed not being both empty
|
||||
sym_ = SymbolTable(config.model_config.tokens, true);
|
||||
}
|
||||
|
||||
if (sym_.Contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
@@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
config_.model_config.bpe_vocab);
|
||||
}
|
||||
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
if (!config_.hotwords_buf.empty()) {
|
||||
InitHotwordsFromBufStr();
|
||||
} else if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
}
|
||||
|
||||
@@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
config_.blank_penalty, config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
@@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
config_.blank_penalty, config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
@@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
#endif
|
||||
|
||||
void InitHotwordsFromBufStr() {
|
||||
// each line in hotwords_file contains space-separated words
|
||||
|
||||
std::istringstream iss(config_.hotwords_buf);
|
||||
if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_,
|
||||
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Failed to encode some hotwords, skip them already, see logs above "
|
||||
"for details.");
|
||||
}
|
||||
hotwords_graph_ = std::make_shared<ContextGraph>(
|
||||
hotwords_, config_.hotwords_score, boost_scores_);
|
||||
}
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user