Implement context biasing with a Aho Corasick automata (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for effciency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
@@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream(
|
||||
const std::vector<std::vector<int32_t>> &context_list) const override {
|
||||
// We create context_graph at this level, because we might have default
|
||||
// context_graph(will be added later if needed) that belongs to the whole
|
||||
// model rather than each stream.
|
||||
auto context_graph =
|
||||
std::make_shared<ContextGraph>(context_list, config_.context_score);
|
||||
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
@@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
-23.025850929940457f);
|
||||
|
||||
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
auto results =
|
||||
decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
|
||||
|
||||
int32_t frame_shift_ms = 10;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
|
||||
Reference in New Issue
Block a user