Add RNN LM rescore for offline ASR with modified_beam_search (#125)
This commit is contained in:
@@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl {
|
||||
std::copy(begin, end, p);
|
||||
p += context_size;
|
||||
}
|
||||
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
|
||||
int32_t end_index) const {
|
||||
assert(end_index <= results.size());
|
||||
|
||||
int32_t batch_size = end_index;
|
||||
int32_t context_size = ContextSize();
|
||||
std::array<int64_t, 2> shape{batch_size, context_size};
|
||||
|
||||
Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
|
||||
Allocator(), shape.data(), shape.size());
|
||||
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const auto &r = results[i];
|
||||
const int64_t *begin = r.ys.data() + r.ys.size() - context_size;
|
||||
const int64_t *end = r.ys.data() + r.ys.size();
|
||||
std::copy(begin, end, p);
|
||||
p += context_size;
|
||||
}
|
||||
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
@@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput(
|
||||
return impl_->BuildDecoderInput(results, end_index);
|
||||
}
|
||||
|
||||
Ort::Value OfflineTransducerModel::BuildDecoderInput(
|
||||
const std::vector<Hypothesis> &results, int32_t end_index) const {
|
||||
return impl_->BuildDecoderInput(results, end_index);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user