add blank_penalty for online transducer (#548)
This commit is contained in:
@@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
const float *p_logit = logit.GetTensorData<float>();
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
|
||||
bool emitted = false;
|
||||
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
|
||||
auto &r = (*result)[i];
|
||||
if (blank_penalty_ > 0.0) {
|
||||
p_logit[0] -= blank_penalty_; // assuming blank id is 0
|
||||
}
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
|
||||
Reference in New Issue
Block a user