Add RNN LM rescore for offline ASR with modified_beam_search (#125)
This commit is contained in:
@@ -193,4 +193,29 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
|
||||
}
|
||||
#endif
|
||||
|
||||
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split) {
|
||||
std::vector<int64_t> cur_encoder_out_shape =
|
||||
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
|
||||
cur_encoder_out_shape[1]};
|
||||
|
||||
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
|
||||
const float *src = cur_encoder_out->GetTensorData<float>();
|
||||
float *dst = ans.GetTensorMutableData<float>();
|
||||
int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
|
||||
for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
|
||||
std::copy(src, src + cur_encoder_out_shape[1], dst);
|
||||
dst += cur_encoder_out_shape[1];
|
||||
}
|
||||
src += cur_encoder_out_shape[1];
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user