/** * Copyright (c) 2023 Xiaomi Corporation * Copyright (c) 2023 Pingfeng Luo */ #include "sherpa-onnx/csrc/hypothesis.h" #include #include namespace sherpa_onnx { void Hypotheses::Add(Hypothesis hyp) { auto key = hyp.Key(); auto it = hyps_dict_.find(key); if (it == hyps_dict_.end()) { hyps_dict_[key] = std::move(hyp); } else { it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); } } Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { if (length_norm == false) { return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), [](const auto &left, auto &right) -> bool { return left.second.TotalLogProb() < right.second.TotalLogProb(); }) ->second; } else { // for length_norm is true return std::max_element( hyps_dict_.begin(), hyps_dict_.end(), [](const auto &left, const auto &right) -> bool { return left.second.TotalLogProb() / left.second.ys.size() < right.second.TotalLogProb() / right.second.ys.size(); }) ->second; } } std::vector Hypotheses::GetTopK(int32_t k, bool length_norm) const { k = std::max(k, 1); k = std::min(k, Size()); std::vector all_hyps = Vec(); if (length_norm == false) { std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), [](const auto &a, const auto &b) { return a.TotalLogProb() > b.TotalLogProb(); }); } else { // for length_norm is true std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), [](const auto &a, const auto &b) { return a.TotalLogProb() / a.ys.size() > b.TotalLogProb() / b.ys.size(); }); } return {all_hyps.begin(), all_hyps.begin() + k}; } const std::vector GetHypsRowSplits( const std::vector &hyps) { std::vector row_splits; row_splits.reserve(hyps.size() + 1); row_splits.push_back(0); int32_t s = 0; for (const auto &h : hyps) { s += h.Size(); row_splits.push_back(s); } return row_splits; } } // namespace sherpa_onnx