Implement context biasing with a Aho Corasick automata (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for effciency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
|
||||
|
||||
class OfflineStream::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config,
|
||||
ContextGraphPtr context_graph)
|
||||
: config_(config), context_graph_(context_graph) {
|
||||
opts_.frame_opts.dither = 0;
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
@@ -152,6 +154,8 @@ class OfflineStream::Impl {
|
||||
|
||||
const OfflineRecognitionResult &GetResult() const { return r_; }
|
||||
|
||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||
|
||||
private:
|
||||
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
||||
int32_t feature_dim) const {
|
||||
@@ -189,11 +193,13 @@ class OfflineStream::Impl {
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
OfflineRecognitionResult r_;
|
||||
ContextGraphPtr context_graph_;
|
||||
};
|
||||
|
||||
OfflineStream::OfflineStream(
|
||||
const OfflineFeatureExtractorConfig &config /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
const OfflineFeatureExtractorConfig &config /*= {}*/,
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
|
||||
impl_->SetResult(r);
|
||||
}
|
||||
|
||||
const ContextGraphPtr &OfflineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
const OfflineRecognitionResult &OfflineStream::GetResult() const {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user