Re-implement LM rescore for online transducer (#1231)
Co-authored-by: Martins Kronis <martins.kuznecovs@tilde.lv>
This commit is contained in:
@@ -156,7 +156,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||
float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
|
||||
float log_prob = prev[i].log_prob;
|
||||
if (lm_ && shallow_fusion_) {
|
||||
log_prob += prev[i].lm_log_prob;
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||
*p_logprob += log_prob;
|
||||
}
|
||||
@@ -192,22 +196,31 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
context_score = std::get<0>(context_res);
|
||||
new_hyp.context_state = std::get<1>(context_res);
|
||||
}
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||
if (lm_ && shallow_fusion_) {
|
||||
lm_->ComputeLMScoreSF(lm_scale_, &new_hyp);
|
||||
}
|
||||
} else {
|
||||
++new_hyp.num_trailing_blanks;
|
||||
}
|
||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||
if (lm_ && shallow_fusion_) {
|
||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||
prev_lm_log_prob; // log_prob only includes the
|
||||
// score of the transducer
|
||||
} else {
|
||||
new_hyp.log_prob = p_logprob[k] + context_score; // rescore or no LM
|
||||
// previous token
|
||||
// score is ignored
|
||||
}
|
||||
|
||||
// export the per-token log scores
|
||||
if (new_token != 0 && new_token != unk_id_) {
|
||||
float y_prob = logit_with_temperature[start * vocab_size + k];
|
||||
new_hyp.ys_probs.push_back(y_prob);
|
||||
|
||||
if (lm_) { // export only when LM is used
|
||||
if (lm_ && shallow_fusion_) { // export only if
|
||||
// LM shallow fusion is used
|
||||
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
|
||||
|
||||
if (lm_scale_ != 0.0) {
|
||||
lm_prob /= lm_scale_; // remove lm-scale
|
||||
}
|
||||
@@ -227,6 +240,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||
} // for (int32_t t = 0; t != num_frames; ++t)
|
||||
|
||||
// classic lm rescore
|
||||
if (lm_ && !shallow_fusion_) {
|
||||
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
|
||||
}
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(true);
|
||||
|
||||
Reference in New Issue
Block a user