/**
 * Copyright (c)  2023  Xiaomi Corporation
 * Copyright (c)  2023  Pingfeng Luo
 *
 */
|
|
|
|
|
|
|
|
|
|
#ifndef SHERPA_ONNX_CSRC_HYPOTHESIS_H_
|
|
|
|
|
#define SHERPA_ONNX_CSRC_HYPOTHESIS_H_
|
|
|
|
|
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
2025-07-09 11:23:46 +03:00
|
|
|
#include <memory>
|
2023-03-01 15:32:54 +08:00
|
|
|
|
2023-05-05 21:23:54 +08:00
|
|
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
2023-06-16 14:26:36 +08:00
|
|
|
#include "sherpa-onnx/csrc/context-graph.h"
|
2025-07-09 11:23:46 +03:00
|
|
|
#include "sherpa-onnx/csrc/lodr-fst.h"
|
2023-03-01 15:32:54 +08:00
|
|
|
#include "sherpa-onnx/csrc/math.h"
|
2023-05-05 21:23:54 +08:00
|
|
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
2023-03-01 15:32:54 +08:00
|
|
|
|
|
|
|
|
namespace sherpa_onnx {
|
|
|
|
|
|
|
|
|
|
struct Hypothesis {
|
|
|
|
|
// The predicted tokens so far. Newly predicated tokens are appended.
|
2023-03-03 12:10:59 +08:00
|
|
|
std::vector<int64_t> ys;
|
2023-03-01 15:32:54 +08:00
|
|
|
|
|
|
|
|
// timestamps[i] contains the frame number after subsampling
|
|
|
|
|
// on which ys[i] is decoded.
|
|
|
|
|
std::vector<int32_t> timestamps;
|
|
|
|
|
|
2024-01-20 22:52:41 +08:00
|
|
|
// The acoustic probability for each token in ys.
|
2024-02-28 23:28:45 +01:00
|
|
|
// Used for keyword spotting task.
|
|
|
|
|
// For transducer mofified beam-search and greedy-search,
|
|
|
|
|
// this is filled with log_posterior scores.
|
2024-01-20 22:52:41 +08:00
|
|
|
std::vector<float> ys_probs;
|
|
|
|
|
|
2024-02-28 23:28:45 +01:00
|
|
|
// lm_probs[i] contains the lm score for each token in ys.
|
|
|
|
|
// Used only in transducer mofified beam-search.
|
|
|
|
|
// Elements filled only if LM is used.
|
|
|
|
|
std::vector<float> lm_probs;
|
|
|
|
|
|
|
|
|
|
// context_scores[i] contains the context-graph score for each token in ys.
|
|
|
|
|
// Used only in transducer mofified beam-search.
|
|
|
|
|
// Elements filled only if `ContextGraph` is used.
|
|
|
|
|
std::vector<float> context_scores;
|
|
|
|
|
|
2023-03-01 15:32:54 +08:00
|
|
|
// The total score of ys in log space.
|
2023-04-23 17:15:18 +08:00
|
|
|
// It contains only acoustic scores
|
2023-03-01 15:32:54 +08:00
|
|
|
double log_prob = 0;
|
|
|
|
|
|
2023-04-23 17:15:18 +08:00
|
|
|
// LM log prob if any.
|
|
|
|
|
double lm_log_prob = 0;
|
|
|
|
|
|
2024-09-06 05:01:25 +03:00
|
|
|
// the nn lm score for next token given the current ys,
|
|
|
|
|
// when using shallow fusion
|
2023-05-10 22:30:57 +08:00
|
|
|
CopyableOrtValue nn_lm_scores;
|
2024-09-06 05:01:25 +03:00
|
|
|
|
|
|
|
|
// cur scored tokens by RNN LM, when rescoring
|
|
|
|
|
int32_t cur_scored_pos = 0;
|
|
|
|
|
|
2023-05-10 22:30:57 +08:00
|
|
|
// the nn lm states
|
2023-05-05 21:23:54 +08:00
|
|
|
std::vector<CopyableOrtValue> nn_lm_states;
|
|
|
|
|
|
2025-07-09 11:23:46 +03:00
|
|
|
// the LODR states
|
|
|
|
|
std::shared_ptr<LodrStateCost> lodr_state;
|
|
|
|
|
|
2023-06-16 14:26:36 +08:00
|
|
|
const ContextState *context_state;
|
|
|
|
|
|
|
|
|
|
// TODO(fangjun): Make it configurable
|
|
|
|
|
// the minimum of tokens in a chunk for streaming RNN LM
|
|
|
|
|
int32_t lm_rescore_min_chunk = 2; // a const
|
|
|
|
|
|
2023-03-01 15:32:54 +08:00
|
|
|
int32_t num_trailing_blanks = 0;
|
|
|
|
|
|
|
|
|
|
Hypothesis() = default;
|
2023-06-16 14:26:36 +08:00
|
|
|
Hypothesis(const std::vector<int64_t> &ys, double log_prob,
|
|
|
|
|
const ContextState *context_state = nullptr)
|
|
|
|
|
: ys(ys), log_prob(log_prob), context_state(context_state) {}
|
2023-03-01 15:32:54 +08:00
|
|
|
|
2023-04-23 17:15:18 +08:00
|
|
|
double TotalLogProb() const { return log_prob + lm_log_prob; }
|
|
|
|
|
|
2023-03-01 15:32:54 +08:00
|
|
|
// If two Hypotheses have the same `Key`, then they contain
|
|
|
|
|
// the same token sequence.
|
|
|
|
|
std::string Key() const {
|
|
|
|
|
// TODO(fangjun): Use a hash function?
|
|
|
|
|
std::ostringstream os;
|
2023-07-27 23:19:49 -07:00
|
|
|
std::string sep;
|
2023-03-01 15:32:54 +08:00
|
|
|
for (auto i : ys) {
|
2023-07-27 23:19:49 -07:00
|
|
|
os << sep << i;
|
2023-03-01 15:32:54 +08:00
|
|
|
sep = "-";
|
|
|
|
|
}
|
|
|
|
|
return os.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// For debugging
|
|
|
|
|
std::string ToString() const {
|
|
|
|
|
std::ostringstream os;
|
|
|
|
|
os << "(" << Key() << ", " << log_prob << ")";
|
|
|
|
|
return os.str();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Hypotheses {
|
|
|
|
|
public:
|
|
|
|
|
Hypotheses() = default;
|
|
|
|
|
|
|
|
|
|
explicit Hypotheses(std::vector<Hypothesis> hyps) {
|
|
|
|
|
for (auto &h : hyps) {
|
|
|
|
|
hyps_dict_[h.Key()] = std::move(h);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
explicit Hypotheses(std::unordered_map<std::string, Hypothesis> hyps_dict)
|
|
|
|
|
: hyps_dict_(std::move(hyps_dict)) {}
|
|
|
|
|
|
|
|
|
|
// Add hyp to this object. If it already exists, its log_prob
|
|
|
|
|
// is updated with the given hyp using log-sum-exp.
|
|
|
|
|
void Add(Hypothesis hyp);
|
|
|
|
|
|
|
|
|
|
// Get the hyp that has the largest log_prob.
|
|
|
|
|
// If length_norm is true, hyp's log_prob is divided by
|
|
|
|
|
// len(hyp.ys) before comparison.
|
|
|
|
|
Hypothesis GetMostProbable(bool length_norm) const;
|
|
|
|
|
|
|
|
|
|
// Get the k hyps that have the largest log_prob.
|
|
|
|
|
// If length_norm is true, hyp's log_prob is divided by
|
|
|
|
|
// len(hyp.ys) before comparison.
|
|
|
|
|
std::vector<Hypothesis> GetTopK(int32_t k, bool length_norm) const;
|
|
|
|
|
|
|
|
|
|
int32_t Size() const { return hyps_dict_.size(); }
|
|
|
|
|
|
|
|
|
|
std::string ToString() const {
|
|
|
|
|
std::ostringstream os;
|
|
|
|
|
for (const auto &p : hyps_dict_) {
|
|
|
|
|
os << p.second.ToString() << "\n";
|
|
|
|
|
}
|
|
|
|
|
return os.str();
|
|
|
|
|
}
|
|
|
|
|
|
2024-05-31 13:17:01 +08:00
|
|
|
auto begin() const { return hyps_dict_.begin(); }
|
|
|
|
|
auto end() const { return hyps_dict_.end(); }
|
2023-03-01 15:32:54 +08:00
|
|
|
|
2023-04-23 17:15:18 +08:00
|
|
|
auto begin() { return hyps_dict_.begin(); }
|
|
|
|
|
auto end() { return hyps_dict_.end(); }
|
|
|
|
|
|
2023-03-01 15:32:54 +08:00
|
|
|
void Clear() { hyps_dict_.clear(); }
|
|
|
|
|
|
|
|
|
|
// Return a list of hyps contained in this object.
|
|
|
|
|
std::vector<Hypothesis> Vec() const {
|
|
|
|
|
std::vector<Hypothesis> ans;
|
|
|
|
|
ans.reserve(hyps_dict_.size());
|
|
|
|
|
for (const auto &p : hyps_dict_) {
|
|
|
|
|
ans.push_back(p.second);
|
|
|
|
|
}
|
|
|
|
|
return ans;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
using Map = std ::unordered_map<std::string, Hypothesis>;
|
|
|
|
|
Map hyps_dict_;
|
|
|
|
|
};
|
|
|
|
|
|
2023-04-26 11:41:04 +08:00
|
|
|
// Compute row splits over a list of Hypotheses sets — presumably
// ans[0] == 0 and ans[i+1] - ans[i] == hyps[i].Size(), so the result
// maps each stream to its range in a flattened batch.
// TODO(review): confirm against the definition in the .cc file.
const std::vector<int32_t> GetHypsRowSplits(
    const std::vector<Hypotheses> &hyps);
|
|
|
|
|
|
2023-03-01 15:32:54 +08:00
|
|
|
} // namespace sherpa_onnx
|
|
|
|
|
|
|
|
|
|
#endif // SHERPA_ONNX_CSRC_HYPOTHESIS_H_
|