Implement context biasing with an Aho-Corasick automaton (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for efficiency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
@@ -8,7 +8,9 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/packed-sequence.h"
|
||||
#include "sherpa-onnx/csrc/slice.h"
|
||||
@@ -17,23 +19,39 @@ namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult>
|
||||
OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) {
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||
OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
|
||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||
|
||||
int32_t batch_size =
|
||||
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
|
||||
|
||||
if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
|
||||
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
int32_t context_size = model_->ContextSize();
|
||||
|
||||
std::vector<int64_t> blanks(context_size, 0);
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
|
||||
std::deque<Hypotheses> finalized;
|
||||
std::vector<Hypotheses> cur(batch_size, blank_hyp);
|
||||
std::vector<Hypotheses> cur;
|
||||
std::vector<Hypothesis> prev;
|
||||
|
||||
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
|
||||
|
||||
for (int32_t i = 0; i < batch_size; ++i) {
|
||||
const ContextState *context_state;
|
||||
if (ss != nullptr) {
|
||||
context_graphs[i] =
|
||||
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
|
||||
if (context_graphs[i] != nullptr)
|
||||
context_state = context_graphs[i]->Root();
|
||||
}
|
||||
Hypotheses blank_hyp({{blanks, 0, context_state}});
|
||||
cur.emplace_back(std::move(blank_hyp));
|
||||
}
|
||||
|
||||
int32_t start = 0;
|
||||
int32_t t = 0;
|
||||
for (auto n : packed_encoder_out.batch_sizes) {
|
||||
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
int32_t new_token = k % vocab_size;
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
|
||||
float context_score = 0;
|
||||
auto context_state = new_hyp.context_state;
|
||||
if (new_token != 0) {
|
||||
// blank id is fixed to 0
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t);
|
||||
if (context_graphs[i] != nullptr) {
|
||||
auto context_res =
|
||||
context_graphs[i]->ForwardOneStep(context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
}
|
||||
}
|
||||
|
||||
new_hyp.log_prob = p_logprob[k];
|
||||
new_hyp.log_prob = p_logprob[k] + context_score;
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
cur.push_back(std::move(h));
|
||||
}
|
||||
|
||||
// Finalize context biasing matching.
|
||||
for (int32_t i = 0; i < cur.size(); ++i) {
|
||||
for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
|
||||
if (context_graphs[i] != nullptr) {
|
||||
auto context_res =
|
||||
context_graphs[i]->Finalize(iter->second.context_state);
|
||||
iter->second.log_prob += context_res.first;
|
||||
iter->second.context_state = context_res.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (lm_) {
|
||||
// use LM for rescoring
|
||||
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
|
||||
|
||||
Reference in New Issue
Block a user