Implement context biasing with a Aho Corasick automata (#145)

* Implement context graph

* Modify the interface to support context biasing

* Support context biasing in modified beam search; add python wrapper

* Support context biasing in python api example

* Minor fixes

* Fix context graph

* Minor fixes

* Fix tests

* Fix style

* Fix style

* Fix comments

* Minor fixes

* Add missing header

* Replace std::shared_ptr with std::unique_ptr for effciency

* Build graph in constructor

* Fix comments

* Minor fixes

* Fix docs
This commit is contained in:
Wei Kang
2023-06-16 14:26:36 +08:00
committed by GitHub
parent 1a1b9fd236
commit 8562711252
23 changed files with 515 additions and 29 deletions

View File

@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
class OfflineStream::Impl {
public:
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
explicit Impl(const OfflineFeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -152,6 +154,8 @@ class OfflineStream::Impl {
const OfflineRecognitionResult &GetResult() const { return r_; }
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
private:
void NemoNormalizeFeatures(float *p, int32_t num_frames,
int32_t feature_dim) const {
@@ -189,11 +193,13 @@ class OfflineStream::Impl {
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
};
OfflineStream::OfflineStream(
const OfflineFeatureExtractorConfig &config /*= {}*/)
: impl_(std::make_unique<Impl>(config)) {}
const OfflineFeatureExtractorConfig &config /*= {}*/,
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::~OfflineStream() = default;
@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
impl_->SetResult(r);
}
const ContextGraphPtr &OfflineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
const OfflineRecognitionResult &OfflineStream::GetResult() const {
return impl_->GetResult();
}