Adding temperature scaling on Joiner logits: (#789)
* Adding temperature scaling on Joiner logits:
- T hard-coded to 2.0
- so far best result NCE 0.122 (still not so high)
- the BPE scores were rescaled with 0.2 (but then also incorrect words
get high confidence, visually reasonable histograms are for 0.5 scale)
- BPE->WORD score merging done by min(.) function
(tried also prob-product, and also arithmetic, geometric, harmonic mean)
- without temperature scaling (i.e. scale 1.0), the best NCE was 0.032 (here product merging was best)
Results seem consistent with: https://arxiv.org/abs/2110.15222
Everything tuned on a very-small set of 100 sentences with 813 words and 10.2% WER, a Czech model.
I also experimented with blank posteriors mixed into the BPE confidences,
but no NCE improvement found, so not pushing that.
Temperature scling added also to the Greedy search confidences.
* making `temperature_scale` configurable from outside
This commit is contained in:
@@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||
model_.get(),
|
||||
lm_.get(),
|
||||
config_.max_active_paths,
|
||||
config_.lm_config.scale,
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
model_.get(), unk_id_, config_.blank_penalty);
|
||||
model_.get(),
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config.decoding_method.c_str());
|
||||
@@ -141,11 +151,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||
model_.get(),
|
||||
lm_.get(),
|
||||
config_.max_active_paths,
|
||||
config_.lm_config.scale,
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
model_.get(), unk_id_, config_.blank_penalty);
|
||||
model_.get(),
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config.decoding_method.c_str());
|
||||
|
||||
@@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
po->Register("temperature-scale", &temperature_scale,
|
||||
"Temperature scale for confidence computation in decoding.");
|
||||
}
|
||||
|
||||
bool OnlineRecognizerConfig::Validate() const {
|
||||
@@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "blank_penalty=" << blank_penalty << ")";
|
||||
os << "blank_penalty=" << blank_penalty << ", ";
|
||||
os << "temperature_scale=" << temperature_scale << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -96,16 +96,23 @@ struct OnlineRecognizerConfig {
|
||||
|
||||
float blank_penalty = 0.0;
|
||||
|
||||
float temperature_scale = 2.0;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(
|
||||
const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
|
||||
const OnlineModelConfig &model_config,
|
||||
const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
bool enable_endpoint, const std::string &decoding_method,
|
||||
int32_t max_active_paths, const std::string &hotwords_file,
|
||||
float hotwords_score, float blank_penalty)
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file,
|
||||
float hotwords_score,
|
||||
float blank_penalty,
|
||||
float temperature_scale)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
@@ -114,9 +121,10 @@ struct OnlineRecognizerConfig {
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
hotwords_score(hotwords_score),
|
||||
hotwords_file(hotwords_file),
|
||||
blank_penalty(blank_penalty) {}
|
||||
hotwords_score(hotwords_score),
|
||||
blank_penalty(blank_penalty),
|
||||
temperature_scale(temperature_scale) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -144,6 +144,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
|
||||
// export the per-token log scores
|
||||
if (y != 0 && y != unk_id_) {
|
||||
// apply temperature-scaling
|
||||
for (int32_t n = 0; n < vocab_size; ++n) {
|
||||
p_logit[n] /= temperature_scale_;
|
||||
}
|
||||
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
||||
// save time by doing it only for
|
||||
// emitted symbols
|
||||
|
||||
@@ -15,8 +15,13 @@ namespace sherpa_onnx {
|
||||
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||
public:
|
||||
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
||||
int32_t unk_id, float blank_penalty)
|
||||
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
|
||||
int32_t unk_id,
|
||||
float blank_penalty,
|
||||
float temperature_scale)
|
||||
: model_(model),
|
||||
unk_id_(unk_id),
|
||||
blank_penalty_(blank_penalty),
|
||||
temperature_scale_(temperature_scale) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||
|
||||
@@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||
OnlineTransducerModel *model_; // Not owned
|
||||
int32_t unk_id_;
|
||||
float blank_penalty_;
|
||||
float temperature_scale_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
|
||||
// copy raw logits, apply temperature-scaling (for confidences)
|
||||
// Note: temperature scaling is used only for the confidences,
|
||||
// the decoding algorithm uses the original logits
|
||||
int32_t p_logit_items = vocab_size * num_hyps;
|
||||
std::vector<float> logit_with_temperature(p_logit_items);
|
||||
{
|
||||
std::copy(p_logit,
|
||||
p_logit + p_logit_items,
|
||||
logit_with_temperature.begin());
|
||||
for (float& elem : logit_with_temperature) {
|
||||
elem /= temperature_scale_;
|
||||
}
|
||||
LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
|
||||
}
|
||||
|
||||
if (blank_penalty_ > 0.0) {
|
||||
// assuming blank id is 0
|
||||
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
||||
@@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
// score of the transducer
|
||||
// export the per-token log scores
|
||||
if (new_token != 0 && new_token != unk_id_) {
|
||||
const Hypothesis &prev_i = prev[hyp_index];
|
||||
// subtract 'prev[i]' path scores, which were added before
|
||||
// getting topk tokens
|
||||
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
|
||||
float y_prob = logit_with_temperature[start * vocab_size + k];
|
||||
new_hyp.ys_probs.push_back(y_prob);
|
||||
|
||||
if (lm_) { // export only when LM is used
|
||||
@@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
cur.push_back(std::move(hyps));
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||
}
|
||||
} // for (int32_t t = 0; t != num_frames; ++t)
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
|
||||
@@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
OnlineLM *lm,
|
||||
int32_t max_active_paths,
|
||||
float lm_scale, int32_t unk_id,
|
||||
float blank_penalty)
|
||||
float blank_penalty,
|
||||
float temperature_scale)
|
||||
: model_(model),
|
||||
lm_(lm),
|
||||
max_active_paths_(max_active_paths),
|
||||
lm_scale_(lm_scale),
|
||||
unk_id_(unk_id),
|
||||
blank_penalty_(blank_penalty) {}
|
||||
blank_penalty_(blank_penalty),
|
||||
temperature_scale_(temperature_scale) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||
|
||||
@@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
float lm_scale_; // used only when lm_ is not nullptr
|
||||
int32_t unk_id_;
|
||||
float blank_penalty_;
|
||||
float temperature_scale_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user