Add RNN LM rescore for offline ASR with modified_beam_search (#125)

This commit is contained in:
Fangjun Kuang
2023-04-23 17:15:18 +08:00
committed by GitHub
parent d49a597431
commit 86017f9833
32 changed files with 842 additions and 52 deletions

View File

@@ -36,38 +36,6 @@ static void UseCachedDecoderOut(
}
}
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split) {
std::vector<int64_t> cur_encoder_out_shape =
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
cur_encoder_out_shape[1]};
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
ans_shape.size());
const float *src = cur_encoder_out->GetTensorData<float>();
float *dst = ans.GetTensorMutableData<float>();
int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
for (int32_t b = 0; b != batch_size; ++b) {
int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
std::copy(src, src + cur_encoder_out_shape[1], dst);
dst += cur_encoder_out_shape[1];
}
src += cur_encoder_out_shape[1];
}
return ans;
}
static void LogSoftmax(float *in, int32_t w, int32_t h) {
for (int32_t i = 0; i != h; ++i) {
LogSoftmax(in, w);
in += w;
}
}
OnlineTransducerDecoderResult
OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();