Track token scores (#571)

* add export of per-token scores (ys, lm, context)

- for the best path of modified-beam-search transducer decoding

* refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult

* export per-token scores also for greedy-search (online-transducer)

- export un-scaled lm_probs (modified-beam search, online-transducer)
- polishing

* fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
This commit is contained in:
Karel Vesely
2024-02-28 23:28:45 +01:00
committed by GitHub
parent 85d59b5840
commit 38c072dcb2
11 changed files with 155 additions and 49 deletions

View File

@@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
r->tokens = std::move(tokens);
r->timestamps = std::move(hyp.timestamps);
// export per-token scores
r->ys_probs = std::move(hyp.ys_probs);
r->lm_probs = std::move(hyp.lm_probs);
r->context_scores = std::move(hyp.context_scores);
r->num_trailing_blanks = hyp.num_trailing_blanks;
}
@@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
new_hyp.log_prob = p_logprob[k] + context_score -
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis& prev_i = prev[hyp_index];
// subtract the path scores of `prev[hyp_index]` (i.e. `prev_i`), which
// were added before getting the topk tokens
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
new_hyp.ys_probs.push_back(y_prob);
if (lm_) { // export only when LM is used
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
}
new_hyp.lm_probs.push_back(lm_prob);
}
// export only when `ContextGraph` is used
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
new_hyp.context_scores.push_back(context_score);
}
}
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
cur.push_back(std::move(hyps));