add blank_penalty for online transducer (#548)

This commit is contained in:
chiiyeh
2024-01-26 12:12:13 +08:00
committed by GitHub
parent 466a6855c8
commit e7b18a2139
13 changed files with 94 additions and 14 deletions

View File

@@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), View(&decoder_out));
const float *p_logit = logit.GetTensorData<float>();
float *p_logit = logit.GetTensorMutableData<float>();
bool emitted = false;
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
auto &r = (*result)[i];
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),