Implement context biasing with an Aho–Corasick automaton (#145)
* Implement context graph * Modify the interface to support context biasing * Support context biasing in modified beam search; add python wrapper * Support context biasing in python api example * Minor fixes * Fix context graph * Minor fixes * Fix tests * Fix style * Fix style * Fix comments * Minor fixes * Add missing header * Replace std::shared_ptr with std::unique_ptr for efficiency * Build graph in constructor * Fix comments * Minor fixes * Fix docs
This commit is contained in:
74
sherpa-onnx/python/sherpa_onnx/utils.py
Normal file
74
sherpa-onnx/python/sherpa_onnx/utils.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import re
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
def encode_contexts(
    modeling_unit: str,
    contexts: List[str],
    sp: Optional["SentencePieceProcessor"] = None,
    tokens_table: Optional[Dict[str, int]] = None,
) -> List[List[int]]:
    """
    Encode the given contexts (a list of strings) into lists of token ids.

    Args:
      modeling_unit:
        One of "bpe", "char", or "bpe+char".
        Note: "char" here means characters of CJK languages, not of
        English-like languages.
      contexts:
        The context phrases to encode (a list of strings).
      sp:
        An instance of SentencePieceProcessor; required whenever
        ``modeling_unit`` contains "bpe".
      tokens_table:
        A mapping from token to token id; required whenever
        ``modeling_unit`` contains "char".

    Returns:
      A list with one entry per context, each entry being a list of
      token ids.
    """
    if "bpe" in modeling_unit:
        assert sp is not None
    if "char" in modeling_unit:
        assert tokens_table is not None
        assert len(tokens_table) > 0, len(tokens_table)

    def lookup(token: str) -> int:
        # Unknown tokens fall back to the id of "<unk>".
        if token in tokens_table:
            return tokens_table[token]
        return tokens_table["<unk>"]

    if modeling_unit == "char":
        encoded = []
        for phrase in contexts:
            assert ' ' not in phrase
            encoded.append([lookup(ch) for ch in phrase])
        return encoded

    if modeling_unit == "bpe":
        return sp.encode(contexts, out_type=int)

    assert modeling_unit == "bpe+char", modeling_unit

    # CJK (China, Japan, Korea) unicode range is [U+4E00, U+9FFF], ref:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    cjk = re.compile(r"([\u4e00-\u9fff])")

    encoded = []
    for phrase in contexts:
        # Example:
        #   phrase.upper() = "你好 IT'S OKAY 的"
        #   pieces         = ["你", "好", " IT'S OKAY ", "的"]
        pieces = [p for p in cjk.split(phrase.upper()) if len(p.strip()) > 0]
        ids = []
        for piece in pieces:
            if cjk.fullmatch(piece) is not None:
                # piece is a single CJK character (e.g., "你"):
                # a direct table lookup.
                ids.append(lookup(piece))
            else:
                # piece is a run of non-CJK text (e.g., " IT'S OKAY "):
                # encode it with the bpe model, then map each bpe piece
                # to its id.
                ids.extend(lookup(p) for p in sp.encode_as_pieces(piece))
        encoded.append(ids)
    return encoded
|
||||
Reference in New Issue
Block a user