decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder * Add keyword spotter runtime * Add binary * First version works * Minor fixes * update text2token * default values * Add jni for kws * add kws android project * Minor fixes * Remove unused interface * Minor fixes * Add workflow * handle extra info in texts * Minor fixes * Add more comments * Fix ci * fix cpp style * Add input box in android demo so that users can specify their keywords * Fix cpp style * Fix comments * Minor fixes * Minor fixes * minor fixes * Minor fixes * Minor fixes * Add CI * Fix code style * cpplint * Fix comments * Fix error
This commit is contained in:
@@ -28,9 +28,14 @@ def cli():
|
||||
)
|
||||
@click.option(
|
||||
"--tokens-type",
|
||||
type=str,
|
||||
type=click.Choice(
|
||||
["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True
|
||||
),
|
||||
required=True,
|
||||
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
|
||||
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
|
||||
fpinyin means full pinyin, each cjkchar has a pinyin (with tone).
|
||||
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||
""",
|
||||
)
|
||||
@click.option(
|
||||
"--bpe-model",
|
||||
@@ -42,14 +47,56 @@ def encode_text(
|
||||
):
|
||||
"""
|
||||
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
|
||||
Each line in the texts contains the original phrase, it might also contain some
|
||||
extra items, for example, the boosting score (starting with :), the triggering
|
||||
threshold (starting with #, only used in keyword spotting task) and the original
|
||||
phrase (starting with @). Note: the extra items will be kept the same in the output.
|
||||
|
||||
example input 1 (tokens_type = ppinyin):
|
||||
|
||||
小爱同学 :2.0 #0.6 @小爱同学
|
||||
你好问问 :3.5 @你好问问
|
||||
小艺小艺 #0.6 @小艺小艺
|
||||
|
||||
example output 1:
|
||||
|
||||
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
|
||||
n ǐ h ǎo w èn w èn :3.5 @你好问问
|
||||
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
|
||||
|
||||
example input 2 (tokens_type = bpe):
|
||||
|
||||
HELLO WORLD :1.5 #0.4
|
||||
HI GOOGLE :2.0 #0.8
|
||||
HEY SIRI #0.35
|
||||
|
||||
example output 2:
|
||||
|
||||
▁HE LL O ▁WORLD :1.5 #0.4
|
||||
▁HI ▁GO O G LE :2.0 #0.8
|
||||
▁HE Y ▁S I RI #0.35
|
||||
"""
|
||||
texts = []
|
||||
# extra information like boosting score (start with :), triggering threshold (start with #)
|
||||
# original keyword (start with @)
|
||||
extra_info = []
|
||||
with open(input, "r", encoding="utf8") as f:
|
||||
for line in f:
|
||||
texts.append(line.strip())
|
||||
extra = []
|
||||
text = []
|
||||
toks = line.strip().split()
|
||||
for tok in toks:
|
||||
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
|
||||
extra.append(tok)
|
||||
else:
|
||||
text.append(tok)
|
||||
texts.append(" ".join(text))
|
||||
extra_info.append(extra)
|
||||
|
||||
encoded_texts = text2token(
|
||||
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
|
||||
)
|
||||
with open(output, "w", encoding="utf8") as f:
|
||||
for txt in encoded_texts:
|
||||
for i, txt in enumerate(encoded_texts):
|
||||
txt += extra_info[i]
|
||||
f.write(" ".join(txt) + "\n")
|
||||
|
||||
@@ -6,6 +6,9 @@ from typing import List, Optional, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from pypinyin import pinyin
|
||||
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
|
||||
|
||||
|
||||
def text2token(
|
||||
texts: List[str],
|
||||
@@ -23,7 +26,9 @@ def text2token(
|
||||
tokens:
|
||||
The path of the tokens.txt.
|
||||
tokens_type:
|
||||
The valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
|
||||
fpinyin means full pinyin, each cjkchar has a pinyin (with tone).
|
||||
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||
bpe_model:
|
||||
The path of the bpe model. Only required when tokens_type is bpe or
|
||||
cjkchar+bpe.
|
||||
@@ -53,6 +58,24 @@ def text2token(
|
||||
texts_list = [list("".join(text.split())) for text in texts]
|
||||
elif tokens_type == "bpe":
|
||||
texts_list = sp.encode(texts, out_type=str)
|
||||
elif "pinyin" in tokens_type:
|
||||
for txt in texts:
|
||||
py = [x[0] for x in pinyin(txt)]
|
||||
if "ppinyin" == tokens_type:
|
||||
res = []
|
||||
for x in py:
|
||||
initial = to_initials(x, strict=False)
|
||||
final = to_finals_tone(x, strict=False)
|
||||
if initial == "" and final == "":
|
||||
res.append(x)
|
||||
else:
|
||||
if initial != "":
|
||||
res.append(initial)
|
||||
if final != "":
|
||||
res.append(final)
|
||||
texts_list.append(res)
|
||||
else:
|
||||
texts_list.append(py)
|
||||
else:
|
||||
assert (
|
||||
tokens_type == "cjkchar+bpe"
|
||||
|
||||
Reference in New Issue
Block a user