Adding temperature scaling on Joiner logits: (#789)
* Adding temperature scaling on Joiner logits:
- T hard-coded to 2.0
- so far best result NCE 0.122 (still not so high)
- the BPE scores were rescaled with 0.2 (but then also incorrect words
get high confidence, visually reasonable histograms are for 0.5 scale)
- BPE->WORD score merging done by min(.) function
(tried also prob-product, and also arithmetic, geometric, harmonic mean)
- without temperature scaling (i.e. scale 1.0), the best NCE was 0.032 (here product merging was best)
Results seem consistent with: https://arxiv.org/abs/2110.15222
Everything tuned on a very-small set of 100 sentences with 813 words and 10.2% WER, a Czech model.
I also experimented with blank posteriors mixed into the BPE confidences,
but no NCE improvement found, so not pushing that.
Temperature scling added also to the Greedy search confidences.
* making `temperature_scale` configurable from outside
This commit is contained in:
@@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||
model_.get(),
|
||||
lm_.get(),
|
||||
config_.max_active_paths,
|
||||
config_.lm_config.scale,
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
model_.get(), unk_id_, config_.blank_penalty);
|
||||
model_.get(),
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config.decoding_method.c_str());
|
||||
@@ -141,11 +151,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale, unk_id_, config_.blank_penalty);
|
||||
model_.get(),
|
||||
lm_.get(),
|
||||
config_.max_active_paths,
|
||||
config_.lm_config.scale,
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
|
||||
model_.get(), unk_id_, config_.blank_penalty);
|
||||
model_.get(),
|
||||
unk_id_,
|
||||
config_.blank_penalty,
|
||||
config_.temperature_scale);
|
||||
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config.decoding_method.c_str());
|
||||
|
||||
@@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"now support greedy_search and modified_beam_search.");
|
||||
po->Register("temperature-scale", &temperature_scale,
|
||||
"Temperature scale for confidence computation in decoding.");
|
||||
}
|
||||
|
||||
bool OnlineRecognizerConfig::Validate() const {
|
||||
@@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "blank_penalty=" << blank_penalty << ")";
|
||||
os << "blank_penalty=" << blank_penalty << ", ";
|
||||
os << "temperature_scale=" << temperature_scale << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -96,16 +96,23 @@ struct OnlineRecognizerConfig {
|
||||
|
||||
float blank_penalty = 0.0;
|
||||
|
||||
float temperature_scale = 2.0;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(
|
||||
const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
|
||||
const OnlineModelConfig &model_config,
|
||||
const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
bool enable_endpoint, const std::string &decoding_method,
|
||||
int32_t max_active_paths, const std::string &hotwords_file,
|
||||
float hotwords_score, float blank_penalty)
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file,
|
||||
float hotwords_score,
|
||||
float blank_penalty,
|
||||
float temperature_scale)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
@@ -114,9 +121,10 @@ struct OnlineRecognizerConfig {
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
hotwords_score(hotwords_score),
|
||||
hotwords_file(hotwords_file),
|
||||
blank_penalty(blank_penalty) {}
|
||||
hotwords_score(hotwords_score),
|
||||
blank_penalty(blank_penalty),
|
||||
temperature_scale(temperature_scale) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -144,6 +144,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
|
||||
// export the per-token log scores
|
||||
if (y != 0 && y != unk_id_) {
|
||||
// apply temperature-scaling
|
||||
for (int32_t n = 0; n < vocab_size; ++n) {
|
||||
p_logit[n] /= temperature_scale_;
|
||||
}
|
||||
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
||||
// save time by doing it only for
|
||||
// emitted symbols
|
||||
|
||||
@@ -15,8 +15,13 @@ namespace sherpa_onnx {
|
||||
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||
public:
|
||||
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
||||
int32_t unk_id, float blank_penalty)
|
||||
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
|
||||
int32_t unk_id,
|
||||
float blank_penalty,
|
||||
float temperature_scale)
|
||||
: model_(model),
|
||||
unk_id_(unk_id),
|
||||
blank_penalty_(blank_penalty),
|
||||
temperature_scale_(temperature_scale) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||
|
||||
@@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||
OnlineTransducerModel *model_; // Not owned
|
||||
int32_t unk_id_;
|
||||
float blank_penalty_;
|
||||
float temperature_scale_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
|
||||
// copy raw logits, apply temperature-scaling (for confidences)
|
||||
// Note: temperature scaling is used only for the confidences,
|
||||
// the decoding algorithm uses the original logits
|
||||
int32_t p_logit_items = vocab_size * num_hyps;
|
||||
std::vector<float> logit_with_temperature(p_logit_items);
|
||||
{
|
||||
std::copy(p_logit,
|
||||
p_logit + p_logit_items,
|
||||
logit_with_temperature.begin());
|
||||
for (float& elem : logit_with_temperature) {
|
||||
elem /= temperature_scale_;
|
||||
}
|
||||
LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
|
||||
}
|
||||
|
||||
if (blank_penalty_ > 0.0) {
|
||||
// assuming blank id is 0
|
||||
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
||||
@@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
// score of the transducer
|
||||
// export the per-token log scores
|
||||
if (new_token != 0 && new_token != unk_id_) {
|
||||
const Hypothesis &prev_i = prev[hyp_index];
|
||||
// subtract 'prev[i]' path scores, which were added before
|
||||
// getting topk tokens
|
||||
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
|
||||
float y_prob = logit_with_temperature[start * vocab_size + k];
|
||||
new_hyp.ys_probs.push_back(y_prob);
|
||||
|
||||
if (lm_) { // export only when LM is used
|
||||
@@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
cur.push_back(std::move(hyps));
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||
}
|
||||
} // for (int32_t t = 0; t != num_frames; ++t)
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
|
||||
@@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
OnlineLM *lm,
|
||||
int32_t max_active_paths,
|
||||
float lm_scale, int32_t unk_id,
|
||||
float blank_penalty)
|
||||
float blank_penalty,
|
||||
float temperature_scale)
|
||||
: model_(model),
|
||||
lm_(lm),
|
||||
max_active_paths_(max_active_paths),
|
||||
lm_scale_(lm_scale),
|
||||
unk_id_(unk_id),
|
||||
blank_penalty_(blank_penalty) {}
|
||||
blank_penalty_(blank_penalty),
|
||||
temperature_scale_(temperature_scale) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||
|
||||
@@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
float lm_scale_; // used only when lm_ is not nullptr
|
||||
int32_t unk_id_;
|
||||
float blank_penalty_;
|
||||
float temperature_scale_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OnlineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(
|
||||
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
||||
const OnlineLMConfig &, const EndpointConfig &,
|
||||
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
|
||||
int32_t, const std::string &, float, float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::init<const FeatureExtractorConfig &,
|
||||
const OnlineModelConfig &,
|
||||
const OnlineLMConfig &,
|
||||
const EndpointConfig &,
|
||||
const OnlineCtcFstDecoderConfig &,
|
||||
bool,
|
||||
const std::string &,
|
||||
int32_t,
|
||||
const std::string &,
|
||||
float,
|
||||
float,
|
||||
float>(),
|
||||
py::arg("feat_config"),
|
||||
py::arg("model_config"),
|
||||
py::arg("lm_config") = OnlineLMConfig(),
|
||||
py::arg("endpoint_config") = EndpointConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
|
||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
|
||||
py::arg("enable_endpoint"),
|
||||
py::arg("decoding_method"),
|
||||
py::arg("max_active_paths") = 4,
|
||||
py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 0,
|
||||
py::arg("blank_penalty") = 0.0,
|
||||
py::arg("temperature_scale") = 2.0)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
@@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
|
||||
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
|
||||
.def_readwrite("temperature_scale", &PyClass::temperature_scale)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ class OnlineRecognizer(object):
|
||||
model_type: str = "",
|
||||
lm: str = "",
|
||||
lm_scale: float = 0.1,
|
||||
temperature_scale: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -123,6 +124,10 @@ class OnlineRecognizer(object):
|
||||
hotwords_score:
|
||||
The hotword score of each token for biasing word/phrase. Used only if
|
||||
hotwords_file is given with modified_beam_search as decoding method.
|
||||
temperature_scale:
|
||||
Temperature scaling for output symbol confidence estiamation.
|
||||
It affects only confidence values, the decoding uses the original
|
||||
logits without temperature.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
model_type:
|
||||
@@ -193,6 +198,7 @@ class OnlineRecognizer(object):
|
||||
hotwords_score=hotwords_score,
|
||||
hotwords_file=hotwords_file,
|
||||
blank_penalty=blank_penalty,
|
||||
temperature_scale=temperature_scale,
|
||||
)
|
||||
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
|
||||
Reference in New Issue
Block a user