add modified beam search (#69)

This commit is contained in:
PF Luo
2023-03-01 15:32:54 +08:00
committed by GitHub
parent e0b76655c8
commit 5326d0f81f
19 changed files with 614 additions and 87 deletions

View File

@@ -4,8 +4,6 @@
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include <assert.h>
#include <algorithm>
#include <utility>
#include <vector>
@@ -15,39 +13,6 @@
namespace sherpa_onnx {
static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
int32_t t) {
std::vector<int64_t> encoder_out_shape =
encoder_out->GetTensorTypeAndShapeInfo().GetShape();
auto batch_size = encoder_out_shape[0];
auto num_frames = encoder_out_shape[1];
assert(t < num_frames);
auto encoder_out_dim = encoder_out_shape[2];
auto offset = num_frames * encoder_out_dim;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
Ort::Value ans =
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
float *dst = ans.GetTensorMutableData<float>();
const float *src = encoder_out->GetTensorData<float>();
for (int32_t i = 0; i != batch_size; ++i) {
std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst);
src += offset;
dst += encoder_out_dim;
}
return ans;
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
@@ -90,7 +55,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out = GetFrame(model_->Allocator(), &encoder_out, t);
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));