diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 46dbcbb3..007c596d 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -10,47 +10,39 @@ #include #include +#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { -static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { +static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out, + int32_t t) { std::vector encoder_out_shape = encoder_out->GetTensorTypeAndShapeInfo().GetShape(); - assert(encoder_out_shape[0] == 1); - int32_t encoder_out_dim = encoder_out_shape[2]; + auto batch_size = encoder_out_shape[0]; + auto num_frames = encoder_out_shape[1]; + assert(t < num_frames); + + auto encoder_out_dim = encoder_out_shape[2]; + + auto offset = num_frames * encoder_out_dim; auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::array shape{1, encoder_out_dim}; + std::array shape{batch_size, encoder_out_dim}; - return Ort::Value::CreateTensor( - memory_info, - encoder_out->GetTensorMutableData() + t * encoder_out_dim, - encoder_out_dim, shape.data(), shape.size()); -} + Ort::Value ans = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); -static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, - int32_t n) { - if (n == 1) { - return std::move(*cur_encoder_out); - } - - std::vector cur_encoder_out_shape = - cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); - - std::array ans_shape{n, cur_encoder_out_shape[1]}; - - Ort::Value ans = Ort::Value::CreateTensor(allocator, ans_shape.data(), - ans_shape.size()); - - const float *src = cur_encoder_out->GetTensorData(); float *dst = ans.GetTensorMutableData(); - for (int32_t i = 0; i != n; ++i) { - std::copy(src, src + cur_encoder_out_shape[1], dst); - dst += cur_encoder_out_shape[1]; + const float *src = encoder_out->GetTensorData(); + + for (int32_t i = 0; i != batch_size; ++i) { + std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst); + src += offset; + dst += encoder_out_dim; } return ans; @@ -83,10 +75,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( encoder_out.GetTensorTypeAndShapeInfo().GetShape(); if (encoder_out_shape[0] != result->size()) { - fprintf(stderr, - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", - static_cast(encoder_out_shape[0]), - static_cast(result->size())); + SHERPA_ONNX_LOGE( + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d", + static_cast(encoder_out_shape[0]), + static_cast(result->size())); exit(-1); } @@ -98,10 +90,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); for (int32_t t = 0; t != num_frames; ++t) { - Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); - cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); + Ort::Value cur_encoder_out = GetFrame(model_->Allocator(), &encoder_out, t); Ort::Value logit = model_->RunJoiner( std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + const float *p_logit = logit.GetTensorData(); bool emitted = false;