Adding temperature scaling on Joiner logits: (#789)
* Adding temperature scaling on Joiner logits:
- T hard-coded to 2.0
- so far best result NCE 0.122 (still not so high)
- the BPE scores were rescaled with 0.2 (but then incorrect words also
get high confidence; histograms look visually reasonable with a 0.5 scale)
- BPE->WORD score merging done by min(.) function
(tried also prob-product, and also arithmetic, geometric, harmonic mean)
- without temperature scaling (i.e. scale 1.0), the best NCE was 0.032 (here product merging was best)
Results seem consistent with: https://arxiv.org/abs/2110.15222
Everything was tuned on a very small set of 100 sentences with 813 words and 10.2% WER, using a Czech model.
I also experimented with blank posteriors mixed into the BPE confidences,
but found no NCE improvement, so I am not pushing that.
Temperature scaling was also added to the greedy search confidences.
* making `temperature_scale` configurable from outside
This commit is contained in:
@@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(),
|
||||||
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
lm_.get(),
|
||||||
|
config_.max_active_paths,
|
||||||
|
config_.lm_config.scale,
|
||||||
|
unk_id_,
|
||||||
|
config_.blank_penalty,
|
||||||
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||||
model_.get(), unk_id_, config_.blank_penalty);
|
model_.get(),
|
||||||
|
unk_id_,
|
||||||
|
config_.blank_penalty,
|
||||||
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config.decoding_method.c_str());
|
config.decoding_method.c_str());
|
||||||
@@ -141,11 +151,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||||
model_.get(), lm_.get(), config_.max_active_paths,
|
model_.get(),
|
||||||
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
lm_.get(),
|
||||||
|
config_.max_active_paths,
|
||||||
|
config_.lm_config.scale,
|
||||||
|
unk_id_,
|
||||||
|
config_.blank_penalty,
|
||||||
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else if (config.decoding_method == "greedy_search") {
|
} else if (config.decoding_method == "greedy_search") {
|
||||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||||
model_.get(), unk_id_, config_.blank_penalty);
|
model_.get(),
|
||||||
|
unk_id_,
|
||||||
|
config_.blank_penalty,
|
||||||
|
config_.temperature_scale);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config.decoding_method.c_str());
|
config.decoding_method.c_str());
|
||||||
|
|||||||
@@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("decoding-method", &decoding_method,
|
po->Register("decoding-method", &decoding_method,
|
||||||
"decoding method,"
|
"decoding method,"
|
||||||
"now support greedy_search and modified_beam_search.");
|
"now support greedy_search and modified_beam_search.");
|
||||||
|
po->Register("temperature-scale", &temperature_scale,
|
||||||
|
"Temperature scale for confidence computation in decoding.");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OnlineRecognizerConfig::Validate() const {
|
bool OnlineRecognizerConfig::Validate() const {
|
||||||
@@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
|||||||
os << "hotwords_score=" << hotwords_score << ", ";
|
os << "hotwords_score=" << hotwords_score << ", ";
|
||||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||||
os << "blank_penalty=" << blank_penalty << ")";
|
os << "blank_penalty=" << blank_penalty << ", ";
|
||||||
|
os << "temperature_scale=" << temperature_scale << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -96,16 +96,23 @@ struct OnlineRecognizerConfig {
|
|||||||
|
|
||||||
float blank_penalty = 0.0;
|
float blank_penalty = 0.0;
|
||||||
|
|
||||||
|
float temperature_scale = 2.0;
|
||||||
|
|
||||||
OnlineRecognizerConfig() = default;
|
OnlineRecognizerConfig() = default;
|
||||||
|
|
||||||
OnlineRecognizerConfig(
|
OnlineRecognizerConfig(
|
||||||
const FeatureExtractorConfig &feat_config,
|
const FeatureExtractorConfig &feat_config,
|
||||||
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
|
const OnlineModelConfig &model_config,
|
||||||
|
const OnlineLMConfig &lm_config,
|
||||||
const EndpointConfig &endpoint_config,
|
const EndpointConfig &endpoint_config,
|
||||||
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||||
bool enable_endpoint, const std::string &decoding_method,
|
bool enable_endpoint,
|
||||||
int32_t max_active_paths, const std::string &hotwords_file,
|
const std::string &decoding_method,
|
||||||
float hotwords_score, float blank_penalty)
|
int32_t max_active_paths,
|
||||||
|
const std::string &hotwords_file,
|
||||||
|
float hotwords_score,
|
||||||
|
float blank_penalty,
|
||||||
|
float temperature_scale)
|
||||||
: feat_config(feat_config),
|
: feat_config(feat_config),
|
||||||
model_config(model_config),
|
model_config(model_config),
|
||||||
lm_config(lm_config),
|
lm_config(lm_config),
|
||||||
@@ -114,9 +121,10 @@ struct OnlineRecognizerConfig {
|
|||||||
enable_endpoint(enable_endpoint),
|
enable_endpoint(enable_endpoint),
|
||||||
decoding_method(decoding_method),
|
decoding_method(decoding_method),
|
||||||
max_active_paths(max_active_paths),
|
max_active_paths(max_active_paths),
|
||||||
hotwords_score(hotwords_score),
|
|
||||||
hotwords_file(hotwords_file),
|
hotwords_file(hotwords_file),
|
||||||
blank_penalty(blank_penalty) {}
|
hotwords_score(hotwords_score),
|
||||||
|
blank_penalty(blank_penalty),
|
||||||
|
temperature_scale(temperature_scale) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
@@ -144,6 +144,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
|
|
||||||
// export the per-token log scores
|
// export the per-token log scores
|
||||||
if (y != 0 && y != unk_id_) {
|
if (y != 0 && y != unk_id_) {
|
||||||
|
// apply temperature-scaling
|
||||||
|
for (int32_t n = 0; n < vocab_size; ++n) {
|
||||||
|
p_logit[n] /= temperature_scale_;
|
||||||
|
}
|
||||||
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
||||||
// save time by doing it only for
|
// save time by doing it only for
|
||||||
// emitted symbols
|
// emitted symbols
|
||||||
|
|||||||
@@ -15,8 +15,13 @@ namespace sherpa_onnx {
|
|||||||
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||||
public:
|
public:
|
||||||
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
||||||
int32_t unk_id, float blank_penalty)
|
int32_t unk_id,
|
||||||
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
|
float blank_penalty,
|
||||||
|
float temperature_scale)
|
||||||
|
: model_(model),
|
||||||
|
unk_id_(unk_id),
|
||||||
|
blank_penalty_(blank_penalty),
|
||||||
|
temperature_scale_(temperature_scale) {}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||||
|
|
||||||
@@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
|||||||
OnlineTransducerModel *model_; // Not owned
|
OnlineTransducerModel *model_; // Not owned
|
||||||
int32_t unk_id_;
|
int32_t unk_id_;
|
||||||
float blank_penalty_;
|
float blank_penalty_;
|
||||||
|
float temperature_scale_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
|
||||||
|
// copy raw logits, apply temperature-scaling (for confidences)
|
||||||
|
// Note: temperature scaling is used only for the confidences,
|
||||||
|
// the decoding algorithm uses the original logits
|
||||||
|
int32_t p_logit_items = vocab_size * num_hyps;
|
||||||
|
std::vector<float> logit_with_temperature(p_logit_items);
|
||||||
|
{
|
||||||
|
std::copy(p_logit,
|
||||||
|
p_logit + p_logit_items,
|
||||||
|
logit_with_temperature.begin());
|
||||||
|
for (float& elem : logit_with_temperature) {
|
||||||
|
elem /= temperature_scale_;
|
||||||
|
}
|
||||||
|
LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
|
||||||
|
}
|
||||||
|
|
||||||
if (blank_penalty_ > 0.0) {
|
if (blank_penalty_ > 0.0) {
|
||||||
// assuming blank id is 0
|
// assuming blank id is 0
|
||||||
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
||||||
@@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
// score of the transducer
|
// score of the transducer
|
||||||
// export the per-token log scores
|
// export the per-token log scores
|
||||||
if (new_token != 0 && new_token != unk_id_) {
|
if (new_token != 0 && new_token != unk_id_) {
|
||||||
const Hypothesis &prev_i = prev[hyp_index];
|
float y_prob = logit_with_temperature[start * vocab_size + k];
|
||||||
// subtract 'prev[i]' path scores, which were added before
|
|
||||||
// getting topk tokens
|
|
||||||
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
|
|
||||||
new_hyp.ys_probs.push_back(y_prob);
|
new_hyp.ys_probs.push_back(y_prob);
|
||||||
|
|
||||||
if (lm_) { // export only when LM is used
|
if (lm_) { // export only when LM is used
|
||||||
@@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
cur.push_back(std::move(hyps));
|
cur.push_back(std::move(hyps));
|
||||||
p_logprob += (end - start) * vocab_size;
|
p_logprob += (end - start) * vocab_size;
|
||||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||||
}
|
} // for (int32_t t = 0; t != num_frames; ++t)
|
||||||
|
|
||||||
for (int32_t b = 0; b != batch_size; ++b) {
|
for (int32_t b = 0; b != batch_size; ++b) {
|
||||||
auto &hyps = cur[b];
|
auto &hyps = cur[b];
|
||||||
|
|||||||
@@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
OnlineLM *lm,
|
OnlineLM *lm,
|
||||||
int32_t max_active_paths,
|
int32_t max_active_paths,
|
||||||
float lm_scale, int32_t unk_id,
|
float lm_scale, int32_t unk_id,
|
||||||
float blank_penalty)
|
float blank_penalty,
|
||||||
|
float temperature_scale)
|
||||||
: model_(model),
|
: model_(model),
|
||||||
lm_(lm),
|
lm_(lm),
|
||||||
max_active_paths_(max_active_paths),
|
max_active_paths_(max_active_paths),
|
||||||
lm_scale_(lm_scale),
|
lm_scale_(lm_scale),
|
||||||
unk_id_(unk_id),
|
unk_id_(unk_id),
|
||||||
blank_penalty_(blank_penalty) {}
|
blank_penalty_(blank_penalty),
|
||||||
|
temperature_scale_(temperature_scale) {}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||||
|
|
||||||
@@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
|||||||
float lm_scale_; // used only when lm_ is not nullptr
|
float lm_scale_; // used only when lm_ is not nullptr
|
||||||
int32_t unk_id_;
|
int32_t unk_id_;
|
||||||
float blank_penalty_;
|
float blank_penalty_;
|
||||||
|
float temperature_scale_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
using PyClass = OnlineRecognizerConfig;
|
using PyClass = OnlineRecognizerConfig;
|
||||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||||
.def(
|
.def(
|
||||||
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
py::init<const FeatureExtractorConfig &,
|
||||||
const OnlineLMConfig &, const EndpointConfig &,
|
const OnlineModelConfig &,
|
||||||
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
|
const OnlineLMConfig &,
|
||||||
int32_t, const std::string &, float, float>(),
|
const EndpointConfig &,
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
const OnlineCtcFstDecoderConfig &,
|
||||||
|
bool,
|
||||||
|
const std::string &,
|
||||||
|
int32_t,
|
||||||
|
const std::string &,
|
||||||
|
float,
|
||||||
|
float,
|
||||||
|
float>(),
|
||||||
|
py::arg("feat_config"),
|
||||||
|
py::arg("model_config"),
|
||||||
py::arg("lm_config") = OnlineLMConfig(),
|
py::arg("lm_config") = OnlineLMConfig(),
|
||||||
py::arg("endpoint_config") = EndpointConfig(),
|
py::arg("endpoint_config") = EndpointConfig(),
|
||||||
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
|
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
|
||||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
py::arg("enable_endpoint"),
|
||||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
py::arg("decoding_method"),
|
||||||
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
|
py::arg("max_active_paths") = 4,
|
||||||
|
py::arg("hotwords_file") = "",
|
||||||
|
py::arg("hotwords_score") = 0,
|
||||||
|
py::arg("blank_penalty") = 0.0,
|
||||||
|
py::arg("temperature_scale") = 2.0)
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
@@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
|||||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||||
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
||||||
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
||||||
|
.def_readwrite("temperature_scale", &PyClass::temperature_scale)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class OnlineRecognizer(object):
|
|||||||
model_type: str = "",
|
model_type: str = "",
|
||||||
lm: str = "",
|
lm: str = "",
|
||||||
lm_scale: float = 0.1,
|
lm_scale: float = 0.1,
|
||||||
|
temperature_scale: float = 2.0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Please refer to
|
Please refer to
|
||||||
@@ -123,6 +124,10 @@ class OnlineRecognizer(object):
|
|||||||
hotwords_score:
|
hotwords_score:
|
||||||
The hotword score of each token for biasing word/phrase. Used only if
|
The hotword score of each token for biasing word/phrase. Used only if
|
||||||
hotwords_file is given with modified_beam_search as decoding method.
|
hotwords_file is given with modified_beam_search as decoding method.
|
||||||
|
temperature_scale:
|
||||||
|
Temperature scaling for output symbol confidence estimation.
|
||||||
|
It affects only confidence values, the decoding uses the original
|
||||||
|
logits without temperature.
|
||||||
provider:
|
provider:
|
||||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||||
model_type:
|
model_type:
|
||||||
@@ -193,6 +198,7 @@ class OnlineRecognizer(object):
|
|||||||
hotwords_score=hotwords_score,
|
hotwords_score=hotwords_score,
|
||||||
hotwords_file=hotwords_file,
|
hotwords_file=hotwords_file,
|
||||||
blank_penalty=blank_penalty,
|
blank_penalty=blank_penalty,
|
||||||
|
temperature_scale=temperature_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recognizer = _Recognizer(recognizer_config)
|
self.recognizer = _Recognizer(recognizer_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user