/**
 * Copyright (c)  2022  Xiaomi Corporation (authors: Daniel Povey)
 * Copyright (c)  2023  (Pingfeng Luo)
 *
 */

// This file is copied from k2/csrc/utils.h

#ifndef SHERPA_ONNX_CSRC_MATH_H_
#define SHERPA_ONNX_CSRC_MATH_H_

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>

namespace sherpa_onnx {

// logf(FLT_EPSILON)
#define SHERPA_ONNX_MIN_LOG_DIFF_FLOAT (-15.9423847198486328125f)

// log(DBL_EPSILON)
#define SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE \
  (-36.0436533891171535515240975655615329742431640625)

// Functor computing log(exp(x) + exp(y)) in a numerically stable way.
//
// The larger argument is factored out, so the result is evaluated as
//   max(x, y) + log1p(exp(min(x, y) - max(x, y))),
// where the argument of exp() is <= 0 and therefore cannot overflow.
// When min - max is below the type's log(epsilon) threshold, the correction
// term is smaller than machine epsilon. It is then dropped and the larger
// argument is returned as-is.
template <typename T>
struct LogAdd;

template <>
struct LogAdd<double> {
  double operator()(double x, double y) const {
    double diff;
    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }
    // diff <= 0 and x is now the larger one.

    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) {
      return x + std::log1p(std::exp(diff));
    }

    return x;  // Correction term is negligible; return the larger one.
  }
};

template <>
struct LogAdd<float> {
  float operator()(float x, float y) const {
    float diff;
    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }
    // diff <= 0 and x is now the larger one.

    // Compare against the *float* threshold log(FLT_EPSILON). Below it, the
    // correction term is under float precision and is not worth computing.
    if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_FLOAT) {
      // std:: overloads pick the float versions (no promotion to double).
      return x + std::log1p(std::exp(diff));
    }

    return x;  // Correction term is negligible; return the larger one.
  }
};

// In-place log-softmax over `input_len` contiguous values:
//   input[i] <- input[i] - log(sum_j exp(input[j])).
// The maximum is subtracted before exponentiating so exp() never overflows.
// A non-positive `input_len` is a no-op (avoids dereferencing the end
// iterator that std::max_element returns for an empty range).
template <typename T>
void LogSoftmax(T *input, int32_t input_len) {
  if (input_len <= 0) {
    return;
  }
  assert(input);

  const T m = *std::max_element(input, input + input_len);

  T sum = 0;
  for (int32_t i = 0; i < input_len; ++i) {
    sum += std::exp(input[i] - m);
  }

  const T offset = m + std::log(sum);
  for (int32_t i = 0; i < input_len; ++i) {
    input[i] -= offset;
  }
}

// Applies the 1-D LogSoftmax independently to each of `h` consecutive rows,
// each holding `w` contiguous values, starting at `in` (row i begins at
// in + i * w). A non-positive `h` is a no-op.
template <typename T>
void LogSoftmax(T *in, int32_t w, int32_t h) {
  // `i < h` (not `i != h`) so a negative h cannot loop until overflow.
  for (int32_t i = 0; i < h; ++i) {
    LogSoftmax(in, w);
    in += w;
  }
}

// Returns the indices of the `topk` largest entries of vec[0, size), ordered
// from the largest value to the smallest. If `topk` exceeds `size`, all
// `size` indices are returned. Returns an empty vector when `size` or `topk`
// is not positive. The relative order of equal values is unspecified
// (std::partial_sort is not stable).
template <typename T>
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
  if (size <= 0 || topk <= 0) {
    return {};
  }
  assert(vec);

  // Clamp *before* sorting: std::partial_sort requires its middle iterator
  // to lie within [first, last].
  const int32_t k_num = std::min(size, topk);

  std::vector<int32_t> vec_index(size);
  std::iota(vec_index.begin(), vec_index.end(), 0);
  std::partial_sort(vec_index.begin(), vec_index.begin() + k_num,
                    vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
                      return vec[index_1] > vec[index_2];
                    });

  vec_index.resize(k_num);
  return vec_index;
}

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_MATH_H_