add blank_penalty for online transducer (#548)
This commit is contained in:
@@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
if (blank_penalty_ > 0.0) {
|
||||
// assuming blank id is 0
|
||||
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
|
||||
}
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
|
||||
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||
|
||||
Reference in New Issue
Block a user