Add RNN LM rescore for offline ASR with modified_beam_search (#125)

This commit is contained in:
Fangjun Kuang
2023-04-23 17:15:18 +08:00
committed by GitHub
parent d49a597431
commit 86017f9833
32 changed files with 842 additions and 52 deletions

View File

@@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl {
std::copy(begin, end, p);
p += context_size;
}
return decoder_input;
}
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
int32_t end_index) const {
assert(end_index <= results.size());
int32_t batch_size = end_index;
int32_t context_size = ContextSize();
std::array<int64_t, 2> shape{batch_size, context_size};
Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
Allocator(), shape.data(), shape.size());
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
for (int32_t i = 0; i != batch_size; ++i) {
const auto &r = results[i];
const int64_t *begin = r.ys.data() + r.ys.size() - context_size;
const int64_t *end = r.ys.data() + r.ys.size();
std::copy(begin, end, p);
p += context_size;
}
return decoder_input;
}
@@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput(
return impl_->BuildDecoderInput(results, end_index);
}
Ort::Value OfflineTransducerModel::BuildDecoderInput(
const std::vector<Hypothesis> &results, int32_t end_index) const {
return impl_->BuildDecoderInput(results, end_index);
}
} // namespace sherpa_onnx