Implement context biasing with an Aho-Corasick automaton (#145)

* Implement context graph

* Modify the interface to support context biasing

* Support context biasing in modified beam search; add python wrapper

* Support context biasing in python api example

* Minor fixes

* Fix context graph

* Minor fixes

* Fix tests

* Fix style

* Fix style

* Fix comments

* Minor fixes

* Add missing header

* Replace std::shared_ptr with std::unique_ptr for efficiency

* Build graph in constructor

* Fix comments

* Minor fixes

* Fix docs
This commit is contained in:
Wei Kang
2023-06-16 14:26:36 +08:00
committed by GitHub
parent 1a1b9fd236
commit 8562711252
23 changed files with 515 additions and 29 deletions

View File

@@ -12,6 +12,7 @@ endif()
set(sources
cat.cc
context-graph.cc
endpoint.cc
features.cc
file-utils.cc
@@ -248,6 +249,7 @@ endif()
if(SHERPA_ONNX_ENABLE_TESTS)
set(sherpa_onnx_test_srcs
cat-test.cc
context-graph-test.cc
packed-sequence-test.cc
pad-sequence-test.cc
slice-test.cc

View File

@@ -0,0 +1,43 @@
// sherpa-onnx/csrc/context-graph-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/context-graph.h"
#include <map>
#include <string>
#include <vector>
#include "gtest/gtest.h"
namespace sherpa_onnx {
TEST(ContextGraph, TestBasic) {
  // Context phrases to bias towards; each character is used directly as a
  // token id, so e.g. "SHE" is the token sequence {'S', 'H', 'E'}.
  std::vector<std::string> contexts_str(
      {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
  std::vector<std::vector<int32_t>> contexts;
  for (int32_t i = 0; i < contexts_str.size(); ++i) {
    contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
  }
  // context_score == 1, so every matched token contributes a bonus of 1.
  auto context_graph = ContextGraph(contexts, 1);

  // Query string -> expected total bonus over the whole query, hand-computed
  // for the phrase set above (partial matches at the end of a query are
  // rolled back by Finalize()).  The expected values are exactly
  // representable integers, so EXPECT_EQ on float is safe here.
  auto queries = std::map<std::string, float>{
      {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
      {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
  for (const auto &iter : queries) {
    float total_scores = 0;
    auto state = context_graph.Root();
    // Feed the query token-by-token, accumulating the per-step bonus.
    for (auto q : iter.first) {
      auto res = context_graph.ForwardOneStep(state, q);
      total_scores += res.first;
      state = res.second;
    }
    // Finalize cancels any pending partial match and must return the root
    // state (the root's sentinel token is -1).
    auto res = context_graph.Finalize(state);
    EXPECT_EQ(res.second->token, -1);
    total_scores += res.first;
    EXPECT_EQ(total_scores, iter.second);
  }
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,105 @@
// sherpa-onnx/csrc/context-graph.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/context-graph.h"
#include <cassert>
#include <queue>
#include <utility>
namespace sherpa_onnx {
/** Insert every phrase of `token_ids` into the trie rooted at root_.
 *
 * Each arc carries a bonus of context_score_.  For a node:
 *  - node_score is the total bonus accumulated from the root;
 *  - local_node_score is the bonus accumulated since the last completed
 *    phrase boundary (0 at an end node, so a finished match is not rolled
 *    back later by ForwardOneStep/Finalize).
 *
 * After all phrases are inserted, the Aho-Corasick fail/output links are
 * filled with a BFS over the trie.
 */
void ContextGraph::Build(
    const std::vector<std::vector<int32_t>> &token_ids) const {
  for (const auto &tokens : token_ids) {
    auto node = root_.get();
    for (size_t j = 0; j != tokens.size(); ++j) {
      int32_t token = tokens[j];
      if (0 == node->next.count(token)) {
        bool is_end = (j + 1 == tokens.size());
        node->next[token] = std::make_unique<ContextState>(
            token, context_score_, node->node_score + context_score_,
            is_end ? 0 : node->local_node_score + context_score_, is_end);
      }
      // NOTE(review): if a phrase that is a prefix of an already inserted
      // phrase arrives later (e.g. "HE" after "HELLO"), the existing node
      // keeps is_end == false and the shorter phrase is never matched.
      // TODO: also mark an existing node as an end node when j is last.
      node = node->next[token].get();
    }
  }
  FillFailOutput();
}
/** Advance the matcher by one token (Aho-Corasick transition).
 *
 * @param state Current matching state (never nullptr).
 * @param token The new token id.
 * @return pair of (score delta for this step, next state).
 */
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
    const ContextState *state, int32_t token) const {
  const ContextState *node;
  float score;
  if (1 == state->next.count(token)) {
    // Direct trie arc: pay the per-token bonus of that arc.
    node = state->next.at(token).get();
    score = node->token_score;
    // NOTE(review): when the previous state was itself an end node, its
    // node_score is added again here — presumably so matching can continue
    // past a completed phrase; confirm against the unit test's totals.
    if (state->is_end) score += state->node_score;
  } else {
    // No direct arc: follow fail links until a state with an arc for
    // `token` is found, or we hit the root (sentinel token == -1).
    node = state->fail;
    while (0 == node->next.count(token)) {
      node = node->fail;
      if (-1 == node->token) break;  // root
    }
    if (1 == node->next.count(token)) {
      node = node->next.at(token).get();
    }
    // Cancel the partial-match bonus accumulated along the abandoned path
    // (local_node_score) and credit the score of the fallback state.
    score = node->node_score - state->local_node_score;
  }
  SHERPA_ONNX_CHECK(nullptr != node);
  // Add the scores of every phrase that ends at this position as a suffix
  // of the current match, by walking the output-link chain.
  float matched_score = 0;
  auto output = node->output;
  while (nullptr != output) {
    matched_score += output->node_score;
    output = output->output;
  }
  return std::make_pair(score + matched_score, node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
const ContextState *state) const {
float score = -state->node_score;
if (state->is_end) {
score = 0;
}
return std::make_pair(score, root_.get());
}
/** Fill the Aho-Corasick fail and output links for every trie node.
 *
 * Standard construction: BFS from the root.  A node's fail link points to
 * the state for the longest proper suffix of its path that is also a trie
 * prefix; its output link points to the nearest end node reachable along
 * the fail chain (nullptr if there is none).
 */
void ContextGraph::FillFailOutput() const {
  std::queue<const ContextState *> node_queue;
  // Depth-1 nodes always fail back to the root.
  for (auto &kv : root_->next) {
    kv.second->fail = root_.get();
    node_queue.push(kv.second.get());
  }
  while (!node_queue.empty()) {
    auto current_node = node_queue.front();
    node_queue.pop();
    for (auto &kv : current_node->next) {
      auto fail = current_node->fail;
      if (1 == fail->next.count(kv.first)) {
        // The parent's fail state has an arc for this token: follow it.
        fail = fail->next.at(kv.first).get();
      } else {
        // Walk up the fail chain until a state with an arc for this token
        // is found, or the root (sentinel token == -1) is reached.
        fail = fail->fail;
        while (0 == fail->next.count(kv.first)) {
          fail = fail->fail;
          if (-1 == fail->token) break;
        }
        if (1 == fail->next.count(kv.first))
          fail = fail->next.at(kv.first).get();
      }
      kv.second->fail = fail;
      // fill the output arc: nearest end node along the fail chain, or
      // nullptr when the walk reaches the root without finding one.
      auto output = fail;
      while (!output->is_end) {
        output = output->fail;
        if (-1 == output->token) {
          output = nullptr;
          break;
        }
      }
      kv.second->output = output;
      node_queue.push(kv.second.get());
    }
  }
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,66 @@
// sherpa-onnx/csrc/context-graph.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
namespace sherpa_onnx {
class ContextGraph;
using ContextGraphPtr = std::shared_ptr<ContextGraph>;
/** One state (trie node) of the context-biasing Aho-Corasick graph.
 *
 * All scalar members now carry in-class default initializers: previously
 * `ContextState() = default;` left them indeterminate, so any
 * default-constructed state held garbage scores.
 */
struct ContextState {
  int32_t token = -1;          // arc label entering this node; -1 is the
                               // root sentinel used throughout the graph
  float token_score = 0;       // bonus paid when traversing the arc
  float node_score = 0;        // total bonus accumulated from the root
  float local_node_score = 0;  // bonus since the last phrase boundary
                               // (0 at an end node)
  bool is_end = false;         // true if a context phrase ends here

  // Children of this trie node, keyed by token id; the node owns them.
  std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
  const ContextState *fail = nullptr;    // Aho-Corasick fail link (not owned)
  const ContextState *output = nullptr;  // nearest end node via fail links

  ContextState() = default;
  ContextState(int32_t token, float token_score, float node_score,
               float local_node_score, bool is_end)
      : token(token),
        token_score(token_score),
        node_score(node_score),
        local_node_score(local_node_score),
        is_end(is_end) {}
};
/** Aho-Corasick automaton over context phrases, used to add a per-token
 * bonus during beam search when a decoded prefix matches a phrase.
 */
class ContextGraph {
 public:
  ContextGraph() = default;

  /** Build the graph in the constructor.
   *
   * @param token_ids One entry per context phrase; each entry is the
   *                  phrase's token-id sequence.
   * @param context_score Bonus added for every matched token.
   */
  ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
               float context_score)
      : context_score_(context_score),
        // -1 is the root sentinel token checked by the matching code.
        root_(std::make_unique<ContextState>(-1, 0, 0, 0, false)) {
    root_->fail = root_.get();  // the root fails to itself
    Build(token_ids);
  }

  /// Advance by one token; returns (score delta, next state).
  std::pair<float, const ContextState *> ForwardOneStep(
      const ContextState *state, int32_t token_id) const;

  /// Finish matching; cancels a pending partial match and returns the root.
  std::pair<float, const ContextState *> Finalize(
      const ContextState *state) const;

  const ContextState *Root() const { return root_.get(); }

 private:
  // Default-initialized so a default-constructed graph holds a defined
  // value (previously indeterminate).
  float context_score_ = 0;
  std::unique_ptr<ContextState> root_;

  /// Insert the phrases into the trie and fill fail/output links.
  void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
  void FillFailOutput() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_

View File

@@ -14,6 +14,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
@@ -39,11 +40,18 @@ struct Hypothesis {
// the nn lm states
std::vector<CopyableOrtValue> nn_lm_states;
const ContextState *context_state;
// TODO(fangjun): Make it configurable
// the minimum of tokens in a chunk for streaming RNN LM
int32_t lm_rescore_min_chunk = 2; // a const
int32_t num_trailing_blanks = 0;
Hypothesis() = default;
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
: ys(ys), log_prob(log_prob) {}
Hypothesis(const std::vector<int64_t> &ys, double log_prob,
const ContextState *context_state = nullptr)
: ys(ys), log_prob(log_prob), context_state(context_state) {}
double TotalLogProb() const { return log_prob + lm_log_prob; }

View File

@@ -6,7 +6,9 @@
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
#include <memory>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
@@ -19,6 +21,12 @@ class OfflineRecognizerImpl {
virtual ~OfflineRecognizerImpl() = default;
virtual std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
exit(-1);
}
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;

View File

@@ -10,6 +10,7 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const override {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
auto context_graph =
std::make_shared<ContextGraph>(context_list, config_.context_score);
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
}
@@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
-23.025850929940457f);
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
auto results =
decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
int32_t frame_shift_ms = 10;
for (int32_t i = 0; i != n; ++i) {

View File

@@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register("max-active-paths", &max_active_paths,
"Used only when decoding_method is modified_beam_search");
po->Register("context-score", &context_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
}
bool OfflineRecognizerConfig::Validate() const {
@@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "model_config=" << model_config.ToString() << ", ";
os << "lm_config=" << lm_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ")";
os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ")";
return os.str();
}
@@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
OfflineRecognizer::~OfflineRecognizer() = default;
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
}
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
return impl_->CreateStream();
}

View File

@@ -26,6 +26,7 @@ struct OfflineRecognizerConfig {
std::string decoding_method = "greedy_search";
int32_t max_active_paths = 4;
float context_score = 1.5;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
@@ -34,12 +35,13 @@ struct OfflineRecognizerConfig {
const OfflineModelConfig &model_config,
const OfflineLMConfig &lm_config,
const std::string &decoding_method,
int32_t max_active_paths)
int32_t max_active_paths, float context_score)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
decoding_method(decoding_method),
max_active_paths(max_active_paths) {}
max_active_paths(max_active_paths),
context_score(context_score) {}
void Register(ParseOptions *po);
bool Validate() const;
@@ -58,6 +60,10 @@ class OfflineRecognizer {
/// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream() const;
/// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
/** Decode a single stream
*
* @param s The stream to decode.

View File

@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
class OfflineStream::Impl {
public:
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
explicit Impl(const OfflineFeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -152,6 +154,8 @@ class OfflineStream::Impl {
const OfflineRecognitionResult &GetResult() const { return r_; }
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
private:
void NemoNormalizeFeatures(float *p, int32_t num_frames,
int32_t feature_dim) const {
@@ -189,11 +193,13 @@ class OfflineStream::Impl {
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
};
OfflineStream::OfflineStream(
const OfflineFeatureExtractorConfig &config /*= {}*/)
: impl_(std::make_unique<Impl>(config)) {}
const OfflineFeatureExtractorConfig &config /*= {}*/,
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::~OfflineStream() = default;
@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
impl_->SetResult(r);
}
const ContextGraphPtr &OfflineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
const OfflineRecognitionResult &OfflineStream::GetResult() const {
return impl_->GetResult();
}

View File

@@ -10,6 +10,7 @@
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
@@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig {
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = nullptr);
~OfflineStream();
/**
@@ -96,6 +98,9 @@ class OfflineStream {
/** Get the recognition result of this stream */
const OfflineRecognitionResult &GetResult() const;
/** Get the ContextGraph of this stream */
const ContextGraphPtr &GetContextGraph() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;

View File

@@ -8,6 +8,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-stream.h"
namespace sherpa_onnx {
@@ -33,7 +34,8 @@ class OfflineTransducerDecoder {
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) = 0;
};
} // namespace sherpa_onnx

View File

@@ -16,7 +16,9 @@ namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
Ort::Value encoder_out_length) {
Ort::Value encoder_out_length,
OfflineStream **ss /*= nullptr*/,
int32_t n /*= 0*/) {
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);

View File

@@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
: model_(model) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) override;
private:
OfflineTransducerModel *model_; // Not owned

View File

@@ -8,7 +8,9 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/packed-sequence.h"
#include "sherpa-onnx/csrc/slice.h"
@@ -17,23 +19,39 @@ namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) {
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);
int32_t batch_size =
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize();
std::vector<int64_t> blanks(context_size, 0);
Hypotheses blank_hyp({{blanks, 0}});
std::deque<Hypotheses> finalized;
std::vector<Hypotheses> cur(batch_size, blank_hyp);
std::vector<Hypotheses> cur;
std::vector<Hypothesis> prev;
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
for (int32_t i = 0; i < batch_size; ++i) {
const ContextState *context_state;
if (ss != nullptr) {
context_graphs[i] =
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
if (context_graphs[i] != nullptr)
context_state = context_graphs[i]->Root();
}
Hypotheses blank_hyp({{blanks, 0, context_state}});
cur.emplace_back(std::move(blank_hyp));
}
int32_t start = 0;
int32_t t = 0;
for (auto n : packed_encoder_out.batch_sizes) {
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
int32_t new_token = k % vocab_size;
Hypothesis new_hyp = prev[hyp_index];
float context_score = 0;
auto context_state = new_hyp.context_state;
if (new_token != 0) {
// blank id is fixed to 0
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t);
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
}
}
new_hyp.log_prob = p_logprob[k];
new_hyp.log_prob = p_logprob[k] + context_score;
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
p_logprob += (end - start) * vocab_size;
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
cur.push_back(std::move(h));
}
// Finalize context biasing matching..
for (int32_t i = 0; i < cur.size(); ++i) {
for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->Finalize(iter->second.context_state);
iter->second.log_prob += context_res.first;
iter->second.context_state = context_res.second;
}
}
}
if (lm_) {
// use LM for rescoring
lm_->ComputeLMScore(lm_scale_, context_size, &cur);

View File

@@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder
lm_scale_(lm_scale) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) override;
private:
OfflineTransducerModel *model_; // Not owned