130 lines
4.6 KiB
Python
130 lines
4.6 KiB
Python
# Copyright (c) 2023 Xiaomi Corporation
|
|
import re
|
|
|
|
from pathlib import Path
|
|
from typing import List, Optional, Union
|
|
|
|
def text2token(
    texts: List[str],
    tokens: str,
    tokens_type: str = "cjkchar",
    bpe_model: Optional[str] = None,
    output_ids: bool = False,
) -> List[List[Union[str, int]]]:
    """
    Encode the given texts (a list of string) to a list of a list of tokens.

    Texts that contain at least one token missing from the tokens file are
    reported on stdout and dropped, so the returned list may be shorter than
    ``texts``.

    Args:
      texts:
        The given contexts list (a list of string).
      tokens:
        The path of the tokens.txt. Every line must consist of exactly two
        whitespace separated fields: the token and its integer id.
      tokens_type:
        The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
        fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
        ppinyin means partial pinyin, it splits pinyin into initial and final.
      bpe_model:
        The path of the bpe model. Only required when tokens_type is bpe or
        cjkchar+bpe.
      output_ids:
        True to output token ids otherwise tokens.

    Returns:
      Return the encoded texts, it is a list of a list of token ids if output_ids
      is True, or it is a list of list of tokens.

    Raises:
      AssertionError:
        If the tokens file or the bpe model does not exist, if the tokens file
        is malformed or contains a duplicate token, or if tokens_type is not
        supported.
      ImportError:
        If the third-party package needed by the selected tokens_type
        (sentencepiece or pypinyin) is not installed.
    """
    assert Path(tokens).is_file(), f"File not exists, {tokens}"
    tokens_table = {}
    with open(tokens, "r", encoding="utf-8") as f:
        for line in f:
            toks = line.strip().split()
            assert len(toks) == 2, len(toks)
            assert toks[0] not in tokens_table, f"Duplicate token: {toks} "
            tokens_table[toks[0]] = int(toks[1])

    sp = None
    if "bpe" in tokens_type:
        # sentencepiece is an optional dependency; import it only when a BPE
        # model is actually going to be used so that the other modes keep
        # working without it.
        try:
            import sentencepiece as spm
        except ImportError:
            print('Please run')
            print(' pip install sentencepiece')
            print('before you continue')
            raise

        # Checking for None first turns a missing bpe_model into the same
        # readable assertion as a wrong path instead of a TypeError from Path().
        assert (
            bpe_model is not None and Path(bpe_model).is_file()
        ), f"File not exists, {bpe_model}"
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)

    texts_list: List[List[str]] = []

    if tokens_type == "cjkchar":
        # Remove all whitespace, then split into single characters.
        texts_list = [list("".join(text.split())) for text in texts]
    elif tokens_type == "bpe":
        texts_list = sp.encode(texts, out_type=str)
    elif "pinyin" in tokens_type:
        # pypinyin is only needed for the pinyin modes, so import it lazily.
        try:
            from pypinyin import pinyin
            from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
        except ImportError:
            print('Please run')
            print(' pip install pypinyin')
            print('before you continue')
            raise

        for txt in texts:
            # pinyin() returns one list of candidates per unit; keep the first.
            py = [x[0] for x in pinyin(txt)]
            if "ppinyin" == tokens_type:
                res = []
                for x in py:
                    initial = to_initials(x, strict=False)
                    final = to_finals_tone(x, strict=False)
                    if initial == "" and final == "":
                        # Neither part could be extracted; keep x unchanged.
                        res.append(x)
                    else:
                        if initial != "":
                            res.append(initial)
                        if final != "":
                            res.append(final)
                texts_list.append(res)
            else:
                texts_list.append(py)
    else:
        assert tokens_type == "cjkchar+bpe", (
            "Supported tokens_type are cjkchar, bpe, cjkchar+bpe, fpinyin, "
            f"ppinyin, given {tokens_type}"
        )
        # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        pattern = re.compile(r"([\u4e00-\u9fff])")
        for text in texts:
            # Example:
            #   text = "你好 IT'S OKAY 的"
            #   chars = ["", "你", "", "好", " IT'S OKAY ", "的", ""]
            # The capturing group keeps the CJK characters in the result of
            # split(); pieces that are empty or whitespace-only are dropped.
            chars = pattern.split(text)
            mix_chars = [w for w in chars if len(w.strip()) > 0]
            text_list = []
            for ch_or_w in mix_chars:
                # ch_or_w is a single CJK character(i.e., "你"), do nothing.
                if pattern.fullmatch(ch_or_w) is not None:
                    text_list.append(ch_or_w)
                # ch_or_w contains non-CJK characters(i.e., " IT'S OKAY "),
                # encode ch_or_w using bpe_model.
                else:
                    text_list += sp.encode_as_pieces(ch_or_w)
            texts_list.append(text_list)

    result: List[List[Union[int, str]]] = []
    for text in texts_list:
        text_list = []
        for txt in text:
            if txt not in tokens_table:
                print(f"OOV token : {txt}, skipping text : {text}.")
                break
            text_list.append(tokens_table[txt] if output_ids else txt)
        else:
            # Reached only when no OOV token interrupted the loop above.
            result.append(text_list)
    return result
|