Track token scores (#571)
* add export of per-token scores (ys, lm, context) - for best path of the modified-beam-search decoding of transducer * refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult * export per-token scores also for greedy-search (online-transducer) - export un-scaled lm_probs (modified-beam search, online-transducer) - polishing * fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
This commit is contained in:
@@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
|
||||
r->tokens = std::vector<int64_t>(start, end);
|
||||
}
|
||||
|
||||
|
||||
void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
Ort::Value encoder_out,
|
||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_batch_decoder_out_cached) {
|
||||
auto &r = result->front();
|
||||
std::vector<int64_t> decoder_out_shape =
|
||||
@@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
if (blank_penalty_ > 0.0) {
|
||||
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
||||
}
|
||||
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
@@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
} else {
|
||||
++r.num_trailing_blanks;
|
||||
}
|
||||
|
||||
// export the per-token log scores
|
||||
if (y != 0 && y != unk_id_) {
|
||||
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
||||
// save time by doing it only for
|
||||
// emitted symbols
|
||||
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
|
||||
// now it contains normalized
|
||||
// probability
|
||||
r.ys_probs.push_back(p_logprob[y]);
|
||||
}
|
||||
}
|
||||
if (emitted) {
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
|
||||
Reference in New Issue
Block a user