Re-implement LM rescore for online transducer (#1231)

Co-authored-by: Martins Kronis <martins.kuznecovs@tilde.lv>
This commit is contained in:
SilverSulfide
2024-09-06 05:01:25 +03:00
committed by GitHub
parent 1f29e4a1a9
commit 888f74bf3c
11 changed files with 175 additions and 31 deletions

View File

@@ -156,7 +156,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
// add log_prob of each hypothesis to p_logprob before taking top_k
for (int32_t i = 0; i != num_hyps; ++i) {
float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
float log_prob = prev[i].log_prob;
if (lm_ && shallow_fusion_) {
log_prob += prev[i].lm_log_prob;
}
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
*p_logprob += log_prob;
}
@@ -192,22 +196,31 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, &new_hyp);
if (lm_ && shallow_fusion_) {
lm_->ComputeLMScoreSF(lm_scale_, &new_hyp);
}
} else {
++new_hyp.num_trailing_blanks;
}
new_hyp.log_prob = p_logprob[k] + context_score -
if (lm_ && shallow_fusion_) {
new_hyp.log_prob = p_logprob[k] + context_score -
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
} else {
new_hyp.log_prob = p_logprob[k] + context_score; // rescore or no LM
// the previous token's LM score
// is not subtracted here
}
// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
float y_prob = logit_with_temperature[start * vocab_size + k];
new_hyp.ys_probs.push_back(y_prob);
if (lm_) { // export only when LM is used
if (lm_ && shallow_fusion_) { // export only if
// LM shallow fusion is used
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
}
@@ -227,6 +240,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
} // for (int32_t b = 0; b != batch_size; ++b)
} // for (int32_t t = 0; t != num_frames; ++t)
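// classic LM rescoring: applied once, after all frames have been decoded,
// and only when an LM is present and shallow fusion is disabled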
if (lm_ && !shallow_fusion_) {
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
}
for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);