Track token scores (#571)
* Add export of per-token scores (ys, lm, context) for the best path of the modified-beam-search decoding of the transducer.
* Refactor the JSON export of OnlineRecognitionResult and extend the pybind11 API of OnlineRecognitionResult.
* Export per-token scores for greedy search as well (online transducer).
  - Export unscaled lm_probs (modified beam search, online transducer).
  - Polishing.
* Fill lm_probs/context_scores only if an LM/ContextGraph is present (keeps the Result smaller).
This commit is contained in:
@@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
|
||||
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
|
||||
r->tokens = std::move(tokens);
|
||||
r->timestamps = std::move(hyp.timestamps);
|
||||
|
||||
// export per-token scores
|
||||
r->ys_probs = std::move(hyp.ys_probs);
|
||||
r->lm_probs = std::move(hyp.lm_probs);
|
||||
r->context_scores = std::move(hyp.context_scores);
|
||||
|
||||
r->num_trailing_blanks = hyp.num_trailing_blanks;
|
||||
}
|
||||
|
||||
@@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
new_hyp.log_prob = p_logprob[k] + context_score -
|
||||
prev_lm_log_prob; // log_prob only includes the
|
||||
// score of the transducer
|
||||
// export the per-token log scores
|
||||
if (new_token != 0 && new_token != unk_id_) {
|
||||
const Hypothesis& prev_i = prev[hyp_index];
|
||||
// subtract 'prev[i]' path scores, which were added before
|
||||
// getting topk tokens
|
||||
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
|
||||
new_hyp.ys_probs.push_back(y_prob);
|
||||
|
||||
if (lm_) { // export only when LM is used
|
||||
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
|
||||
if (lm_scale_ != 0.0) {
|
||||
lm_prob /= lm_scale_; // remove lm-scale
|
||||
}
|
||||
new_hyp.lm_probs.push_back(lm_prob);
|
||||
}
|
||||
|
||||
// export only when `ContextGraph` is used
|
||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||
new_hyp.context_scores.push_back(context_score);
|
||||
}
|
||||
}
|
||||
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
cur.push_back(std::move(hyps));
|
||||
|
||||
Reference in New Issue
Block a user