add modified beam search (#69)
This commit is contained in:
@@ -4,8 +4,6 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -15,39 +13,6 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static Ort::Value GetFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
|
||||
int32_t t) {
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
auto batch_size = encoder_out_shape[0];
|
||||
auto num_frames = encoder_out_shape[1];
|
||||
assert(t < num_frames);
|
||||
|
||||
auto encoder_out_dim = encoder_out_shape[2];
|
||||
|
||||
auto offset = num_frames * encoder_out_dim;
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
|
||||
|
||||
Ort::Value ans =
|
||||
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
|
||||
|
||||
float *dst = ans.GetTensorMutableData<float>();
|
||||
const float *src = encoder_out->GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst);
|
||||
src += offset;
|
||||
dst += encoder_out_dim;
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
OnlineTransducerDecoderResult
|
||||
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
@@ -90,7 +55,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
|
||||
for (int32_t t = 0; t != num_frames; ++t) {
|
||||
Ort::Value cur_encoder_out = GetFrame(model_->Allocator(), &encoder_out, t);
|
||||
Ort::Value cur_encoder_out =
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user