Implement context biasing with an Aho-Corasick automaton (#145)

* Implement context graph

* Modify the interface to support context biasing

* Support context biasing in modified beam search; add python wrapper

* Support context biasing in python api example

* Minor fixes

* Fix context graph

* Minor fixes

* Fix tests

* Fix style

* Fix style

* Fix comments

* Minor fixes

* Add missing header

* Replace std::shared_ptr with std::unique_ptr for efficiency

* Build graph in constructor

* Fix comments

* Minor fixes

* Fix docs
This commit is contained in:
Wei Kang
2023-06-16 14:26:36 +08:00
committed by GitHub
parent 1a1b9fd236
commit 8562711252
23 changed files with 515 additions and 29 deletions

View File

@@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
.def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const OfflineLMConfig &,
const std::string &, int32_t>(),
const std::string &, int32_t, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4)
py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("context_score", &PyClass::context_score)
.def("__str__", &PyClass::ToString);
}
@@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) {
using PyClass = OfflineRecognizer;
py::class_<PyClass>(*m, "OfflineRecognizer")
.def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
.def("create_stream", &PyClass::CreateStream)
.def("create_stream",
[](const PyClass &self) { return self.CreateStream(); })
.def(
"create_stream",
[](PyClass &self,
const std::vector<std::vector<int32_t>> &contexts_list) {
return self.CreateStream(contexts_list);
},
py::arg("contexts_list"))
.def("decode_stream", &PyClass::DecodeStream)
.def("decode_streams",
[](PyClass &self, std::vector<OfflineStream *> ss) {
[](const PyClass &self, std::vector<OfflineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
});
}

View File

@@ -1,5 +1,12 @@
from typing import Dict, List, Optional
from _sherpa_onnx import Display
from .online_recognizer import OnlineRecognizer
from .online_recognizer import OnlineStream
from .offline_recognizer import OfflineRecognizer
from .utils import encode_contexts

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023 by manyeyes
from pathlib import Path
from typing import List
from typing import List, Optional
from _sherpa_onnx import (
OfflineFeatureExtractorConfig,
@@ -39,6 +39,7 @@ class OfflineRecognizer(object):
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
context_score: float = 1.5,
debug: bool = False,
provider: str = "cpu",
):
@@ -96,6 +97,7 @@ class OfflineRecognizer(object):
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
context_score=context_score,
)
self.recognizer = _Recognizer(recognizer_config)
return self
@@ -216,8 +218,11 @@ class OfflineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
return self
def create_stream(self):
return self.recognizer.create_stream()
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
else:
return self.recognizer.create_stream(contexts_list)
def decode_stream(self, s: OfflineStream):
self.recognizer.decode_stream(s)

View File

@@ -0,0 +1,74 @@
from typing import Dict, List, Optional
def encode_contexts(
    modeling_unit: str,
    contexts: List[str],
    sp: Optional["SentencePieceProcessor"] = None,
    tokens_table: Optional[Dict[str, int]] = None,
) -> List[List[int]]:
    """
    Encode the given contexts (a list of string) to a list of a list of token ids.

    Args:
      modeling_unit:
        The valid values are bpe, char, bpe+char.
        Note: char here means characters in CJK languages, not English like
        languages.
      contexts:
        The given contexts list (a list of string).
      sp:
        An instance of SentencePieceProcessor. Required when modeling_unit
        contains "bpe".
      tokens_table:
        The tokens_table containing the tokens and the corresponding ids.
        Required when modeling_unit contains "char"; it must contain an
        "<unk>" entry used for out-of-vocabulary tokens.

    Returns:
      Return the contexts_list, it is a list of a list of token ids.
    """
    contexts_list = []
    if "bpe" in modeling_unit:
        assert sp is not None
    if "char" in modeling_unit:
        assert tokens_table is not None
        assert len(tokens_table) > 0, len(tokens_table)

    def _token_id(tok: str) -> int:
        # Single dict lookup; unknown tokens fall back to <unk>.
        return tokens_table.get(tok, tokens_table["<unk>"])

    if "char" == modeling_unit:
        for context in contexts:
            # char mode expects whitespace-free CJK strings.
            assert " " not in context
            contexts_list.append([_token_id(txt) for txt in context])
    elif "bpe" == modeling_unit:
        contexts_list = sp.encode(contexts, out_type=int)
    else:
        assert modeling_unit == "bpe+char", modeling_unit
        # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        pattern = re.compile(r"([\u4e00-\u9fff])")
        for context in contexts:
            # Example:
            #   txt   = "你好 IT'S OKAY 的"
            #   chars = ["你", "好", " IT'S OKAY ", "的"]
            chars = pattern.split(context.upper())
            mix_chars = [w for w in chars if len(w.strip()) > 0]
            ids = []
            for ch_or_w in mix_chars:
                if pattern.fullmatch(ch_or_w) is not None:
                    # ch_or_w is a single CJK character (i.e., "你"); look
                    # it up directly.
                    ids.append(_token_id(ch_or_w))
                else:
                    # ch_or_w contains non-CJK characters (i.e., " IT'S OKAY ");
                    # encode ch_or_w using the bpe model.
                    ids.extend(_token_id(p) for p in sp.encode_as_pieces(ch_or_w))
            contexts_list.append(ids)
    return contexts_list