Implement context biasing with an Aho-Corasick automaton (#145)

* Implement context graph

* Modify the interface to support context biasing

* Support context biasing in modified beam search; add python wrapper

* Support context biasing in python api example

* Minor fixes

* Fix context graph

* Minor fixes

* Fix tests

* Fix style

* Fix style

* Fix comments

* Minor fixes

* Add missing header

* Replace std::shared_ptr with std::unique_ptr for efficiency

* Build graph in constructor

* Fix comments

* Minor fixes

* Fix docs
This commit is contained in:
Wei Kang
2023-06-16 14:26:36 +08:00
committed by GitHub
parent 1a1b9fd236
commit 8562711252
23 changed files with 515 additions and 29 deletions

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023 by manyeyes
from pathlib import Path
from typing import List
from typing import List, Optional
from _sherpa_onnx import (
OfflineFeatureExtractorConfig,
@@ -39,6 +39,7 @@ class OfflineRecognizer(object):
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
context_score: float = 1.5,
debug: bool = False,
provider: str = "cpu",
):
@@ -96,6 +97,7 @@ class OfflineRecognizer(object):
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
context_score=context_score,
)
self.recognizer = _Recognizer(recognizer_config)
return self
@@ -216,8 +218,11 @@ class OfflineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
return self
def create_stream(self):
return self.recognizer.create_stream()
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
else:
return self.recognizer.create_stream(contexts_list)
def decode_stream(self, s: OfflineStream):
self.recognizer.decode_stream(s)