Replace Clone() with View() (#432)

Co-authored-by: hiedean <hiedean@tju.edu.cn>
This commit is contained in:
HieDean
2023-11-20 09:20:50 +08:00
committed by GitHub
parent ac00edab5b
commit e6a2d0da3b
5 changed files with 14 additions and 12 deletions

View File

@@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
}
if (is_batch_decoder_out_cached) {
auto &r = result->front();
std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
std::vector<int64_t> decoder_out_shape =
r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
decoder_out_shape[0] = batch_size;
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size());
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
decoder_out_shape.data(), decoder_out_shape.size());
UseCachedDecoderOut(*result, &decoder_out);
} else {
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
@@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
std::move(cur_encoder_out), View(&decoder_out));
const float *p_logit = logit.GetTensorData<float>();