add modified beam search (#69)
This commit is contained in:
@@ -44,6 +44,38 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
|
||||
int32_t t) {
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
auto batch_size = encoder_out_shape[0];
|
||||
auto num_frames = encoder_out_shape[1];
|
||||
assert(t < num_frames);
|
||||
|
||||
auto encoder_out_dim = encoder_out_shape[2];
|
||||
|
||||
auto offset = num_frames * encoder_out_dim;
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 2> shape{batch_size, encoder_out_dim};
|
||||
|
||||
Ort::Value ans =
|
||||
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
|
||||
|
||||
float *dst = ans.GetTensorMutableData<float>();
|
||||
const float *src = encoder_out->GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst);
|
||||
src += offset;
|
||||
dst += encoder_out_dim;
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
std::vector<Ort::AllocatedStringPtr> v =
|
||||
|
||||
Reference in New Issue
Block a user