add shallow fusion (#147)

This commit is contained in:
PF Luo
2023-05-10 22:30:57 +08:00
committed by GitHub
parent 7969cf44ac
commit 824b0809a4
9 changed files with 104 additions and 125 deletions

View File

@@ -121,7 +121,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
// add log_prob of each hypothesis to p_logprob before taking top_k
for (int32_t i = 0; i != num_hyps; ++i) {
float log_prob = prev[i].log_prob;
float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
*p_logprob += log_prob;
}
@@ -141,14 +141,18 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
int32_t new_token = k % vocab_size;
Hypothesis new_hyp = prev[hyp_index];
const float prev_lm_log_prob = new_hyp.lm_log_prob;
if (new_token != 0) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.num_trailing_blanks = 0;
if (lm_) {
lm_->ComputeLMScore(lm_scale_, &new_hyp);
}
} else {
++new_hyp.num_trailing_blanks;
}
new_hyp.log_prob = p_logprob[k];
new_hyp.log_prob = p_logprob[k] - prev_lm_log_prob;
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
cur.push_back(std::move(hyps));
@@ -156,10 +160,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
} // for (int32_t b = 0; b != batch_size; ++b)
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
}
for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);