// sherpa-onnx/csrc/context-graph.h // // Copyright (c) 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ #define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ #include #include #include #include #include #include #include "sherpa-onnx/csrc/log.h" namespace sherpa_onnx { class ContextGraph; using ContextGraphPtr = std::shared_ptr; struct ContextState { int32_t token; float token_score; float node_score; float output_score; int32_t level; float ac_threshold; bool is_end; std::string phrase; std::unordered_map> next; const ContextState *fail = nullptr; const ContextState *output = nullptr; ContextState() = default; ContextState(int32_t token, float token_score, float node_score, float output_score, int32_t level = 0, float ac_threshold = 0.0f, bool is_end = false, const std::string &phrase = {}) : token(token), token_score(token_score), node_score(node_score), output_score(output_score), level(level), ac_threshold(ac_threshold), is_end(is_end), phrase(phrase) {} }; class ContextGraph { public: ContextGraph() = default; ContextGraph(const std::vector> &token_ids, float context_score, float ac_threshold, const std::vector &scores = {}, const std::vector &phrases = {}, const std::vector &ac_thresholds = {}) : context_score_(context_score), ac_threshold_(ac_threshold) { root_ = std::make_unique(-1, 0, 0, 0); root_->fail = root_.get(); Build(token_ids, scores, phrases, ac_thresholds); } ContextGraph(const std::vector> &token_ids, float context_score, const std::vector &scores = {}) : ContextGraph(token_ids, context_score, 0.0f, scores, std::vector(), std::vector()) {} std::tuple ForwardOneStep( const ContextState *state, int32_t token_id, bool strict_mode = true) const; std::pair IsMatched( const ContextState *state) const; std::pair Finalize( const ContextState *state) const; const ContextState *Root() const { return root_.get(); } private: float context_score_; float ac_threshold_; std::unique_ptr root_; void Build(const std::vector> &token_ids, const std::vector &scores, const std::vector &phrases, const std::vector &ac_thresholds) const; void FillFailOutput() const; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_