Refactor hotwords, support loading hotwords from file (#296)

This commit is contained in:
Wei Kang
2023-09-14 19:33:17 +08:00
committed by GitHub
parent 087367d7fe
commit 47184f9db7
34 changed files with 803 additions and 300 deletions

View File

@@ -4,4 +4,4 @@ from _sherpa_onnx import Display, OfflineStream, OnlineStream
from .offline_recognizer import OfflineRecognizer
from .online_recognizer import OnlineRecognizer
from .utils import encode_contexts
from .utils import text2token

View File

@@ -0,0 +1,55 @@
# Copyright (c) 2023 Xiaomi Corporation
import logging
import click
from pathlib import Path
from sherpa_onnx import text2token
@click.group()
def cli():
    """
    The shell entry point to sherpa-onnx.
    """
    # NOTE: click uses the docstring above as the help text of this command
    # group, so it is user-facing output rather than an ordinary comment.
    #
    # Configure the root logger once here, in the group callback, so the
    # sub-commands registered on this group share the same log format.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
@cli.command(name="text2token")
@click.argument("input", type=click.Path(exists=True, dir_okay=False))
@click.argument("output", type=click.Path())
@click.option(
    "--tokens", type=str, required=True, help="The path to tokens.txt."
)
@click.option(
    "--tokens-type",
    type=str,
    required=True,
    help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
)
@click.option(
    "--bpe-model",
    type=str,
    help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.",
)
def encode_text(
    input: Path, output: Path, tokens: Path, tokens_type: str, bpe_model: Path
):
    """
    Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
    """
    # NOTE: click uses the docstring above as this command's help text.
    # NOTE(review): the options are declared with type=str, so the values
    # arriving here are presumably plain strings (and bpe_model is None when
    # --bpe-model is omitted) despite the Path annotations -- confirm.

    # Every line of INPUT is one text; surrounding whitespace is dropped.
    with open(input, "r", encoding="utf8") as src:
        texts = [raw_line.strip() for raw_line in src]

    encoded_texts = text2token(
        texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
    )

    # One encoded text per output line, tokens separated by single spaces.
    with open(output, "w", encoding="utf8") as dst:
        dst.writelines(" ".join(pieces) + "\n" for pieces in encoded_texts)

View File

@@ -43,7 +43,8 @@ class OfflineRecognizer(object):
feature_dim: int = 80,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
context_score: float = 1.5,
hotwords_file: str = "",
hotwords_score: float = 1.5,
debug: bool = False,
provider: str = "cpu",
):
@@ -105,7 +106,8 @@ class OfflineRecognizer(object):
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
context_score=context_score,
hotwords_file=hotwords_file,
hotwords_score=hotwords_score,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
@@ -379,11 +381,11 @@ class OfflineRecognizer(object):
self.config = recognizer_config
return self
def create_stream(self, hotwords: Optional[str] = None):
    """
    Create a new stream from the underlying recognizer.

    Args:
      hotwords:
        Optional hotwords for this stream. When given, the string is passed
        unchanged to the underlying recognizer's ``create_stream``; when
        None, ``create_stream`` is called without arguments.
        NOTE(review): the expected format of the hotwords string is not
        visible here -- confirm against the underlying recognizer.
    Returns:
      Return whatever the underlying recognizer's ``create_stream`` returns.
    """
    if hotwords is None:
        return self.recognizer.create_stream()
    return self.recognizer.create_stream(hotwords)
def decode_stream(self, s: OfflineStream):
self.recognizer.decode_stream(s)

View File

@@ -42,7 +42,8 @@ class OnlineRecognizer(object):
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
context_score: float = 1.5,
hotwords_score: float = 1.5,
hotwords_file: str = "",
provider: str = "cpu",
model_type: str = "",
):
@@ -138,7 +139,8 @@ class OnlineRecognizer(object):
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
max_active_paths=max_active_paths,
context_score=context_score,
hotwords_score=hotwords_score,
hotwords_file=hotwords_file,
)
self.recognizer = _Recognizer(recognizer_config)
@@ -248,11 +250,11 @@ class OnlineRecognizer(object):
self.config = recognizer_config
return self
def create_stream(self, hotwords: Optional[str] = None):
    """
    Create a new stream from the underlying recognizer.

    Args:
      hotwords:
        Optional hotwords for this stream. When given, the string is passed
        unchanged to the underlying recognizer's ``create_stream``; when
        None, ``create_stream`` is called without arguments.
        NOTE(review): the expected format of the hotwords string is not
        visible here -- confirm against the underlying recognizer.
    Returns:
      Return whatever the underlying recognizer's ``create_stream`` returns.
    """
    if hotwords is None:
        return self.recognizer.create_stream()
    return self.recognizer.create_stream(hotwords)
def decode_stream(self, s: OnlineStream):
self.recognizer.decode_stream(s)

View File

@@ -1,74 +1,95 @@
from typing import Dict, List, Optional
# Copyright (c) 2023 Xiaomi Corporation
import re
from pathlib import Path
from typing import List, Optional, Union
import sentencepiece as spm
def text2token(
    texts: List[str],
    tokens: str,
    tokens_type: str = "cjkchar",
    bpe_model: Optional[str] = None,
    output_ids: bool = False,
) -> List[List[Union[str, int]]]:
    """
    Encode the given texts (a list of string) to a list of a list of tokens.

    Texts containing a token that is not listed in the tokens file are
    reported on stdout and left out of the result, so the returned list can be
    shorter than ``texts``.

    Args:
      texts:
        The texts to encode (a list of string).
      tokens:
        The path of the tokens.txt. Every non-blank line holds a token and its
        integer id separated by whitespace.
      tokens_type:
        The valid values are cjkchar, bpe, cjkchar+bpe.
      bpe_model:
        The path of the bpe model. Only required when tokens_type is bpe or
        cjkchar+bpe.
      output_ids:
        True to output token ids otherwise tokens.
    Returns:
      Return the encoded texts, it is a list of a list of token ids if output_ids
      is True, or it is a list of list of tokens.
    Raises:
      AssertionError:
        If tokens_type is not supported, a required file does not exist,
        bpe_model is missing for a bpe based tokens_type, or the tokens file
        has a malformed or duplicated entry.
    """
    # Validate the cheap argument first so a typo in tokens_type is reported
    # before any file is read or any model is loaded.
    assert tokens_type in (
        "cjkchar",
        "bpe",
        "cjkchar+bpe",
    ), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}"

    assert Path(tokens).is_file(), f"File not exists, {tokens}"
    tokens_table = {}
    with open(tokens, "r", encoding="utf-8") as f:
        for line in f:
            toks = line.strip().split()
            if not toks:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            assert len(toks) == 2, len(toks)
            assert toks[0] not in tokens_table, f"Duplicate token: {toks} "
            tokens_table[toks[0]] = int(toks[1])

    sp = None
    if "bpe" in tokens_type:
        # Without this check Path(None) fails with an opaque TypeError.
        assert (
            bpe_model is not None
        ), f"bpe_model is required when tokens_type is {tokens_type}"
        assert Path(bpe_model).is_file(), f"File not exists, {bpe_model}"
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)

    texts_list: List[List[str]] = []
    if tokens_type == "cjkchar":
        # Drop all whitespace, then treat every remaining character as a token.
        texts_list = [list("".join(text.split())) for text in texts]
    elif tokens_type == "bpe":
        texts_list = sp.encode(texts, out_type=str)
    else:
        # tokens_type == "cjkchar+bpe"
        # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        pattern = re.compile(r"([\u4e00-\u9fff])")
        for text in texts:
            # Example:
            #   text = "你好 IT'S OKAY 的"
            #   chars = ["", "你", "", "好", " IT'S OKAY ", "的", ""]
            #   mix_chars = ["你", "好", " IT'S OKAY ", "的"]
            chars = pattern.split(text)
            mix_chars = [w for w in chars if len(w.strip()) > 0]
            text_list = []
            for ch_or_w in mix_chars:
                if pattern.fullmatch(ch_or_w) is not None:
                    # ch_or_w is a single CJK character (i.e., "你"), keep it
                    # as it is.
                    text_list.append(ch_or_w)
                else:
                    # ch_or_w contains non-CJK characters (i.e., " IT'S OKAY "),
                    # encode ch_or_w using bpe_model.
                    text_list += sp.encode_as_pieces(ch_or_w)
            texts_list.append(text_list)

    result: List[List[Union[int, str]]] = []
    for text in texts_list:
        # First token that has no entry in tokens.txt, if there is one.
        oov = next((txt for txt in text if txt not in tokens_table), None)
        if oov is not None:
            print(f"OOV token : {oov}, skipping text : {text}.")
            continue
        if output_ids:
            result.append([tokens_table[txt] for txt in text])
        else:
            result.append(list(text))
    return result