Implement context biasing with a Aho Corasick automata (#145)

* Implement context graph

* Modify the interface to support context biasing

* Support context biasing in modified beam search; add python wrapper

* Support context biasing in python api example

* Minor fixes

* Fix context graph

* Minor fixes

* Fix tests

* Fix style

* Fix style

* Fix comments

* Minor fixes

* Add missing header

* Replace std::shared_ptr with std::unique_ptr for effciency

* Build graph in constructor

* Fix comments

* Minor fixes

* Fix docs
This commit is contained in:
Wei Kang
2023-06-16 14:26:36 +08:00
committed by GitHub
parent 1a1b9fd236
commit 8562711252
23 changed files with 515 additions and 29 deletions

View File

@@ -8,7 +8,9 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/packed-sequence.h"
#include "sherpa-onnx/csrc/slice.h"
@@ -17,23 +19,39 @@ namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) {
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);
int32_t batch_size =
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize();
std::vector<int64_t> blanks(context_size, 0);
Hypotheses blank_hyp({{blanks, 0}});
std::deque<Hypotheses> finalized;
std::vector<Hypotheses> cur(batch_size, blank_hyp);
std::vector<Hypotheses> cur;
std::vector<Hypothesis> prev;
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
for (int32_t i = 0; i < batch_size; ++i) {
const ContextState *context_state;
if (ss != nullptr) {
context_graphs[i] =
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
if (context_graphs[i] != nullptr)
context_state = context_graphs[i]->Root();
}
Hypotheses blank_hyp({{blanks, 0, context_state}});
cur.emplace_back(std::move(blank_hyp));
}
int32_t start = 0;
int32_t t = 0;
for (auto n : packed_encoder_out.batch_sizes) {
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
int32_t new_token = k % vocab_size;
Hypothesis new_hyp = prev[hyp_index];
float context_score = 0;
auto context_state = new_hyp.context_state;
if (new_token != 0) {
// blank id is fixed to 0
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t);
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
}
}
new_hyp.log_prob = p_logprob[k];
new_hyp.log_prob = p_logprob[k] + context_score;
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
p_logprob += (end - start) * vocab_size;
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
cur.push_back(std::move(h));
}
// Finalize context biasing matching..
for (int32_t i = 0; i < cur.size(); ++i) {
for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->Finalize(iter->second.context_state);
iter->second.log_prob += context_res.first;
iter->second.context_state = context_res.second;
}
}
}
if (lm_) {
// use LM for rescoring
lm_->ComputeLMScore(lm_scale_, context_size, &cur);