add shallow fusion (#147)
This commit is contained in:
@@ -121,7 +121,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
|
||||
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||
float log_prob = prev[i].log_prob;
|
||||
float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
|
||||
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||
*p_logprob += log_prob;
|
||||
}
|
||||
@@ -141,14 +141,18 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
int32_t new_token = k % vocab_size;
|
||||
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
const float prev_lm_log_prob = new_hyp.lm_log_prob;
|
||||
if (new_token != 0) {
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t + frame_offset);
|
||||
new_hyp.num_trailing_blanks = 0;
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||
}
|
||||
} else {
|
||||
++new_hyp.num_trailing_blanks;
|
||||
}
|
||||
new_hyp.log_prob = p_logprob[k];
|
||||
new_hyp.log_prob = p_logprob[k] - prev_lm_log_prob;
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
cur.push_back(std::move(hyps));
|
||||
@@ -156,10 +160,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||
}
|
||||
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
|
||||
}
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(true);
|
||||
|
||||
Reference in New Issue
Block a user