decoder for open vocabulary keyword spotting (#505)

* various fixes to ContextGraph to support open vocabulary keywords decoder

* Add keyword spotter runtime

* Add binary

* First version works

* Minor fixes

* update text2token

* default values

* Add jni for kws

* add kws android project

* Minor fixes

* Remove unused interface

* Minor fixes

* Add workflow

* handle extra info in texts

* Minor fixes

* Add more comments

* Fix ci

* fix cpp style

* Add input box in android demo so that users can specify their keywords

* Fix cpp style

* Fix comments

* Minor fixes

* Minor fixes

* minor fixes

* Minor fixes

* Minor fixes

* Add CI

* Fix code style

* cpplint

* Fix comments

* Fix error
This commit is contained in:
Wei Kang
2024-01-20 22:52:41 +08:00
committed by GitHub
parent bf1dd3daf6
commit b6c020901a
77 changed files with 3316 additions and 68 deletions

View File

@@ -28,9 +28,14 @@ def cli():
)
@click.option(
"--tokens-type",
type=str,
type=click.Choice(
["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True
),
required=True,
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
fpinyin means full pinyin, each cjkchar has a pinyin (with tone).
ppinyin means partial pinyin, it splits pinyin into initial and final.
""",
)
@click.option(
"--bpe-model",
@@ -42,14 +47,56 @@ def encode_text(
):
"""
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
Each line in the texts contains the original phrase, it might also contain some
extra items, for example, the boosting score (starting with :), the triggering
threshold (starting with #, only used in keyword spotting task) and the original
phrase (starting with @). Note: the extra items will be kept the same in the output.
example input 1 (tokens_type = ppinyin):
小爱同学 :2.0 #0.6 @小爱同学
你好问问 :3.5 @你好问问
小艺小艺 #0.6 @小艺小艺
example output 1:
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
n ǐ h ǎo w èn w èn :3.5 @你好问问
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
example input 2 (tokens_type = bpe):
HELLO WORLD :1.5 #0.4
HI GOOGLE :2.0 #0.8
HEY SIRI #0.35
example output 2:
▁HE LL O ▁WORLD :1.5 #0.4
▁HI ▁GO O G LE :2.0 #0.8
▁HE Y ▁S I RI #0.35
"""
texts = []
# extra information like boosting score (starting with :), triggering threshold (starting with #)
# original keyword (starting with @)
extra_info = []
with open(input, "r", encoding="utf8") as f:
for line in f:
texts.append(line.strip())
extra = []
text = []
toks = line.strip().split()
for tok in toks:
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
extra.append(tok)
else:
text.append(tok)
texts.append(" ".join(text))
extra_info.append(extra)
encoded_texts = text2token(
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
)
with open(output, "w", encoding="utf8") as f:
for txt in encoded_texts:
for i, txt in enumerate(encoded_texts):
txt += extra_info[i]
f.write(" ".join(txt) + "\n")

View File

@@ -6,6 +6,9 @@ from typing import List, Optional, Union
import sentencepiece as spm
from pypinyin import pinyin
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
def text2token(
texts: List[str],
@@ -23,7 +26,9 @@ def text2token(
tokens:
The path of the tokens.txt.
tokens_type:
The valid values are cjkchar, bpe, cjkchar+bpe.
The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
fpinyin means full pinyin, each cjkchar has a pinyin (with tone).
ppinyin means partial pinyin, it splits pinyin into initial and final.
bpe_model:
The path of the bpe model. Only required when tokens_type is bpe or
cjkchar+bpe.
@@ -53,6 +58,24 @@ def text2token(
texts_list = [list("".join(text.split())) for text in texts]
elif tokens_type == "bpe":
texts_list = sp.encode(texts, out_type=str)
elif "pinyin" in tokens_type:
for txt in texts:
py = [x[0] for x in pinyin(txt)]
if "ppinyin" == tokens_type:
res = []
for x in py:
initial = to_initials(x, strict=False)
final = to_finals_tone(x, strict=False)
if initial == "" and final == "":
res.append(x)
else:
if initial != "":
res.append(initial)
if final != "":
res.append(final)
texts_list.append(res)
else:
texts_list.append(py)
else:
assert (
tokens_type == "cjkchar+bpe"