Implement context biasing with a Aho Corasick automata (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for effciency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -39,11 +40,18 @@ struct Hypothesis {
|
||||
// the nn lm states
|
||||
std::vector<CopyableOrtValue> nn_lm_states;
|
||||
|
||||
const ContextState *context_state;
|
||||
|
||||
// TODO(fangjun): Make it configurable
|
||||
// the minimum of tokens in a chunk for streaming RNN LM
|
||||
int32_t lm_rescore_min_chunk = 2; // a const
|
||||
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
Hypothesis() = default;
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
||||
: ys(ys), log_prob(log_prob) {}
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob,
|
||||
const ContextState *context_state = nullptr)
|
||||
: ys(ys), log_prob(log_prob), context_state(context_state) {}
|
||||
|
||||
double TotalLogProb() const { return log_prob + lm_log_prob; }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user