Add RNN LM rescore for offline ASR with modified_beam_search (#125)

This commit is contained in:
Fangjun Kuang
2023-04-23 17:15:18 +08:00
committed by GitHub
parent d49a597431
commit 86017f9833
32 changed files with 842 additions and 52 deletions

View File

@@ -193,4 +193,29 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
}
#endif
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split) {
std::vector<int64_t> cur_encoder_out_shape =
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
cur_encoder_out_shape[1]};
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
ans_shape.size());
const float *src = cur_encoder_out->GetTensorData<float>();
float *dst = ans.GetTensorMutableData<float>();
int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
for (int32_t b = 0; b != batch_size; ++b) {
int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
std::copy(src, src + cur_encoder_out_shape[1], dst);
dst += cur_encoder_out_shape[1];
}
src += cur_encoder_out_shape[1];
}
return ans;
}
} // namespace sherpa_onnx