Track token scores (#571)

* add export of per-token scores (ys, lm, context)

- for the best path of the modified-beam-search decoding of the transducer

* refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult

* export per-token scores also for greedy-search (online-transducer)

- export un-scaled lm_probs (modified-beam-search, online-transducer)
- polishing

* fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
This commit is contained in:
Karel Vesely
2024-02-28 23:28:45 +01:00
committed by GitHub
parent 85d59b5840
commit 38c072dcb2
11 changed files with 155 additions and 49 deletions

View File

@@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
r->tokens = std::vector<int64_t>(start, end);
}
void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
@@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
break;
}
}
if (is_batch_decoder_out_cached) {
auto &r = result->front();
std::vector<int64_t> decoder_out_shape =
@@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
@@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode(
} else {
++r.num_trailing_blanks;
}
// export the per-token log scores
if (y != 0 && y != unk_id_) {
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
// save time by doing it only for
// emitted symbols
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
// now it contains normalized
// log-probabilities
r.ys_probs.push_back(p_logprob[y]);
}
}
if (emitted) {
Ort::Value decoder_input = model_->BuildDecoderInput(*result);