add blank_penalty for offline transducer (#542)

This commit is contained in:
chiiyeh
2024-01-25 15:00:09 +08:00
committed by GitHub
parent a9e7747736
commit 3bb3849ec5
13 changed files with 97 additions and 14 deletions

View File

@@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
start += n;
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
std::move(cur_decoder_out));
const float *p_logit = logit.GetTensorData<float>();
float *p_logit = logit.GetTensorMutableData<float>();
bool emitted = false;
for (int32_t i = 0; i != n; ++i) {
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),