Add lm rescore to online-modified-beam-search (#133)
This commit is contained in:
@@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(fangjun): use std::partial_sort to replace std::sort.
|
||||
// Remember also to fix sherpa-ncnn
|
||||
template <class T>
|
||||
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
||||
std::vector<int32_t> vec_index(size);
|
||||
std::iota(vec_index.begin(), vec_index.end(), 0);
|
||||
|
||||
std::sort(vec_index.begin(), vec_index.end(),
|
||||
[vec](int32_t index_1, int32_t index_2) {
|
||||
return vec[index_1] > vec[index_2];
|
||||
});
|
||||
std::partial_sort(vec_index.begin(), vec_index.begin() + topk,
|
||||
vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
|
||||
return vec[index_1] > vec[index_2];
|
||||
});
|
||||
|
||||
int32_t k_num = std::min<int32_t>(size, topk);
|
||||
std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
|
||||
|
||||
Reference in New Issue
Block a user