Add lm rescore to online-modified-beam-search (#133)

This commit is contained in:
PF Luo
2023-05-05 21:23:54 +08:00
committed by GitHub
parent 3b9c3db31d
commit 8c6a6768d5
26 changed files with 495 additions and 39 deletions

View File

@@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
}
}
// TODO(fangjun): use std::partial_sort to replace std::sort.
// Remember also to fix sherpa-ncnn
template <class T>
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
std::vector<int32_t> vec_index(size);
std::iota(vec_index.begin(), vec_index.end(), 0);
std::sort(vec_index.begin(), vec_index.end(),
[vec](int32_t index_1, int32_t index_2) {
return vec[index_1] > vec[index_2];
});
std::partial_sort(vec_index.begin(), vec_index.begin() + topk,
vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
return vec[index_1] > vec[index_2];
});
int32_t k_num = std::min<int32_t>(size, topk);
std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);