Files
enginex-mr_series-asr/utils/tokenizer.py
2025-08-20 14:29:42 +08:00

161 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import subprocess
from enum import Enum
from typing import List
from utils.logger import logger
class TokenizerType(str, Enum):
    """Strategy for splitting a sentence into comparison units."""

    # One token per character — for scripts without whitespace word
    # boundaries (zh/ja/ti in TOKENIZER_MAPPING).
    word = "word"
    # Word tokens extracted around whitespace/word boundaries.
    whitespace = "whitespace"
class LangType(str, Enum):
    """Language codes with dedicated normalization paths (see Tokenizer.norm)."""

    zh = "zh"
    en = "en"
# Per-language tokenizer selection. Character-based ("word") for scripts
# without whitespace word boundaries (zh/ja/ti); whitespace-based otherwise.
# Idiom fix: a single dict literal replaces 21 repetitive item assignments.
TOKENIZER_MAPPING = {
    "zh": TokenizerType.word,
    "en": TokenizerType.whitespace,
    "ru": TokenizerType.whitespace,
    "ar": TokenizerType.whitespace,
    "tr": TokenizerType.whitespace,
    "es": TokenizerType.whitespace,
    "pt": TokenizerType.whitespace,
    "id": TokenizerType.whitespace,
    "he": TokenizerType.whitespace,
    "ja": TokenizerType.word,
    "pl": TokenizerType.whitespace,
    "de": TokenizerType.whitespace,
    "fr": TokenizerType.whitespace,
    "nl": TokenizerType.whitespace,
    "el": TokenizerType.whitespace,
    "vi": TokenizerType.whitespace,
    "th": TokenizerType.whitespace,
    "it": TokenizerType.whitespace,
    "fa": TokenizerType.whitespace,
    "ti": TokenizerType.word,
}
# NLP helpers for the English normalization path in Tokenizer.norm.
# NOTE(review): these imports sit mid-file rather than at the top;
# consider consolidating them with the header import block.
import nltk
import re
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer

# Module-level singleton reused across calls by Tokenizer.norm (lang == "en").
lemmatizer = WordNetLemmatizer()
class Tokenizer:
    """Text normalization and tokenization utilities (ASR metric pipeline)."""

    @classmethod
    def norm_and_tokenize(cls, sentences: List[str], lang: str = None):
        """Normalize *sentences* for *lang*, then tokenize them.

        Pipeline: strip general punctuation -> language-specific text
        normalization -> tokenization.

        Args:
            sentences: raw sentences to process.
            lang: language code used to look up the tokenizer / normalizer.

        Returns:
            A list of token lists, one per input sentence.
        """
        tokenizer = TOKENIZER_MAPPING.get(lang, None)
        sentences = cls.replace_general_punc(sentences, tokenizer)
        sentences = cls.norm(sentences, lang)
        return cls.tokenize(sentences, lang)

    @classmethod
    def tokenize(cls, sentences: List[str], lang: str = None):
        """Split each sentence into tokens per the language's tokenizer.

        - ``word``: one token per character (zh/ja/ti).
        - ``whitespace``: lowercase ``\\w+`` regex tokens.

        Raises:
            ValueError: if *lang* has no registered tokenizer.
        """
        tokenizer = TOKENIZER_MAPPING.get(lang, None)
        if tokenizer == TokenizerType.word:
            # Character-level tokens for scripts without whitespace boundaries.
            return [list(sentence) for sentence in sentences]
        elif tokenizer == TokenizerType.whitespace:
            return [re.findall(r"\w+", sentence.lower()) for sentence in sentences]
        else:
            # Log message kept verbatim ("no matching tokenizer found").
            logger.error("找不到对应的分词器")
            # Fix: raise instead of exit(-1) — library code must not kill
            # the host process; callers can catch and report the error.
            raise ValueError(f"no tokenizer registered for lang={lang!r}")

    @classmethod
    def norm(cls, sentences: List[str], lang: LangType = None):
        """Apply language-specific text normalization.

        zh: speechio TextNorm (fullwidth->halfwidth, uppercase, fillers
        removed). en: ASCII-letters-only, NLTK word tokenization + WordNet
        lemmatization. Other languages: punctuation stripped and lowercased.
        """
        if lang == "zh":
            from utils.speechio import textnorm_zh as textnorm

            normalizer = textnorm.TextNorm(
                to_banjiao=True,
                to_upper=True,
                to_lower=False,
                remove_fillers=True,
                # Unlike the batch-recognition setup, erhua is kept here
                # (original comment noted this deliberate change to False).
                remove_erhua=False,
                check_chars=False,
                remove_space=False,
                cc_mode="",
            )
            return [normalizer(sentence) for sentence in sentences]
        elif lang == "en":
            # (Dead commented-out subprocess/spacy variants removed — see VCS
            # history if the file-based speechio normalizer is ever needed.)
            result = []
            for sentence in sentences:
                # Keep only ASCII letters and whitespace before tokenizing.
                sentence = re.sub(r"[^a-zA-Z\s]", "", sentence)
                tokens = word_tokenize(sentence)
                tokens = [lemmatizer.lemmatize(t) for t in tokens]
                result.append(" ".join(tokens))
            return result
        else:
            # Fallback: map a broad set of ASCII + fullwidth punctuation to
            # spaces, then lowercase. String reproduced byte-for-byte.
            punc = "!?。"#$%&'()*+,-/:;<=>[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏.`! #$%^&*()_+-=|';\":/.,?><~·!#¥%……&*()——+-=“:’;、。,?》《{}"
            return [
                sentence.translate(str.maketrans(dict.fromkeys(punc, " "))).lower()
                for sentence in sentences
            ]

    @classmethod
    def replace_general_punc(
        cls, sentences: List[str], tokenizer: TokenizerType, language: str = None
    ) -> List[str]:
        """Strip common punctuation (replaces utils.metrics.cut_sentence).

        If *language* is given it overrides *tokenizer* via TOKENIZER_MAPPING.
        Whitespace-tokenized languages replace punctuation with a space;
        character-tokenized languages drop it entirely. Each result is also
        stripped and lowercased.
        """
        if language:
            tokenizer = TOKENIZER_MAPPING.get(language)
        # NOTE(review): several entries below render as empty strings — they
        # were likely invisible/fullwidth Unicode characters lost in transit
        # (the hosting UI flagged "ambiguous Unicode"); verify against VCS.
        general_puncs = [
            "······",
            "......",
            "",
            "",
            "",
            "",
            "",
            "",
            "...",
            ".",
            ",",
            "?",
            "!",
            ";",
            ":",
        ]
        replacer = " " if tokenizer == TokenizerType.whitespace else ""
        trans = str.maketrans(dict.fromkeys("".join(general_puncs), replacer))
        return [sentence.translate(trans).strip().lower() for sentence in sentences]