Judge before UseCachedDecoderOut (#431)
Co-authored-by: hiedean <hiedean@tju.edu.cn>
This commit is contained in:
@@ -89,9 +89,24 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
UseCachedDecoderOut(*result, &decoder_out);
|
||||
Ort::Value decoder_out{nullptr};
|
||||
bool is_batch_decoder_out_cached = true;
|
||||
for (const auto &r : *result) {
|
||||
if (!r.decoder_out) {
|
||||
is_batch_decoder_out_cached = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_batch_decoder_out_cached) {
|
||||
auto &r = result->front();
|
||||
std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
decoder_out_shape[0] = batch_size;
|
||||
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size());
|
||||
UseCachedDecoderOut(*result, &decoder_out);
|
||||
} else {
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||
decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
}
|
||||
|
||||
for (int32_t t = 0; t != num_frames; ++t) {
|
||||
Ort::Value cur_encoder_out =
|
||||
|
||||
Reference in New Issue
Block a user