Adding temperature scaling on Joiner logits: (#789)

* Adding temperature scaling on Joiner logits:

- T hard-coded to 2.0
- so far best result NCE 0.122 (still not so high)
    - the BPE scores were rescaled with 0.2 (but then also incorrect words
      get high confidence, visually reasonable histograms are for 0.5 scale)
    - BPE->WORD score merging done by min(.) function
      (tried also prob-product, and also arithmetic, geometric, harmonic mean)

- without temperature scaling (i.e. scale 1.0), the best NCE was 0.032 (here product merging was best)

Results seem consistent with: https://arxiv.org/abs/2110.15222

Everything tuned on a very-small set of 100 sentences with 813 words and 10.2% WER, a Czech model.

I also experimented with blank posteriors mixed into the BPE confidences,
but no NCE improvement found, so not pushing that.

Temperature scling added also to the Greedy search confidences.

* making `temperature_scale` configurable from outside
This commit is contained in:
Karel Vesely
2024-04-26 03:44:26 +02:00
committed by GitHub
parent 15772d2150
commit 2e45d327a5
9 changed files with 107 additions and 30 deletions

View File

@@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
// copy raw logits, apply temperature-scaling (for confidences)
// Note: temperature scaling is used only for the confidences,
// the decoding algorithm uses the original logits
int32_t p_logit_items = vocab_size * num_hyps;
std::vector<float> logit_with_temperature(p_logit_items);
{
std::copy(p_logit,
p_logit + p_logit_items,
logit_with_temperature.begin());
for (float& elem : logit_with_temperature) {
elem /= temperature_scale_;
}
LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
}
if (blank_penalty_ > 0.0) {
// assuming blank id is 0
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
@@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
// score of the transducer
// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis &prev_i = prev[hyp_index];
// subtract 'prev[i]' path scores, which were added before
// getting topk tokens
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
float y_prob = logit_with_temperature[start * vocab_size + k];
new_hyp.ys_probs.push_back(y_prob);
if (lm_) { // export only when LM is used
@@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
cur.push_back(std::move(hyps));
p_logprob += (end - start) * vocab_size;
} // for (int32_t b = 0; b != batch_size; ++b)
}
} // for (int32_t t = 0; t != num_frames; ++t)
for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];