initial commit

Dockerfile.funasr-mlu370 (new file, 22 lines)
@@ -0,0 +1,22 @@
FROM git.modelhub.org.cn:9443/enginex-cambricon/mlu370-pytorch:v25.01-torch2.5.0-torchmlu1.24.1-ubuntu22.04-py310

WORKDIR /root

COPY requirements.txt /root

SHELL ["/bin/bash", "-c"]

RUN source /torch/venv3/pytorch/bin/activate && \
    pip uninstall -y deepspeed
RUN source /torch/venv3/pytorch/bin/activate && \
    pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

RUN source /torch/venv3/pytorch/bin/activate && \
    pip install funasr==1.2.6 openai-whisper -i https://pypi.tuna.tsinghua.edu.cn/simple

ADD . /root/
ADD nltk_data.tar.gz /root/
RUN tar -xzvf nltk_data.tar.gz
RUN cp ./replaced_files/cam_mlu370_x8/auto_model.py /torch/venv3/pytorch/lib/python3.10/site-packages/funasr/auto/
EXPOSE 80
ENTRYPOINT ["/bin/bash", "-c", "source /torch/venv3/pytorch/bin/activate && ./start_funasr.sh"]

README.md (new file, 53 lines)
@@ -0,0 +1,53 @@
# FunASR on Cambricon MLU370 Series

## Building the Image

```shell
docker build -f ./Dockerfile.funasr-mlu370 -t <your_image> .
```

## Usage

### Quick Image Test

FunASR must be tested inside a container built from the image above. Test steps:

1. This project ships with sample test data: the audio file `lei-jun-test.wav` and its reference transcript `lei-jun.txt`. You also need the path to an ASR model; this example assumes the SenseVoiceSmall model has already been downloaded to /model/SenseVoiceSmall.
2. Run the following quick-test command from the project directory:

```shell
docker run -it \
    -v $PWD:/tmp/workspace \
    -v /model/SenseVoiceSmall:/model \
    --device=/dev/cambricon_dev0:/dev/cambricon_dev0 \
    --device=/dev/cambricon_ctl:/dev/cambricon_ctl \
    -e MODEL_DIR=/model \
    -e TEST_FILE=lei-jun-test.wav \
    -e ANSWER_FILE=lei-jun.txt \
    -e CUSTOM_DEVICE=MLU370 \
    --cpus=4 --memory=16g \
    <your_image>
```

When the command above runs successfully, the terminal shows the recognition result for the test audio, the runtime, and the 1-cer quality metric.
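The 1-cer metric reported here is one minus the character error rate (CER). As a minimal stdlib sketch of how such a score can be computed (the `edit_distance`/`one_minus_cer` helpers and the sample strings are illustrative, not the project's actual `test_funasr.py` code):

```python
def edit_distance(ref: str, hyp: str) -> int:
    """Levenshtein distance between two character sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution
        prev = cur
    return prev[-1]

def one_minus_cer(ref: str, hyp: str) -> float:
    """1 - CER: higher is better, 1.0 means a perfect transcript."""
    return 1.0 - edit_distance(ref, hyp) / len(ref)

print(one_minus_cer("朋友们晚上好", "朋友们晚上好"))  # identical transcript -> 1.0
```

A value of 0.98, as in the tables below, means roughly 2% of reference characters were substituted, deleted, or inserted.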
### Custom Manual Runs

You can enter the image interactively with a docker run command similar to the one above. The main test script is `test_funasr.py`; in it you can change the model path, the test file path, and the funASR invocation logic as needed.

## Model Support on Cambricon MLU370 Series Accelerators

We adapted all major funASR model families to Cambricon MLU370 series accelerator cards. Test method: run speech recognition on the same long audio clip on an Nvidia A100 and on a Cambricon MLU370 series card, and record the runtime and the 1-cer metric. Each run uses only a single card.

### Cambricon MLU370-X8

| Model family | Model URL | A100 runtime (s) | MLU370-X8 runtime (s) | A100 1-cer | MLU370-X8 1-cer | Notes |
|------|---------------|-----|----|-------|-------|---------------------|
| sense_voice | https://www.modelscope.cn/models/iic/SenseVoiceSmall | 1.4021 | 1.7852 | 0.980033 | 0.980033 | |
| whisper | https://www.modelscope.cn/models/iic/Whisper-large-v3 | 18.3006 | 45.8322 | 0.910150 | 0.910150 | |
| paraformer | https://modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch | 3.9866 | 7.9367 | 0.955075 | 0.955075 | |
| conformer | https://www.modelscope.cn/models/iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch | 83.8346 | 138.5394 | 0.349418 | 0.346090 | |
| uni_asr | https://www.modelscope.cn/models/iic/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline | 118.4741 | 175.5439 | 0.717138 | 0.950083 | 1-cer is markedly higher on Cambricon<sup>(*)</sup> |

### Cambricon MLU370-X4

| Model family | Model URL | A100 runtime (s) | MLU370-X4 runtime (s) | A100 1-cer | MLU370-X4 1-cer | Notes |
|------|---------------|-----|----|-------|-------|---------------------|
| sense_voice | https://www.modelscope.cn/models/iic/SenseVoiceSmall | 1.2767 | 1.6982 | 0.980033 | 0.980033 | |
| whisper | https://www.modelscope.cn/models/iic/Whisper-large-v3 | 18.8058 | 42.1840 | 0.910150 | 0.910150 | |
| paraformer | https://modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch | 4.2537 | 7.9753 | 0.955075 | 0.955075 | |
| conformer | https://www.modelscope.cn/models/iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch | 82.9607 | 128.3273 | 0.349418 | 0.346090 | |
| uni_asr | https://www.modelscope.cn/models/iic/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline | 109.2819 | 112.0971 | 0.717138 | 0.950083 | 1-cer is markedly higher on Cambricon<sup>(*)</sup> |

(*) The uni_asr model behaves very differently on Cambricon cards than on the A100 or on CPU. uni_asr is unusual in that it has built-in VAD; the test audio contains pauses by the speaker, and on A100/CPU uni_asr then recognizes only part of the speech, so the 1-cer measured there is low (0.71). On a Cambricon card, with exactly the same calling code, all of the audio is recognized, so the score is higher (0.95). Although this output looks reasonably correct, the behavior differs from what the same test produces on the A100/CPU and on other domestic accelerators, so the model's built-in VAD may not be taking effect on Cambricon.
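The score gap in the footnote follows directly from the metric: a hypothesis that drops part of the reference (as when built-in VAD cuts the audio short) scores a much lower 1-cer than a complete one, because every missing character counts as a deletion. A stdlib sketch of this effect (the `cer_score` helper and the strings are illustrative, not project code):

```python
def cer_score(ref: str, hyp: str) -> float:
    """1 - character error rate, computed via Levenshtein distance."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h)))
        prev = cur
    return 1.0 - prev[-1] / len(ref)

ref = "朋友们晚上好欢迎大家来参加今天晚上的活动谢谢大家"
full = ref                      # all audio recognized (Cambricon behavior in the table)
partial = ref[: len(ref) // 2]  # VAD cut the tail off (A100/CPU behavior in the table)
print(cer_score(ref, full))     # 1.0
print(cer_score(ref, partial))  # 0.5: every dropped character counts as a deletion
```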

download_nltk_model.py (new file, 4 lines)
@@ -0,0 +1,4 @@
import nltk
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('omw-1.4')

lei-jun-test.wav (new file, binary)
Binary file not shown.

lei-jun.txt (new file, 1 line)
@@ -0,0 +1 @@
朋友们晚上好,欢迎大家来参加今天晚上的活动,谢谢大家。这是我第四次办年度演讲,前三次呢因为疫情的原因,都在小米科技园内举办。现场呢人很少。这是第四次,我们仔细想了想,我们还是想办一个比较大的聚会。然后呢让我们的新朋友老朋友一起聚一聚。今天的话呢我们就在北京的国家会议中心呢举办了这么一个活动。现场呢来了很多人大概有3500人。还有很多很多的朋友呢,通过观看直播的方式来参与。再一次呢对大家的参加,表示感谢,谢谢大家。两个月前我参加了今年武汉大学的毕业典礼。今年呢是武汉大学建校130周年,作为校友被母校邀请在毕业典礼上致辞,这对我来说,是至高无上的荣誉。站在讲台的那一刻,面对全校师生,关于武大的所有的记忆一下子涌现在脑海里。今天呢我就先和大家聊聊武大往事。那还是36年前,1987年我呢考上了武汉大学的计算机系。在武汉大学的图书馆里看了一本书《硅谷之火》,建立了我一生的梦想。看完书以后,热血沸腾,激动得睡不着觉。我还记得那天晚上星光很亮。我就在武大的操场上,就是屏幕上这个操场,走了一圈又一圈,走了整整一个晚上。我心里有团火,我也想办一个伟大的公司,就是这样,梦想之火,在我心里彻底点燃了。但是一个大一的新生,但是一个大一的新生,一个从县城里出来的年轻人,什么也不会,什么也没有,就想创办一家伟大的公司,这不就是天方夜谭吗?这么离谱的一个梦想,该如何实现呢?那天晚上我想了一整晚上,说实话,越想越糊涂,完全理不清头绪,后来我在想,哎干脆别想了,把书念好是正事,所以呢我就下定决心认认真真读书,那么我怎么能够把书读得不同凡响呢?

nltk_data.tar.gz (new file, binary)
Binary file not shown.

replaced_files/cam_mlu370_x8/auto_model.py (new file, 678 lines)
@@ -0,0 +1,678 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import json
import time
import copy
import torch
import random
import string
import logging
import os.path
import numpy as np
from tqdm import tqdm

from omegaconf import DictConfig, ListConfig
from funasr.utils.misc import deep_update
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.utils.timestamp_tools import timestamp_sentence_en
from funasr.download.download_model_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.vad_utils import merge_vad
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils import export_utils
from funasr.utils import misc

try:
    from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
    from funasr.models.campplus.cluster_backend import ClusterBackend
except:
    pass


def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """ """
    data_list = []
    key_list = []
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]

    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str):
        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
            data_in = download_from_url(data_in)

    if isinstance(data_in, str) and os.path.exists(
        data_in
    ):  # wav_path; filelist: wav.scp, file.jsonl;text.txt;
        _, file_extension = os.path.splitext(data_in)
        file_extension = file_extension.lower()
        if file_extension in filelist:  # filelist: wav.scp, file.jsonl;text.txt;
            with open(data_in, encoding="utf-8") as fin:
                for line in fin:
                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                    if data_in.endswith(".jsonl"):  # file.jsonl: json.dumps({"source": data})
                        lines = json.loads(line.strip())
                        data = lines["source"]
                        key = data["key"] if "key" in data else key
                    else:  # filelist, wav.scp, text.txt: id \t data or data
                        lines = line.strip().split(maxsplit=1)
                        data = lines[1] if len(lines) > 1 else lines[0]
                        key = lines[0] if len(lines) > 1 else key

                    data_list.append(data)
                    key_list.append(key)
        else:
            if key is None:
                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key = misc.extract_filename_without_extension(data_in)
            data_list = [data_in]
            key_list = [key]
    elif isinstance(data_in, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):  # multiple inputs
            data_list_tmp = []
            for data_in_i, data_type_i in zip(data_in, data_type):
                key_list, data_list_i = prepare_data_iterator(
                    data_in=data_in_i, data_type=data_type_i
                )
                data_list_tmp.append(data_list_i)
            data_list = []
            for item in zip(*data_list_tmp):
                data_list.append(item)
        else:
            # [audio sample point, fbank, text]
            data_list = data_in
            key_list = []
            for data_i in data_in:
                if isinstance(data_i, str) and os.path.exists(data_i):
                    key = misc.extract_filename_without_extension(data_i)
                else:
                    if key is None:
                        key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key_list.append(key)

    else:  # raw text; audio sample point, fbank; bytes
        if isinstance(data_in, bytes):  # audio bytes
            data_in = load_bytes(data_in)
        if key is None:
            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
        data_list = [data_in]
        key_list = [key]

    return key_list, data_list

class AutoModel:

    def __init__(self, **kwargs):

        try:
            from funasr.utils.version_checker import check_for_update

            check_for_update(disable=kwargs.get("disable_update", False))
        except:
            pass

        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
        logging.basicConfig(level=log_level)

        model, kwargs = self.build_model(**kwargs)

        # if vad_model is not None, build vad model else None
        vad_model = kwargs.get("vad_model", None)
        vad_kwargs = {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
        if vad_model is not None:
            logging.info("Building VAD model.")
            vad_kwargs["model"] = vad_model
            vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
            vad_kwargs["device"] = kwargs["device"]
            vad_model, vad_kwargs = self.build_model(**vad_kwargs)

        # if punc_model is not None, build punc model else None
        punc_model = kwargs.get("punc_model", None)
        punc_kwargs = {} if kwargs.get("punc_kwargs", {}) is None else kwargs.get("punc_kwargs", {})
        if punc_model is not None:
            logging.info("Building punc model.")
            punc_kwargs["model"] = punc_model
            punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
            punc_kwargs["device"] = kwargs["device"]
            punc_model, punc_kwargs = self.build_model(**punc_kwargs)

        # if spk_model is not None, build spk model else None
        spk_model = kwargs.get("spk_model", None)
        spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
        cb_kwargs = (
            {} if spk_kwargs.get("cb_kwargs", {}) is None else spk_kwargs.get("cb_kwargs", {})
        )
        if spk_model is not None:
            logging.info("Building SPK model.")
            spk_kwargs["model"] = spk_model
            spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
            spk_kwargs["device"] = kwargs["device"]
            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
            self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"])
            spk_mode = kwargs.get("spk_mode", "punc_segment")
            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
                logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
            self.spk_mode = spk_mode

        self.kwargs = kwargs
        self.model = model
        self.vad_model = vad_model
        self.vad_kwargs = vad_kwargs
        self.punc_model = punc_model
        self.punc_kwargs = punc_kwargs
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs
        self.model_path = kwargs.get("model_path")

    @staticmethod
    def build_model(**kwargs):
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        # check both cuda and mlu availability
        mlu_available = device.startswith("mlu") and hasattr(torch, "mlu") and torch.mlu.is_available()
        cuda_available = device.startswith("cuda") and torch.cuda.is_available() and kwargs.get("ngpu", 1) != 0
        if not (mlu_available or cuda_available):
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device

        torch.set_num_threads(kwargs.get("ncpu", 4))

        # build tokenizer
        tokenizer = kwargs.get("tokenizer", None)
        kwargs["tokenizer"] = tokenizer
        kwargs["vocab_size"] = -1

        if tokenizer is not None:
            tokenizers = (
                tokenizer.split(",") if isinstance(tokenizer, str) else tokenizer
            )  # type of tokenizers is list!!!
            tokenizers_conf = kwargs.get("tokenizer_conf", {})
            tokenizers_build = []
            vocab_sizes = []
            token_lists = []

            ### === only for kws ===
            token_list_files = kwargs.get("token_lists", [])
            seg_dicts = kwargs.get("seg_dicts", [])
            ### === only for kws ===

            if not isinstance(tokenizers_conf, (list, tuple, ListConfig)):
                tokenizers_conf = [tokenizers_conf] * len(tokenizers)

            for i, tokenizer in enumerate(tokenizers):
                tokenizer_class = tables.tokenizer_classes.get(tokenizer)
                tokenizer_conf = tokenizers_conf[i]

                ### === only for kws ===
                if len(token_list_files) > 1:
                    tokenizer_conf["token_list"] = token_list_files[i]
                if len(seg_dicts) > 1:
                    tokenizer_conf["seg_dict"] = seg_dicts[i]
                ### === only for kws ===

                tokenizer = tokenizer_class(**tokenizer_conf)
                tokenizers_build.append(tokenizer)
                token_list = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
                token_list = (
                    tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else token_list
                )
                vocab_size = -1
                if token_list is not None:
                    vocab_size = len(token_list)

                if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
                    vocab_size = tokenizer.get_vocab_size()
                token_lists.append(token_list)
                vocab_sizes.append(vocab_size)

            if len(tokenizers_build) <= 1:
                tokenizers_build = tokenizers_build[0]
                token_lists = token_lists[0]
                vocab_sizes = vocab_sizes[0]

            kwargs["tokenizer"] = tokenizers_build
            kwargs["vocab_size"] = vocab_sizes
            kwargs["token_list"] = token_lists

        # build frontend
        frontend = kwargs.get("frontend", None)
        kwargs["input_size"] = None
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs.get("frontend_conf", {}))
            kwargs["input_size"] = (
                frontend.output_size() if hasattr(frontend, "output_size") else None
            )
        kwargs["frontend"] = frontend
        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        assert model_class is not None, f'{kwargs["model"]} is not registered'
        model_conf = {}
        deep_update(model_conf, kwargs.get("model_conf", {}))
        deep_update(model_conf, kwargs)
        model = model_class(**model_conf)

        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
            if os.path.exists(init_param):
                logging.info(f"Loading pretrained params from {init_param}")
                load_pretrained_model(
                    model=model,
                    path=init_param,
                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                    oss_bucket=kwargs.get("oss_bucket", None),
                    scope_map=kwargs.get("scope_map", []),
                    excludes=kwargs.get("excludes", None),
                )
            else:
                print(f"error, init_param does not exist!: {init_param}")

        # fp16
        if kwargs.get("fp16", False):
            model.to(torch.float16)
        elif kwargs.get("bf16", False):
            model.to(torch.bfloat16)
        model.to(device)

        if not kwargs.get("disable_log", True):
            tables.print()

        return model, kwargs

    def __call__(self, *args, **cfg):
        kwargs = self.kwargs
        deep_update(kwargs, cfg)
        res = self.model(*args, **kwargs)
        return res

    def generate(self, input, input_len=None, **cfg):
        if self.vad_model is None:
            return self.inference(input, input_len=input_len, **cfg)

        else:
            return self.inference_with_vad(input, input_len=input_len, **cfg)

    def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
        kwargs = self.kwargs if kwargs is None else kwargs
        if "cache" in kwargs:
            kwargs.pop("cache")
        deep_update(kwargs, cfg)
        model = self.model if model is None else model
        model.eval()

        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
        #     batch_size = 1

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
        )

        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
        disable_pbar = self.kwargs.get("disable_pbar", False)
        pbar = (
            tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
        )
        time_speech_total = 0.0
        time_escape_total = 0.0
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}

            if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank":  # fbank
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len

            time1 = time.perf_counter()
            with torch.no_grad():
                res = model.inference(**batch, **kwargs)
                if isinstance(res, (list, tuple)):
                    results = res[0] if len(res) > 0 else [{"text": ""}]
                    meta_data = res[1] if len(res) > 1 else {}
            time2 = time.perf_counter()

            asr_result_list.extend(results)

            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
            batch_data_time = meta_data.get("batch_data_time", -1)
            time_escape = time2 - time1
            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
            speed_stats["forward"] = f"{time_escape:0.3f}"
            speed_stats["batch_size"] = f"{len(results)}"
            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
            description = f"{speed_stats}, "
            if pbar:
                pbar.update(end_idx - beg_idx)
                pbar.set_description(description)
            time_speech_total += batch_data_time
            time_escape_total += time_escape

        if pbar:
            # pbar.update(1)
            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")

        device = next(model.parameters()).device
        if device.type == "cuda":
            with torch.cuda.device(device):
                torch.cuda.empty_cache()
        elif device.type == "mlu":
            with torch.mlu.device(device):
                torch.mlu.empty_cache()
        return asr_result_list

    def inference_with_vad(self, input, input_len=None, **cfg):
        kwargs = self.kwargs
        # step.1: compute the vad model
        deep_update(self.vad_kwargs, cfg)
        beg_vad = time.time()
        res = self.inference(
            input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
        )
        end_vad = time.time()

        # FIX(gcf): concat the vad clips for sense voice model for better aed
        if cfg.get("merge_vad", False):
            for i in range(len(res)):
                res[i]["value"] = merge_vad(
                    res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
                )

        # step.2 compute asr model
        model = self.model
        deep_update(kwargs, cfg)
        batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
        kwargs["batch_size"] = batch_size

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None)
        )
        results_ret_list = []
        time_speech_total_all_samples = 1e-6

        beg_total = time.time()
        pbar_total = (
            tqdm(colour="red", total=len(res), dynamic_ncols=True)
            if not kwargs.get("disable_pbar", False)
            else None
        )
        for i in range(len(res)):
            key = res[i]["key"]
            vadsegments = res[i]["value"]
            input_i = data_list[i]
            fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
            speech = load_audio_text_image_video(input_i, fs=fs, audio_fs=kwargs.get("fs", 16000))
            speech_lengths = len(speech)
            n = len(vadsegments)
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []

            if not len(sorted_data):
                results_ret_list.append({"key": key, "text": "", "timestamp": []})
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue

            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])

            if kwargs["device"] == "cpu":
                batch_size = 0

            beg_idx = 0
            beg_asr_total = time.time()
            time_speech_total_per_sample = speech_lengths / 16000
            time_speech_total_all_samples += time_speech_total_per_sample

            # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)

            all_segments = []
            max_len_in_batch = 0
            end_idx = 1
            for j, _ in enumerate(range(0, n)):
                # pbar_sample.update(1)
                sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
                potential_batch_length = max(max_len_in_batch, sample_length) * (j + 1 - beg_idx)
                # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
                if (
                    j < n - 1
                    and sample_length < batch_size_threshold_ms
                    and potential_batch_length < batch_size
                ):
                    max_len_in_batch = max(max_len_in_batch, sample_length)
                    end_idx += 1
                    continue

                speech_j, speech_lengths_j = slice_padding_audio_samples(
                    speech, speech_lengths, sorted_data[beg_idx:end_idx]
                )
                results = self.inference(
                    speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
                )
                if self.spk_model is not None:
                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                    for _b in range(len(speech_j)):
                        vad_segments = [
                            [
                                sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
                                sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
                                np.array(speech_j[_b]),
                            ]
                        ]
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
                        speech_b = [i[2] for i in segments]
                        spk_res = self.inference(
                            speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg
                        )
                        results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
                beg_idx = end_idx
                end_idx += 1
                max_len_in_batch = sample_length
                if len(results) < 1:
                    continue
                results_sorted.extend(results)

            # end_asr_total = time.time()
            # time_escape_total_per_sample = end_asr_total - beg_asr_total
            # pbar_sample.update(1)
            # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
            #                             f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
            #                             f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
            if len(results_sorted) != n:
                results_ret_list.append({"key": key, "text": "", "timestamp": []})
                logging.info("decoding, utt: {}, empty result".format(key))
                continue
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
                restored_data[index] = results_sorted[j]
            result = {}

            # results combine for texts, timestamps, speaker embeddings and others
            # TODO: rewrite for clean code
            for j in range(n):
                for k, v in restored_data[j].items():
                    if k.startswith("timestamp"):
                        if k not in result:
                            result[k] = []
                        for t in restored_data[j][k]:
                            t[0] += vadsegments[j][0]
                            t[1] += vadsegments[j][0]
                        result[k].extend(restored_data[j][k])
                    elif k == "spk_embedding":
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
                    elif "text" in k:
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += " " + restored_data[j][k]
                    else:
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += restored_data[j][k]

            if not len(result["text"].strip()):
                continue
            return_raw_text = kwargs.get("return_raw_text", False)
            # step.3 compute punc model
            raw_text = None
            if self.punc_model is not None:
                deep_update(self.punc_kwargs, cfg)
                punc_res = self.inference(
                    result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg
                )
                raw_text = copy.copy(result["text"])
                if return_raw_text:
                    result["raw_text"] = raw_text
                result["text"] = punc_res[0]["text"]

            # speaker embedding cluster after resorted
            if self.spk_model is not None and kwargs.get("return_spk_res", True):
                if raw_text is None:
                    logging.error("Missing punc_model, which is required by spk_model.")
                all_segments = sorted(all_segments, key=lambda x: x[0])
                spk_embedding = result["spk_embedding"]
                labels = self.cb_model(
                    spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
                )
                # del result['spk_embedding']
                sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
                if self.spk_mode == "vad_segment":  # recover sentence_list
                    sentence_list = []
                    for rest, vadsegment in zip(restored_data, vadsegments):
                        if "timestamp" not in rest:
                            logging.error(
                                "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
can predict timestamp, and speaker diarization relies on timestamps."
                            )
                        sentence_list.append(
                            {
                                "start": vadsegment[0],
                                "end": vadsegment[1],
                                "sentence": rest["text"],
                                "timestamp": rest["timestamp"],
                            }
                        )
                elif self.spk_mode == "punc_segment":
                    if "timestamp" not in result:
                        logging.error(
                            "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
can predict timestamp, and speaker diarization relies on timestamps."
                        )
                    if kwargs.get("en_post_proc", False):
                        sentence_list = timestamp_sentence_en(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                    else:
                        sentence_list = timestamp_sentence(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                distribute_spk(sentence_list, sv_output)
                result["sentence_info"] = sentence_list
            elif kwargs.get("sentence_timestamp", False):
                if not len(result["text"].strip()):
                    sentence_list = []
                else:
                    if kwargs.get("en_post_proc", False):
                        sentence_list = timestamp_sentence_en(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                    else:
                        sentence_list = timestamp_sentence(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                result["sentence_info"] = sentence_list
            if "spk_embedding" in result:
                del result["spk_embedding"]

            result["key"] = key
            results_ret_list.append(result)
            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            if pbar_total:
                pbar_total.update(1)
                pbar_total.set_description(
                    f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                    f"time_speech: {time_speech_total_per_sample: 0.3f}, "
                    f"time_escape: {time_escape_total_per_sample:0.3f}"
                )

        # end_total = time.time()
        # time_escape_total_all_samples = end_total - beg_total
        # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
        #       f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
        #       f"time_escape_all: {time_escape_total_all_samples:0.3f}")
        return results_ret_list

    def export(self, input=None, **cfg):
        """
|
|
||||||
|
:param input:
|
||||||
|
:param type:
|
||||||
|
:param quantize:
|
||||||
|
:param fallback_num:
|
||||||
|
:param calib_num:
|
||||||
|
:param opset_version:
|
||||||
|
:param cfg:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
device = cfg.get("device", "cpu")
|
||||||
|
model = self.model.to(device=device)
|
||||||
|
kwargs = self.kwargs
|
||||||
|
deep_update(kwargs, cfg)
|
||||||
|
kwargs["device"] = device
|
||||||
|
del kwargs["model"]
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
type = kwargs.get("type", "onnx")
|
||||||
|
|
||||||
|
key_list, data_list = prepare_data_iterator(
|
||||||
|
input, input_len=None, data_type=kwargs.get("data_type", None), key=None
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
|
||||||
|
|
||||||
|
return export_dir
|
||||||
11
requirements.txt
Normal file
@@ -0,0 +1,11 @@
requests
wheel
websocket-client
pydantic<2.0.0
numpy<2.0
PyYAML
Levenshtein
ruamel.yaml
nltk==3.7
pynini==2.1.6
soundfile
3
start_funasr.sh
Executable file
@@ -0,0 +1,3 @@
unset CUDA_VISIBLE_DEVICES
unset NVIDIA_VISIBLE_DEVICES
python3 ./test_funasr.py
192
test_funasr.py
Normal file
@@ -0,0 +1,192 @@
import os
import time
import torchaudio
import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from utils.calculate import cal_per_cer
import json

CUSTOM_DEVICE = os.getenv("CUSTOM_DEVICE", "").lower()
if CUSTOM_DEVICE.startswith("mlu"):
    import torch_mlu


def split_audio(waveform, sample_rate, segment_seconds=20):
    segment_samples = segment_seconds * sample_rate
    segments = []
    for i in range(0, waveform.shape[1], segment_samples):
        segment = waveform[:, i:i + segment_samples]
        if segment.shape[1] > 0:
            segments.append(segment)
    return segments


def determine_model_type(model_name):
    if "sensevoice" in model_name.lower():
        return "sense_voice"
    elif "whisper" in model_name.lower():
        return "whisper"
    elif "paraformer" in model_name.lower():
        return "paraformer"
    elif "conformer" in model_name.lower():
        return "conformer"
    elif "uniasr" in model_name.lower():
        return "uni_asr"
    else:
        return "unknown"


def test_funasr(model_dir, audio_file, answer_file, use_gpu):
    model_name = os.path.basename(model_dir)
    model_type = determine_model_type(model_name)

    device = "cpu"
    if use_gpu:
        if CUSTOM_DEVICE.startswith("mlu"):
            device = "mlu:0"
        else:
            device = "cuda:0"

    if device == "cuda:0" and torch.cuda.get_device_name() == "Iluvatar BI-V100" and model_type == "whisper":
        # On the Iluvatar BI-V100, Whisper needs to bypass unsupported operators
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    print(f"device: {device}", flush=True)

    # Skip the VAD, punctuation, and speaker models to test raw ASR capability only
    model = AutoModel(
        model=model_dir,
        # vad_model="fsmn-vad",
        # vad_kwargs={"max_single_segment_time": 30000},
        vad_model=None,
        device=device,
        disable_update=True
    )

    waveform, sample_rate = torchaudio.load(audio_file)
    # print(waveform.shape)
    duration = waveform.shape[1] / sample_rate
    segments = split_audio(waveform, sample_rate, segment_seconds=20)

    generated_text = ""
    processing_time = 0

    if model_type == "uni_asr":
        # UniASR is special: it is designed for long audio and has built-in VAD.
        # If a split segment's first 20 s is almost all music with no speech, it
        # errors out, because VAD may trim away all audio and leave the
        # encoder/decoder with zero-length input.
        ts1 = time.time()
        res = model.generate(
            input=audio_file
        )
        generated_text = res[0]["text"]
        ts2 = time.time()
        processing_time = ts2 - ts1
    else:
        # Feed the split segments one by one
        for i, segment in enumerate(segments):
            segment_path = f"temp_seg_{i}.wav"
            torchaudio.save(segment_path, segment, sample_rate)
            ts1 = time.time()
            if model_type == "sense_voice":
                res = model.generate(
                    input=segment_path,
                    cache={},
                    language="auto",  # "zn", "en", "yue", "ja", "ko", "nospeech"
                    use_itn=True,
                    batch_size_s=60,
                    merge_vad=False,
                    # merge_length_s=15,
                )
                text = rich_transcription_postprocess(res[0]["text"])
            elif model_type == "whisper":
                DecodingOptions = {
                    "task": "transcribe",
                    "language": "zh",
                    "beam_size": None,
                    "fp16": False,
                    "without_timestamps": False,
                    "prompt": None,
                }
                res = model.generate(
                    DecodingOptions=DecodingOptions,
                    input=segment_path,
                    batch_size_s=0,
                )
                text = res[0]["text"]
            elif model_type == "paraformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
                # Paraformer emits output character by character; too many
                # embedded spaces would hurt the 1-cer score
                text = text.replace(" ", "")
            elif model_type == "conformer":
                res = model.generate(
                    input=segment_path,
                    batch_size_s=300
                )
                text = res[0]["text"]
            # elif model_type == "uni_asr":
            #     if i == 0:
            #         os.remove(segment_path)
            #         continue
            #     res = model.generate(
            #         input=segment_path
            #     )
            #     text = res[0]["text"]
            else:
                raise RuntimeError("unknown model type")
            ts2 = time.time()
            generated_text += text
            processing_time += (ts2 - ts1)
            os.remove(segment_path)

    rtf = processing_time / duration
    print("Text:", generated_text, flush=True)
    print(f"Audio duration:\t{duration:.3f} s", flush=True)
    print(f"Elapsed:\t{processing_time:.3f} s", flush=True)
    print(f"RTF = {processing_time:.3f}/{duration:.3f} = {rtf:.3f}", flush=True)
    with open(answer_file, 'r', encoding='utf-8') as f:
        groundtruth_text = f.read()
    acc = cal_per_cer(generated_text, groundtruth_text, "zh")
    print(f"1-cer = {acc}", flush=True)

    return processing_time, acc, generated_text


if __name__ == "__main__":
    test_result = {
        "time_cuda": 0,
        "acc_cuda": 0,
        "text_cuda": "",
        "success": False
    }

    try:
        model_dict = {
            "sense_voice": "SenseVoiceSmall",
            "paraformer": "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
            "conformer": "speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
            "whisper": "Whisper-large-v3",
            "uni_asr": "speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline"
        }
        LOCAL_TEST = os.getenv("LOCAL_TEST", "false").lower() == "true"
        K8S_TEST = os.getenv("K8S_TEST", "false").lower() == "true"
        workspace_path = "../" if LOCAL_TEST else "/tmp/workspace"
        model_dir = os.path.join("/model", model_dict["paraformer"]) if LOCAL_TEST else os.environ["MODEL_DIR"]
        audio_file = "lei-jun-test.wav" if LOCAL_TEST else os.path.join(workspace_path, os.environ["TEST_FILE"])
        answer_file = "lei-jun.txt" if LOCAL_TEST else os.path.join(workspace_path, os.environ["ANSWER_FILE"])
        result_file = "result.json" if LOCAL_TEST else os.path.join(workspace_path, os.environ["RESULT_FILE"])
        # test_funasr(model_dir, audio_file, answer_file, False)
        processing_time, acc, generated_text = test_funasr(model_dir, audio_file, answer_file, True)
        test_result["time_cuda"] = processing_time
        test_result["acc_cuda"] = acc
        test_result["text_cuda"] = generated_text
        test_result["success"] = True
    except Exception as e:
        print(f"ASR test failed: {e}", flush=True)
    with open(result_file, "w", encoding="utf-8") as fp:
        json.dump(test_result, fp, ensure_ascii=False, indent=4)
    # When the image is launched by the SUT, keep the pod alive forever to satisfy
    # the k8s deployment; not needed for local testing or plain docker run
    if K8S_TEST:
        print("Start to sleep indefinitely", flush=True)
        time.sleep(100000)
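The fixed-length chunking that `test_funasr.py` applies before per-segment inference can be sketched without torchaudio. This is a minimal illustration only; `split_samples` is a hypothetical stand-in that operates on a plain list of samples instead of a waveform tensor, but uses the same stride arithmetic as `split_audio`:

```python
def split_samples(samples, sample_rate, segment_seconds=20):
    """Split a flat list of audio samples into fixed-length segments,
    mirroring the stride logic of split_audio() in test_funasr.py."""
    segment_samples = segment_seconds * sample_rate
    return [samples[i:i + segment_samples]
            for i in range(0, len(samples), segment_samples)
            if samples[i:i + segment_samples]]

# 50 s of silent 16 kHz audio -> segments of 20 s, 20 s, and 10 s
fake = [0.0] * (50 * 16000)
segs = split_samples(fake, 16000)
print([len(s) / 16000 for s in segs])  # -> [20.0, 20.0, 10.0]
```

The last segment is allowed to be shorter than 20 s, which matches the `segment.shape[1] > 0` guard in the original.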
0
utils/__init__.py
Normal file
BIN
utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/calculate.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/logger.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/model.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/reader.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/tokenizer.cpython-310.pyc
Normal file
Binary file not shown.
834
utils/calculate.py
Normal file
@@ -0,0 +1,834 @@
import re
import time

import Levenshtein
from utils.tokenizer import Tokenizer
from typing import List, Tuple
from utils.model import SegmentModel
from utils.model import AudioItem
from utils.reader import read_data
from utils.logger import logger
from utils.model import VoiceSegment
from utils.model import WordModel
from difflib import SequenceMatcher
import logging


def calculate_punctuation_ratio(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float:
    """
    Compute the punctuation ratio (generated vs. reference)
    :param datas:
    :return:
    """
    total_standard_punctuation = 0
    total_gen_punctuation = 0
    for answer, results in datas:
        # Concatenate all text segments, then count punctuation
        # in the reference and in the generated transcript
        standard_text = ""
        for item in answer.voice:
            standard_text = standard_text + item.answer
        gen_text = ""
        for item in results:
            gen_text = gen_text + item.text

        total_standard_punctuation = total_standard_punctuation + count_punctuation(standard_text)
        total_gen_punctuation = total_gen_punctuation + count_punctuation(gen_text)

    punctuation_ratio = total_gen_punctuation / total_standard_punctuation
    return punctuation_ratio


def calculate_acc(datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str) -> float:
    """
    Compute accuracy (1 - CER)
    :param datas:
    :return:
    """
    total_acc = 0
    for answer, results in datas:
        # Concatenate all text segments and compute 1 - CER
        # against the reference answer
        standard_text = ""
        for item in answer.voice:
            standard_text = standard_text + item.answer
        gen_text = ""
        for item in results:
            gen_text = gen_text + item.text
        acc = cal_per_cer(gen_text, standard_text, language)
        total_acc = total_acc + acc
    acc = total_acc / len(datas)
    return acc
def get_alignment_type(language: str):
    chart_langs = ["zh", "ja", "ko", "th", "lo", "my", "km", "bo"]  # Chinese, Japanese, Korean, Thai, Lao, Burmese, Khmer, Tibetan
    if language in chart_langs:
        return "chart"
    else:
        return "word"


def cal_per_cer(text: str, answer: str, language: str):
    if not answer:
        return 1.0 if text else 0.0  # empty reference: 0 if the prediction is also empty, otherwise 1

    text = remove_punctuation(text)
    answer = remove_punctuation(answer)

    text_chars = Tokenizer.norm_and_tokenize([text], language)[0]
    answer_chars = Tokenizer.norm_and_tokenize([answer], language)[0]

    # If the reference tokenizes to nothing, return a default accuracy
    if not answer_chars:
        return 0.0  # or 1.0, depending on the desired semantics

    alignment_type = get_alignment_type(language)
    if alignment_type == "chart":
        text_chars = list(text)
        answer_chars = list(answer)
        ops = Levenshtein.editops(text_chars, answer_chars)
        insert = len(list(filter(lambda x: x[0] == "insert", ops)))
        delete = len(list(filter(lambda x: x[0] == "delete", ops)))
        replace = len(list(filter(lambda x: x[0] == "replace", ops)))
    else:
        matcher = SequenceMatcher(None, text_chars, answer_chars)

        insert = 0
        delete = 0
        replace = 0

        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == 'replace':
                replace += max(i2 - i1, j2 - j1)
            elif tag == 'delete':
                delete += (i2 - i1)
            elif tag == 'insert':
                insert += (j2 - j1)

    cer = (insert + delete + replace) / len(answer_chars)
    acc = 1 - cer
    return acc
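The word-level branch of `cal_per_cer` counts edit operations from `SequenceMatcher` opcodes. A self-contained sketch of that counting, using only the standard library (`token_error_count` is an illustrative name, not part of this repository):

```python
from difflib import SequenceMatcher

def token_error_count(pred_tokens, ref_tokens):
    """Count insert/delete/replace errors between two token lists,
    mirroring the word-level opcode counting in cal_per_cer."""
    insert = delete = replace = 0
    for tag, i1, i2, j1, j2 in SequenceMatcher(None, pred_tokens, ref_tokens).get_opcodes():
        if tag == 'replace':
            replace += max(i2 - i1, j2 - j1)
        elif tag == 'delete':
            delete += i2 - i1
        elif tag == 'insert':
            insert += j2 - j1
    return insert + delete + replace

pred = "the cat sat on mat".split()
ref = "the cat sat on the mat".split()
errors = token_error_count(pred, ref)   # one missing "the" -> 1 error
print(errors, 1 - errors / len(ref))    # -> 1 0.8333...
```

Dividing the error count by the reference length gives the WER/CER, and `1 - cer` is the accuracy figure reported elsewhere in this file.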
def cal_total_cer(samples: list):
    """
    samples: List of tuples [(pred_text, ref_text), ...]
    """
    total_insert = 0
    total_delete = 0
    total_replace = 0
    total_ref_len = 0

    for text, answer in samples:

        if not answer:
            return 1.0 if text else 0.0  # empty reference: 0 if the prediction is also empty, otherwise 1

        text = remove_punctuation(text)
        answer = remove_punctuation(answer)

        text_chars = list(text)
        answer_chars = list(answer)

        ops = Levenshtein.editops(text_chars, answer_chars)
        insert = len(list(filter(lambda x: x[0] == "insert", ops)))
        delete = len(list(filter(lambda x: x[0] == "delete", ops)))
        replace = len(list(filter(lambda x: x[0] == "replace", ops)))

        total_insert += insert
        total_delete += delete
        total_replace += replace
        total_ref_len += len(answer_chars)

    total_cer = (total_insert + total_delete + total_replace) / total_ref_len if total_ref_len > 0 else 0.0
    total_acc = 1 - total_cer
    return total_acc


def remove_punctuation(text: str) -> str:
    # Strip Chinese and English punctuation
    return re.sub(r'[^\w\s\u4e00-\u9fff]', '', text)


def count_punctuation(text: str) -> int:
    """Count the punctuation characters in the text"""
    return len(re.findall(r"[^\w\s\u4e00-\u9fa5]", text))


from typing import List, Optional, Tuple
def calculate_standard_sentence_delay(datas: List[Tuple[AudioItem, List[SegmentModel]]]) -> float:
    for audio_item, asr_results in datas:
        if not audio_item.voice:
            continue  # no reference content
        #
        audio_texts = []
        asr_texts = []

        ref = audio_item.voice[0]  # by default take the first reference segment
        ref_end_ms = int(ref.end * 1000)

        # Find all ASR segments whose text contains the tail character of the
        # reference (simplified to "contains the last character of the reference")
        target_char = ref.answer.strip()[-1]  # tail character of the reference
        matching_results = [r for r in asr_results if target_char in r.text and r.words]

        if not matching_results:
            continue  # no ASR segment contains the tail character

        # Take the maximum end_time of the last word across these ASR segments
        # as the tail-character time
        latest_word_time = max(word.end_time for r in matching_results for word in r.words)

        delay = latest_word_time - ref_end_ms

        print(audio_item)
        print(asr_results)
        return 0


def align_texts(ref_text: str, hyp_text: str) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Compute a character-level alignment of two strings via edit distance
    Returns: [(ref_idx, hyp_idx), ...]
    """
    ops = Levenshtein.editops(ref_text, hyp_text)
    ref_len = len(ref_text)
    hyp_len = len(hyp_text)
    ref_idx = 0
    hyp_idx = 0
    alignment = []

    for op, i, j in ops:
        while ref_idx < i and hyp_idx < j:
            alignment.append((ref_idx, hyp_idx))
            ref_idx += 1
            hyp_idx += 1
        if op == "replace":
            alignment.append((i, j))
            ref_idx = i + 1
            hyp_idx = j + 1
        elif op == "delete":
            alignment.append((i, None))
            ref_idx = i + 1
        elif op == "insert":
            alignment.append((None, j))
            hyp_idx = j + 1

    while ref_idx < ref_len and hyp_idx < hyp_len:
        alignment.append((ref_idx, hyp_idx))
        ref_idx += 1
        hyp_idx += 1
    while ref_idx < ref_len:
        alignment.append((ref_idx, None))
        ref_idx += 1
    while hyp_idx < hyp_len:
        alignment.append((None, hyp_idx))
        hyp_idx += 1

    return alignment


def align_tokens(ref_text: List[str], hyp_text: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Compute the alignment of two tokenized strings
    Returns: [(ref_idx, hyp_idx), ...]
    """
    matcher = SequenceMatcher(None, ref_text, hyp_text)
    alignment = []

    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal' or tag == 'replace':
            for r, h in zip(range(i1, i2), range(j1, j2)):
                alignment.append((r, h))
        elif tag == 'delete':
            for r in range(i1, i2):
                alignment.append((r, None))
        elif tag == 'insert':
            for h in range(j1, j2):
                alignment.append((None, h))
    return alignment
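The token alignment used throughout this file can be exercised in isolation. A short usage sketch of the `SequenceMatcher`-based scheme (`align_pairs` is an illustrative restatement of `align_tokens`, not an import from this repository):

```python
from difflib import SequenceMatcher
from typing import List, Optional, Tuple

def align_pairs(ref: List[str], hyp: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    """Token alignment via SequenceMatcher opcodes: matched or replaced
    tokens pair up, deletions map to (r, None), insertions to (None, h)."""
    out = []
    for tag, i1, i2, j1, j2 in SequenceMatcher(None, ref, hyp).get_opcodes():
        if tag in ('equal', 'replace'):
            out.extend(zip(range(i1, i2), range(j1, j2)))
        elif tag == 'delete':
            out.extend((r, None) for r in range(i1, i2))
        elif tag == 'insert':
            out.extend((None, h) for h in range(j1, j2))
    return out

# "b" is missing from the hypothesis, so ref index 1 aligns to nothing
print(align_pairs(["a", "b", "c"], ["a", "c"]))  # -> [(0, 0), (1, None), (2, 1)]
```

These (ref_idx, hyp_idx) pairs are exactly the structure the head/tail-word lookups below walk through to recover word timestamps.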
def find_tail_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    # alignment = align_texts(ref_text, merged_text)
    """
    Find the tail character (non-punctuation) of the reference text ref_text,
    locate the corresponding tail-character index in the predicted text
    pred_text via the alignment, then map it to a word index via char2word_idx
    and return the corresponding WordModel

    :param ref_text: character list of the reference text
    :param pred_text: character list of the predicted text
    :param merged_words: WordModel list for the predicted text
    :param char2word_idx: WordModel index for each character of the predicted text
    :param alignment: character alignment (ref_idx, hyp_idx) between ref_text and pred_text
    :return: the WordModel of the tail character, or None
    """
    punct_set = set(",。!?、,.!?;;")

    # 1. Find the index tail_ref_idx of the last non-punctuation character in ref_text
    tail_ref_idx = len(ref_text) - 1
    while tail_ref_idx >= 0 and ref_text[tail_ref_idx] in punct_set:
        tail_ref_idx -= 1

    if tail_ref_idx < 0:
        # All punctuation, no tail character to find
        return None

    # 2. Find the hyp_idx corresponding to ref_idx == tail_ref_idx in the alignment
    tail_hyp_idx = None
    for ref_idx, hyp_idx in reversed(alignment):
        if ref_idx == tail_ref_idx and hyp_idx is not None:
            tail_hyp_idx = hyp_idx
            break

    if tail_hyp_idx is None:
        # No corresponding hyp_idx
        return None

    # 3. hyp_idx out of range
    if tail_hyp_idx >= len(char2word_idx):
        return None

    # 4. Map to the word index via char2word_idx
    word_index = char2word_idx[tail_hyp_idx]

    if word_index >= len(merged_words):
        return None

    # 5. Return the corresponding WordModel
    return merged_words[word_index]


def find_head_word_time(
    ref_text: List[str],
    pred_text: List[str],
    merged_words: List[WordModel],
    char2word_idx: List[int],
    alignment: List[Tuple[Optional[int], Optional[int]]],
) -> Optional[WordModel]:
    """
    Find the start_time of the reference's head character in the ASR output

    Parameters:
        ref_text: full reference text
        merged_text: full merged ASR text (character by character)
        merged_words: merged ASR WordModel list
        char2word_idx: character-to-word index mapping list

    Returns:
        start_time (in ms) of the word containing the head character, or None if not found
    """
    # alignment = align_texts(ref_text, merged_text)

    ref_head_index = 0  # head-character index is fixed at 0

    for ref_idx, hyp_idx in alignment:
        if ref_idx == ref_head_index and hyp_idx is not None:
            if 0 <= hyp_idx < len(char2word_idx):
                word_idx = char2word_idx[hyp_idx]
                return merged_words[word_idx]
    return None
def merge_asr_results(asr_list: List[SegmentModel]) -> Tuple[str, List[WordModel], List[int]]:
    """
    Merge multiple ASRResultModel objects into one large text and word list,
    and build the WordModel index for each character
    Returns:
        - merged text merged_text
        - WordModel list merged_words
        - index of the WordModel containing each character, char2word_idx
    """
    # merged_text = ""
    # merged_words = []
    # char2word_idx = []
    #
    # for asr in asr_list:
    #     if not asr.text or not asr.words:
    #         continue
    #     merged_text += asr.text
    #     for word in asr.words:
    #         word.segment = asr
    #         merged_words.append(word)
    #         for ch in word.text:
    #             char2word_idx.append(len(merged_words) - 1)
    # return merged_text, merged_words, char2word_idx
    """
    Merge multiple ASRResultModel objects into one large text and word list,
    strip punctuation, and build the WordModel index for each character

    Returns:
        - punctuation-stripped merged text merged_text
        - WordModel list merged_words (punctuation included)
        - WordModel index for each character of the stripped text, char2word_idx
    """
    punct_set = set(",。!?、,.!?;;")  # punctuation to filter out

    merged_text = ""
    merged_words = []
    char2word_idx = []

    for asr in asr_list:
        if not asr.text or not asr.words:
            continue
        merged_words_start_len = len(merged_words)
        for word in asr.words:
            word.segment = asr
            merged_words.append(word)

        # Walk all words; strip punctuation while concatenating,
        # maintaining char2word_idx as we go
        for idx_in_asr, word in enumerate(asr.words):
            word_idx = merged_words_start_len + idx_in_asr
            for ch in word.text:
                if ch not in punct_set:
                    merged_text += ch
                    char2word_idx.append(word_idx)

    return merged_text, merged_words, char2word_idx


def rebuild_char2word_idx(pred_tokens: List[str], merged_words: List[WordModel]) -> List[int]:
    """
    Rebuild char2word_idx so it corresponds one-to-one with pred_tokens
    """
    char2word_idx = []
    word_char_idx = 0
    for word_idx, word in enumerate(merged_words):
        for _ in word.text:
            if word_char_idx < len(pred_tokens):
                char2word_idx.append(word_idx)
                word_char_idx += 1
    return char2word_idx
def build_hyp_token_to_asr_chart_index(
    hyp_tokens: List[str],
    asr_words: List[WordModel]
) -> List[int]:
    """
    Build a mapping from hyp_token index to asr_word index
    Assumes each asr_word's text forms a contiguous substring of hyp_tokens (simple matching)
    """
    hyp_to_asr_word_idx = [-1] * len(hyp_tokens)

    i_asr = 0
    i_hyp = 0

    while i_asr < len(asr_words) and i_hyp < len(hyp_tokens):
        asr_word = asr_words[i_asr].text
        length = len(asr_word)
        # Concatenate `length` tokens of hyp_tokens starting at i_hyp
        hyp_substr = "".join(hyp_tokens[i_hyp:i_hyp + length])
        if hyp_substr == asr_word:
            # Match succeeded, record the mapping
            for k in range(i_hyp, i_hyp + length):
                hyp_to_asr_word_idx[k] = i_asr
            i_hyp += length
            i_asr += 1
        else:
            # On mismatch, one could try growing or shrinking the match
            # window for fault tolerance; this logic can be refined as needed
            # Simplified here: skip one hyp token
            i_hyp += 1

    return hyp_to_asr_word_idx


import re

def normalize(text: str) -> str:
    return re.sub(r"[^\w']+", '', text.lower())  # strip non-word characters, keep apostrophes

def build_hyp_token_to_asr_word_index(hyp_tokens: List[str], asr_words: List[WordModel]) -> List[int]:
    hyp_to_asr_word_idx = [-1] * len(hyp_tokens)
    i_hyp, i_asr = 0, 0

    while i_hyp < len(hyp_tokens) and i_asr < len(asr_words):
        hyp_token = normalize(hyp_tokens[i_hyp])
        asr_word = normalize(asr_words[i_asr].text)

        # Match on containment/prefix relations for robustness
        if hyp_token == asr_word or hyp_token in asr_word or asr_word in hyp_token:
            hyp_to_asr_word_idx[i_hyp] = i_asr
            i_hyp += 1
            i_asr += 1
        else:
            i_hyp += 1

    return hyp_to_asr_word_idx
def find_tail_word(
    ref_tokens: List[str],  # reference-text token list
    hyp_tokens: List[str],  # predicted-text token list
    alignment: List[Tuple[Optional[int], Optional[int]]],  # (ref_idx, hyp_idx) alignment result
    hyp_to_asr_word_idx: dict,
    asr_words: List[WordModel],
    punct_set: set = set(",。!?、,.!?;;")
) -> Optional[WordModel]:
    """
    Use the reference's tail token to locate the corresponding predicted
    token, then map it to an ASR word to get its time
    """

    """
    Find the WordModel (tail word) in the ASR result that corresponds to the
    predicted text's last validly aligned word
    """

    # 1. Strip trailing punctuation from ref, find the index of the ref tail word
    tail_ref_idx = len(ref_tokens) - 1
    while tail_ref_idx >= 0 and ref_tokens[tail_ref_idx] in punct_set:
        tail_ref_idx -= 1
    if tail_ref_idx < 0:
        return None

    # 2. Find the corresponding hyp_idx in the alignment
    tail_hyp_idx = None
    for ref_idx, hyp_idx in reversed(alignment):
        if ref_idx == tail_ref_idx and hyp_idx is not None:
            tail_hyp_idx = hyp_idx
            break

    # 3. If not found, fall back to the last ref_idx that has a match
    if tail_hyp_idx is None:
        for ref_idx, hyp_idx in reversed(alignment):
            if hyp_idx is not None:
                tail_hyp_idx = hyp_idx
                break

    if tail_hyp_idx is None or tail_hyp_idx >= len(hyp_to_asr_word_idx):
        return None

    # 4. Map to the ASR word index
    asr_word_idx = hyp_to_asr_word_idx[tail_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        return None

    return asr_words[asr_word_idx]


def find_tail_word2(
    ref_tokens: List[str],  # reference token list
    hyp_tokens: List[str],  # predicted token list
    alignment: List[Tuple[Optional[int], Optional[int]]],  # (ref_idx, hyp_idx) alignment
    hyp_to_asr_word_idx: List[int],  # ASR word index for each hyp token
    asr_words: List[WordModel],
    punct_set: set = set(",。!?、,.!?;;"),
    enable_debug: bool = False
) -> Optional[WordModel]:
    """
    Find the WordModel (tail word) in the ASR result that corresponds to the
    predicted text's last validly aligned word

    Returns None when not found
    """
    # Step 1. Find the index of the last non-punctuation token in ref_tokens
    tail_ref_idx = len(ref_tokens) - 1
    while tail_ref_idx >= 0 and ref_tokens[tail_ref_idx] in punct_set:
        tail_ref_idx -= 1
    if tail_ref_idx < 0:
        if enable_debug:
            print("All punctuation; tail word not found")
        return None

    # Step 2. Look up the hyp_idx corresponding to tail_ref_idx in the alignment
    tail_hyp_idx = None
    for ref_idx, hyp_idx in reversed(alignment):
        if ref_idx == tail_ref_idx and hyp_idx is not None:
            tail_hyp_idx = hyp_idx
            break

    # Step 3. Fallback: if not found, scan backwards for the nearest
    # non-punctuation ref_idx that has an alignment
    fallback_idx = tail_ref_idx
    while tail_hyp_idx is None and fallback_idx >= 0:
        if ref_tokens[fallback_idx] not in punct_set:
            for ref_idx, hyp_idx in reversed(alignment):
                if ref_idx == fallback_idx and hyp_idx is not None:
                    tail_hyp_idx = hyp_idx
                    break
        fallback_idx -= 1

    if tail_hyp_idx is None or tail_hyp_idx >= len(hyp_to_asr_word_idx):
        if enable_debug:
            print(f"tail_hyp_idx not found or out of range: {tail_hyp_idx}")
        return None

    asr_word_idx = hyp_to_asr_word_idx[tail_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        if enable_debug:
            print(f"asr_word_idx invalid: {asr_word_idx}")
        return None

    return asr_words[asr_word_idx]


def find_head_word(
    ref_tokens: List[str],
    hyp_tokens: List[str],
    alignment: List[Tuple[Optional[int], Optional[int]]],
    hyp_to_asr_word_idx: dict,
    asr_words: List[WordModel],
    punct_set: set = set(",。!?、,.!?;;")
) -> Optional[WordModel]:
    """
    Use the first non-punctuation token of the reference to locate the
    corresponding predicted token, then map it to an ASR word to get its time
    """

    # 1. Find the index of the first non-punctuation token of the reference
    head_ref_idx = 0
    while head_ref_idx < len(ref_tokens) and ref_tokens[head_ref_idx] in punct_set:
        head_ref_idx += 1
    if head_ref_idx >= len(ref_tokens):
        return None

    # 2. Find the corresponding hyp_idx in the alignment
    head_hyp_idx = None
    for ref_idx, hyp_idx in alignment:
        if ref_idx == head_ref_idx and hyp_idx is not None:
            head_hyp_idx = hyp_idx
            break

    if head_hyp_idx is None or head_hyp_idx >= len(hyp_to_asr_word_idx):
        return None

    # 3. Map to the asr_words index
    asr_word_idx = hyp_to_asr_word_idx[head_hyp_idx]
    if asr_word_idx is None or asr_word_idx < 0 or asr_word_idx >= len(asr_words):
        return None

    return asr_words[asr_word_idx]


def calculate_sentence_delay(
    datas: List[Tuple[AudioItem, List[SegmentModel]]], language: str = "zh"
) -> (float, float, float):
    """

    :param datas: reference answers and model results
    :return: ratio of tail words not found, number of corrected zeros, average delay time.
    """
    tail_offset_time = 0  # tail-character offset time
|
||||||
|
standard_offset_time = 0 # 尾字偏移时间
|
||||||
|
tail_not_found = 0 # 未找到尾字数量
|
||||||
|
tail_found = 0 # 找到尾字数量
|
||||||
|
|
||||||
|
standard_fix = 0
|
||||||
|
final_fix = 0
|
||||||
|
|
||||||
|
head_offset_time = 0 # 尾字偏移时间
|
||||||
|
final_offset_time = 0 # 尾字偏移时间
|
||||||
|
head_not_found = 0 # 未找到尾字数量
|
||||||
|
head_found = 0 # 找到尾字数量
|
||||||
|
|
||||||
|
for audio_item, asr_list in datas:
|
||||||
|
if not audio_item.voice:
|
||||||
|
continue
|
||||||
|
# (以防万一)将标答中所有的文本连起来,并将标答中最后一条信息的结束时间作为结束时间。
|
||||||
|
ref_text = ""
|
||||||
|
for voice in audio_item.voice:
|
||||||
|
ref_text = ref_text + voice.answer.strip()
|
||||||
|
if not ref_text:
|
||||||
|
continue
|
||||||
|
logger.debug(f"-=-=-=-=-=-=-=-=-=-=-=-=-=start-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=")
|
||||||
|
|
||||||
|
ref_end_ms = int(audio_item.voice[-1].end * 1000)
|
||||||
|
ref_start_ms = int(audio_item.voice[0].start * 1000)
|
||||||
|
|
||||||
|
# 录音所有的片段都过一下,text的end和receive的end
|
||||||
|
# 统计定稿时间。 接收到segment的时间 - segment.end_time
|
||||||
|
#
|
||||||
|
# pred_text = ""
|
||||||
|
# asr_words: List[WordModel] = []
|
||||||
|
# for asr in asr_list:
|
||||||
|
# pred_text = pred_text + asr.text
|
||||||
|
# final_offset = asr.receive_time - asr.end_time
|
||||||
|
# asr_words = asr_words + asr.words
|
||||||
|
# for word in asr.words:
|
||||||
|
# word.segment = asr
|
||||||
|
# if final_offset < 0:
|
||||||
|
# final_fix = final_fix + 1
|
||||||
|
# # 统计被修复的数量
|
||||||
|
# final_offset = 0
|
||||||
|
# # 统计定稿偏移时间
|
||||||
|
# final_offset_time = final_offset_time + final_offset
|
||||||
|
|
||||||
|
pred_text = []
|
||||||
|
asr_words: List[WordModel] = []
|
||||||
|
temp_final_offset_time = 0
|
||||||
|
for asr in asr_list:
|
||||||
|
pred_text = pred_text + [word.text for word in asr.words]
|
||||||
|
final_offset = asr.receive_time - asr.end_time
|
||||||
|
logger.debug(f"asr.receive_time {asr.receive_time} , asr.end_time {asr.end_time} , final_offset {final_offset}")
|
||||||
|
asr_words = asr_words + asr.words
|
||||||
|
for word in asr.words:
|
||||||
|
word.segment = asr
|
||||||
|
if final_offset < 0:
|
||||||
|
final_fix = final_fix + 1
|
||||||
|
# 统计被修复的数量
|
||||||
|
final_offset = 0
|
||||||
|
# 统计定稿偏移时间
|
||||||
|
temp_final_offset_time = temp_final_offset_time + final_offset
|
||||||
|
final_offset_time = final_offset_time + temp_final_offset_time / len(asr_list)
|
||||||
|
|
||||||
|
# 处理模型给出的结果。
|
||||||
|
logger.debug(f"text: {ref_text},pred_text: {pred_text}")
|
||||||
|
# 计算对应关系
|
||||||
|
|
||||||
|
# pred_tokens 是与原文一致的,只是可能多了几个为空的位置。需要打平为一维数组,并记录对应的word的位置。
|
||||||
|
flat_pred_tokens = []
|
||||||
|
hyp_to_asr_word_idx = {} # key: flat_pred_token_index -> asr_word_index
|
||||||
|
|
||||||
|
alignment_type = get_alignment_type(language)
|
||||||
|
if alignment_type == "chart":
|
||||||
|
label_tokens = Tokenizer.tokenize([ref_text], language)[0]
|
||||||
|
pred_tokens = Tokenizer.tokenize(pred_text, language)
|
||||||
|
for asr_idx, token_group in enumerate(pred_tokens):
|
||||||
|
for token in token_group:
|
||||||
|
flat_pred_tokens.append(token)
|
||||||
|
hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx
|
||||||
|
alignment = align_texts(label_tokens, "".join(flat_pred_tokens))
|
||||||
|
else:
|
||||||
|
label_tokens = Tokenizer.norm_and_tokenize([ref_text], language)[0]
|
||||||
|
pred_tokens = Tokenizer.norm_and_tokenize(pred_text, language)
|
||||||
|
for asr_idx, token_group in enumerate(pred_tokens):
|
||||||
|
for token in token_group:
|
||||||
|
flat_pred_tokens.append(token)
|
||||||
|
hyp_to_asr_word_idx[len(flat_pred_tokens) - 1] = asr_idx
|
||||||
|
alignment = align_tokens(label_tokens, flat_pred_tokens)
|
||||||
|
|
||||||
|
logger.debug(f"ref_tokens: {label_tokens}")
|
||||||
|
logger.debug(f"pred_tokens: {pred_tokens}")
|
||||||
|
logger.debug(f"alignment sample: {alignment[:30]}") # 只打印前30个,避免日志过大
|
||||||
|
logger.debug(f"hyp_to_asr_word_idx: {hyp_to_asr_word_idx}")
|
||||||
|
|
||||||
|
head_word_info = find_head_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words)
|
||||||
|
|
||||||
|
if head_word_info is None:
|
||||||
|
# 统计没有找到首字的数量
|
||||||
|
head_not_found = head_not_found + 1
|
||||||
|
logger.debug(f"未找到首字")
|
||||||
|
else:
|
||||||
|
logger.debug(f"head_word: {head_word_info.text} ref_start_ms:{ref_start_ms}")
|
||||||
|
# 找到首字
|
||||||
|
# 统计首字偏移时间 首字在策略中出现的word的时间 - 标答start_time
|
||||||
|
head_offset_time = head_offset_time + abs(head_word_info.start_time - ref_start_ms)
|
||||||
|
# 统计找到首字的数量
|
||||||
|
head_found += 1
|
||||||
|
|
||||||
|
# 找尾字所在的模型返回words信息。
|
||||||
|
|
||||||
|
tail_word_info = find_tail_word(label_tokens, pred_tokens, alignment, hyp_to_asr_word_idx, asr_words)
|
||||||
|
if tail_word_info is None:
|
||||||
|
# 没有找到尾字,记录数量
|
||||||
|
tail_not_found = tail_not_found + 1
|
||||||
|
logger.debug(f"未找到尾字")
|
||||||
|
else:
|
||||||
|
# 找到尾字了
|
||||||
|
logger.debug(f"tail_word: {tail_word_info.text} ref_end_ms: {ref_end_ms}")
|
||||||
|
# 统计尾字偏移时间 标答的end_time - 策略尾字所在word的end_time
|
||||||
|
tail_offset_time = abs(ref_end_ms - tail_word_info.end_time) + tail_offset_time
|
||||||
|
|
||||||
|
# 统计标答句延迟时间 策略尾字所在word的实际接收时间 - 标答句end时间
|
||||||
|
standard_offset = tail_word_info.segment.receive_time - ref_end_ms
|
||||||
|
logger.debug(f"tail_word_info.segment.receive_time {tail_word_info.segment.receive_time } , tail_word_info.end_time {tail_word_info.end_time} , ref_end_ms {ref_end_ms}")
|
||||||
|
# 如果小于0修正为0
|
||||||
|
if standard_offset < 0:
|
||||||
|
standard_offset = 0
|
||||||
|
# 统计被修正的数量
|
||||||
|
standard_fix = standard_fix + 1
|
||||||
|
standard_offset_time = standard_offset + standard_offset_time
|
||||||
|
|
||||||
|
# 统计找到尾字的数量
|
||||||
|
tail_found += 1
|
||||||
|
|
||||||
|
logger.info(f"-=-=-=-=-=-=-=-=-=-=-=-=-=end-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=")
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"找到首字字数量: {head_found},未找到首字数量:{head_not_found},找到尾字数量: {tail_found},未找到尾字数量:{tail_not_found},修正标答偏移负数数量:{standard_fix},修正定稿偏移负数数量:{final_fix},")
|
||||||
|
logger.debug(
|
||||||
|
f"尾字偏移总时间:{tail_offset_time},标答句偏移总时间:{standard_offset_time},首字偏移总时间:{head_offset_time},定稿偏移总时间:{final_offset_time},")
|
||||||
|
|
||||||
|
#
|
||||||
|
# 统计平均值
|
||||||
|
head_not_found_ratio = head_not_found / (head_found + head_not_found)
|
||||||
|
tail_not_found_ratio = tail_not_found / (tail_found + tail_not_found)
|
||||||
|
|
||||||
|
average_tail_offset = tail_offset_time / tail_found / 1000
|
||||||
|
average_head_offset = head_offset_time / head_found / 1000
|
||||||
|
|
||||||
|
average_standard_offset = standard_offset_time / tail_found / 1000
|
||||||
|
average_final_offset = final_offset_time / tail_found / 1000
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"首字未找到比例:{head_not_found_ratio},尾字未找到比例:{tail_not_found_ratio},首字偏移时间:{average_head_offset},尾字偏移时间:{average_tail_offset},标答句偏移时间:{average_standard_offset},定稿偏移时间:{average_final_offset}")
|
||||||
|
|
||||||
|
return head_not_found_ratio, average_head_offset, tail_not_found_ratio, average_standard_offset, average_final_offset, average_tail_offset
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
checks = [
|
||||||
|
{
|
||||||
|
"type": "zh",
|
||||||
|
"ref": "今天天气真好",
|
||||||
|
"hyp": "今天真好"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "zh",
|
||||||
|
"ref": "我喜欢吃苹果",
|
||||||
|
"hyp": "我很喜欢吃香蕉"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "zh",
|
||||||
|
"ref": "我喜欢吃苹果",
|
||||||
|
"hyp": "我喜欢吃苹果"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "en",
|
||||||
|
"ref": "I like to eat apples",
|
||||||
|
"hyp": "I really like eating apples"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "en",
|
||||||
|
"ref": "She is going to the market",
|
||||||
|
"hyp": "She went market"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "en",
|
||||||
|
"ref": "Hello world",
|
||||||
|
"hyp": "Hello world"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "en",
|
||||||
|
"ref": "Good morning",
|
||||||
|
"hyp": "Bad night"
|
||||||
|
},
|
||||||
|
]
|
||||||
|
for check in checks:
|
||||||
|
ref = check.get("ref")
|
||||||
|
type = check.get("type")
|
||||||
|
hyp = check.get("hyp")
|
||||||
|
|
||||||
|
res1 = align_texts(ref, hyp)
|
||||||
|
|
||||||
|
res2 = align_tokens(list(ref), list(hyp))
|
||||||
|
|
||||||
|
from utils.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
tokens_pred = Tokenizer.norm_and_tokenize([ref], type)
|
||||||
|
print(time.time() - start)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
Tokenizer.norm_and_tokenize([ref + ref + ref], type)
|
||||||
|
print(time.time() - start)
|
||||||
|
|
||||||
|
tokens_label = Tokenizer.norm_and_tokenize([hyp], type)
|
||||||
|
print(tokens_pred)
|
||||||
|
print(tokens_label)
|
||||||
|
res3 = align_tokens(tokens_pred[0], tokens_label[0])
|
||||||
|
print(res1 == res2)
|
||||||
|
print(res1 == res3)
|
||||||
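The `alignment` consumed by `find_head_word` and `find_tail_word2` is a list of `(ref_idx, hyp_idx)` pairs where either side may be `None`. The project's own `align_texts`/`align_tokens` are not shown in this chunk; the sketch below is an assumption of how such pairs can be produced with the standard library's `difflib`, not the repository's actual implementation.

```python
# Hypothetical sketch: produce (ref_idx, hyp_idx) alignment pairs like those
# consumed by find_head_word / find_tail_word2, using difflib.SequenceMatcher.
# This is NOT the project's align_tokens; it only illustrates the data shape.
import difflib
from typing import List, Optional, Tuple


def sketch_align_tokens(ref: List[str], hyp: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    pairs: List[Tuple[Optional[int], Optional[int]]] = []
    sm = difflib.SequenceMatcher(a=ref, b=hyp, autojunk=False)
    for op, i1, i2, j1, j2 in sm.get_opcodes():
        if op == "equal":
            # matched tokens align one-to-one
            pairs.extend(zip(range(i1, i2), range(j1, j2)))
        else:
            # substitutions / deletions / insertions: pair what we can, pad with None
            for k in range(max(i2 - i1, j2 - j1)):
                ri = i1 + k if i1 + k < i2 else None
                hi = j1 + k if j1 + k < j2 else None
                pairs.append((ri, hi))
    return pairs


# ref "今天天气真好" vs hyp "今天真好": tokens 2 and 3 of the reference are deleted
print(sketch_align_tokens(list("今天天气真好"), list("今天真好")))
```

A deleted reference token shows up as `(ref_idx, None)`, which is exactly why the tail/head lookups skip pairs whose `hyp_idx` is `None`.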
195
utils/client.py
Normal file
@@ -0,0 +1,195 @@
import json, os, threading, time, traceback
from typing import (
    Any, List
)
from copy import deepcopy
from utils.logger import logger

from websocket import (
    create_connection,
    WebSocketConnectionClosedException,
    ABNF
)

from utils.model import (
    ASRResponseModel,
    SegmentModel
)
import queue

from pydantic_core import ValidationError

_IS_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH") is None


class ASRWebSocketClient:
    def __init__(self, url: str):
        self.endpoint = f"{url}/recognition"
        # self.ctx = deepcopy(ctx)
        self.ws = None
        self.conn_attempts = -1
        self.failed = False
        self.terminate_time = float("inf")
        self.sent_timestamps: List[float] = []
        self.received_timestamps: List[float] = []
        self.results: List[Any] = []
        self.connected = True
        logger.info(f"Target endpoint: {self.endpoint}")

    def execute(self, path: str) -> List[SegmentModel]:
        # Two threads: one sends, one receives and records the receive time of each message.
        send_thread = threading.Thread(target=self.send, args=(path,))
        # record the start time
        start_time = time.time()
        # start the sender thread
        send_thread.start()

        # queue used to hand the result back
        result_queue = queue.Queue()

        # thread wrapper around receive
        def receive_thread_fn():
            try:
                res = self.receive(start_time)
                result_queue.put(res)
            except Exception:
                # push a sentinel on failure
                result_queue.put(None)

        # start the receive thread (we care about its return value)
        receive_thread = threading.Thread(target=receive_thread_fn)
        receive_thread.start()

        # wait for both threads to finish
        send_thread.join()
        receive_thread.join()

        # fetch the receive thread's return value
        result = result_queue.get()

        return result

    def initialize_connection(self, language: str):
        expiration = time.time() + float(os.getenv("end_time", "2"))
        self.connected = False
        init = False
        while True:
            try:
                if not init:
                    logger.debug("WS setup: sending connection request.")
                    self.ws = create_connection(self.endpoint)
                    body = json.dumps(self._get_init_payload(language))
                    logger.debug(f"WS setup: sending init payload. {body}")
                    self.ws.send(body)
                    init = True
                msg = self.ws.recv()
                logger.debug(f"Received response: {msg}")
                if len(msg) == 0:
                    time.sleep(0.5)  # sleep briefly while waiting for data to come back
                    continue
                if isinstance(msg, str):
                    try:
                        msg = json.loads(msg)
                    except Exception:
                        raise Exception("WS setup: response is not valid JSON!")
                if isinstance(msg, dict):
                    connected = msg.get("success")
                    if connected:
                        logger.debug("WS setup: connection established!")
                        self.conn_attempts = 0
                        self.connected = True
                        return
                    else:
                        logger.info("WS setup: connection attempt failed!")
                        init = False
                        self.conn_attempts = self.conn_attempts + 1
                        if self.conn_attempts > 5:
                            raise ConnectionRefusedError("Still unable to establish the WS connection after 5 retries.")
                        if time.time() > expiration:
                            raise RuntimeError("WS setup: connection timed out!")

            # NOTE: `except A or B` only catches A; a tuple is required here.
            except (WebSocketConnectionClosedException, TimeoutError):
                raise Exception("WS setup: connection dropped during initialization, aborting.")
            except Exception as e:
                logger.info("WS setup: connection attempt failed!")
                init = False
                self.conn_attempts = self.conn_attempts + 1
                if self.conn_attempts > 5:
                    raise ConnectionRefusedError("Still unable to establish the WS connection after 5 retries.")
                if time.time() > expiration:
                    raise RuntimeError("WS setup: connection timed out!")

    def shutdown(self):
        try:
            if self.ws:
                self.ws.close()
        except Exception:
            pass
        self.connected = False

    def _get_init_payload(self, language="zh") -> dict:
        # language = "zh"
        return {
            "language": language
        }

    def _get_finish_payload(self) -> dict:
        return {
            "end": "true"
        }

    def send(self, path):
        skip_wav_header = path.endswith("wav")
        with open(path, "rb") as f:
            if skip_wav_header:
                # a canonical WAV header is 44 bytes
                f.read(44)

            while chunk := f.read(3200):
                logger.debug(f"Sending {len(chunk)} bytes.")
                self.ws.send(chunk, opcode=ABNF.OPCODE_BINARY)
                time.sleep(0.1)
            self.ws.send(json.dumps(self._get_finish_payload()))

    def receive(self, start_time) -> List[SegmentModel]:
        results = []
        while True:
            msg = self.ws.recv()
            # record the time the data was read (in milliseconds)
            now = 1000 * (time.time() - start_time)
            logger.debug(f"{now} received response: {msg}")
            res = json.loads(msg)
            if res.get("asr_results"):
                item = ASRResponseModel.model_validate_json(msg).asr_results
                item.receive_time = now
                results.append(item)
                # logger.info(item.summary())
            else:
                logger.debug("end of responses")
                # sort by para_seq and check that the sequence numbers are contiguous
                results.sort(key=lambda x: x.para_seq)

                missing_seqs = []
                for i in range(len(results) - 1):
                    expected_next = results[i].para_seq + 1
                    actual_next = results[i + 1].para_seq
                    if actual_next != expected_next:
                        missing_seqs.extend(range(expected_next, actual_next))

                if missing_seqs:
                    logger.warning(f"Detected missing paragraph sequence numbers: {missing_seqs}")
                else:
                    logger.debug("responses are complete")

                return results


if __name__ == '__main__':
    ws_client = ASRWebSocketClient("ws://localhost:18000")
    ws_client.initialize_connection("zh")

    res = ws_client.execute("/Users/yu/Documents/code-work/asr-live-iluvatar/zh_250312/zh/99.wav")
    for i in res:
        print(i.summary())
        print()
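The gap check at the end of `receive()` can be exercised in isolation. This is a standalone restatement of that scan (function name `find_missing_seqs` is introduced here for illustration, it does not exist in the repository): after sorting by `para_seq`, every index skipped between consecutive results is collected.

```python
# Standalone sketch of the para_seq gap check performed in receive():
# sort the sequence numbers, then collect every index skipped between
# consecutive entries.
from typing import List


def find_missing_seqs(seqs: List[int]) -> List[int]:
    seqs = sorted(seqs)
    missing: List[int] = []
    for a, b in zip(seqs, seqs[1:]):
        if b != a + 1:
            missing.extend(range(a + 1, b))
    return missing


print(find_missing_seqs([1, 2, 4, 7]))  # [3, 5, 6]
```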
331
utils/helm.py
Normal file
@@ -0,0 +1,331 @@
import copy
import io
import json
import logging
import os
import tarfile
import time
from collections import defaultdict

import requests
from ruamel.yaml import YAML

from typing import Dict, Any

sut_chart_root = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "helm-chart", "sut"
)
logger = logging.getLogger(__file__)
lb_headers = (
    {"Authorization": "Bearer " + os.getenv("LEADERBOARD_API_TOKEN", "")}
    if os.getenv("LEADERBOARD_API_TOKEN")
    else None
)
# int factory so `pull_num[pod] += 1` works for pods not seen before
pull_num: defaultdict = defaultdict(int)

JOB_ID = int(os.getenv("JOB_ID", "-1"))
assert JOB_ID != -1
LOAD_SUT_URL = os.getenv("LOAD_SUT_URL")
assert LOAD_SUT_URL is not None
GET_JOB_SUT_INFO_URL = os.getenv("GET_JOB_SUT_INFO_URL")
assert GET_JOB_SUT_INFO_URL is not None


def apply_env_to_values(values, envs):
    if "env" not in values:
        values["env"] = []
    old_key_list = [x["name"] for x in values["env"]]
    for k, v in envs.items():
        try:
            idx = old_key_list.index(k)
            values["env"][idx]["value"] = v
        except ValueError:
            values["env"].append({"name": k, "value": v})
    return values


def merge_values(base_value, incr_value):
    if isinstance(base_value, dict) and isinstance(incr_value, dict):
        for k in incr_value:
            base_value[k] = (
                merge_values(base_value[k], incr_value[k])
                if k in base_value
                else incr_value[k]
            )
    elif isinstance(base_value, list) and isinstance(incr_value, list):
        base_value.extend(incr_value)
    else:
        base_value = incr_value
    return base_value


def gen_chart_tarball(docker_image, policy=None):
    """Resolve the docker image digest and build a helm chart tarball from it;
    raises on failure.

    Args:
        docker_image: docker image reference

    Returns:
        tuple[BytesIO, dict]: [helm chart tarball file object, values content]
    """
    # load values template
    with open(os.path.join(sut_chart_root, "values.yaml.tmpl")) as fp:
        yaml = YAML(typ="rt")
        values = yaml.load(fp)
    # update docker_image
    get_image_hash_url = os.getenv("GET_IMAGE_HASH_URL", None)
    if get_image_hash_url is not None:
        # convert tag to hash for docker_image
        resp = requests.get(
            get_image_hash_url,
            headers=lb_headers,
            params={"image": docker_image},
            timeout=600,
        )
        assert resp.status_code == 200, (
            "Convert tag to hash for docker image failed, API retcode %d"
            % resp.status_code
        )
        resp = resp.json()
        assert resp[
            "success"
        ], "Convert tag to hash for docker image failed, response: %s" % str(resp)
        token = resp["data"]["image"].rsplit(":", 2)
        assert len(token) == 3, "Invalid docker image %s" % resp["data"]["image"]
        values["image"]["repository"] = token[0]
        values["image"]["tag"] = ":".join(token[1:])
    else:
        token = docker_image.rsplit(":", 1)
        if len(token) != 2:
            raise RuntimeError("Invalid docker image %s" % docker_image)
        values["image"]["repository"] = token[0]
        values["image"]["tag"] = token[1]
    values["image"]["pullPolicy"] = policy
    # output values.yaml
    with open(os.path.join(sut_chart_root, "values.yaml"), "w") as fp:
        yaml = YAML(typ="rt")
        yaml.dump(values, fp)
    # tarball
    tarfp = io.BytesIO()
    with tarfile.open(fileobj=tarfp, mode="w:gz") as tar:
        tar.add(
            sut_chart_root, arcname=os.path.basename(sut_chart_root), recursive=True
        )
    tarfp.seek(0)
    logger.debug(f"Generated chart using values: {values}")
    return tarfp, values


def deploy_chart(
    name_suffix,
    readiness_timeout,
    chart_str=None,
    chart_fileobj=None,
    extra_values=None,
    restart_count_limit=3,
    pullimage_count_limit=3,
):
    """Deploy the SUT; raises on failure.

    Args:
        name_suffix (str): distinguishes the SUTs when one job has several
        readiness_timeout (int): readiness timeout in seconds
        chart_str (str, optional): chart URL; when not None, chart_fileobj is ignored. Defaults to None.
        chart_fileobj (BytesIO, optional): helm chart tarball file object, used when chart_str is None. Defaults to None.
        extra_values (dict, optional): additional helm values. Defaults to None.
        restart_count_limit (int, optional): SUT restart limit; exceeding it raises. Defaults to 3.
        pullimage_count_limit (int, optional): image pull retry limit; exceeding it raises. Defaults to 3.

    Returns:
        tuple[str, str]: [k8s service name for reaching the SUT, name passed to unload_sut]
    """
    logger.info(
        f"Deploying SUT application for JOB {JOB_ID}, name_suffix {name_suffix}, extra_values {extra_values}"
    )
    # deploy
    payload = {
        "job_id": JOB_ID,
        "resource_name": name_suffix,
        "priorityclassname": os.environ.get("priorityclassname"),
    }
    extra_values = {} if not extra_values else extra_values
    payload["values"] = json.dumps(extra_values, ensure_ascii=False)
    if chart_str is not None:
        payload["helm_chart"] = chart_str
        resp = requests.post(
            LOAD_SUT_URL,
            data=payload,
            headers=lb_headers,
            timeout=600,
        )
    else:
        assert (
            chart_fileobj is not None
        ), "Either chart_str or chart_fileobj should be set"
        resp = requests.post(
            LOAD_SUT_URL,
            data=payload,
            headers=lb_headers,
            files=[("helm_chart_file", (name_suffix + ".tgz", chart_fileobj))],
            timeout=600,
        )
    if resp.status_code != 200:
        raise RuntimeError(
            "Failed to deploy application status_code %d %s"
            % (resp.status_code, resp.text)
        )
    resp = resp.json()
    if not resp["success"]:
        raise RuntimeError("Failed to deploy application response %r" % resp)
    service_name = resp["data"]["service_name"]
    sut_name = resp["data"]["sut_name"]
    logger.info(f"SUT application deployed with service_name {service_name}")
    # waiting for the application to become ready
    running_at = None
    while True:
        retry_interval = 10
        logger.info(
            f"Waiting {retry_interval} seconds to check whether SUT application {service_name} is ready..."
        )
        time.sleep(retry_interval)
        check_result, running_at = check_sut_ready_from_resp(
            service_name,
            running_at,
            readiness_timeout,
            restart_count_limit,
            pullimage_count_limit,
        )
        if check_result:
            break

    logger.info(
        f"SUT application for JOB {JOB_ID} name_suffix {name_suffix} is ready, service_name {service_name}"
    )
    return service_name, sut_name


def check_sut_ready_from_resp(
    service_name,
    running_at,
    readiness_timeout,
    restart_count_limit,
    pullimage_count_limit,
):
    try:
        resp = requests.get(
            f"{GET_JOB_SUT_INFO_URL}/{JOB_ID}",
            headers=lb_headers,
            params={"with_detail": True},
            timeout=600,
        )
    except Exception as e:
        logger.warning(
            f"Exception occurred while getting SUT application {service_name} status: {e}"
        )
        return False, running_at
    if resp.status_code != 200:
        logger.warning(
            f"Get SUT application {service_name} status failed with status_code {resp.status_code}"
        )
        return False, running_at
    resp = resp.json()
    if not resp["success"]:
        logger.warning(
            f"Get SUT application {service_name} status failed with response {resp}"
        )
        return False, running_at
    if len(resp["data"]["sut"]) == 0:
        logger.warning("Empty SUT application status")
        return False, running_at
    resp_data_sut = copy.deepcopy(resp["data"]["sut"])
    for status in resp_data_sut:
        del status["detail"]
    logger.info(f"Got SUT application status: {resp_data_sut}")
    for status in resp["data"]["sut"]:
        if status["phase"] in ["Succeeded", "Failed"]:
            raise RuntimeError(
                f"Some pods of SUT application {service_name} terminated with status {status}"
            )
        elif status["phase"] in ["Pending", "Unknown"]:
            return False, running_at
        elif status["phase"] != "Running":
            raise RuntimeError(
                f"Unexpected pod status {status} of SUT application {service_name}"
            )
        if running_at is None:
            running_at = time.time()
        for ct in status["detail"]["status"]["container_statuses"]:
            if ct["restart_count"] > 0:
                logger.info(
                    f"pod {status['pod_name']} restart count = {ct['restart_count']}"
                )
                if ct["restart_count"] > restart_count_limit:
                    raise RuntimeError(
                        f"pod {status['pod_name']} restarted too many times (over {restart_count_limit})"
                    )
            if (
                ct["state"]["waiting"] is not None
                and "reason" in ct["state"]["waiting"]
                and ct["state"]["waiting"]["reason"]
                in ["ImagePullBackOff", "ErrImagePull"]
            ):
                pull_num[status["pod_name"]] += 1
                logger.info(
                    f"pod {status['pod_name']} has {pull_num[status['pod_name']]} image-pull inspections, pulling image info: {ct['state']['waiting']}"
                )
                if pull_num[status["pod_name"]] > pullimage_count_limit:
                    raise RuntimeError(f"pod {status['pod_name']} cannot pull image")
        if not status["conditions"]["Ready"]:
            if running_at is not None and time.time() - running_at > readiness_timeout:
                raise RuntimeError(
                    f"SUT Application readiness has exceeded readiness_timeout:{readiness_timeout}s"
                )
            return False, running_at
    return True, running_at


def resource_check(values: Dict[str, Any], image: str):
    # fill in the image info
    values["docker_image"] = image
    values["image_pull_policy"] = "Always"
    # values["command"] = ["python", "run.py"]
    # fill in the resource limits
    values["resources"] = {
        "limits": {
            "cpu": 4,
            "memory": "16Gi",
            "iluvatar.ai/gpu": "1"
        },
        "requests": {
            "cpu": 4,
            "memory": "16Gi",
            "iluvatar.ai/gpu": "1"
        },
    }

    values["nodeSelector"] = {
        "contest.4pd.io/accelerator": "iluvatar-BI-V100"
    }
    values["tolerations"] = [
        {
            "key": "hosttype",
            "operator": "Equal",
            "value": "iluvatar",
            "effect": "NoSchedule",
        }
    ]

    """
    nodeSelector:
      contest.4pd.io/accelerator: iluvatar-BI-V100
    tolerations:
      - key: hosttype
        operator: Equal
        value: iluvatar
        effect: NoSchedule
    """

    # TODO: add selection rules
    return values
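`merge_values` implements a recursive deep merge: nested dicts merge key by key, lists concatenate, and scalars from the increment overwrite the base. The sketch below copies the function verbatim and exercises those three rules on a helm-values-shaped example (the `base`/`incr` values are illustrative, not from the repository).

```python
# Deep-merge semantics of merge_values, copied as a standalone sketch:
# dicts merge recursively, lists concatenate, scalars are overwritten.
def merge_values(base_value, incr_value):
    if isinstance(base_value, dict) and isinstance(incr_value, dict):
        for k in incr_value:
            base_value[k] = (
                merge_values(base_value[k], incr_value[k])
                if k in base_value
                else incr_value[k]
            )
    elif isinstance(base_value, list) and isinstance(incr_value, list):
        base_value.extend(incr_value)
    else:
        base_value = incr_value
    return base_value


base = {"env": [{"name": "A", "value": "1"}], "image": {"tag": "v1"}}
incr = {"env": [{"name": "B", "value": "2"}], "image": {"tag": "v2"}, "replicas": 3}
merged = merge_values(base, incr)
print(merged["image"]["tag"])  # "v2"  (scalar overwritten)
print(len(merged["env"]))      # 2     (lists concatenated)
```

Note that the merge mutates `base_value` in place, which is why callers that need the original should deep-copy first.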
36
utils/logger.py
Normal file
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
import logging
import os

level = logging.INFO
level_str = "INFO"

# level = logging.DEBUG
# level_str = "DEBUG"

logging.basicConfig(
    format="%(asctime)s %(name)-12s %(levelname)-4s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", level_str),
)
logger = logging.getLogger(__file__)

# another, more detailed logger
log = logging.getLogger("detailed_logger")
log.propagate = False
log.setLevel(level)

formatter = logging.Formatter(
    "[%(asctime)s] %(levelname)s : %(pathname)s:%(lineno)d - %(message)s",
    "%Y-%m-%d %H:%M:%S",
)

streamHandler = logging.StreamHandler()
streamHandler.setLevel(level)
streamHandler.setFormatter(formatter)
log.addHandler(streamHandler)
320
utils/metrics.py
Normal file
@@ -0,0 +1,320 @@
# coding: utf-8

import os
from collections import Counter
from copy import deepcopy
from typing import List, Tuple

import Levenshtein
import numpy as np
from schemas.context import ASRContext
from utils.logger import logger
from utils.tokenizer import Tokenizer, TokenizerType
from utils.update_submit import change_product_available

IN_TEST = os.getenv("SUBMIT_CONFIG_FILEPATH", None) is None


def text_align(context: ASRContext) -> Tuple:
    start_end_count = 0

    label_start_time_list = []
    label_end_time_list = []
    for label_item in context.labels:
        label_start_time_list.append(label_item.start)
        label_end_time_list.append(label_item.end)
    pred_start_time_list = []
    pred_end_time_list = []
    sentence_start = True
    for pred_item in context.preds:
        if sentence_start:
            pred_start_time_list.append(pred_item.recognition_results.start_time)
        if pred_item.recognition_results.final_result:
            pred_end_time_list.append(pred_item.recognition_results.end_time)
        sentence_start = pred_item.recognition_results.final_result
    # check start0 < end0 < start1 < end1 < start2 < end2 ...
    if IN_TEST:
        print(pred_start_time_list)
        print(pred_end_time_list)
    pred_time_list = []
    i, j = 0, 0
    while i < len(pred_start_time_list) and j < len(pred_end_time_list):
        pred_time_list.append(pred_start_time_list[i])
        pred_time_list.append(pred_end_time_list[j])
        i += 1
        j += 1
    if i < len(pred_start_time_list):
        pred_time_list.append(pred_start_time_list[-1])
    for i in range(1, len(pred_time_list)):
        # allow a 600 ms grace period here
        if pred_time_list[i] < pred_time_list[i - 1] - 0.6:
            logger.error("recognized start/end times violate start0 < end0 < start1 < end1 < start2 < end2 ...")
            logger.error(
                f"start and end times of each recognized sentence: \
                start times: {pred_start_time_list}, \
                end times: {pred_end_time_list}"
            )
            start_end_count += 1
            # change_product_available()
    # times must match within a 300 ms window
    start_time_align_count = 0
    end_time_align_count = 0
    for label_start_time in label_start_time_list:
        for pred_start_time in pred_start_time_list:
            if pred_start_time <= label_start_time + 0.3 and pred_start_time >= label_start_time - 0.3:
|
||||||
|
start_time_align_count += 1
|
||||||
|
break
|
||||||
|
for label_end_time in label_end_time_list:
|
||||||
|
for pred_end_time in pred_end_time_list:
|
||||||
|
if pred_end_time <= label_end_time + 0.3 and pred_end_time >= label_end_time - 0.3:
|
||||||
|
end_time_align_count += 1
|
||||||
|
break
|
||||||
|
logger.info(
|
||||||
|
f"start-time 对齐个数 {start_time_align_count}, \
|
||||||
|
end-time 对齐个数 {end_time_align_count}\
|
||||||
|
数据集中句子总数 {len(label_start_time_list)}"
|
||||||
|
)
|
||||||
|
return start_time_align_count, end_time_align_count, start_end_count
|
||||||
|
|
||||||
|
|
||||||
|
def first_delay(context: ASRContext) -> Tuple:
|
||||||
|
first_send_time = context.preds[0].send_time
|
||||||
|
first_delay_list = []
|
||||||
|
sentence_start = True
|
||||||
|
for pred_context in context.preds:
|
||||||
|
if sentence_start:
|
||||||
|
sentence_begin_time = pred_context.recognition_results.start_time
|
||||||
|
first_delay_time = pred_context.recv_time - first_send_time - sentence_begin_time
|
||||||
|
first_delay_list.append(first_delay_time)
|
||||||
|
sentence_start = pred_context.recognition_results.final_result
|
||||||
|
if IN_TEST:
|
||||||
|
print(f"当前音频的首字延迟为{first_delay_list}")
|
||||||
|
logger.info(f"当前音频的首字延迟均值为 {np.mean(first_delay_list)}s")
|
||||||
|
return np.sum(first_delay_list), len(first_delay_list)
|
||||||
|
|
||||||
|
|
||||||
|
def revision_delay(context: ASRContext):
|
||||||
|
first_send_time = context.preds[0].send_time
|
||||||
|
revision_delay_list = []
|
||||||
|
for pred_context in context.preds:
|
||||||
|
if pred_context.recognition_results.final_result:
|
||||||
|
sentence_end_time = pred_context.recognition_results.end_time
|
||||||
|
revision_delay_time = pred_context.recv_time - first_send_time - sentence_end_time
|
||||||
|
revision_delay_list.append(revision_delay_time)
|
||||||
|
|
||||||
|
if IN_TEST:
|
||||||
|
print(revision_delay_list)
|
||||||
|
logger.info(f"当前音频的修正延迟均值为 {np.mean(revision_delay_list)}s")
|
||||||
|
return np.sum(revision_delay_list), len(revision_delay_list)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_unique_token_count(context: ASRContext):
|
||||||
|
# print(context.__dict__)
|
||||||
|
# 对于每一个返回的结果都进行 tokenize
|
||||||
|
pred_text_list = [pred_context.recognition_results.text for pred_context in context.preds]
|
||||||
|
pred_text_tokenized_list = Tokenizer.norm_and_tokenize(pred_text_list, lang=context.lang)
|
||||||
|
# print(pred_text_list)
|
||||||
|
# print(pred_text_tokenized_list)
|
||||||
|
|
||||||
|
# 判断当前是否修改了超过 3s 内的 token 数目
|
||||||
|
## 当前句子的最开始接受时间
|
||||||
|
first_recv_time = None
|
||||||
|
## 不可修改的 token 个数
|
||||||
|
unmodified_token_cnt = 0
|
||||||
|
## 3s 的 index 位置
|
||||||
|
time_token_idx = 0
|
||||||
|
## 当前是句子的开始
|
||||||
|
final_sentence = True
|
||||||
|
|
||||||
|
## 修改了不可修改的范围
|
||||||
|
is_unmodified_token = False
|
||||||
|
|
||||||
|
for idx, (now_tokens, pred_context) in enumerate(zip(pred_text_tokenized_list, context.preds)):
|
||||||
|
## 当前是句子的第一次返回
|
||||||
|
if final_sentence:
|
||||||
|
first_recv_time = pred_context.recv_time
|
||||||
|
unmodified_token_cnt = 0
|
||||||
|
time_token_idx = idx
|
||||||
|
final_sentence = pred_context.recognition_results.final_result
|
||||||
|
continue
|
||||||
|
final_sentence = pred_context.recognition_results.final_result
|
||||||
|
## 当前 pred 的 recv-time
|
||||||
|
pred_recv_time = pred_context.recv_time
|
||||||
|
## 最开始 3s 直接忽略
|
||||||
|
if pred_recv_time - first_recv_time < 3:
|
||||||
|
continue
|
||||||
|
## 根据历史返回信息,获得最长不可修改长度
|
||||||
|
while time_token_idx < idx:
|
||||||
|
context_pred_tmp = context.preds[time_token_idx]
|
||||||
|
context_pred_tmp_recv_time = context_pred_tmp.recv_time
|
||||||
|
tmp_tokens = pred_text_tokenized_list[time_token_idx]
|
||||||
|
if pred_recv_time - context_pred_tmp_recv_time >= 3:
|
||||||
|
unmodified_token_cnt = max(unmodified_token_cnt, len(tmp_tokens))
|
||||||
|
time_token_idx += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
## 和自己的上一条音频比,只能修改 unmodified_token_cnt 个 token
|
||||||
|
last_tokens = pred_text_tokenized_list[idx - 1]
|
||||||
|
if context.lang in ['ar', 'he']:
|
||||||
|
tokens_check_pre, tokens_check_now = last_tokens[::-1], now_tokens[::-1]
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
tokens_check_pre, tokens_check_now = last_tokens, now_tokens
|
||||||
|
for token_a, token_b in zip(tokens_check_pre[:unmodified_token_cnt], tokens_check_now[:unmodified_token_cnt]):
|
||||||
|
if token_a != token_b:
|
||||||
|
is_unmodified_token = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if is_unmodified_token and int(os.getenv('test', 0)):
|
||||||
|
logger.error(
|
||||||
|
f"{idx}-{unmodified_token_cnt}-{last_tokens[:unmodified_token_cnt]}-{now_tokens[:unmodified_token_cnt]}"
|
||||||
|
)
|
||||||
|
if is_unmodified_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
if is_unmodified_token:
|
||||||
|
logger.error("修改了不可修改的文字范围")
|
||||||
|
# change_product_available()
|
||||||
|
if int(os.getenv('test', 0)):
|
||||||
|
final_result = True
|
||||||
|
result_list = []
|
||||||
|
for tokens, pred in zip(pred_text_tokenized_list, context.preds):
|
||||||
|
if final_result:
|
||||||
|
result_list.append([])
|
||||||
|
result_list[-1].append((tokens, pred.recv_time - context.preds[0].recv_time))
|
||||||
|
final_result = pred.recognition_results.final_result
|
||||||
|
for item in result_list:
|
||||||
|
logger.info(str(item))
|
||||||
|
|
||||||
|
# 记录每个 patch 的 token 个数
|
||||||
|
patch_unique_cnt_counter = Counter()
|
||||||
|
patch_unique_cnt_in_one_sentence = set()
|
||||||
|
for pred_text_tokenized, pred_context in zip(pred_text_tokenized_list, context.preds):
|
||||||
|
token_cnt = len(pred_text_tokenized)
|
||||||
|
patch_unique_cnt_in_one_sentence.add(token_cnt)
|
||||||
|
if pred_context.recognition_results.final_result:
|
||||||
|
for unique_cnt in patch_unique_cnt_in_one_sentence:
|
||||||
|
patch_unique_cnt_counter[unique_cnt] += 1
|
||||||
|
patch_unique_cnt_in_one_sentence.clear()
|
||||||
|
if context.preds and not context.preds[-1].recognition_results.final_result:
|
||||||
|
for unique_cnt in patch_unique_cnt_in_one_sentence:
|
||||||
|
patch_unique_cnt_counter[unique_cnt] += 1
|
||||||
|
# print(patch_unique_cnt_counter)
|
||||||
|
logger.info(
|
||||||
|
f"当前音频的 patch token 均值为 {mean_on_counter(patch_unique_cnt_counter)}, \
|
||||||
|
当前音频的 patch token 方差为 {var_on_counter(patch_unique_cnt_counter)}"
|
||||||
|
)
|
||||||
|
return patch_unique_cnt_counter
|
||||||
|
|
||||||
|
|
||||||
|
def mean_on_counter(counter: Counter):
|
||||||
|
total_sum = sum(key * count for key, count in counter.items())
|
||||||
|
total_count = sum(counter.values())
|
||||||
|
return total_sum * 1.0 / total_count
|
||||||
|
|
||||||
|
|
||||||
|
def var_on_counter(counter: Counter):
|
||||||
|
total_sum = sum(key * count for key, count in counter.items())
|
||||||
|
total_count = sum(counter.values())
|
||||||
|
mean = total_sum * 1.0 / total_count
|
||||||
|
return sum((key - mean) ** 2 * count for key, count in counter.items()) / total_count
|
||||||
|
|
||||||
|
|
||||||
|
def edit_distance(arr1: List, arr2: List):
|
||||||
|
operations = Levenshtein.editops(arr1, arr2)
|
||||||
|
i = sum([1 for operation in operations if operation[0] == "insert"])
|
||||||
|
s = sum([1 for operation in operations if operation[0] == "replace"])
|
||||||
|
d = sum([1 for operation in operations if operation[0] == "delete"])
|
||||||
|
c = len(arr1) - s - d
|
||||||
|
return s, d, i, c
|
||||||
|
|
||||||
|
|
||||||
|
def cer(tokens_gt_mapping: List[str], tokens_dt_mapping: List[str]):
|
||||||
|
"""输入的是经过编辑距离映射后的两个 token 序列,返回 1-cer, token-cnt"""
|
||||||
|
insert = sum(1 for item in tokens_gt_mapping if item is None)
|
||||||
|
delete = sum(1 for item in tokens_dt_mapping if item is None)
|
||||||
|
equal = sum(1 for token_gt, token_dt in zip(tokens_gt_mapping, tokens_dt_mapping) if token_gt == token_dt)
|
||||||
|
replace = len(tokens_gt_mapping) - insert - equal
|
||||||
|
|
||||||
|
token_count = replace + equal + delete
|
||||||
|
cer_value = (replace + delete + insert) * 1.0 / token_count
|
||||||
|
logger.info(f"当前音频的 cer/wer 值为 {cer_value}, token 个数为 {token_count}")
|
||||||
|
return 1 - cer_value, token_count
|
||||||
|
|
||||||
|
|
||||||
|
def cut_rate(
|
||||||
|
tokens_gt: List[List[str]],
|
||||||
|
tokens_dt: List[List[str]],
|
||||||
|
tokens_gt_mapping: List[str],
|
||||||
|
tokens_dt_mapping: List[str],
|
||||||
|
):
|
||||||
|
sentence_final_token_index_gt = sentence_final_token_index(tokens_gt, tokens_gt_mapping)
|
||||||
|
sentence_final_token_index_dt = sentence_final_token_index(tokens_dt, tokens_dt_mapping)
|
||||||
|
sentence_final_token_index_gt = set(sentence_final_token_index_gt)
|
||||||
|
sentence_final_token_index_dt = set(sentence_final_token_index_dt)
|
||||||
|
sentence_count_gt = len(sentence_final_token_index_gt)
|
||||||
|
miss_count = len(sentence_final_token_index_gt - sentence_final_token_index_dt)
|
||||||
|
more_count = len(sentence_final_token_index_dt - sentence_final_token_index_gt)
|
||||||
|
rate = max(1 - (miss_count + more_count * 2) / sentence_count_gt, 0)
|
||||||
|
return rate, sentence_count_gt, miss_count, more_count
|
||||||
|
|
||||||
|
|
||||||
|
def token_mapping(tokens_gt: List[str], tokens_dt: List[str]) -> Tuple[List[str], List[str]]:
|
||||||
|
arr1 = deepcopy(tokens_gt)
|
||||||
|
arr2 = deepcopy(tokens_dt)
|
||||||
|
operations = Levenshtein.editops(arr1, arr2)
|
||||||
|
for op in operations[::-1]:
|
||||||
|
if op[0] == "insert":
|
||||||
|
arr1.insert(op[1], None)
|
||||||
|
elif op[0] == "delete":
|
||||||
|
arr2.insert(op[2], None)
|
||||||
|
return arr1, arr2
|
||||||
|
|
||||||
|
|
||||||
|
def sentence_final_token_index(tokens: List[List[str]], tokens_mapping: List[str]) -> List[int]:
|
||||||
|
"""获得原句子中每个句子尾部 token 的 index"""
|
||||||
|
token_index_list = []
|
||||||
|
token_index = 0
|
||||||
|
for token_in_one_sentence in tokens:
|
||||||
|
for _ in range(len(token_in_one_sentence)):
|
||||||
|
while token_index < len(tokens_mapping) and tokens_mapping[token_index] is None:
|
||||||
|
token_index += 1
|
||||||
|
token_index += 1
|
||||||
|
token_index_list.append(token_index - 1)
|
||||||
|
return token_index_list
|
||||||
|
|
||||||
|
|
||||||
|
def cut_sentence(sentences: List[str], tokenizerType: TokenizerType) -> List[str]:
|
||||||
|
"""use self.cut_punc to cut all sentences, merge them and put them into list"""
|
||||||
|
sentence_cut_list = []
|
||||||
|
for sentence in sentences:
|
||||||
|
sentence_list = [sentence]
|
||||||
|
sentence_tmp_list = []
|
||||||
|
for punc in [
|
||||||
|
"······",
|
||||||
|
"......",
|
||||||
|
"。",
|
||||||
|
",",
|
||||||
|
"?",
|
||||||
|
"!",
|
||||||
|
";",
|
||||||
|
":",
|
||||||
|
"...",
|
||||||
|
".",
|
||||||
|
",",
|
||||||
|
"?",
|
||||||
|
"!",
|
||||||
|
";",
|
||||||
|
":",
|
||||||
|
]:
|
||||||
|
for sentence in sentence_list:
|
||||||
|
sentence_tmp_list.extend(sentence.split(punc))
|
||||||
|
sentence_list, sentence_tmp_list = sentence_tmp_list, []
|
||||||
|
sentence_list = [item for item in sentence_list if item]
|
||||||
|
|
||||||
|
if tokenizerType == TokenizerType.whitespace:
|
||||||
|
sentence_cut_list.append(" ".join(sentence_list))
|
||||||
|
else:
|
||||||
|
sentence_cut_list.append("".join(sentence_list))
|
||||||
|
|
||||||
|
return sentence_cut_list
|
||||||
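The `mean_on_counter` / `var_on_counter` helpers above compute a frequency-weighted mean and population variance over patch token counts. A minimal self-contained sketch of the same arithmetic, using only the standard library and hypothetical data:

```python
from collections import Counter


def mean_on_counter(counter: Counter) -> float:
    # frequency-weighted mean: each key is a value, each count is its frequency
    total_sum = sum(key * count for key, count in counter.items())
    total_count = sum(counter.values())
    return total_sum / total_count


def var_on_counter(counter: Counter) -> float:
    # frequency-weighted population variance around the weighted mean
    mean = mean_on_counter(counter)
    total_count = sum(counter.values())
    return sum((key - mean) ** 2 * count for key, count in counter.items()) / total_count


# hypothetical patch token counts: three patches of 2 tokens, one patch of 4 tokens
patch_counter = Counter({2: 3, 4: 1})
print(mean_on_counter(patch_counter))  # 2.5
print(var_on_counter(patch_counter))   # 0.75
```

Storing only `{token_count: frequency}` instead of the raw list keeps the per-audio statistics O(distinct counts) in memory, which is why the metrics module aggregates into a `Counter` first.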
63
utils/model.py
Normal file
@@ -0,0 +1,63 @@
from pydantic import BaseModel, Field
from typing import (
    Optional,
    List,
    Any
)


class WordModel(BaseModel):
    text: str
    start_time: int  # or float, depending on the timestamp format
    end_time: int
    segment: Optional[Any] = Field(default=None, exclude=True)  # owning segment
    # receive_time: Optional[Any] = None  # receive-time offset of the owning segment; duplicates ASRResultModel.receive_time for convenience

    class Config:
        fields = {
            'segment': {'exclude': True}
        }


class SegmentModel(BaseModel):
    # time at which the segment was received
    receive_time: Optional[Any] = None
    language: str
    para_seq: int
    final_result: bool
    text: str
    start_time: int  # or float, if the timestamps are millisecond precision
    end_time: int
    words: List[WordModel]  # the word-level results

    def summary(self) -> str:
        duration = (self.end_time - self.start_time) / 1000  # seconds
        return (
            f"\n"
            f"language:{self.language} \n"
            f"para_seq:{self.para_seq} \n"
            f"final_result:{self.final_result}\n"
            f"text:{self.text}\n"
            f"words:[{', '.join(w.text for w in self.words)}]\n"
            f"start_time:{self.start_time}\n"
            f"end_time:{self.end_time}\n"
        )


class ASRResponseModel(BaseModel):
    asr_results: SegmentModel


class VoiceSegment(BaseModel):
    answer: str
    start: float
    end: float


class AudioItem(BaseModel):
    audio_length: float
    duration: Optional[float] = None
    file: str
    orig_file: Optional[str] = None
    voice: List[VoiceSegment]
    absolute_path: Optional[str] = None
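For illustration, a trimmed sketch of how `AudioItem` validates one `query_data` entry from `data.yaml` (assuming pydantic v2, whose `model_validate` is what `utils/reader.py` calls; the fields here are a subset of the models above):

```python
from typing import List, Optional

from pydantic import BaseModel


class VoiceSegment(BaseModel):
    answer: str
    start: float
    end: float


class AudioItem(BaseModel):
    audio_length: float
    duration: Optional[float] = None  # optional in data.yaml
    file: str
    voice: List[VoiceSegment]
    absolute_path: Optional[str] = None  # filled in later by the reader


# one raw dict as it comes out of yaml.safe_load
item = AudioItem.model_validate({
    "audio_length": 1.01,
    "file": "zh/0.wav",
    "voice": [{"answer": "ok", "start": 0.0, "end": 1.01}],
})
print(item.file)  # zh/0.wav
```

Missing optional fields stay `None`, so the reader can attach `absolute_path` after validation without the raw YAML needing to carry it.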
23
utils/platform_tools.py
Normal file
@@ -0,0 +1,23 @@
import os
from utils.logger import logger
import requests
import json


def mark_not_available():
    headers = {"Content-Type": "application/json"}
    if os.getenv("LEADERBOARD_API_TOKEN"):
        headers["Authorization"] = "Bearer " + os.getenv("LEADERBOARD_API_TOKEN")

    logger.info("Marking the product as unavailable...")
    try:
        submit_id = str(os.getenv("SUBMIT_ID", "-1"))
        resp = requests.post(
            os.getenv("UPDATE_SUBMIT_URL", "http://contest.4pd.io:8080/submit/update"),
            data=json.dumps({submit_id: {"product_avaliable": 0}}),
            headers=headers,
            timeout=600,
        )
        logger.info(resp.json())
    except Exception as e:
        logger.error(f"change product available error, {e}")
138
utils/reader.py
Normal file
@@ -0,0 +1,138 @@
import logging

import yaml
from typing import (
    Tuple,
    List, Any
)
from utils.model import AudioItem
import os
import zipfile
import tarfile
import gzip
import shutil
from pathlib import Path


def read_data(dataset_filepath: str) -> Tuple[str, List[AudioItem]]:
    """
    Read the dataset file and return the language and the audio items.

    Args:
        dataset_filepath (str): path to the dataset file

    Returns:
        Tuple[str, List[AudioItem]]:
            - language: the language string specified in the file
            - audios: the audio items parsed from the file
    """
    try:
        # assume the dataset is an archive and extract it first
        data_extract_path = "/tmp/datas"
        data_yaml_path = extract_file(dataset_filepath, data_extract_path)
        if not data_yaml_path:
            raise ValueError("data.yaml not found in the dataset.")
        dataset_filepath = str(Path(data_yaml_path).parent.resolve())
    except Exception as e:
        logging.exception(e)

    with open(f"{dataset_filepath}/data.yaml") as f:
        datas = yaml.safe_load(f)
    language = datas.get("global", {}).get("lang", "zh")
    query_data = datas.get("query_data", [])

    audios = []
    """
    - audio_length: 1.0099999999997635
      duration: 1.0099999999997635
      file: zh/0.wav
      orig_file: ./112801_1/112801_1-631-772.wav
      voice:
      - answer: 好吧。
        end: 1.0099999999997635
        start: 0
    """
    for item in query_data:
        audio = AudioItem.model_validate(item)
        audio.absolute_path = f"{dataset_filepath}/{audio.file}"
        audios.append(audio)

    return (
        language,
        audios
    )


def extract_file(filepath: str, output_dir: str = ".") -> str:
    """
    Extract the dataset archive into output_dir and return the path to its data.yaml.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    data_yaml_path = None
    # note: the leaderboard_data_samples dataset has no file extension, so extension-based
    # dispatch was dropped here.
    # TODO the existing datasets are not named with a .zip extension; force zip extraction
    with zipfile.ZipFile(filepath, 'r') as zf:
        # collect all regular files (skip directories)
        all_files = [f for f in zf.namelist() if not f.endswith('/')]

        if not all_files:
            raise ValueError(f"The dataset archive is empty: {filepath}")

        # determine the common path prefix to strip
        parts_list = [Path(f).parts for f in all_files]
        common_parts = os.path.commonprefix(parts_list)
        strip_prefix_len = len(common_parts)

        for file in all_files:
            file_parts = Path(file).parts
            relative_parts = file_parts[strip_prefix_len:]
            dest_path = Path(output_dir).joinpath(*relative_parts)

            # create parent directories
            dest_path.parent.mkdir(parents=True, exist_ok=True)

            # extract and write
            with zf.open(file) as source, open(dest_path, "wb") as target:
                shutil.copyfileobj(source, target)

            # check whether this is data.yaml
            if Path(file).name == "data.yaml":
                data_yaml_path = str(dest_path.resolve())

        logging.info("Dataset extracted successfully.")

    return data_yaml_path
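The prefix-stripping logic in `extract_file` relies on `os.path.commonprefix` comparing element-wise when given tuples of path components, so it strips whole directories rather than partial names. A self-contained sketch with an in-memory archive (hypothetical file names):

```python
import io
import os
import zipfile
from pathlib import Path

# build a small in-memory archive whose entries share a top-level folder
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("dataset/zh/0.wav", b"\x00")
    zf.writestr("dataset/data.yaml", "global:\n  lang: zh\n")

with zipfile.ZipFile(buf) as zf:
    # keep regular files only (directory entries end with '/')
    all_files = [f for f in zf.namelist() if not f.endswith("/")]
    # commonprefix on tuples of parts compares component by component
    parts_list = [Path(f).parts for f in all_files]
    strip_prefix_len = len(os.path.commonprefix(parts_list))
    relative = [str(Path(*Path(f).parts[strip_prefix_len:])) for f in all_files]

print(relative)  # ['zh/0.wav', 'data.yaml']
```

Operating on `Path(...).parts` instead of raw strings avoids the classic `commonprefix` pitfall where `"dataset/zh"` and `"dataset/data.yaml"` would share the character prefix `"dataset/"` plus part of the next component.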
46
utils/service.py
Normal file
@@ -0,0 +1,46 @@
import os
import sys
from utils.helm import deploy_chart, gen_chart_tarball
from utils.logger import logger
import yaml

UNIT_TEST = os.getenv("UNIT_TEST", 0)


def register_sut(st_config, resource_name, **kwargs):
    st_config_values = st_config.get("values", {})
    docker_image = st_config_values["docker_image"]
    image_pull_policy = st_config_values["image_pull_policy"]
    chart_tar_fp, chart_values = gen_chart_tarball(docker_image, image_pull_policy)
    sut_service_name, _ = deploy_chart(
        resource_name,
        int(os.getenv("readiness_timeout", 60 * 3)),
        chart_fileobj=chart_tar_fp,
        extra_values=st_config_values,
        restart_count_limit=int(os.getenv('restart_count', 3)),
    )
    chart_tar_fp.close()
    sut_service_port = str(chart_values["service"]["port"])
    return "ws://{}:{}".format(sut_service_name, sut_service_port)


def start_server(submit_config_filepath: str, language: str):
    resource_name = "model-server"
    # read the submit config, patch it, and start the system under test
    with open(submit_config_filepath, "r") as fp:
        st_config = yaml.safe_load(fp)
    from utils.helm import resource_check
    if language == "zh":
        image = "harbor-contest.4pd.io/yuxiaojie/judge_flow/asr-live-iluvatar/asr_engine_zh_semantic:contest-v0"
    elif language == "en":
        image = "harbor-contest.4pd.io/yuxiaojie/judge_flow/asr-live-iluvatar/asr_engine_en_semantic:contest-v0"
    else:
        image = ""
    st_config["values"] = resource_check(st_config.get("values", {}), image)
    sut_url = register_sut(st_config, resource_name)
    print(sut_url)
    return sut_url


if __name__ == '__main__':
    start_server("/Users/yu/Documents/code-work/asr-live-iluvatar/script/config.yaml", "zh")
3
utils/speechio/__init__.py
Normal file
@@ -0,0 +1,3 @@
'''
reference: https://github.com/SpeechColab/Leaderboard/tree/f287a992dc359d1c021bfc6ce810e5e36608e057/utils
'''
BIN
utils/speechio/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/speechio/__pycache__/textnorm_zh.cpython-310.pyc
Normal file
Binary file not shown.
551
utils/speechio/error_rate_en.py
Normal file
@@ -0,0 +1,551 @@
#!/usr/bin/env python3
# coding=utf8
# Copyright 2022 Zhenxiang MA, Jiayu DU (SpeechColab)

import argparse
import csv
import json
import logging
import os
import sys
from typing import Iterable

logging.basicConfig(stream=sys.stderr, level=logging.ERROR, format='[%(levelname)s] %(message)s')

import pynini
from pynini.lib import pynutil


# reference: https://github.com/kylebgorman/pynini/blob/master/pynini/lib/edit_transducer.py
# to import original lib:
# from pynini.lib.edit_transducer import EditTransducer
class EditTransducer:
    DELETE = "<delete>"
    INSERT = "<insert>"
    SUBSTITUTE = "<substitute>"

    def __init__(
        self,
        symbol_table,
        vocab: Iterable[str],
        insert_cost: float = 1.0,
        delete_cost: float = 1.0,
        substitute_cost: float = 1.0,
        bound: int = 0,
    ):
        # Left factor; note that we divide the edit costs by two because they also
        # will be incurred when traversing the right factor.
        sigma = pynini.union(
            *[pynini.accep(token, token_type=symbol_table) for token in vocab],
        ).optimize()

        insert = pynutil.insert(f"[{self.INSERT}]", weight=insert_cost / 2)
        delete = pynini.cross(sigma, pynini.accep(f"[{self.DELETE}]", weight=delete_cost / 2))
        substitute = pynini.cross(sigma, pynini.accep(f"[{self.SUBSTITUTE}]", weight=substitute_cost / 2))

        edit = pynini.union(insert, delete, substitute).optimize()

        if bound:
            sigma_star = pynini.closure(sigma)
            self._e_i = sigma_star.copy()
            for _ in range(bound):
                self._e_i.concat(edit.ques).concat(sigma_star)
        else:
            self._e_i = edit.union(sigma).closure()

        self._e_i.optimize()

        right_factor_std = EditTransducer._right_factor(self._e_i)
        # right_factor_ext allows 0-cost matching between a token's raw form & auxiliary form
        # e.g.: 'I' -> 'I#', 'AM' -> 'AM#'
        right_factor_ext = (
            pynini.union(
                *[
                    pynini.cross(
                        pynini.accep(x, token_type=symbol_table),
                        pynini.accep(x + '#', token_type=symbol_table),
                    )
                    for x in vocab
                ]
            )
            .optimize()
            .closure()
        )
        self._e_o = pynini.union(right_factor_std, right_factor_ext).closure().optimize()

    @staticmethod
    def _right_factor(ifst: pynini.Fst) -> pynini.Fst:
        ofst = pynini.invert(ifst)
        syms = pynini.generated_symbols()
        insert_label = syms.find(EditTransducer.INSERT)
        delete_label = syms.find(EditTransducer.DELETE)
        pairs = [(insert_label, delete_label), (delete_label, insert_label)]
        right_factor = ofst.relabel_pairs(ipairs=pairs)
        return right_factor

    def create_lattice(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.Fst:
        lattice = (iexpr @ self._e_i) @ (self._e_o @ oexpr)
        EditTransducer.check_wellformed_lattice(lattice)
        return lattice

    @staticmethod
    def check_wellformed_lattice(lattice: pynini.Fst) -> None:
        if lattice.start() == pynini.NO_STATE_ID:
            raise RuntimeError("Edit distance composition lattice is empty.")

    def compute_distance(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> float:
        lattice = self.create_lattice(iexpr, oexpr)
        # The shortest cost from all final states to the start state is
        # equivalent to the cost of the shortest path.
        start = lattice.start()
        return float(pynini.shortestdistance(lattice, reverse=True)[start])

    def compute_alignment(self, iexpr: pynini.FstLike, oexpr: pynini.FstLike) -> pynini.FstLike:
        print(iexpr)
        print(oexpr)
        lattice = self.create_lattice(iexpr, oexpr)
        alignment = pynini.shortestpath(lattice, nshortest=1, unique=True)
        return alignment.optimize()


class ErrorStats:
    def __init__(self):
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0  # in both ref & hyp
        self.num_hyp_without_ref = 0

        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0
        self.modified_token_error_rate = 0.0

        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        # return json.dumps(self.__dict__, indent=4)
        return json.dumps(self.__dict__)

    def to_kaldi(self):
        info = (
            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        )
        return info

    def to_summary(self):
        summary = (
            '==================== Overall Statistics ====================\n'
            F'num_ref_utts: {self.num_ref_utts}\n'
            F'num_hyp_utts: {self.num_hyp_utts}\n'
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
            F'num_eval_utts: {self.num_eval_utts}\n'
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
            F'token_error_rate: {self.token_error_rate:.2f}%\n'
            F'modified_token_error_rate: {self.modified_token_error_rate:.2f}%\n'
            F'token_stats:\n'
            F'  - tokens:{self.C + self.S + self.D:>7}\n'
            F'  - edits: {self.S + self.I + self.D:>7}\n'
            F'  - cor:   {self.C:>7}\n'
            F'  - sub:   {self.S:>7}\n'
            F'  - ins:   {self.I:>7}\n'
            F'  - del:   {self.D:>7}\n'
            '============================================================\n'
        )
        return summary


class Utterance:
    def __init__(self, uid, text):
        self.uid = uid
        self.text = text


def LoadKaldiArc(filepath):
    utts = {}
    with open(filepath, 'r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                cols = line.split(maxsplit=1)
                assert len(cols) == 2 or len(cols) == 1
                uid = cols[0]
                text = cols[1] if len(cols) == 2 else ''
                if utts.get(uid) is not None:
                    raise RuntimeError(F'Found duplicated utterance id {uid}')
                utts[uid] = Utterance(uid, text)
    return utts


def BreakHyphen(token: str):
    # 'T-SHIRT' should also introduce new words into the vocabulary, e.g.:
    #   1. 'T' & 'SHIRT'
    #   2. 'TSHIRT'
    assert '-' in token
    v = token.split('-')
    v.append(token.replace('-', ''))
    return v


def LoadGLM(rel_path):
    '''
    glm.csv:
        I'VE,I HAVE
        GOING TO,GONNA
        ...
        T-SHIRT,T SHIRT,TSHIRT

    glm:
        {
            '<RULE_00000>': ["I'VE", 'I HAVE'],
            '<RULE_00001>': ['GOING TO', 'GONNA'],
            ...
            '<RULE_99999>': ['T-SHIRT', 'T SHIRT', 'TSHIRT'],
        }
    '''
    logging.info(f'Loading GLM from {rel_path} ...')

    abs_path = os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path
    reader = list(csv.reader(open(abs_path, encoding="utf-8"), delimiter=','))

    glm = {}
    for k, rule in enumerate(reader):
        rule_name = f'<RULE_{k:06d}>'
        glm[rule_name] = [phrase.strip() for phrase in rule]
    logging.info(f'  #rule: {len(glm)}')

    return glm


def SymbolEQ(symbol_table, i1, i2):
    return symbol_table.find(i1).strip('#') == symbol_table.find(i2).strip('#')


def PrintSymbolTable(symbol_table: pynini.SymbolTable):
    print('SYMBOL_TABLE:')
    for k in range(symbol_table.num_symbols()):
        sym = symbol_table.find(k)
        assert symbol_table.find(sym) == k  # the symbol table's find() supports bi-directional lookup (id <-> sym)
        print(k, sym)
    print()


def BuildSymbolTable(vocab) -> pynini.SymbolTable:
    logging.info('Building symbol table ...')
    symbol_table = pynini.SymbolTable()
    symbol_table.add_symbol('<epsilon>')

    for w in vocab:
        symbol_table.add_symbol(w)
    logging.info(f'  #symbols: {symbol_table.num_symbols()}')

    # PrintSymbolTable(symbol_table)
    # symbol_table.write_text('symbol_table.txt')
    return symbol_table


def BuildGLMTagger(glm, symbol_table) -> pynini.Fst:
    logging.info('Building GLM tagger ...')
    rule_taggers = []
    for rule_tag, rule in glm.items():
        for phrase in rule:
            rule_taggers.append(
                (
                    pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table))
                    + pynini.accep(phrase, token_type=symbol_table)
|
||||||
|
+ pynutil.insert(pynini.accep(rule_tag, token_type=symbol_table))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
alphabet = pynini.union(
|
||||||
|
*[pynini.accep(sym, token_type=symbol_table) for k, sym in symbol_table if k != 0] # non-epsilon
|
||||||
|
).optimize()
|
||||||
|
|
||||||
|
tagger = pynini.cdrewrite(
|
||||||
|
pynini.union(*rule_taggers).optimize(), '', '', alphabet.closure()
|
||||||
|
).optimize() # could be slow with large vocabulary
|
||||||
|
return tagger
|
||||||
|
|
||||||
|
|
||||||
|
def TokenWidth(token: str):
|
||||||
|
def CharWidth(c):
|
||||||
|
return 2 if (c >= '\u4e00') and (c <= '\u9fa5') else 1
|
||||||
|
|
||||||
|
return sum([CharWidth(c) for c in token])
|
||||||
|
|
||||||
|
|
||||||
|
def PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, stream=sys.stderr):
|
||||||
|
assert len(edit_ali) == len(ref_ali) and len(ref_ali) == len(hyp_ali)
|
||||||
|
|
||||||
|
H = ' HYP# : '
|
||||||
|
R = ' REF : '
|
||||||
|
E = ' EDIT : '
|
||||||
|
for i, e in enumerate(edit_ali):
|
||||||
|
h, r = hyp_ali[i], ref_ali[i]
|
||||||
|
e = '' if e == 'C' else e # don't bother printing correct edit-tag
|
||||||
|
|
||||||
|
nr, nh, ne = TokenWidth(r), TokenWidth(h), TokenWidth(e)
|
||||||
|
n = max(nr, nh, ne) + 1
|
||||||
|
|
||||||
|
H += h + ' ' * (n - nh)
|
||||||
|
R += r + ' ' * (n - nr)
|
||||||
|
E += e + ' ' * (n - ne)
|
||||||
|
|
||||||
|
print(F' HYP : {raw_hyp}', file=stream)
|
||||||
|
print(H, file=stream)
|
||||||
|
print(R, file=stream)
|
||||||
|
print(E, file=stream)
|
||||||
|
|
||||||
|
|
||||||
|
def ComputeTokenErrorRate(c, s, i, d):
|
||||||
|
assert (s + d + c) != 0
|
||||||
|
num_edits = s + d + i
|
||||||
|
ref_len = c + s + d
|
||||||
|
hyp_len = c + s + i
|
||||||
|
return 100.0 * num_edits / ref_len, 100.0 * num_edits / max(ref_len, hyp_len)
|
||||||
|
|
||||||
|
|
||||||
|
def ComputeSentenceErrorRate(num_err_utts, num_utts):
|
||||||
|
assert num_utts != 0
|
||||||
|
return 100.0 * num_err_utts / num_utts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--logk', type=int, default=500, help='logging interval')
|
||||||
|
parser.add_argument(
|
||||||
|
'--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER'
|
||||||
|
)
|
||||||
|
parser.add_argument('--glm', type=str, default='glm_en.csv', help='glm')
|
||||||
|
parser.add_argument('--ref', type=str, required=True, help='reference kaldi arc file')
|
||||||
|
parser.add_argument('--hyp', type=str, required=True, help='hypothesis kaldi arc file')
|
||||||
|
parser.add_argument('result_file', type=str)
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(args)
|
||||||
|
|
||||||
|
stats = ErrorStats()
|
||||||
|
|
||||||
|
logging.info('Generating tokenizer ...')
|
||||||
|
if args.tokenizer == 'whitespace':
|
||||||
|
|
||||||
|
def word_tokenizer(text):
|
||||||
|
return text.strip().split()
|
||||||
|
|
||||||
|
tokenizer = word_tokenizer
|
||||||
|
elif args.tokenizer == 'char':
|
||||||
|
|
||||||
|
def char_tokenizer(text):
|
||||||
|
return [c for c in text.strip().replace(' ', '')]
|
||||||
|
|
||||||
|
tokenizer = char_tokenizer
|
||||||
|
else:
|
||||||
|
tokenizer = None
|
||||||
|
assert tokenizer
|
||||||
|
|
||||||
|
logging.info('Loading REF & HYP ...')
|
||||||
|
ref_utts = LoadKaldiArc(args.ref)
|
||||||
|
hyp_utts = LoadKaldiArc(args.hyp)
|
||||||
|
|
||||||
|
# check valid utterances in hyp that have matched non-empty reference
|
||||||
|
uids = []
|
||||||
|
for uid in sorted(hyp_utts.keys()):
|
||||||
|
if uid in ref_utts.keys():
|
||||||
|
if ref_utts[uid].text.strip(): # non-empty reference
|
||||||
|
uids.append(uid)
|
||||||
|
else:
|
||||||
|
logging.warning(F'Found {uid} with empty reference, skipping...')
|
||||||
|
else:
|
||||||
|
logging.warning(F'Found {uid} without reference, skipping...')
|
||||||
|
stats.num_hyp_without_ref += 1
|
||||||
|
|
||||||
|
stats.num_hyp_utts = len(hyp_utts)
|
||||||
|
stats.num_ref_utts = len(ref_utts)
|
||||||
|
stats.num_eval_utts = len(uids)
|
||||||
|
logging.info(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')
|
||||||
|
print(f' #hyp:{stats.num_hyp_utts}, #ref:{stats.num_ref_utts}, #utts_to_evaluate:{stats.num_eval_utts}')
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
for uid in uids:
|
||||||
|
ref_tokens = tokenizer(ref_utts[uid].text)
|
||||||
|
hyp_tokens = tokenizer(hyp_utts[uid].text)
|
||||||
|
for t in ref_tokens + hyp_tokens:
|
||||||
|
tokens.append(t)
|
||||||
|
if '-' in t:
|
||||||
|
tokens.extend(BreakHyphen(t))
|
||||||
|
vocab_from_utts = list(set(tokens))
|
||||||
|
logging.info(f' HYP&REF vocab size: {len(vocab_from_utts)}')
|
||||||
|
print(f' HYP&REF vocab size: {len(vocab_from_utts)}')
|
||||||
|
|
||||||
|
assert args.glm
|
||||||
|
glm = LoadGLM(args.glm)
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
for rule in glm.values():
|
||||||
|
for phrase in rule:
|
||||||
|
for t in tokenizer(phrase):
|
||||||
|
tokens.append(t)
|
||||||
|
if '-' in t:
|
||||||
|
tokens.extend(BreakHyphen(t))
|
||||||
|
vocab_from_glm = list(set(tokens))
|
||||||
|
logging.info(f' GLM vocab size: {len(vocab_from_glm)}')
|
||||||
|
print(f' GLM vocab size: {len(vocab_from_glm)}')
|
||||||
|
|
||||||
|
vocab = list(set(vocab_from_utts + vocab_from_glm))
|
||||||
|
logging.info(f'Global vocab size: {len(vocab)}')
|
||||||
|
print(f'Global vocab size: {len(vocab)}')
|
||||||
|
|
||||||
|
symtab = BuildSymbolTable(
|
||||||
|
# Normal evaluation vocab + auxiliary form for alternative paths + GLM tags
|
||||||
|
vocab
|
||||||
|
+ [x + '#' for x in vocab]
|
||||||
|
+ [x for x in glm.keys()]
|
||||||
|
)
|
||||||
|
glm_tagger = BuildGLMTagger(glm, symtab)
|
||||||
|
edit_transducer = EditTransducer(symbol_table=symtab, vocab=vocab)
|
||||||
|
print(edit_transducer)
|
||||||
|
|
||||||
|
logging.info('Evaluating error rate ...')
|
||||||
|
print('Evaluating error rate ...')
|
||||||
|
fo = open(args.result_file, 'w+', encoding='utf8')
|
||||||
|
ndone = 0
|
||||||
|
for uid in uids:
|
||||||
|
ref = ref_utts[uid].text
|
||||||
|
raw_hyp = hyp_utts[uid].text
|
||||||
|
|
||||||
|
ref_fst = pynini.accep(' '.join(tokenizer(ref)), token_type=symtab)
|
||||||
|
print(ref_fst)
|
||||||
|
|
||||||
|
# print(ref_fst.string(token_type = symtab))
|
||||||
|
|
||||||
|
raw_hyp_fst = pynini.accep(' '.join(tokenizer(raw_hyp)), token_type=symtab)
|
||||||
|
# print(raw_hyp_fst.string(token_type = symtab))
|
||||||
|
|
||||||
|
# Say, we have:
|
||||||
|
# RULE_001: "I'M" <-> "I AM"
|
||||||
|
# REF: HEY I AM HERE
|
||||||
|
# HYP: HEY I'M HERE
|
||||||
|
#
|
||||||
|
# We want to expand HYP with GLM rules(marked with auxiliary #)
|
||||||
|
# HYP#: HEY {I'M | I# AM#} HERE
|
||||||
|
# REF is honored to keep its original form.
|
||||||
|
#
|
||||||
|
# This could be considered as a flexible on-the-fly TN towards HYP.
|
||||||
|
|
||||||
|
# 1. GLM rule tagging:
|
||||||
|
# HEY I'M HERE
|
||||||
|
# ->
|
||||||
|
# HEY <RULE_001> I'M <RULE_001> HERE
|
||||||
|
lattice = (raw_hyp_fst @ glm_tagger).optimize()
|
||||||
|
tagged_ir = pynini.shortestpath(lattice, nshortest=1, unique=True).string(token_type=symtab)
|
||||||
|
# print(hyp_tagged)
|
||||||
|
|
||||||
|
# 2. GLM rule expansion:
|
||||||
|
# HEY <RULE_001> I'M <RULE_001> HERE
|
||||||
|
# ->
|
||||||
|
# sausage-like fst: HEY {I'M | I# AM#} HERE
|
||||||
|
tokens = tagged_ir.split()
|
||||||
|
sausage = pynini.accep('', token_type=symtab)
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens): # invariant: tokens[0, i) has been built into fst
|
||||||
|
forms = []
|
||||||
|
if tokens[i].startswith('<RULE_') and tokens[i].endswith('>'): # rule segment
|
||||||
|
rule_name = tokens[i]
|
||||||
|
rule = glm[rule_name]
|
||||||
|
# pre-condition: i -> ltag
|
||||||
|
raw_form = ''
|
||||||
|
for j in range(i + 1, len(tokens)):
|
||||||
|
if tokens[j] == rule_name:
|
||||||
|
raw_form = ' '.join(tokens[i + 1 : j])
|
||||||
|
break
|
||||||
|
assert raw_form
|
||||||
|
# post-condition: i -> ltag, j -> rtag
|
||||||
|
|
||||||
|
forms.append(raw_form)
|
||||||
|
for phrase in rule:
|
||||||
|
if phrase != raw_form:
|
||||||
|
forms.append(' '.join([x + '#' for x in phrase.split()]))
|
||||||
|
i = j + 1
|
||||||
|
else: # normal token segment
|
||||||
|
token = tokens[i]
|
||||||
|
forms.append(token)
|
||||||
|
if "-" in token: # token with hyphen yields extra forms
|
||||||
|
forms.append(' '.join([x + '#' for x in token.split('-')])) # 'T-SHIRT' -> 'T# SHIRT#'
|
||||||
|
forms.append(token.replace('-', '') + '#') # 'T-SHIRT' -> 'TSHIRT#'
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
sausage_segment = pynini.union(*[pynini.accep(x, token_type=symtab) for x in forms]).optimize()
|
||||||
|
sausage += sausage_segment
|
||||||
|
hyp_fst = sausage.optimize()
|
||||||
|
print(hyp_fst)
|
||||||
|
|
||||||
|
# Utterance-Level error rate evaluation
|
||||||
|
alignment = edit_transducer.compute_alignment(ref_fst, hyp_fst)
|
||||||
|
print("alignment", alignment)
|
||||||
|
|
||||||
|
distance = 0.0
|
||||||
|
C, S, I, D = 0, 0, 0, 0 # Cor, Sub, Ins, Del
|
||||||
|
edit_ali, ref_ali, hyp_ali = [], [], []
|
||||||
|
for state in alignment.states():
|
||||||
|
for arc in alignment.arcs(state):
|
||||||
|
i, o = arc.ilabel, arc.olabel
|
||||||
|
if i != 0 and o != 0 and SymbolEQ(symtab, i, o):
|
||||||
|
e = 'C'
|
||||||
|
r, h = symtab.find(i), symtab.find(o)
|
||||||
|
|
||||||
|
C += 1
|
||||||
|
distance += 0.0
|
||||||
|
elif i != 0 and o != 0 and not SymbolEQ(symtab, i, o):
|
||||||
|
e = 'S'
|
||||||
|
r, h = symtab.find(i), symtab.find(o)
|
||||||
|
|
||||||
|
S += 1
|
||||||
|
distance += 1.0
|
||||||
|
elif i == 0 and o != 0:
|
||||||
|
e = 'I'
|
||||||
|
r, h = '*', symtab.find(o)
|
||||||
|
|
||||||
|
I += 1
|
||||||
|
distance += 1.0
|
||||||
|
elif i != 0 and o == 0:
|
||||||
|
e = 'D'
|
||||||
|
r, h = symtab.find(i), '*'
|
||||||
|
|
||||||
|
D += 1
|
||||||
|
distance += 1.0
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
edit_ali.append(e)
|
||||||
|
ref_ali.append(r)
|
||||||
|
hyp_ali.append(h)
|
||||||
|
# assert(distance == edit_transducer.compute_distance(ref_fst, sausage))
|
||||||
|
|
||||||
|
utt_ter, utt_mter = ComputeTokenErrorRate(C, S, I, D)
|
||||||
|
# print(F'{{"uid":{uid}, "score":{-distance}, "TER":{utt_ter:.2f}, "mTER":{utt_mter:.2f}, "cor":{C}, "sub":{S}, "ins":{I}, "del":{D}}}', file=fo)
|
||||||
|
# PrintPrettyAlignment(raw_hyp, edit_ali, ref_ali, hyp_ali, fo)
|
||||||
|
|
||||||
|
if utt_ter > 0:
|
||||||
|
stats.num_utts_with_error += 1
|
||||||
|
|
||||||
|
stats.C += C
|
||||||
|
stats.S += S
|
||||||
|
stats.I += I
|
||||||
|
stats.D += D
|
||||||
|
|
||||||
|
ndone += 1
|
||||||
|
if ndone % args.logk == 0:
|
||||||
|
logging.info(f'{ndone} utts evaluated.')
|
||||||
|
logging.info(f'{ndone} utts evaluated in total.')
|
||||||
|
|
||||||
|
# Corpus-Level evaluation
|
||||||
|
stats.token_error_rate, stats.modified_token_error_rate = ComputeTokenErrorRate(stats.C, stats.S, stats.I, stats.D)
|
||||||
|
stats.sentence_error_rate = ComputeSentenceErrorRate(stats.num_utts_with_error, stats.num_eval_utts)
|
||||||
|
|
||||||
|
print(stats.to_json(), file=fo)
|
||||||
|
# print(stats.to_kaldi())
|
||||||
|
# print(stats.to_summary(), file=fo)
|
||||||
|
|
||||||
|
fo.close()
|
||||||
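The TER vs. modified-TER distinction in `ComputeTokenErrorRate` above is plain arithmetic over the edit counts. A minimal standalone sketch (the counts `c=8, s=1, i=2, d=1` are made-up illustration values, not from any real evaluation):

```python
def compute_token_error_rate(c, s, i, d):
    # Mirrors ComputeTokenErrorRate in error_rate_en.py:
    # TER normalizes the edit count by reference length; modified TER
    # normalizes by max(ref_len, hyp_len), making it symmetric in REF/HYP
    # and bounding it by 100% when the hypothesis is longer.
    num_edits = s + d + i
    ref_len = c + s + d  # ref tokens = correct + substituted + deleted
    hyp_len = c + s + i  # hyp tokens = correct + substituted + inserted
    return 100.0 * num_edits / ref_len, 100.0 * num_edits / max(ref_len, hyp_len)

ter, mter = compute_token_error_rate(c=8, s=1, i=2, d=1)
print(round(ter, 2), round(mter, 2))  # 40.0 36.36 (ref_len=10, hyp_len=11, edits=4)
```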
370
utils/speechio/error_rate_zh.py
Normal file
@@ -0,0 +1,370 @@
#!/usr/bin/env python3
# coding=utf8

# Copyright 2021 Jiayu DU

import sys
import argparse
import json
import logging
logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')

DEBUG = None

def GetEditType(ref_token, hyp_token):
    if ref_token is None and hyp_token is not None:
        return 'I'
    elif ref_token is not None and hyp_token is None:
        return 'D'
    elif ref_token == hyp_token:
        return 'C'
    elif ref_token != hyp_token:
        return 'S'
    else:
        raise RuntimeError

class AlignmentArc:
    def __init__(self, src, dst, ref, hyp):
        self.src = src
        self.dst = dst
        self.ref = ref
        self.hyp = hyp
        self.edit_type = GetEditType(ref, hyp)

def similarity_score_function(ref_token, hyp_token):
    return 0.0 if (ref_token == hyp_token) else -1.0

def insertion_score_function(token):
    return -1.0

def deletion_score_function(token):
    return -1.0

def EditDistance(
        ref,
        hyp,
        similarity_score_function = similarity_score_function,
        insertion_score_function = insertion_score_function,
        deletion_score_function = deletion_score_function):
    assert len(ref) != 0
    class DPState:
        def __init__(self):
            self.score = -float('inf')
            # backpointer
            self.prev_r = None
            self.prev_h = None

    def print_search_grid(S, R, H, fstream):
        print(file=fstream)
        for r in range(R):
            for h in range(H):
                print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
            print(file=fstream)

    R = len(ref) + 1
    H = len(hyp) + 1

    # Construct DP search space, a (R x H) grid
    S = [ [] for r in range(R) ]
    for r in range(R):
        S[r] = [ DPState() for x in range(H) ]

    # initialize DP search grid origin, S(r = 0, h = 0)
    S[0][0].score = 0.0
    S[0][0].prev_r = None
    S[0][0].prev_h = None

    # initialize REF axis
    for r in range(1, R):
        S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
        S[r][0].prev_r = r-1
        S[r][0].prev_h = 0

    # initialize HYP axis
    for h in range(1, H):
        S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
        S[0][h].prev_r = 0
        S[0][h].prev_h = h-1

    best_score = S[0][0].score
    best_state = (0, 0)

    for r in range(1, R):
        for h in range(1, H):
            sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
            new_score = S[r-1][h-1].score + sub_or_cor_score
            if new_score >= S[r][h].score:
                S[r][h].score = new_score
                S[r][h].prev_r = r-1
                S[r][h].prev_h = h-1

            del_score = deletion_score_function(ref[r-1])
            new_score = S[r-1][h].score + del_score
            if new_score >= S[r][h].score:
                S[r][h].score = new_score
                S[r][h].prev_r = r - 1
                S[r][h].prev_h = h

            ins_score = insertion_score_function(hyp[h-1])
            new_score = S[r][h-1].score + ins_score
            if new_score >= S[r][h].score:
                S[r][h].score = new_score
                S[r][h].prev_r = r
                S[r][h].prev_h = h-1

    best_score = S[R-1][H-1].score
    best_state = (R-1, H-1)

    if DEBUG:
        print_search_grid(S, R, H, sys.stderr)

    # Backtracing best alignment path, i.e. a list of arcs
    #   arc = (src, dst, ref, hyp, edit_type)
    #   src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
    best_path = []
    r, h = best_state[0], best_state[1]
    prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
    score = S[r][h].score
    # loop invariant:
    #   1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
    #   2. score is the value of point(r, h) on DP search grid
    while prev_r is not None or prev_h is not None:
        src = (prev_r, prev_h)
        dst = (r, h)
        if (r == prev_r + 1 and h == prev_h + 1):  # Substitution or correct
            arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
        elif (r == prev_r + 1 and h == prev_h):  # Deletion
            arc = AlignmentArc(src, dst, ref[prev_r], None)
        elif (r == prev_r and h == prev_h + 1):  # Insertion
            arc = AlignmentArc(src, dst, None, hyp[prev_h])
        else:
            raise RuntimeError
        best_path.append(arc)
        r, h = prev_r, prev_h
        prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
        score = S[r][h].score

    best_path.reverse()
    return (best_path, best_score)

def PrettyPrintAlignment(alignment, stream = sys.stderr):
    def get_token_str(token):
        if token is None:
            return "*"
        return token

    def is_double_width_char(ch):
        if (ch >= '\u4e00') and (ch <= '\u9fa5'):  # codepoint range for Chinese chars
            return True
            # TODO: support other double-width-char languages such as Japanese, Korean
        else:
            return False

    def display_width(token_str):
        m = 0
        for c in token_str:
            if is_double_width_char(c):
                m += 2
            else:
                m += 1
        return m

    R = ' REF  : '
    H = ' HYP  : '
    E = ' EDIT : '
    for arc in alignment:
        r = get_token_str(arc.ref)
        h = get_token_str(arc.hyp)
        e = arc.edit_type if arc.edit_type != 'C' else ''

        nr, nh, ne = display_width(r), display_width(h), display_width(e)
        n = max(nr, nh, ne) + 1

        R += r + ' ' * (n-nr)
        H += h + ' ' * (n-nh)
        E += e + ' ' * (n-ne)

    print(R, file=stream)
    print(H, file=stream)
    print(E, file=stream)

def CountEdits(alignment):
    c, s, i, d = 0, 0, 0, 0
    for arc in alignment:
        if arc.edit_type == 'C':
            c += 1
        elif arc.edit_type == 'S':
            s += 1
        elif arc.edit_type == 'I':
            i += 1
        elif arc.edit_type == 'D':
            d += 1
        else:
            raise RuntimeError
    return (c, s, i, d)

def ComputeTokenErrorRate(c, s, i, d):
    return 100.0 * (s + d + i) / (s + d + c)

def ComputeSentenceErrorRate(num_err_utts, num_utts):
    assert num_utts != 0
    return 100.0 * num_err_utts / num_utts


class EvaluationResult:
    def __init__(self):
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0  # seen in both ref & hyp
        self.num_hyp_without_ref = 0

        self.C = 0
        self.S = 0
        self.I = 0
        self.D = 0
        self.token_error_rate = 0.0

        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        return json.dumps(self.__dict__)

    def to_kaldi(self):
        info = (
            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        )
        return info

    def to_sclite(self):
        return "TODO"

    def to_espnet(self):
        return "TODO"

    def to_summary(self):
        # return json.dumps(self.__dict__, indent=4)
        summary = (
            '==================== Overall Statistics ====================\n'
            F'num_ref_utts: {self.num_ref_utts}\n'
            F'num_hyp_utts: {self.num_hyp_utts}\n'
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
            F'num_eval_utts: {self.num_eval_utts}\n'
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
            F'token_error_rate: {self.token_error_rate:.2f}%\n'
            F'token_stats:\n'
            F' - tokens:{self.C + self.S + self.D:>7}\n'
            F' - edits: {self.S + self.I + self.D:>7}\n'
            F' - cor: {self.C:>7}\n'
            F' - sub: {self.S:>7}\n'
            F' - ins: {self.I:>7}\n'
            F' - del: {self.D:>7}\n'
            '============================================================\n'
        )
        return summary


class Utterance:
    def __init__(self, uid, text):
        self.uid = uid
        self.text = text


def LoadUtterances(filepath, format):
    utts = {}
    if format == 'text':  # utt_id word1 word2 ...
        with open(filepath, 'r', encoding='utf8') as f:
            for line in f:
                line = line.strip()
                if line:
                    cols = line.split(maxsplit=1)
                    assert len(cols) == 2 or len(cols) == 1
                    uid = cols[0]
                    text = cols[1] if len(cols) == 2 else ''
                    if utts.get(uid) is not None:
                        raise RuntimeError(F'Found duplicated utterance id {uid}')
                    utts[uid] = Utterance(uid, text)
    else:
        raise RuntimeError(F'Unsupported text format {format}')
    return utts


def tokenize_text(text, tokenizer):
    if tokenizer == 'whitespace':
        return text.split()
    elif tokenizer == 'char':
        return [ ch for ch in ''.join(text.split()) ]
    else:
        raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # optional
    parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
    parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
    parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
    # required
    parser.add_argument('--ref', type=str, required=True, help='input reference file')
    parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')

    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    logging.info(args)

    ref_utts = LoadUtterances(args.ref, args.ref_format)
    hyp_utts = LoadUtterances(args.hyp, args.hyp_format)

    r = EvaluationResult()

    # check valid utterances in hyp that have matched non-empty reference
    eval_utts = []
    r.num_hyp_without_ref = 0
    for uid in sorted(hyp_utts.keys()):
        if uid in ref_utts.keys():  # TODO: efficiency
            if ref_utts[uid].text.strip():  # non-empty reference
                eval_utts.append(uid)
            else:
                logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            logging.warning(F'Found {uid} without reference, skipping...')
            r.num_hyp_without_ref += 1

    r.num_hyp_utts = len(hyp_utts)
    r.num_ref_utts = len(ref_utts)
    r.num_eval_utts = len(eval_utts)

    with open(args.result_file, 'w+', encoding='utf8') as fo:
        for uid in eval_utts:
            ref = ref_utts[uid]
            hyp = hyp_utts[uid]

            alignment, score = EditDistance(
                tokenize_text(ref.text, args.tokenizer),
                tokenize_text(hyp.text, args.tokenizer)
            )

            c, s, i, d = CountEdits(alignment)
            utt_ter = ComputeTokenErrorRate(c, s, i, d)

            # utt-level evaluation result
            print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
            PrettyPrintAlignment(alignment, fo)

            r.C += c
            r.S += s
            r.I += i
            r.D += d

            if utt_ter > 0:
                r.num_utts_with_error += 1

        # corpus level evaluation result
        r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
        r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)

        print(r.to_summary(), file=fo)

    print(r.to_json())
    print(r.to_kaldi())
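As a sanity check of the DP recurrence in `EditDistance` above: with unit costs it reduces to classic Levenshtein distance over token lists. A self-contained sketch on a toy REF/HYP pair (not the script's own function, just the same recurrence minimizing edits instead of maximizing negative scores, without backpointers):

```python
def edit_distance(ref, hyp):
    # dp[r][h] = minimal number of edits aligning ref[:r] to hyp[:h];
    # same grid and transitions as EditDistance above with unit costs.
    R, H = len(ref), len(hyp)
    dp = [[0] * (H + 1) for _ in range(R + 1)]
    for r in range(1, R + 1):
        dp[r][0] = r  # deletions along the REF axis
    for h in range(1, H + 1):
        dp[0][h] = h  # insertions along the HYP axis
    for r in range(1, R + 1):
        for h in range(1, H + 1):
            dp[r][h] = min(
                dp[r - 1][h - 1] + (ref[r - 1] != hyp[h - 1]),  # sub / cor
                dp[r - 1][h] + 1,                               # del
                dp[r][h - 1] + 1,                               # ins
            )
    return dp[R][H]

ref = 'HEY I AM HERE'.split()
hyp = "HEY I'M HERE".split()
print(edit_distance(ref, hyp))  # 2: one substitution plus one deletion
```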
744
utils/speechio/glm_en.csv
Normal file
@@ -0,0 +1,744 @@
I'M,I AM
|
||||||
|
I'LL,I WILL
|
||||||
|
I'D,I HAD
|
||||||
|
I'VE,I HAVE
|
||||||
|
I WOULD'VE,I'D HAVE
|
||||||
|
YOU'RE,YOU ARE
|
||||||
|
YOU'LL,YOU WILL
|
||||||
|
YOU'D,YOU WOULD
|
||||||
|
YOU'VE,YOU HAVE
|
||||||
|
HE'S,HE IS,HE WAS
|
||||||
|
HE'LL,HE WILL
|
||||||
|
HE'D,HE HAD
|
||||||
|
SHE'S,SHE IS,SHE WAS
|
||||||
|
SHE'LL,SHE WILL
|
||||||
|
SHE'D,SHE HAD
|
||||||
|
IT'S,IT IS,IT WAS
|
||||||
|
IT'LL,IT WILL
|
||||||
|
WE'RE,WE ARE,WE WERE
|
||||||
|
WE'LL,WE WILL
|
||||||
|
WE'D,WE WOULD
|
||||||
|
WE'VE,WE HAVE
|
||||||
|
WHO'LL,WHO WILL
|
||||||
|
THEY'RE,THEY ARE
|
||||||
|
THEY'LL,THEY WILL
|
||||||
|
THAT'S,THAT IS,THAT WAS
|
||||||
|
THAT'LL,THAT WILL
|
||||||
|
HERE'S,HERE IS,HERE WAS
|
||||||
|
THERE'S,THERE IS,THERE WAS
|
||||||
|
WHERE'S,WHERE IS,WHERE WAS
|
||||||
|
WHAT'S,WHAT IS,WHAT WAS
|
||||||
|
LET'S,LET US
|
||||||
|
WHO'S,WHO IS
|
||||||
|
ONE'S,ONE IS
|
||||||
|
THERE'LL,THERE WILL
|
||||||
|
SOMEBODY'S,SOMEBODY IS
|
||||||
|
EVERYBODY'S,EVERYBODY IS
|
||||||
|
WOULD'VE,WOULD HAVE
|
||||||
|
CAN'T,CANNOT,CAN NOT
|
||||||
|
HADN'T,HAD NOT
|
||||||
|
HASN'T,HAS NOT
|
||||||
|
HAVEN'T,HAVE NOT
|
||||||
|
ISN'T,IS NOT
|
||||||
|
AREN'T,ARE NOT
|
||||||
|
WON'T,WILL NOT
|
||||||
|
WOULDN'T,WOULD NOT
|
||||||
|
SHOULDN'T,SHOULD NOT
|
||||||
|
DON'T,DO NOT
|
||||||
|
DIDN'T,DID NOT
|
||||||
|
GOTTA,GOT TO
|
||||||
|
GONNA,GOING TO
|
||||||
|
WANNA,WANT TO
|
||||||
|
LEMME,LET ME
|
||||||
|
GIMME,GIVE ME
|
||||||
|
DUNNO,DON'T KNOW
|
||||||
|
GOTCHA,GOT YOU
|
||||||
|
KINDA,KIND OF
|
||||||
|
MYSELF,MY SELF
|
||||||
|
YOURSELF,YOUR SELF
|
||||||
|
HIMSELF,HIM SELF
|
||||||
|
HERSELF,HER SELF
|
||||||
|
ITSELF,IT SELF
|
||||||
|
OURSELVES,OUR SELVES
|
||||||
|
OKAY,OK,O K
|
||||||
|
Y'ALL,YALL,YOU ALL
|
||||||
|
'CAUSE,'COS,CUZ,BECAUSE
|
||||||
|
FUCKIN',FUCKING
|
||||||
|
KILLING,KILLIN'
|
||||||
|
EVERYDAY,EVERY DAY
|
||||||
|
DOCTOR,DR,DR.
|
||||||
|
MRS,MISSES,MISSUS
|
||||||
|
MR,MR.,MISTER
|
||||||
|
SR,SR.,SENIOR
|
||||||
|
JR,JR.,JUNIOR
|
||||||
|
ST,ST.,SAINT
|
||||||
|
VOL,VOL.,VOLUME
|
||||||
|
CM,CENTIMETER,CENTIMETRE
|
||||||
|
MM,MILLIMETER,MILLIMETRE
|
||||||
|
KM,KILOMETER,KILOMETRE
|
||||||
|
KB,KILOBYTES,KILO BYTES,K B
|
||||||
|
MB,MEGABYTES,MEGA BYTES
|
||||||
|
GB,GIGABYTES,GIGA BYTES,G B
|
||||||
|
THOUSAND,THOUSAND AND
|
||||||
|
HUNDRED,HUNDRED AND
|
||||||
|
A HUNDRED,ONE HUNDRED
|
||||||
|
TWO THOUSAND AND,TWENTY,TWO THOUSAND
|
||||||
|
STORYTELLER,STORY TELLER
|
||||||
|
TSHIRT,T SHIRT
|
||||||
|
TSHIRTS,T SHIRTS
|
||||||
|
LEUKAEMIA,LEUKEMIA
|
||||||
|
OESTROGEN,ESTROGEN
|
||||||
|
ACKNOWLEDGMENT,ACKNOWLEDGEMENT
|
||||||
|
JUDGMENT,JUDGEMENT
|
||||||
|
MAMMA,MAMA
|
||||||
|
DINING,DINNING
|
||||||
|
FLACK,FLAK
|
||||||
|
LEARNT,LEARNED
|
||||||
|
BLONDE,BLOND
|
||||||
|
JUMPSTART,JUMP START
|
||||||
|
RIGHTNOW,RIGHT NOW
|
||||||
|
EVERYONE,EVERY ONE
|
||||||
|
NAME'S,NAME IS
|
||||||
|
FAMILY'S,FAMILY IS
|
||||||
|
COMPANY'S,COMPANY HAS
|
||||||
|
GRANDKID,GRAND KID
|
||||||
|
GRANDKIDS,GRAND KIDS
|
||||||
|
MEALTIMES,MEAL TIMES
|
||||||
|
ALRIGHT,ALL RIGHT
|
||||||
|
GROWNUP,GROWN UP
|
||||||
|
GROWNUPS,GROWN UPS
|
||||||
|
SCHOOLDAYS,SCHOOL DAYS
|
||||||
|
SCHOOLCHILDREN,SCHOOL CHILDREN
|
||||||
|
CASEBOOK,CASE BOOK
|
||||||
|
HUNGOVER,HUNG OVER
|
||||||
|
HANDCLAPS,HAND CLAPS
|
||||||
|
HANDCLAP,HAND CLAP
|
||||||
|
HEATWAVE,HEAT WAVE
|
||||||
|
ADDON,ADD ON
|
||||||
|
ONTO,ON TO
|
||||||
|
INTO,IN TO
|
||||||
|
GOTO,GO TO
|
||||||
|
GUNSHOT,GUN SHOT
|
||||||
|
MOTHERFUCKER,MOTHER FUCKER
|
||||||
|
OFTENTIMES,OFTEN TIMES
|
||||||
|
SARTRE'S,SARTRE IS
|
||||||
|
NONSTARTER,NON STARTER
|
||||||
|
NONSTARTERS,NON STARTERS
|
||||||
|
LONGTIME,LONG TIME
|
||||||
|
POLICYMAKERS,POLICY MAKERS
|
||||||
|
ANYMORE,ANY MORE
|
||||||
|
CANADA'S,CANADA IS
|
||||||
|
CELLPHONE,CELL PHONE
|
||||||
|
WORKPLACE,WORK PLACE
|
||||||
|
UNDERESTIMATING,UNDER ESTIMATING
|
||||||
|
CYBERSECURITY,CYBER SECURITY
|
||||||
|
NORTHEAST,NORTH EAST
|
||||||
|
ANYTIME,ANY TIME
|
||||||
|
LIVESTREAM,LIVE STREAM
|
||||||
|
LIVESTREAMS,LIVE STREAMS
|
||||||
|
WEBCAM,WEB CAM
|
||||||
|
EMAIL,E MAIL
|
||||||
|
ECAM,E CAM
|
||||||
|
VMIX,V MIX
|
||||||
|
SETUP,SET UP
|
||||||
|
SMARTPHONE,SMART PHONE
|
||||||
|
MULTICASTING,MULTI CASTING
|
||||||
|
CHITCHAT,CHIT CHAT
|
||||||
|
SEMIFINAL,SEMI FINAL
|
||||||
|
SEMIFINALS,SEMI FINALS
|
||||||
|
BBQ,BARBECUE
|
||||||
|
STORYLINE,STORY LINE
|
||||||
|
STORYLINES,STORY LINES
|
||||||
|
BRO,BROTHER
|
||||||
|
BROS,BROTHERS
|
||||||
|
OVERPROTECTIIVE,OVER PROTECTIVE
|
||||||
|
TIMEOUT,TIME OUT
|
||||||
|
ADVISOR,ADVISER
|
||||||
|
TIMBERWOLVES,TIMBER WOLVES
|
||||||
|
WEBPAGE,WEB PAGE
|
||||||
|
NEWCOMER,NEW COMER
|
||||||
|
DELMAR,DEL MAR
|
||||||
|
NETPLAY,NET PLAY
|
||||||
|
STREETSIDE,STREET SIDE
|
||||||
|
COLOURED,COLORED
|
||||||
|
COLOURFUL,COLORFUL
|
||||||
|
O,ZERO
|
||||||
|
ETCETERA,ET CETERA
|
||||||
|
FUNDRAISING,FUND RAISING
|
||||||
|
RAINFOREST,RAIN FOREST
|
||||||
|
BREATHTAKING,BREATH TAKING
|
||||||
|
WIKIPAGE,WIKI PAGE
|
||||||
|
OVERTIME,OVER TIME
|
||||||
|
TRAIN'S,TRAIN IS
|
||||||
|
ANYONE,ANY ONE
|
||||||
|
PHYSIOTHERAPY,PHYSIO THERAPY
|
||||||
|
ANYBODY,ANY BODY
|
||||||
|
BOTTLECAPS,BOTTLE CAPS
|
||||||
|
BOTTLECAP,BOTTLE CAP
|
||||||
|
STEPFATHER'S,STEP FATHER'S
|
||||||
|
STEPFATHER,STEP FATHER
|
||||||
|
WARTIME,WAR TIME
|
||||||
|
SCREENSHOT,SCREEN SHOT
|
||||||
|
TIMELINE,TIME LINE
|
||||||
|
CITY'S,CITY IS
|
||||||
|
NONPROFIT,NON PROFIT
|
||||||
|
KPOP,K POP
|
||||||
|
HOMEBASE,HOME BASE
|
||||||
|
LIFELONG,LIFE LONG
|
||||||
|
LAWSUITS,LAW SUITS
|
||||||
|
MULTIBILLION,MULTI BILLION
|
||||||
|
ROADMAP,ROAD MAP
|
||||||
|
GUY'S,GUY IS
|
||||||
|
CHECKOUT,CHECK OUT
|
||||||
|
SQUARESPACE,SQUARE SPACE
|
||||||
|
REDLINING,RED LINING
|
||||||
|
BASE'S,BASE IS
|
||||||
|
TAKEAWAY,TAKE AWAY
|
||||||
|
CANDYLAND,CANDY LAND
|
||||||
|
ANTISOCIAL,ANTI SOCIAL
|
||||||
|
CASEWORK,CASE WORK
|
||||||
|
RIGOR,RIGOUR
|
||||||
|
ORGANIZATIONS,ORGANISATIONS
|
||||||
|
ORGANIZATION,ORGANISATION
|
||||||
|
SIGNPOST,SIGN POST
|
||||||
|
WWII,WORLD WAR TWO
|
||||||
|
WINDOWPANE,WINDOW PANE
|
||||||
|
SUREFIRE,SURE FIRE
|
||||||
|
MOUNTAINTOP,MOUNTAIN TOP
|
||||||
|
SALESPERSON,SALES PERSON
|
||||||
|
NETWORK,NET WORK
|
||||||
|
MINISERIES,MINI SERIES
|
||||||
|
EDWARDS'S,EDWARDS IS
|
||||||
|
INTERSUBJECTIVITY,INTER SUBJECTIVITY
|
||||||
|
LIBERALISM'S,LIBERALISM IS
|
||||||
|
TAGLINE,TAG LINE
|
||||||
|
SHINETHEORY,SHINE THEORY
|
||||||
|
CALLYOURGIRLFRIEND,CALL YOUR GIRLFRIEND
|
||||||
|
STARTUP,START UP
|
||||||
|
BREAKUP,BREAK UP
|
||||||
|
RADIOTOPIA,RADIO TOPIA
|
||||||
|
HEARTBREAKING,HEART BREAKING
|
||||||
|
AUTOIMMUNE,AUTO IMMUNE
|
||||||
|
SINISE'S,SINISE IS
|
||||||
|
KICKBACK,KICK BACK
|
||||||
|
FOGHORN,FOG HORN
|
||||||
|
BADASS,BAD ASS
|
||||||
|
POWERAMERICAFORWARD,POWER AMERICA FORWARD
|
||||||
|
GOOGLE'S,GOOGLE IS
|
||||||
|
ROLEPLAY,ROLE PLAY
|
||||||
|
PRICE'S,PRICE IS
|
||||||
|
STANDOFF,STAND OFF
|
||||||
|
FOREVER,FOR EVER
|
||||||
|
GENERAL'S,GENERAL IS
|
||||||
|
DOG'S,DOG IS
|
||||||
|
AUDIOBOOK,AUDIO BOOK
|
||||||
|
ANYWAY,ANY WAY
|
||||||
|
PIGEONHOLE,PIGEON HOLE
|
||||||
|
EGGSHELLS,EGG SHELLS
|
||||||
|
VACCINE'S,VACCINE IS
|
||||||
|
WORKOUT,WORK OUT
|
||||||
|
ADMINISTRATOR'S,ADMINISTRATOR IS
|
||||||
|
FUCKUP,FUCK UP
|
||||||
|
RUNOFFS,RUN OFFS
|
||||||
|
COLORWAY,COLOR WAY
|
||||||
|
WAITLIST,WAIT LIST
|
||||||
|
HEALTHCARE,HEALTH CARE
|
||||||
|
TEXTBOOK,TEXT BOOK
|
||||||
|
CALLBACK,CALL BACK
|
||||||
|
PARTYGOERS,PARTY GOERS
|
||||||
|
SOMEDAY,SOME DAY
|
||||||
|
NIGHTGOWN,NIGHT GOWN
|
||||||
|
STANDALONE,STAND ALONE
|
||||||
|
BUSINESSWOMAN,BUSINESS WOMAN
|
||||||
|
STORYTELLING,STORY TELLING
|
||||||
|
MARKETPLACE,MARKET PLACE
|
||||||
|
CRATEJOY,CRATE JOY
|
||||||
|
OUTPERFORMED,OUT PERFORMED
|
||||||
|
TRUEBOTANICALS,TRUE BOTANICALS
|
||||||
|
NONFICTION,NON FICTION
|
||||||
|
SPINOFF,SPIN OFF
|
||||||
|
MOTHERFUCKING,MOTHER FUCKING
|
||||||
|
TRACKLIST,TRACK LIST
|
||||||
|
GODDAMN,GOD DAMN
|
||||||
|
PORNHUB,PORN HUB
|
||||||
|
UNDERAGE,UNDER AGE
|
||||||
|
GOODBYE,GOOD BYE
|
||||||
|
HARDCORE,HARD CORE
|
||||||
|
TRUCK'S,TRUCK IS
|
||||||
|
COUNTERSTEERING,COUNTER STEERING
|
||||||
|
BUZZWORD,BUZZ WORD
|
||||||
|
SUBCOMPONENTS,SUB COMPONENTS
|
||||||
|
MOREOVER,MORE OVER
|
||||||
|
PICKUP,PICK UP
|
||||||
|
NEWSLETTER,NEWS LETTER
|
||||||
|
KEYWORD,KEY WORD
|
||||||
|
LOGIN,LOG IN
|
||||||
|
TOOLBOX,TOOL BOX
|
||||||
|
LINK'S,LINK IS
|
||||||
|
PRIMIALVIDEO,PRIMAL VIDEO
|
||||||
|
DOTNET,DOT NET
|
||||||
|
AIRSTRIKE,AIR STRIKE
|
||||||
|
HAIRSTYLE,HAIR STYLE
|
||||||
|
TOWNSFOLK,TOWNS FOLK
|
||||||
|
GOLDFISH,GOLD FISH
|
||||||
|
TOM'S,TOM IS
|
||||||
|
HOMETOWN,HOME TOWN
|
||||||
|
CORONAVIRUS,CORONA VIRUS
|
||||||
|
PLAYSTATION,PLAY STATION
|
||||||
|
TOMORROW,TO MORROW
|
||||||
|
TIMECONSUMING,TIME CONSUMING
|
||||||
|
POSTWAR,POST WAR
|
||||||
|
HANDSON,HANDS ON
|
||||||
|
SHAKEUP,SHAKE UP
|
||||||
|
ECOMERS,E COMERS
|
||||||
|
COFOUNDER,CO FOUNDER
|
||||||
|
HIGHEND,HIGH END
|
||||||
|
INPERSON,IN PERSON
|
||||||
|
SELFREGULATION,SELF REGULATION
|
||||||
|
INDEPTH,IN DEPTH
|
||||||
|
ALLTIME,ALL TIME
|
||||||
|
LONGTERM,LONG TERM
|
||||||
|
SOCALLED,SO CALLED
|
||||||
|
SELFCONFIDENCE,SELF CONFIDENCE
|
||||||
|
STANDUP,STAND UP
|
||||||
|
MINDBOGGLING,MIND BOGGLING
|
||||||
|
BEINGFOROTHERS,BEING FOR OTHERS
|
||||||
|
COWROTE,CO WROTE
|
||||||
|
COSTARRED,CO STARRED
|
||||||
|
EDITORINCHIEF,EDITOR IN CHIEF
|
||||||
|
HIGHSPEED,HIGH SPEED
|
||||||
|
DECISIONMAKING,DECISION MAKING
|
||||||
|
WELLBEING,WELL BEING
|
||||||
|
NONTRIVIAL,NON TRIVIAL
|
||||||
|
PREEXISTING,PRE EXISTING
|
||||||
|
STATEOWNED,STATE OWNED
|
||||||
|
PLUGIN,PLUG IN
|
||||||
|
PROVERSION,PRO VERSION
|
||||||
|
OPTIN,OPT IN
|
||||||
|
FOLLOWUP,FOLLOW UP
|
||||||
|
FOLLOWUPS,FOLLOW UPS
|
||||||
|
WIFI,WI FI
|
||||||
|
THIRDPARTY,THIRD PARTY
|
||||||
|
PROFESSIONALLOOKING,PROFESSIONAL LOOKING
|
||||||
|
FULLSCREEN,FULL SCREEN
|
||||||
|
BUILTIN,BUILT IN
|
||||||
|
MULTISTREAM,MULTI STREAM
|
||||||
|
LOWCOST,LOW COST
|
||||||
|
RESTREAM,RE STREAM
|
||||||
|
GAMECHANGER,GAME CHANGER
|
||||||
|
WELLDEVELOPED,WELL DEVELOPED
|
||||||
|
QUARTERINCH,QUARTER INCH
|
||||||
|
FASTFASHION,FAST FASHION
|
||||||
|
ECOMMERCE,E COMMERCE
|
||||||
|
PRIZEWINNING,PRIZE WINNING
|
||||||
|
NEVERENDING,NEVER ENDING
|
||||||
|
MINDBLOWING,MIND BLOWING
|
||||||
|
REALLIFE,REAL LIFE
|
||||||
|
REOPEN,RE OPEN
|
||||||
|
ONDEMAND,ON DEMAND
|
||||||
|
PROBLEMSOLVING,PROBLEM SOLVING
|
||||||
|
HEAVYHANDED,HEAVY HANDED
|
||||||
|
OPENENDED,OPEN ENDED
|
||||||
|
SELFCONTROL,SELF CONTROL
|
||||||
|
WELLMEANING,WELL MEANING
|
||||||
|
COHOST,CO HOST
|
||||||
|
RIGHTSBASED,RIGHTS BASED
|
||||||
|
HALFBROTHER,HALF BROTHER
|
||||||
|
FATHERINLAW,FATHER IN LAW
|
||||||
|
COAUTHOR,CO AUTHOR
|
||||||
|
REELECTION,RE ELECTION
|
||||||
|
SELFHELP,SELF HELP
|
||||||
|
PROLIFE,PRO LIFE
|
||||||
|
ANTIDUKE,ANTI DUKE
|
||||||
|
POSTSTRUCTURALIST,POST STRUCTURALIST
|
||||||
|
COFOUNDED,CO FOUNDED
|
||||||
|
XRAY,X RAY
|
||||||
|
ALLAROUND,ALL AROUND
|
||||||
|
HIGHTECH,HIGH TECH
|
||||||
|
TMOBILE,T MOBILE
|
||||||
|
INHOUSE,IN HOUSE
|
||||||
|
POSTMORTEM,POST MORTEM
|
||||||
|
LITTLEKNOWN,LITTLE KNOWN
|
||||||
|
FALSEPOSITIVE,FALSE POSITIVE
|
||||||
|
ANTIVAXXER,ANTI VAXXER
|
||||||
|
EMAILS,E MAILS
|
||||||
|
DRIVETHROUGH,DRIVE THROUGH
|
||||||
|
DAYTODAY,DAY TO DAY
|
||||||
|
COSTAR,CO STAR
|
||||||
|
EBAY,E BAY
|
||||||
|
KOOLAID,KOOL AID
|
||||||
|
ANTIDEMOCRATIC,ANTI DEMOCRATIC
|
||||||
|
MIDDLEAGED,MIDDLE AGED
|
||||||
|
SHORTLIVED,SHORT LIVED
|
||||||
|
BESTSELLING,BEST SELLING
|
||||||
|
TICTACS,TIC TACS
|
||||||
|
UHHUH,UH HUH
|
||||||
|
MULTITANK,MULTI TANK
|
||||||
|
JAWDROPPING,JAW DROPPING
|
||||||
|
LIVESTREAMING,LIVE STREAMING
|
||||||
|
HARDWORKING,HARD WORKING
|
||||||
|
BOTTOMDWELLING,BOTTOM DWELLING
|
||||||
|
PRESHOW,PRE SHOW
|
||||||
|
HANDSFREE,HANDS FREE
|
||||||
|
TRICKORTREATING,TRICK OR TREATING
|
||||||
|
PRERECORDED,PRE RECORDED
|
||||||
|
DOGOODERS,DO GOODERS
|
||||||
|
WIDERANGING,WIDE RANGING
|
||||||
|
LIFESAVING,LIFE SAVING
|
||||||
|
SKIREPORT,SKI REPORT
|
||||||
|
SNOWBASE,SNOW BASE
|
||||||
|
JAYZ,JAY Z
|
||||||
|
SPIDERMAN,SPIDER MAN
|
||||||
|
FREEKICK,FREE KICK
|
||||||
|
EDWARDSHELAIRE,EDWARDS HELAIRE
|
||||||
|
SHORTTERM,SHORT TERM
|
||||||
|
HAVENOTS,HAVE NOTS
|
||||||
|
SELFINTEREST,SELF INTEREST
|
||||||
|
SELFINTERESTED,SELF INTERESTED
|
||||||
|
SELFCOMPASSION,SELF COMPASSION
|
||||||
|
MACHINELEARNING,MACHINE LEARNING
|
||||||
|
COAUTHORED,CO AUTHORED
|
||||||
|
NONGOVERNMENT,NON GOVERNMENT
|
||||||
|
SUBSAHARAN,SUB SAHARAN
|
||||||
|
COCHAIR,CO CHAIR
|
||||||
|
LARGESCALE,LARGE SCALE
|
||||||
|
VIDEOONDEMAND,VIDEO ON DEMAND
|
||||||
|
FIRSTCLASS,FIRST CLASS
|
||||||
|
COFOUNDERS,CO FOUNDERS
|
||||||
|
COOP,CO OP
|
||||||
|
PREORDERS,PRE ORDERS
|
||||||
|
DOUBLEENTRY,DOUBLE ENTRY
|
||||||
|
SELFCONFIDENT,SELF CONFIDENT
|
||||||
|
SELFPORTRAIT,SELF PORTRAIT
|
||||||
|
NONWHITE,NON WHITE
|
||||||
|
ONBOARD,ON BOARD
|
||||||
|
HALFLIFE,HALF LIFE
|
||||||
|
ONCOURT,ON COURT
|
||||||
|
SCIFI,SCI FI
|
||||||
|
XMEN,X MEN
|
||||||
|
DAYLEWIS,DAY LEWIS
|
||||||
|
LALALAND,LA LA LAND
|
||||||
|
AWARDWINNING,AWARD WINNING
|
||||||
|
BOXOFFICE,BOX OFFICE
|
||||||
|
TRIDACTYLS,TRI DACTYLS
|
||||||
|
TRIDACTYL,TRI DACTYL
|
||||||
|
MEDIUMSIZED,MEDIUM SIZED
|
||||||
|
POSTSECONDARY,POST SECONDARY
|
||||||
|
FULLTIME,FULL TIME
|
||||||
|
GOKART,GO KART
|
||||||
|
OPENAIR,OPEN AIR
|
||||||
|
WELLKNOWN,WELL KNOWN
|
||||||
|
ICECREAM,ICE CREAM
|
||||||
|
EARTHMOON,EARTH MOON
|
||||||
|
STATEOFTHEART,STATE OF THE ART
|
||||||
|
BSIDE,B SIDE
|
||||||
|
EASTWEST,EAST WEST
|
||||||
|
ALLSTAR,ALL STAR
|
||||||
|
RUNNERUP,RUNNER UP
|
||||||
|
HORSEDRAWN,HORSE DRAWN
|
||||||
|
OPENSOURCE,OPEN SOURCE
|
||||||
|
PURPOSEBUILT,PURPOSE BUILT
|
||||||
|
SQUAREFREE,SQUARE FREE
|
||||||
|
PRESENTDAY,PRESENT DAY
|
||||||
|
CANADAUNITED,CANADA UNITED
|
||||||
|
HOTCHPOTCH,HOTCH POTCH
|
||||||
|
LOWLYING,LOW LYING
|
||||||
|
RIGHTHANDED,RIGHT HANDED
|
||||||
|
PEARSHAPED,PEAR SHAPED
|
||||||
|
BESTKNOWN,BEST KNOWN
|
||||||
|
FULLLENGTH,FULL LENGTH
|
||||||
|
YEARROUND,YEAR ROUND
|
||||||
|
PREELECTION,PRE ELECTION
|
||||||
|
RERECORD,RE RECORD
|
||||||
|
MINIALBUM,MINI ALBUM
|
||||||
|
LONGESTRUNNING,LONGEST RUNNING
|
||||||
|
ALLIRELAND,ALL IRELAND
|
||||||
|
NORTHWESTERN,NORTH WESTERN
|
||||||
|
PARTTIME,PART TIME
|
||||||
|
NONGOVERNMENTAL,NON GOVERNMENTAL
|
||||||
|
ONLINE,ON LINE
|
||||||
|
ONAIR,ON AIR
|
||||||
|
NORTHSOUTH,NORTH SOUTH
|
||||||
|
RERELEASED,RE RELEASED
|
||||||
|
LEFTHANDED,LEFT HANDED
|
||||||
|
BSIDES,B SIDES
|
||||||
|
ANGLOSAXON,ANGLO SAXON
|
||||||
|
SOUTHSOUTHEAST,SOUTH SOUTHEAST
|
||||||
|
CROSSCOUNTRY,CROSS COUNTRY
|
||||||
|
REBUILT,RE BUILT
|
||||||
|
FREEFORM,FREE FORM
|
||||||
|
SCOOBYDOO,SCOOBY DOO
|
||||||
|
ATLARGE,AT LARGE
|
||||||
|
COUNCILMANAGER,COUNCIL MANAGER
|
||||||
|
LONGRUNNING,LONG RUNNING
|
||||||
|
PREWAR,PRE WAR
|
||||||
|
REELECTED,RE ELECTED
|
||||||
|
HIGHSCHOOL,HIGH SCHOOL
|
||||||
|
RUNNERSUP,RUNNERS UP
|
||||||
|
NORTHWEST,NORTH WEST
|
||||||
|
WEBBASED,WEB BASED
|
||||||
|
HIGHQUALITY,HIGH QUALITY
|
||||||
|
RIGHTWING,RIGHT WING
|
||||||
|
LANEFOX,LANE FOX
|
||||||
|
PAYPERVIEW,PAY PER VIEW
|
||||||
|
COPRODUCTION,CO PRODUCTION
|
||||||
|
NONPARTISAN,NON PARTISAN
|
||||||
|
FIRSTPERSON,FIRST PERSON
|
||||||
|
WORLDRENOWNED,WORLD RENOWNED
|
||||||
|
VICEPRESIDENT,VICE PRESIDENT
|
||||||
|
PROROMAN,PRO ROMAN
|
||||||
|
COPRODUCED,CO PRODUCED
|
||||||
|
LOWPOWER,LOW POWER
|
||||||
|
SELFESTEEM,SELF ESTEEM
|
||||||
|
SEMITRANSPARENT,SEMI TRANSPARENT
|
||||||
|
SECONDINCOMMAND,SECOND IN COMMAND
|
||||||
|
HIGHRISE,HIGH RISE
|
||||||
|
COHOSTED,CO HOSTED
|
||||||
|
AFRICANAMERICAN,AFRICAN AMERICAN
|
||||||
|
SOUTHWEST,SOUTH WEST
|
||||||
|
WELLPRESERVED,WELL PRESERVED
|
||||||
|
FEATURELENGTH,FEATURE LENGTH
|
||||||
|
HIPHOP,HIP HOP
|
||||||
|
ALLBIG,ALL BIG
|
||||||
|
SOUTHEAST,SOUTH EAST
|
||||||
|
COUNTERATTACK,COUNTER ATTACK
|
||||||
|
QUARTERFINALS,QUARTER FINALS
|
||||||
|
STABLEDOOR,STABLE DOOR
|
||||||
|
DARKEYED,DARK EYED
|
||||||
|
ALLAMERICAN,ALL AMERICAN
|
||||||
|
THIRDPERSON,THIRD PERSON
|
||||||
|
LOWLEVEL,LOW LEVEL
|
||||||
|
NTERMINAL,N TERMINAL
|
||||||
|
DRIEDUP,DRIED UP
|
||||||
|
AFRICANAMERICANS,AFRICAN AMERICANS
|
||||||
|
ANTIAPARTHEID,ANTI APARTHEID
|
||||||
|
STOKEONTRENT,STOKE ON TRENT
|
||||||
|
NORTHNORTHEAST,NORTH NORTHEAST
|
||||||
|
BRANDNEW,BRAND NEW
|
||||||
|
RIGHTANGLED,RIGHT ANGLED
|
||||||
|
GOVERNMENTOWNED,GOVERNMENT OWNED
|
||||||
|
SONINLAW,SON IN LAW
|
||||||
|
SUBJECTOBJECTVERB,SUBJECT OBJECT VERB
|
||||||
|
LEFTARM,LEFT ARM
|
||||||
|
LONGLIVED,LONG LIVED
|
||||||
|
REDEYE,RED EYE
|
||||||
|
TPOSE,T POSE
|
||||||
|
NIGHTVISION,NIGHT VISION
|
||||||
|
SOUTHEASTERN,SOUTH EASTERN
|
||||||
|
WELLRECEIVED,WELL RECEIVED
|
||||||
|
ALFAYOUM,AL FAYOUM
|
||||||
|
TIMEBASED,TIME BASED
|
||||||
|
KETTLEDRUMS,KETTLE DRUMS
|
||||||
|
BRIGHTEYED,BRIGHT EYED
|
||||||
|
REDBROWN,RED BROWN
|
||||||
|
SAMESEX,SAME SEX
|
||||||
|
PORTDEPAIX,PORT DE PAIX
|
||||||
|
CLEANUP,CLEAN UP
|
||||||
|
PERCENT,PERCENT SIGN
|
||||||
|
TAKEOUT,TAKE OUT
|
||||||
|
KNOWHOW,KNOW HOW
|
||||||
|
FISHBONE,FISH BONE
|
||||||
|
FISHSTICKS,FISH STICKS
|
||||||
|
PAPERWORK,PAPER WORK
|
||||||
|
NICKNACKS,NICK NACKS
|
||||||
|
STREETTALKING,STREET TALKING
|
||||||
|
NONACADEMIC,NON ACADEMIC
|
||||||
|
SHELLY,SHELLEY
|
||||||
|
SHELLY'S,SHELLEY'S
|
||||||
|
JIMMY,JIMMIE
|
||||||
|
JIMMY'S,JIMMIE'S
|
||||||
|
DRUGSTORE,DRUG STORE
|
||||||
|
THRU,THROUGH
|
||||||
|
PLAYDATE,PLAY DATE
|
||||||
|
MICROLIFE,MICRO LIFE
|
||||||
|
SKILLSET,SKILL SET
|
||||||
|
SKILLSETS,SKILL SETS
|
||||||
|
TRADEOFF,TRADE OFF
|
||||||
|
TRADEOFFS,TRADE OFFS
|
||||||
|
ONSCREEN,ON SCREEN
|
||||||
|
PLAYBACK,PLAY BACK
|
||||||
|
ARTWORK,ART WORK
|
||||||
|
COWORKER,CO WORKER
|
||||||
|
COWORKERS,CO WORKERS
|
||||||
|
SOMETIME,SOME TIME
|
||||||
|
SOMETIMES,SOME TIMES
|
||||||
|
CROWDFUNDING,CROWD FUNDING
|
||||||
|
AM,A.M.,A M
|
||||||
|
PM,P.M.,P M
|
||||||
|
TV,T V
|
||||||
|
MBA,M B A
|
||||||
|
USA,U S A
|
||||||
|
US,U S
|
||||||
|
UK,U K
|
||||||
|
CEO,C E O
|
||||||
|
CFO,C F O
|
||||||
|
COO,C O O
|
||||||
|
CIO,C I O
|
||||||
|
FM,F M
|
||||||
|
GMC,G M C
|
||||||
|
FSC,F S C
|
||||||
|
NPD,N P D
|
||||||
|
APM,A P M
|
||||||
|
NGO,N G O
|
||||||
|
TD,T D
|
||||||
|
LOL,L O L
|
||||||
|
IPO,I P O
|
||||||
|
CNBC,C N B C
|
||||||
|
IPOS,I P OS
|
||||||
|
CNBC'S,C N B C'S
|
||||||
|
JT,J T
|
||||||
|
NPR,N P R
|
||||||
|
NPR'S,N P R'S
|
||||||
|
MP,M P
|
||||||
|
IOI,I O I
|
||||||
|
DW,D W
|
||||||
|
CNN,C N N
|
||||||
|
WSM,W S M
|
||||||
|
ET,E T
|
||||||
|
IT,I T
|
||||||
|
RJ,R J
|
||||||
|
DVD,D V D
|
||||||
|
DVD'S,D V D'S
|
||||||
|
HBO,H B O
|
||||||
|
LA,L A
|
||||||
|
XC,X C
|
||||||
|
SUV,S U V
|
||||||
|
NBA,N B A
|
||||||
|
NBA'S,N B A'S
|
||||||
|
ESPN,E S P N
|
||||||
|
ESPN'S,E S P N'S
|
||||||
|
ADT,A D T
|
||||||
|
HD,H D
|
||||||
|
VIP,V I P
|
||||||
|
TMZ,T M Z
|
||||||
|
CBC,C B C
|
||||||
|
NPO,N P O
|
||||||
|
BBC,B B C
|
||||||
|
LA'S,L A'S
|
||||||
|
TMZ'S,T M Z'S
|
||||||
|
HIV,H I V
|
||||||
|
FTC,F T C
|
||||||
|
EU,E U
|
||||||
|
PHD,P H D
|
||||||
|
AI,A I
|
||||||
|
FHI,F H I
|
||||||
|
ICML,I C M L
|
||||||
|
ICLR,I C L R
|
||||||
|
BMW,B M W
|
||||||
|
EV,E V
|
||||||
|
CR,C R
|
||||||
|
API,A P I
|
||||||
|
ICO,I C O
|
||||||
|
LTE,L T E
|
||||||
|
OBS,O B S
|
||||||
|
PC,P C
|
||||||
|
IO,I O
|
||||||
|
CRM,C R M
|
||||||
|
RTMP,R T M P
|
||||||
|
ASMR,A S M R
|
||||||
|
GG,G G
|
||||||
|
WWW,W W W
|
||||||
|
PEI,P E I
|
||||||
|
JJ,J J
|
||||||
|
PT,P T
|
||||||
|
DJ,D J
|
||||||
|
SD,S D
|
||||||
|
POW,P.O.W.,P O W
|
||||||
|
FYI,F Y I
|
||||||
|
DC,D C,D.C
|
||||||
|
ABC,A B C
|
||||||
|
TJ,T J
|
||||||
|
WMDT,W M D T
|
||||||
|
WDTN,W D T N
|
||||||
|
TY,T Y
|
||||||
|
EJ,E J
|
||||||
|
CJ,C J
|
||||||
|
ACL,A C L
|
||||||
|
UK'S,U K'S
|
||||||
|
GTV,G T V
|
||||||
|
MDMA,M D M A
|
||||||
|
DFW,D F W
|
||||||
|
WTF,W T F
|
||||||
|
AJ,A J
|
||||||
|
MD,M D
|
||||||
|
PH,P H
|
||||||
|
ID,I D
|
||||||
|
SEO,S E O
|
||||||
|
UTM'S,U T M'S
|
||||||
|
EC,E C
|
||||||
|
UFC,U F C
|
||||||
|
RV,R V
|
||||||
|
UTM,U T M
|
||||||
|
CSV,C S V
|
||||||
|
SMS,S M S
|
||||||
|
GRB,G R B
|
||||||
|
GT,G T
|
||||||
|
LEM,L E M
|
||||||
|
XR,X R
|
||||||
|
EDU,E D U
|
||||||
|
NBC,N B C
|
||||||
|
EMS,E M S
|
||||||
|
CDC,C D C
|
||||||
|
MLK,M L K
|
||||||
|
IE,I E
|
||||||
|
OC,O C
|
||||||
|
HR,H R
|
||||||
|
MA,M A
|
||||||
|
DEE,D E E
|
||||||
|
AP,A P
|
||||||
|
UFO,U F O
|
||||||
|
DE,D E
|
||||||
|
LGBTQ,L G B T Q
|
||||||
|
PTA,P T A
|
||||||
|
NHS,N H S
|
||||||
|
CMA,C M A
|
||||||
|
MGM,M G M
|
||||||
|
AKA,A K A
|
||||||
|
HW,H W
|
||||||
|
GOP,G O P
|
||||||
|
GOP'S,G O P'S
|
||||||
|
FBI,F B I
|
||||||
|
PRX,P R X
|
||||||
|
CTO,C T O
|
||||||
|
URL,U R L
|
||||||
|
EIN,E I N
|
||||||
|
MLS,M L S
|
||||||
|
CSI,C S I
|
||||||
|
AOC,A O C
|
||||||
|
CND,C N D
|
||||||
|
CP,C P
|
||||||
|
PP,P P
|
||||||
|
CLI,C L I
|
||||||
|
PB,P B
|
||||||
|
FDA,F D A
|
||||||
|
MRNA,M R N A
|
||||||
|
PR,P R
|
||||||
|
VP,V P
|
||||||
|
DNC,D N C
|
||||||
|
MSNBC,M S N B C
|
||||||
|
GQ,G Q
|
||||||
|
UT,U T
|
||||||
|
XXI,X X I
|
||||||
|
HRV,H R V
|
||||||
|
WHO,W H O
|
||||||
|
CRO,C R O
|
||||||
|
DPA,D P A
|
||||||
|
PPE,P P E
|
||||||
|
EVA,E V A
|
||||||
|
BP,B P
|
||||||
|
GPS,G P S
|
||||||
|
AR,A R
|
||||||
|
PJ,P J
|
||||||
|
MLM,M L M
|
||||||
|
OLED,O L E D
|
||||||
|
BO,B O
|
||||||
|
VE,V E
|
||||||
|
UN,U N
|
||||||
|
SLS,S L S
|
||||||
|
DM,D M
|
||||||
|
DM'S,D M'S
|
||||||
|
ASAP,A S A P
|
||||||
|
ETA,E T A
|
||||||
|
DOB,D O B
|
||||||
|
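The mapping file above can be applied as a pre-scoring normalization pass. A minimal sketch, assuming rows list equivalent surface forms with the last field as the canonical spelling (the loader and token-level replacement strategy here are illustrative, not this repo's scoring code):

```python
import csv

def load_mapping(path):
    """Load equivalence rows (e.g. LEARNT,LEARNED or AM,A.M.,A M) into a dict
    mapping every variant to the last (canonical) form on its row."""
    mapping = {}
    with open(path, newline="") as fp:
        for row in csv.reader(fp):
            if len(row) >= 2:
                canonical = row[-1]
                for variant in row[:-1]:
                    mapping[variant] = canonical
    return mapping

def normalize(text, mapping):
    """Uppercase the transcript, then replace each whole token by its canonical form."""
    return " ".join(mapping.get(tok, tok) for tok in text.upper().split())
```

Applying this to both hypothesis and reference before WER computation keeps spelling variants (LEARNT vs. LEARNED) from counting as errors.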
20
utils/speechio/interjections_en.csv
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
ach
|
||||||
|
ah
|
||||||
|
eee
|
||||||
|
eh
|
||||||
|
er
|
||||||
|
ew
|
||||||
|
ha
|
||||||
|
hee
|
||||||
|
hm
|
||||||
|
hmm
|
||||||
|
hmmm
|
||||||
|
huh
|
||||||
|
mm
|
||||||
|
mmm
|
||||||
|
oof
|
||||||
|
uh
|
||||||
|
uhh
|
||||||
|
um
|
||||||
|
oh
|
||||||
|
hum
|
||||||
|
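A common use of an interjection list like this is dropping filler tokens before scoring. A minimal sketch (case-insensitive, whole-token matching is an assumption):

```python
# Interjection tokens from interjections_en.csv above.
INTERJECTIONS = {"ach", "ah", "eee", "eh", "er", "ew", "ha", "hee", "hm",
                 "hmm", "hmmm", "huh", "mm", "mmm", "oof", "uh", "uhh",
                 "um", "oh", "hum"}

def strip_interjections(text):
    """Remove standalone interjection tokens from a transcript."""
    return " ".join(t for t in text.split() if t.lower() not in INTERJECTIONS)
```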
1
utils/speechio/nemo_text_processing/README.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
nemo_version from commit:eae1684f7f33c2a18de9ecfa42ec7db93d39e631
|
||||||
13
utils/speechio/nemo_text_processing/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
# Text Normalization
|
||||||
|
|
||||||
|
Text Normalization is part of NeMo's `nemo_text_processing` - a Python package that is installed with the `nemo_toolkit`.
|
||||||
|
It converts text from written form into its verbalized form, e.g. "123" -> "one hundred twenty three".
|
||||||
|
|
||||||
|
See [NeMo documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/text_normalization/wfst/wfst_text_normalization.html) for details.
|
||||||
|
|
||||||
|
Tutorial with overview of the package capabilities: [Text_(Inverse)_Normalization.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb)
|
||||||
|
|
||||||
|
Tutorial on how to customize the underlying grammars: [WFST_Tutorial.ipynb](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/text_processing/WFST_Tutorial.ipynb)
|
||||||
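The written-to-spoken conversion described above ("123" -> "one hundred twenty three") can be illustrated with a toy cardinal verbalizer. NeMo itself uses WFST grammars; this self-contained sketch only covers integers 0-999:

```python
ONES = ["zero", "one", "two", "three", "four", "five", "six", "seven",
        "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen",
        "fifteen", "sixteen", "seventeen", "eighteen", "nineteen"]
TENS = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy",
        "eighty", "ninety"]

def verbalize_cardinal(n):
    """Spell out an integer in 0-999, e.g. 123 -> 'one hundred twenty three'."""
    if n < 20:
        return ONES[n]
    if n < 100:
        tens, ones = divmod(n, 10)
        return TENS[tens] + ("" if ones == 0 else " " + ONES[ones])
    hundreds, rest = divmod(n, 100)
    return ONES[hundreds] + " hundred" + ("" if rest == 0 else " " + verbalize_cardinal(rest))
```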
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,350 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import string
import sys
|
||||||
|
from collections import defaultdict, namedtuple
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
from unicodedata import category
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
EOS_TYPE = "EOS"
|
||||||
|
PUNCT_TYPE = "PUNCT"
|
||||||
|
PLAIN_TYPE = "PLAIN"
|
||||||
|
Instance = namedtuple('Instance', 'token_type un_normalized normalized')
|
||||||
|
known_types = [
|
||||||
|
"PLAIN",
|
||||||
|
"DATE",
|
||||||
|
"CARDINAL",
|
||||||
|
"LETTERS",
|
||||||
|
"VERBATIM",
|
||||||
|
"MEASURE",
|
||||||
|
"DECIMAL",
|
||||||
|
"ORDINAL",
|
||||||
|
"DIGIT",
|
||||||
|
"MONEY",
|
||||||
|
"TELEPHONE",
|
||||||
|
"ELECTRONIC",
|
||||||
|
"FRACTION",
|
||||||
|
"TIME",
|
||||||
|
"ADDRESS",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
|
||||||
|
"""
|
||||||
|
https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
|
||||||
|
Loads text file in the Kaggle Google text normalization file format: <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
|
||||||
|
E.g.
|
||||||
|
PLAIN Brillantaisia <self>
|
||||||
|
PLAIN is <self>
|
||||||
|
PLAIN a <self>
|
||||||
|
PLAIN genus <self>
|
||||||
|
PLAIN of <self>
|
||||||
|
PLAIN plant <self>
|
||||||
|
PLAIN in <self>
|
||||||
|
PLAIN family <self>
|
||||||
|
PLAIN Acanthaceae <self>
|
||||||
|
PUNCT . sil
|
||||||
|
<eos> <eos>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: file path to text file
|
||||||
|
|
||||||
|
Returns: flat list of instances
|
||||||
|
"""
|
||||||
|
res = []
|
||||||
|
with open(file_path, 'r') as fp:
|
||||||
|
for line in fp:
|
||||||
|
parts = line.strip().split("\t")
|
||||||
|
if parts[0] == "<eos>":
|
||||||
|
res.append(Instance(token_type=EOS_TYPE, un_normalized="", normalized=""))
|
||||||
|
else:
|
||||||
|
l_type, l_token, l_normalized = parts
|
||||||
|
l_token = l_token.lower()
|
||||||
|
l_normalized = l_normalized.lower()
|
||||||
|
|
||||||
|
if l_type == PLAIN_TYPE:
|
||||||
|
res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_token))
|
||||||
|
elif l_type != PUNCT_TYPE:
|
||||||
|
res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized))
|
||||||
|
return res
|
||||||
|
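The tab-separated format handled above can be sketched independently of the full loader. This simplified parser follows the docstring's `<class>\t<written>\t<spoken or <self>>` rule but omits the loader's PLAIN/PUNCT filtering (the sample lines are made up):

```python
from collections import namedtuple

Instance = namedtuple("Instance", "token_type un_normalized normalized")

def parse_line(line):
    """Parse one Kaggle TN line: <semiotic class>\t<written>\t<spoken or <self>>."""
    l_type, token, norm = line.rstrip("\n").split("\t")
    if norm == "<self>":          # trivial class: spoken form equals written form
        norm = token
    return Instance(l_type, token.lower(), norm.lower())
```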
|
||||||
|
|
||||||
|
def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) -> List[Instance]:
|
||||||
|
"""
|
||||||
|
Load given list of text files using the `load_func` function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_paths: list of file paths
|
||||||
|
load_func: loading function
|
||||||
|
|
||||||
|
Returns: flat list of instances
|
||||||
|
"""
|
||||||
|
res = []
|
||||||
|
for file_path in file_paths:
|
||||||
|
res.extend(load_func(file_path=file_path))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def clean_generic(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Cleans text without affecting semiotic classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: string
|
||||||
|
|
||||||
|
Returns: cleaned string
|
||||||
|
"""
|
||||||
|
text = text.strip()
|
||||||
|
text = text.lower()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True) -> float:
|
||||||
|
"""
|
||||||
|
Evaluates accuracy given predictions and labels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preds: predictions
|
||||||
|
labels: labels
|
||||||
|
input: optional, only needed for verbosity
|
||||||
|
verbose: if true prints [input], golden labels and predictions
|
||||||
|
|
||||||
|
Returns: accuracy
|
||||||
|
"""
|
||||||
|
acc = 0
|
||||||
|
nums = len(preds)
|
||||||
|
for i in range(nums):
|
||||||
|
pred_norm = clean_generic(preds[i])
|
||||||
|
label_norm = clean_generic(labels[i])
|
||||||
|
if pred_norm == label_norm:
|
||||||
|
acc = acc + 1
|
||||||
|
else:
|
||||||
|
if input:
|
||||||
|
print(f"input: {json.dumps(input[i])}")
|
||||||
|
print(f"gold: {json.dumps(label_norm)}")
|
||||||
|
print(f"pred: {json.dumps(pred_norm)}")
|
||||||
|
return acc / nums
|
||||||
|
|
||||||
|
|
||||||
|
def training_data_to_tokens(
|
||||||
|
data: List[Instance], category: Optional[str] = None
|
||||||
|
) -> Dict[str, Tuple[List[str], List[str]]]:
|
||||||
|
"""
|
||||||
|
Filters the instance list by category if provided and converts it into a map from token type to list of un_normalized and normalized strings
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: list of instances
|
||||||
|
category: optional semiotic class category name
|
||||||
|
|
||||||
|
Returns Dict: token type -> (list of un_normalized strings, list of normalized strings)
|
||||||
|
"""
|
||||||
|
result = defaultdict(lambda: ([], []))
|
||||||
|
for instance in data:
|
||||||
|
if instance.token_type != EOS_TYPE:
|
||||||
|
if category is None or instance.token_type == category:
|
||||||
|
result[instance.token_type][0].append(instance.un_normalized)
|
||||||
|
result[instance.token_type][1].append(instance.normalized)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def training_data_to_sentences(data: List[Instance]) -> Tuple[List[str], List[str], List[Set[str]]]:
|
||||||
|
"""
|
||||||
|
Takes instance list, creates list of sentences split by EOS_Token
|
||||||
|
Args:
|
||||||
|
data: list of instances
|
||||||
|
Returns (list of unnormalized sentences, list of normalized sentences, list of sets of categories in a sentence)
|
||||||
|
"""
|
||||||
|
# split data at EOS boundaries
|
||||||
|
sentences = []
|
||||||
|
sentence = []
|
||||||
|
categories = []
|
||||||
|
sentence_categories = set()
|
||||||
|
|
||||||
|
for instance in data:
|
||||||
|
if instance.token_type == EOS_TYPE:
|
||||||
|
sentences.append(sentence)
|
||||||
|
sentence = []
|
||||||
|
categories.append(sentence_categories)
|
||||||
|
sentence_categories = set()
|
||||||
|
else:
|
||||||
|
sentence.append(instance)
|
||||||
|
sentence_categories.update([instance.token_type])
|
||||||
|
un_normalized = [" ".join([instance.un_normalized for instance in sentence]) for sentence in sentences]
|
||||||
|
normalized = [" ".join([instance.normalized for instance in sentence]) for sentence in sentences]
|
||||||
|
return un_normalized, normalized, categories
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_punctuation(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Normalizes quotes and spaces.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: text
|
||||||
|
|
||||||
|
Returns: text with normalized spaces and quotes
|
||||||
|
"""
|
||||||
|
text = (
|
||||||
|
text.replace('( ', '(')
|
||||||
|
.replace(' )', ')')
|
||||||
|
.replace('{ ', '{')
|
||||||
|
.replace(' }', '}')
|
||||||
|
.replace('[ ', '[')
|
||||||
|
.replace(' ]', ']')
|
||||||
|
.replace(' ', ' ')
|
||||||
|
.replace('”', '"')
|
||||||
|
.replace("’", "'")
|
||||||
|
.replace("»", '"')
|
||||||
|
.replace("«", '"')
|
||||||
|
.replace("\\", "")
|
||||||
|
.replace("„", '"')
|
||||||
|
.replace("´", "'")
|
||||||
|
.replace("’", "'")
|
||||||
|
.replace('“', '"')
|
||||||
|
.replace("‘", "'")
|
||||||
|
.replace('`', "'")
|
||||||
|
.replace('- -', "--")
|
||||||
|
)
|
||||||
|
|
||||||
|
for punct in "!,.:;?":
|
||||||
|
text = text.replace(f' {punct}', punct)
|
||||||
|
return text.strip()
|
||||||
|
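The space-before-punctuation cleanup that `post_process_punctuation` performs can be sketched in isolation; this standalone helper (not the module's own function) covers only the whitespace and terminal-punctuation part:

```python
def tidy_spacing(text):
    """Collapse whitespace runs, then drop the space detokenizers
    leave before sentence punctuation."""
    text = " ".join(text.split())            # collapse runs of whitespace
    for punct in "!,.:;?":
        text = text.replace(" " + punct, punct)
    return text
```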
|
||||||
|
|
||||||
|
def pre_process(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Optional text preprocessing before normalization (part of TTS TN pipeline)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: string that may include semiotic classes
|
||||||
|
|
||||||
|
Returns: text with spaces around punctuation marks
|
||||||
|
"""
|
||||||
|
space_both = '[]'
|
||||||
|
for punct in space_both:
|
||||||
|
text = text.replace(punct, ' ' + punct + ' ')
|
||||||
|
|
||||||
|
# remove extra space
|
||||||
|
text = re.sub(r' +', ' ', text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def load_file(file_path: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Loads given text file with separate lines into list of string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: file path
|
||||||
|
|
||||||
|
Returns: flat list of string
|
||||||
|
"""
|
||||||
|
res = []
|
||||||
|
with open(file_path, 'r') as fp:
|
||||||
|
for line in fp:
|
||||||
|
res.append(line)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def write_file(file_path: str, data: List[str]):
|
||||||
|
"""
|
||||||
|
Writes out list of string to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: file path
|
||||||
|
data: list of string
|
||||||
|
|
||||||
|
"""
|
||||||
|
with open(file_path, 'w') as fp:
|
||||||
|
for line in data:
|
||||||
|
fp.write(line + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False):
|
||||||
|
"""
|
||||||
|
Post-processing of the normalized output to match input in terms of spaces around punctuation marks.
|
||||||
|
After NN normalization, Moses detokenization puts a space after
|
||||||
|
punctuation marks, and attaches an opening quote "'" to the word to the right.
|
||||||
|
E.g., input to the TN NN model is "12 test' example",
|
||||||
|
after normalization and detokenization -> "twelve test 'example" (the quote is considered to be an opening quote,
|
||||||
|
but it doesn't match the input and can cause issues during TTS voice generation.)
|
||||||
|
The current function will match the punctuation and spaces of the normalized text with the input sequence.
|
||||||
|
"12 test' example" -> "twelve test 'example" -> "twelve test' example" (the quote was shifted to match the input).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: input text (original input to the NN, before normalization or tokenization)
|
||||||
|
normalized_text: output text (output of the TN NN model)
|
||||||
|
add_unicode_punct: set to True to handle unicode punctuation marks as well as default string.punctuation (increases post processing time)
|
||||||
|
"""
|
||||||
|
# in the post-processing WFST graph "``" are replaced with '"' quotes (otherwise single quotes "`" won't be handled correctly)
|
||||||
|
# this function fixes spaces around them based on input sequence, so here we're making the same double quote replacement
|
||||||
|
# to make sure these new double quotes work with this function
|
||||||
|
if "``" in input and "``" not in normalized_text:
|
||||||
|
input = input.replace("``", '"')
|
||||||
|
input = [x for x in input]
|
||||||
|
normalized_text = [x for x in normalized_text]
|
||||||
|
punct_marks = [x for x in string.punctuation if x in input]
|
||||||
|
|
||||||
|
if add_unicode_punct:
|
||||||
|
punct_unicode = [
|
||||||
|
chr(i)
|
||||||
|
for i in range(sys.maxunicode)
|
||||||
|
if category(chr(i)).startswith("P") and chr(i) not in punct_marks and chr(i) in input
|
||||||
|
]
|
||||||
|
punct_marks = punct_marks + punct_unicode
|
||||||
|
|
||||||
|
for punct in punct_marks:
|
||||||
|
try:
|
||||||
|
equal = True
|
||||||
|
if input.count(punct) != normalized_text.count(punct):
|
||||||
|
equal = False
|
||||||
|
idx_in, idx_out = 0, 0
|
||||||
|
while punct in input[idx_in:]:
|
||||||
|
idx_out = normalized_text.index(punct, idx_out)
|
||||||
|
idx_in = input.index(punct, idx_in)
|
||||||
|
|
||||||
|
def _is_valid(idx_out, idx_in, normalized_text, input):
|
||||||
|
"""Check if previous or next word match (for cases when punctuation marks are part of
|
||||||
|
semiotic token, i.e. some punctuation can be missing in the normalized text)"""
|
||||||
|
return (idx_out > 0 and idx_in > 0 and normalized_text[idx_out - 1] == input[idx_in - 1]) or (
|
||||||
|
idx_out < len(normalized_text) - 1
|
||||||
|
and idx_in < len(input) - 1
|
||||||
|
and normalized_text[idx_out + 1] == input[idx_in + 1]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not equal and not _is_valid(idx_out, idx_in, normalized_text, input):
|
||||||
|
idx_in += 1
|
||||||
|
continue
|
||||||
|
if idx_in > 0 and idx_out > 0:
|
||||||
|
if normalized_text[idx_out - 1] == " " and input[idx_in - 1] != " ":
|
||||||
|
normalized_text[idx_out - 1] = ""
|
||||||
|
|
||||||
|
elif normalized_text[idx_out - 1] != " " and input[idx_in - 1] == " ":
|
||||||
|
normalized_text[idx_out - 1] += " "
|
||||||
|
|
||||||
|
if idx_in < len(input) - 1 and idx_out < len(normalized_text) - 1:
|
||||||
|
if normalized_text[idx_out + 1] == " " and input[idx_in + 1] != " ":
|
||||||
|
normalized_text[idx_out + 1] = ""
|
||||||
|
elif normalized_text[idx_out + 1] != " " and input[idx_in + 1] == " ":
|
||||||
|
normalized_text[idx_out] = normalized_text[idx_out] + " "
|
||||||
|
idx_out += 1
|
||||||
|
idx_in += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
normalized_text = "".join(normalized_text)
|
||||||
|
return re.sub(r' +', ' ', normalized_text)
|
||||||
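The space-shifting moves above are easier to follow in isolation. This standalone sketch (illustrative only, not part of the code above; it handles just the first occurrence of a single mark) applies the same two moves to the docstring's own example:

```python
# Sketch of the quote/space re-alignment idea: shift spaces around a
# punctuation mark in the normalized text so they match the original input.
def realign_punct(original: str, normalized: str, punct: str = "'") -> str:
    orig = list(original)
    norm = list(normalized)
    try:
        i_in = orig.index(punct)
        i_out = norm.index(punct)
    except ValueError:
        return normalized
    # input has no space before the mark but normalized does -> drop it
    if i_in > 0 and i_out > 0 and orig[i_in - 1] != " " and norm[i_out - 1] == " ":
        norm[i_out - 1] = ""
    # input has a space after the mark but normalized doesn't -> insert one
    if (
        i_in < len(orig) - 1
        and i_out < len(norm) - 1
        and orig[i_in + 1] == " "
        and norm[i_out + 1] != " "
    ):
        norm[i_out] = norm[i_out] + " "
    return "".join(norm)

print(realign_punct("12 test' example", "twelve test 'example"))  # twelve test' example
```

The quote that Moses detokenization attached to "example" is moved back onto "test", matching the input.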
@@ -0,0 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
from nemo_text_processing.text_normalization.en.verbalizers.verbalize import VerbalizeFst
from nemo_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
@@ -0,0 +1,342 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser
from typing import List

import regex as re
from nemo_text_processing.text_normalization.data_loader_utils import (
    EOS_TYPE,
    Instance,
    load_files,
    training_data_to_sentences,
)


"""
This file is for evaluation purposes.
filter_loaded_data() cleans data (a list of instances) for text normalization. Filters and cleaners can be specified for each semiotic class individually.
For example, normalized text should only include characters and whitespace characters but no punctuation.
Cardinal unnormalized instances should contain at least one integer; all other characters are removed.
"""


class Filter:
    """
    Filter class

    Args:
        class_type: semiotic class used in dataset
        process_func: function to transform text
        filter_func: function to filter text
    """

    def __init__(self, class_type: str, process_func: object, filter_func: object):
        self.class_type = class_type
        self.process_func = process_func
        self.filter_func = filter_func

    def filter(self, instance: Instance) -> bool:
        """
        Filters the given instance with the filter function.

        Args:
            instance: instance to filter

        Returns: True if the given instance fulfills the criteria or does not belong to the class type
        """
        if instance.token_type != self.class_type:
            return True
        return self.filter_func(instance)

    def process(self, instance: Instance) -> Instance:
        """
        Processes the given instance with the process function.

        Args:
            instance: instance to process

        Returns: the processed instance if it belongs to the expected class type, otherwise the original instance
        """
        if instance.token_type != self.class_type:
            return instance
        return self.process_func(instance)


def filter_cardinal_1(instance: Instance) -> bool:
    ok = re.search(r"[0-9]", instance.un_normalized)
    return ok


def process_cardinal_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    un_normalized = re.sub(r"[^0-9]", "", un_normalized)
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_ordinal_1(instance: Instance) -> bool:
    ok = re.search(r"(st|nd|rd|th)\s*$", instance.un_normalized)
    return ok


def process_ordinal_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    un_normalized = re.sub(r"[,\s]", "", un_normalized)
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_decimal_1(instance: Instance) -> bool:
    ok = re.search(r"[0-9]", instance.un_normalized)
    return ok


def process_decimal_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    un_normalized = re.sub(r",", "", un_normalized)
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_measure_1(instance: Instance) -> bool:
    ok = True
    return ok


def process_measure_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    un_normalized = re.sub(r",", "", un_normalized)
    un_normalized = re.sub(r"m2", "m²", un_normalized)
    un_normalized = re.sub(r"(\d)([^\d.\s])", r"\1 \2", un_normalized)
    normalized = re.sub(r"[^a-z\s]", "", normalized)
    normalized = re.sub(r"per ([a-z\s]*)s$", r"per \1", normalized)
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_money_1(instance: Instance) -> bool:
    ok = re.search(r"[0-9]", instance.un_normalized)
    return ok


def process_money_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    un_normalized = re.sub(r",", "", un_normalized)
    un_normalized = re.sub(r"a\$", r"$", un_normalized)
    un_normalized = re.sub(r"us\$", r"$", un_normalized)
    un_normalized = re.sub(r"(\d)m\s*$", r"\1 million", un_normalized)
    un_normalized = re.sub(r"(\d)bn?\s*$", r"\1 billion", un_normalized)
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_time_1(instance: Instance) -> bool:
    ok = re.search(r"[0-9]", instance.un_normalized)
    return ok


def process_time_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    un_normalized = re.sub(r": ", ":", un_normalized)
    un_normalized = re.sub(r"(\d)\s?a\s?m\s?", r"\1 a.m.", un_normalized)
    un_normalized = re.sub(r"(\d)\s?p\s?m\s?", r"\1 p.m.", un_normalized)
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_plain_1(instance: Instance) -> bool:
    ok = True
    return ok


def process_plain_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_punct_1(instance: Instance) -> bool:
    ok = True
    return ok


def process_punct_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_date_1(instance: Instance) -> bool:
    ok = True
    return ok


def process_date_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    un_normalized = re.sub(r",", "", un_normalized)
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_letters_1(instance: Instance) -> bool:
    ok = True
    return ok


def process_letters_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_verbatim_1(instance: Instance) -> bool:
    ok = True
    return ok


def process_verbatim_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_digit_1(instance: Instance) -> bool:
    ok = re.search(r"[0-9]", instance.un_normalized)
    return ok


def process_digit_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_telephone_1(instance: Instance) -> bool:
    ok = re.search(r"[0-9]", instance.un_normalized)
    return ok


def process_telephone_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_electronic_1(instance: Instance) -> bool:
    ok = re.search(r"[0-9]", instance.un_normalized)
    return ok


def process_electronic_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_fraction_1(instance: Instance) -> bool:
    ok = re.search(r"[0-9]", instance.un_normalized)
    return ok


def process_fraction_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


def filter_address_1(instance: Instance) -> bool:
    ok = True
    return ok


def process_address_1(instance: Instance) -> Instance:
    un_normalized = instance.un_normalized
    normalized = instance.normalized
    normalized = re.sub(r"[^a-z ]", "", normalized)
    return Instance(token_type=instance.token_type, un_normalized=un_normalized, normalized=normalized)


filters = []
filters.append(Filter(class_type="CARDINAL", process_func=process_cardinal_1, filter_func=filter_cardinal_1))
filters.append(Filter(class_type="ORDINAL", process_func=process_ordinal_1, filter_func=filter_ordinal_1))
filters.append(Filter(class_type="DECIMAL", process_func=process_decimal_1, filter_func=filter_decimal_1))
filters.append(Filter(class_type="MEASURE", process_func=process_measure_1, filter_func=filter_measure_1))
filters.append(Filter(class_type="MONEY", process_func=process_money_1, filter_func=filter_money_1))
filters.append(Filter(class_type="TIME", process_func=process_time_1, filter_func=filter_time_1))

filters.append(Filter(class_type="DATE", process_func=process_date_1, filter_func=filter_date_1))
filters.append(Filter(class_type="PLAIN", process_func=process_plain_1, filter_func=filter_plain_1))
filters.append(Filter(class_type="PUNCT", process_func=process_punct_1, filter_func=filter_punct_1))
filters.append(Filter(class_type="LETTERS", process_func=process_letters_1, filter_func=filter_letters_1))
filters.append(Filter(class_type="VERBATIM", process_func=process_verbatim_1, filter_func=filter_verbatim_1))
filters.append(Filter(class_type="DIGIT", process_func=process_digit_1, filter_func=filter_digit_1))
filters.append(Filter(class_type="TELEPHONE", process_func=process_telephone_1, filter_func=filter_telephone_1))
filters.append(Filter(class_type="ELECTRONIC", process_func=process_electronic_1, filter_func=filter_electronic_1))
filters.append(Filter(class_type="FRACTION", process_func=process_fraction_1, filter_func=filter_fraction_1))
filters.append(Filter(class_type="ADDRESS", process_func=process_address_1, filter_func=filter_address_1))
filters.append(Filter(class_type=EOS_TYPE, process_func=lambda x: x, filter_func=lambda x: True))


def filter_loaded_data(data: List[Instance], verbose: bool = False) -> List[Instance]:
    """
    Filters a list of instances.

    Args:
        data: list of instances
        verbose: set to True to print filtered instances

    Returns: filtered and transformed list of instances
    """
    updates_instances = []
    for instance in data:
        updated_instance = False
        for fil in filters:
            if fil.class_type == instance.token_type and fil.filter(instance):
                instance = fil.process(instance)
                updated_instance = True
        if updated_instance:
            if verbose:
                print(instance)
            updates_instances.append(instance)
    return updates_instances


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--input", help="input file path", type=str, default='./en_with_types/output-00001-of-00100')
    parser.add_argument("--verbose", help="print filtered instances", action='store_true')
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    file_path = args.input

    print("Loading training data: " + file_path)
    instance_list = load_files([file_path])  # List of instances
    filtered_instance_list = filter_loaded_data(instance_list, args.verbose)
    training_data_to_sentences(filtered_instance_list)
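The Filter/process pipeline above can be exercised without the NeMo data loader. This sketch uses a local namedtuple stand-in for `Instance` (the real one comes from `data_loader_utils`) and mirrors the CARDINAL cleaners:

```python
# Illustrative sketch of the filter -> process flow for one CARDINAL instance.
import re
from collections import namedtuple

Instance = namedtuple("Instance", ["token_type", "un_normalized", "normalized"])

def filter_cardinal(instance):
    # keep only instances whose unnormalized side contains at least one digit
    return re.search(r"[0-9]", instance.un_normalized) is not None

def process_cardinal(instance):
    # strip everything but digits on the unnormalized side,
    # everything but lowercase letters and spaces on the normalized side
    return instance._replace(
        un_normalized=re.sub(r"[^0-9]", "", instance.un_normalized),
        normalized=re.sub(r"[^a-z ]", "", instance.normalized),
    )

raw = Instance("CARDINAL", "1,234", "one thousand two hundred thirty-four")
cleaned = process_cardinal(raw) if filter_cardinal(raw) else raw
print(cleaned.un_normalized)  # 1234
print(cleaned.normalized)     # one thousand two hundred thirtyfour
```

Note how the cleaners deliberately drop the comma and the hyphen so that the evaluation compares only digits against lowercase words.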
@@ -0,0 +1,196 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright 2015 and onwards Google, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import string
from pathlib import Path
from typing import Dict

import pynini
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini import Far
from pynini.examples import plurals
from pynini.export import export
from pynini.lib import byte, pynutil, utf8

NEMO_CHAR = utf8.VALID_UTF8_CHAR

NEMO_DIGIT = byte.DIGIT
NEMO_LOWER = pynini.union(*string.ascii_lowercase).optimize()
NEMO_UPPER = pynini.union(*string.ascii_uppercase).optimize()
NEMO_ALPHA = pynini.union(NEMO_LOWER, NEMO_UPPER).optimize()
NEMO_ALNUM = pynini.union(NEMO_DIGIT, NEMO_ALPHA).optimize()
NEMO_HEX = pynini.union(*string.hexdigits).optimize()
NEMO_NON_BREAKING_SPACE = u"\u00A0"
NEMO_SPACE = " "
NEMO_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", u"\u00A0").optimize()
NEMO_NOT_SPACE = pynini.difference(NEMO_CHAR, NEMO_WHITE_SPACE).optimize()
NEMO_NOT_QUOTE = pynini.difference(NEMO_CHAR, r'"').optimize()

NEMO_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize()
NEMO_GRAPH = pynini.union(NEMO_ALNUM, NEMO_PUNCT).optimize()

NEMO_SIGMA = pynini.closure(NEMO_CHAR)

delete_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE))
delete_zero_or_one_space = pynutil.delete(pynini.closure(NEMO_WHITE_SPACE, 0, 1))
insert_space = pynutil.insert(" ")
delete_extra_space = pynini.cross(pynini.closure(NEMO_WHITE_SPACE, 1), " ")
delete_preserve_order = pynini.closure(
    pynutil.delete(" preserve_order: true")
    | (pynutil.delete(" field_order: \"") + NEMO_NOT_QUOTE + pynutil.delete("\""))
)

suppletive = pynini.string_file(get_abs_path("data/suppletive.tsv"))
# _v = pynini.union("a", "e", "i", "o", "u")
_c = pynini.union(
    "b", "c", "d", "f", "g", "h", "j", "k", "l", "m", "n", "p", "q", "r", "s", "t", "v", "w", "x", "y", "z"
)
_ies = NEMO_SIGMA + _c + pynini.cross("y", "ies")
_es = NEMO_SIGMA + pynini.union("s", "sh", "ch", "x", "z") + pynutil.insert("es")
_s = NEMO_SIGMA + pynutil.insert("s")

graph_plural = plurals._priority_union(
    suppletive, plurals._priority_union(_ies, plurals._priority_union(_es, _s, NEMO_SIGMA), NEMO_SIGMA), NEMO_SIGMA
).optimize()

SINGULAR_TO_PLURAL = graph_plural
PLURAL_TO_SINGULAR = pynini.invert(graph_plural)
TO_LOWER = pynini.union(*[pynini.cross(x, y) for x, y in zip(string.ascii_uppercase, string.ascii_lowercase)])
TO_UPPER = pynini.invert(TO_LOWER)
MIN_NEG_WEIGHT = -0.0001
MIN_POS_WEIGHT = 0.0001


def generator_main(file_name: str, graphs: Dict[str, 'pynini.FstLike']):
    """
    Exports graphs as an OpenFst finite state archive (FAR) file with the given file name and rule names.

    Args:
        file_name: exported file name
        graphs: mapping from rule name to the Pynini WFST graph to be exported
    """
    exporter = export.Exporter(file_name)
    for rule, graph in graphs.items():
        exporter[rule] = graph.optimize()
    exporter.close()
    print(f'Created {file_name}')


def get_plurals(fst):
    """
    Given singular forms, returns plurals.

    Args:
        fst: Fst

    Returns plurals of the given singular forms
    """
    return SINGULAR_TO_PLURAL @ fst


def get_singulars(fst):
    """
    Given plural forms, returns singulars.

    Args:
        fst: Fst

    Returns singulars of the given plural forms
    """
    return PLURAL_TO_SINGULAR @ fst


def convert_space(fst) -> 'pynini.FstLike':
    """
    Converts spaces to nonbreaking spaces.
    Used only in tagger grammars for transducing token values within quotes, e.g. name: "hello kitty".
    This makes the transducer significantly slower, so use it only when there could be spaces within quotes; otherwise leave it out.

    Args:
        fst: input fst

    Returns an output fst where breaking spaces are converted to nonbreaking spaces
    """
    return fst @ pynini.cdrewrite(pynini.cross(NEMO_SPACE, NEMO_NON_BREAKING_SPACE), "", "", NEMO_SIGMA)


class GraphFst:
    """
    Base class for all grammar fsts.

    Args:
        name: name of grammar class
        kind: either 'classify' or 'verbalize'
        deterministic: if True, provides a single transduction option;
            if False, multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, name: str, kind: str, deterministic: bool = True):
        self.name = name
        self.kind = kind
        self._fst = None
        self.deterministic = deterministic

        self.far_path = Path(os.path.dirname(__file__) + '/grammars/' + kind + '/' + name + '.far')
        if self.far_exist():
            self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()

    def far_exist(self) -> bool:
        """
        Returns True if the FAR can be loaded.
        """
        return self.far_path.exists()

    @property
    def fst(self) -> 'pynini.FstLike':
        return self._fst

    @fst.setter
    def fst(self, fst):
        self._fst = fst

    def add_tokens(self, fst) -> 'pynini.FstLike':
        """
        Wraps the class name around the given fst.

        Args:
            fst: input fst

        Returns:
            Fst: fst
        """
        return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }")

    def delete_tokens(self, fst) -> 'pynini.FstLike':
        """
        Deletes the class name wrapped around the output of the given fst.

        Args:
            fst: input fst

        Returns:
            Fst: fst
        """
        res = (
            pynutil.delete(f"{self.name}")
            + delete_space
            + pynutil.delete("{")
            + delete_space
            + fst
            + delete_space
            + pynutil.delete("}")
        )
        return res @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", NEMO_SIGMA)
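graph_plural composes its fallbacks by priority union: suppletive forms win over the -y -> -ies rule, which wins over -es after sibilants, which wins over the default -s. A pure-Python sketch of that same priority order (illustrative only; the real code does this with pynini WFSTs, and the suppletive entries here are a hypothetical excerpt of data/suppletive.tsv):

```python
# Mimic the priority order of graph_plural: suppletive > -ies > -es > -s.
SUPPLETIVE = {"child": "children", "foot": "feet"}  # hypothetical excerpt
SIBILANT_ENDINGS = ("s", "sh", "ch", "x", "z")
VOWELS = set("aeiou")

def pluralize(word: str) -> str:
    if word in SUPPLETIVE:                      # highest priority: suppletive.tsv
        return SUPPLETIVE[word]
    if word.endswith("y") and len(word) > 1 and word[-2] not in VOWELS:
        return word[:-1] + "ies"                # consonant + y -> ies
    if word.endswith(SIBILANT_ENDINGS):
        return word + "es"                      # sibilant ending -> es
    return word + "s"                           # default: append s

print(pluralize("foot"), pluralize("city"), pluralize("box"), pluralize("cat"))
# feet cities boxes cats
```

Each rule only fires when no higher-priority rule matched, which is exactly what the nested `plurals._priority_union` calls encode.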
@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,50 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_UPPER, GraphFst, insert_space
from pynini.lib import pynutil


class AbbreviationFst(GraphFst):
    """
    Finite state transducer for classifying abbreviations,
    e.g. "ABC" -> tokens { abbreviation { value: "A B C" } }

    Args:
        whitelist: whitelist FST
        deterministic: if True, provides a single transduction option;
            if False, multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, whitelist: 'pynini.FstLike', deterministic: bool = True):
        super().__init__(name="abbreviation", kind="classify", deterministic=deterministic)

        dot = pynini.accep(".")
        # A.B.C. -> A. B. C.
        graph = NEMO_UPPER + dot + pynini.closure(insert_space + NEMO_UPPER + dot, 1)
        # A.B.C. -> A.B.C.
        graph |= NEMO_UPPER + dot + pynini.closure(NEMO_UPPER + dot, 1)
        # ABC -> A B C
        graph |= NEMO_UPPER + pynini.closure(insert_space + NEMO_UPPER, 1)

        # exclude words that are included in the whitelist
        graph = pynini.compose(
            pynini.difference(pynini.project(graph, "input"), pynini.project(whitelist.graph, "input")), graph
        )

        graph = pynutil.insert("value: \"") + graph.optimize() + pynutil.insert("\"")
        graph = self.add_tokens(graph)
        self.fst = graph.optimize()
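For intuition, the three input shapes the abbreviation grammar accepts can be mimicked with plain regexes (an illustrative sketch, not the pynini implementation, and without the whitelist exclusion step):

```python
# Mimic the abbreviation patterns: "A.B.C." -> "A. B. C." and "ABC" -> "A B C";
# anything else is passed through unchanged.
import re

def space_abbreviation(token: str) -> str:
    if re.fullmatch(r"(?:[A-Z]\.){2,}", token):   # dotted: A.B.C. -> A. B. C.
        return " ".join(re.findall(r"[A-Z]\.", token))
    if re.fullmatch(r"[A-Z]{2,}", token):         # bare uppercase: ABC -> A B C
        return " ".join(token)
    return token

print(space_abbreviation("A.B.C."))  # A. B. C.
print(space_abbreviation("ABC"))     # A B C
print(space_abbreviation("Nasa"))    # Nasa
```

The real FST additionally keeps a no-op "A.B.C. -> A.B.C." option and subtracts whitelist entries from the accepted inputs.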
@@ -0,0 +1,138 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_DIGIT,
    NEMO_NOT_QUOTE,
    NEMO_SIGMA,
    GraphFst,
    insert_space,
)
from nemo_text_processing.text_normalization.en.taggers.date import get_four_digit_year_graph
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.examples import plurals
from pynini.lib import pynutil


class CardinalFst(GraphFst):
    """
    Finite state transducer for classifying cardinals, e.g.
    -23 -> cardinal { negative: "true" integer: "twenty three" }

    Args:
        deterministic: if True, provides a single transduction option;
            if False, multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True, lm: bool = False):
        super().__init__(name="cardinal", kind="classify", deterministic=deterministic)

        self.lm = lm
        self.deterministic = deterministic
        # TODO replace to have "oh" as a default for "0"
        graph = pynini.Far(get_abs_path("data/number/cardinal_number_name.far")).get_fst()
        self.graph_hundred_component_at_least_one_none_zero_digit = (
            pynini.closure(NEMO_DIGIT, 2, 3) | pynini.difference(NEMO_DIGIT, pynini.accep("0"))
        ) @ graph

        graph_digit = pynini.string_file(get_abs_path("data/number/digit.tsv"))
        graph_zero = pynini.string_file(get_abs_path("data/number/zero.tsv"))

        single_digits_graph = pynini.invert(graph_digit | graph_zero)
        self.single_digits_graph = single_digits_graph + pynini.closure(insert_space + single_digits_graph)

        if not deterministic:
            # for a single token allow only the same normalization
            # "007" -> {"oh oh seven", "zero zero seven"} not {"oh zero seven"}
            single_digits_graph_zero = pynini.invert(graph_digit | graph_zero)
            single_digits_graph_oh = pynini.invert(graph_digit) | pynini.cross("0", "oh")

            self.single_digits_graph = single_digits_graph_zero + pynini.closure(
                insert_space + single_digits_graph_zero
            )
            self.single_digits_graph |= single_digits_graph_oh + pynini.closure(insert_space + single_digits_graph_oh)

            single_digits_graph_with_commas = pynini.closure(
                self.single_digits_graph + insert_space, 1, 3
            ) + pynini.closure(
                pynutil.delete(",")
                + single_digits_graph
                + insert_space
                + single_digits_graph
                + insert_space
                + single_digits_graph,
                1,
            )

        optional_minus_graph = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)

        graph = (
            pynini.closure(NEMO_DIGIT, 1, 3)
            + (pynini.closure(pynutil.delete(",") + NEMO_DIGIT ** 3) | pynini.closure(NEMO_DIGIT ** 3))
        ) @ graph

        self.graph = graph
        self.graph_with_and = self.add_optional_and(graph)

        if deterministic:
            long_numbers = pynini.compose(NEMO_DIGIT ** (5, ...), self.single_digits_graph).optimize()
            final_graph = plurals._priority_union(long_numbers, self.graph_with_and, NEMO_SIGMA).optimize()
            cardinal_with_leading_zeros = pynini.compose(
|
||||||
|
pynini.accep("0") + pynini.closure(NEMO_DIGIT), self.single_digits_graph
|
||||||
|
)
|
||||||
|
final_graph |= cardinal_with_leading_zeros
|
||||||
|
else:
|
||||||
|
leading_zeros = pynini.compose(pynini.closure(pynini.accep("0"), 1), self.single_digits_graph)
|
||||||
|
cardinal_with_leading_zeros = (
|
||||||
|
leading_zeros + pynutil.insert(" ") + pynini.compose(pynini.closure(NEMO_DIGIT), self.graph_with_and)
|
||||||
|
)
|
||||||
|
|
||||||
|
# add small weight to non-default graphs to make sure the deterministic option is listed first
|
||||||
|
final_graph = (
|
||||||
|
self.graph_with_and
|
||||||
|
| pynutil.add_weight(self.single_digits_graph, 0.0001)
|
||||||
|
| get_four_digit_year_graph() # allows e.g. 4567 be pronouced as forty five sixty seven
|
||||||
|
| pynutil.add_weight(single_digits_graph_with_commas, 0.0001)
|
||||||
|
| cardinal_with_leading_zeros
|
||||||
|
)
|
||||||
|
|
||||||
|
final_graph = optional_minus_graph + pynutil.insert("integer: \"") + final_graph + pynutil.insert("\"")
|
||||||
|
final_graph = self.add_tokens(final_graph)
|
||||||
|
self.fst = final_graph.optimize()
|
||||||
|
|
||||||
|
def add_optional_and(self, graph):
|
||||||
|
graph_with_and = graph
|
||||||
|
|
||||||
|
if not self.lm:
|
||||||
|
graph_with_and = pynutil.add_weight(graph, 0.00001)
|
||||||
|
not_quote = pynini.closure(NEMO_NOT_QUOTE)
|
||||||
|
no_thousand_million = pynini.difference(
|
||||||
|
not_quote, not_quote + pynini.union("thousand", "million") + not_quote
|
||||||
|
).optimize()
|
||||||
|
integer = (
|
||||||
|
not_quote + pynutil.add_weight(pynini.cross("hundred ", "hundred and ") + no_thousand_million, -0.0001)
|
||||||
|
).optimize()
|
||||||
|
|
||||||
|
no_hundred = pynini.difference(NEMO_SIGMA, not_quote + pynini.accep("hundred") + not_quote).optimize()
|
||||||
|
integer |= (
|
||||||
|
not_quote + pynutil.add_weight(pynini.cross("thousand ", "thousand and ") + no_hundred, -0.0001)
|
||||||
|
).optimize()
|
||||||
|
|
||||||
|
optional_hundred = pynini.compose((NEMO_DIGIT - "0") ** 3, graph).optimize()
|
||||||
|
optional_hundred = pynini.compose(optional_hundred, NEMO_SIGMA + pynini.cross(" hundred", "") + NEMO_SIGMA)
|
||||||
|
graph_with_and |= pynini.compose(graph, integer).optimize()
|
||||||
|
graph_with_and |= optional_hundred
|
||||||
|
return graph_with_and
|
||||||
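The cardinal tagger above serializes its result as a token string such as `cardinal { negative: "true" integer: "twenty three" }`. As a minimal sketch of what that format looks like to downstream code, here is a hypothetical pure-Python parser for it (no pynini required; `parse_cardinal_token` is not part of NeMo, just an illustration of the field syntax):

```python
import re

def parse_cardinal_token(token: str) -> dict:
    """Parse a serialized cardinal token such as
    'cardinal { negative: "true" integer: "twenty three" }'
    into a plain dict of its quoted fields."""
    match = re.fullmatch(r'cardinal\s*\{(.*)\}', token.strip())
    if match is None:
        raise ValueError(f"not a cardinal token: {token!r}")
    # each field has the shape: name: "value"
    return dict(re.findall(r'(\w+):\s*"([^"]*)"', match.group(1)))

parse_cardinal_token('cardinal { negative: "true" integer: "twenty three" }')
# → {'negative': 'true', 'integer': 'twenty three'}
```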
@@ -0,0 +1,370 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_CHAR,
    NEMO_DIGIT,
    NEMO_LOWER,
    NEMO_SIGMA,
    NEMO_NOT_QUOTE,
    TO_LOWER,
    GraphFst,
    delete_extra_space,
    delete_space,
    insert_space,
)
from nemo_text_processing.text_normalization.en.utils import (
    augment_labels_with_punct_at_end,
    get_abs_path,
    load_labels,
)
from pynini.examples import plurals
from pynini.lib import pynutil

graph_teen = pynini.invert(pynini.string_file(get_abs_path("data/number/teen.tsv"))).optimize()
graph_digit = pynini.invert(pynini.string_file(get_abs_path("data/number/digit.tsv"))).optimize()
ties_graph = pynini.invert(pynini.string_file(get_abs_path("data/number/ty.tsv"))).optimize()
year_suffix = load_labels(get_abs_path("data/date/year_suffix.tsv"))
year_suffix.extend(augment_labels_with_punct_at_end(year_suffix))
year_suffix = pynini.string_map(year_suffix).optimize()


def get_ties_graph(deterministic: bool = True):
    """
    Returns two digit transducer, e.g.
    03 -> o three
    12 -> twelve
    20 -> twenty
    """
    graph = graph_teen | ties_graph + pynutil.delete("0") | ties_graph + insert_space + graph_digit

    if deterministic:
        graph = graph | pynini.cross("0", "o") + insert_space + graph_digit
    else:
        graph = graph | (pynini.cross("0", "o") | pynini.cross("0", "zero")) + insert_space + graph_digit

    return graph.optimize()


def get_four_digit_year_graph(deterministic: bool = True):
    """
    Returns a four digit transducer which is combination of ties/teen or digits
    (using hundred instead of thousand format), e.g.
    1219 -> twelve nineteen
    3900 -> thirty nine hundred
    """
    graph_ties = get_ties_graph(deterministic)

    graph_with_s = (
        (graph_ties + insert_space + graph_ties)
        | (graph_teen + insert_space + (ties_graph | pynini.cross("1", "ten")))
    ) + pynutil.delete("0s")

    graph_with_s |= (graph_teen | graph_ties) + insert_space + pynini.cross("00", "hundred") + pynutil.delete("s")
    graph_with_s = graph_with_s @ pynini.cdrewrite(
        pynini.cross("y", "ies") | pynutil.insert("s"), "", "[EOS]", NEMO_SIGMA
    )

    graph = graph_ties + insert_space + graph_ties
    graph |= (graph_teen | graph_ties) + insert_space + pynini.cross("00", "hundred")

    thousand_graph = (
        graph_digit
        + insert_space
        + pynini.cross("00", "thousand")
        + (pynutil.delete("0") | insert_space + graph_digit)
    )
    thousand_graph |= (
        graph_digit
        + insert_space
        + pynini.cross("000", "thousand")
        + pynini.closure(pynutil.delete(" "), 0, 1)
        + pynini.accep("s")
    )

    graph |= graph_with_s
    if deterministic:
        graph = plurals._priority_union(thousand_graph, graph, NEMO_SIGMA)
    else:
        graph |= thousand_graph

    return graph.optimize()


def _get_two_digit_year_with_s_graph():
    # to handle '70s -> seventies
    graph = (
        pynini.closure(pynutil.delete("'"), 0, 1)
        + pynini.compose(
            ties_graph + pynutil.delete("0s"), pynini.cdrewrite(pynini.cross("y", "ies"), "", "[EOS]", NEMO_SIGMA)
        )
    ).optimize()
    return graph


def _get_year_graph(cardinal_graph, deterministic: bool = True):
    """
    Transducer for year, only from 1000 - 2999, e.g.
    1290 -> twelve ninety
    2000 - 2009 will be verbalized as two thousand.

    Transducer for 3 digit year, e.g. 123 -> one twenty three

    Transducer for year with suffix
    123 A.D., 4200 B.C.
    """
    graph = get_four_digit_year_graph(deterministic)
    graph = (pynini.union("1", "2") + (NEMO_DIGIT ** 3) + pynini.closure(pynini.cross(" s", "s") | "s", 0, 1)) @ graph

    graph |= _get_two_digit_year_with_s_graph()

    three_digit_year = (NEMO_DIGIT @ cardinal_graph) + insert_space + (NEMO_DIGIT ** 2) @ cardinal_graph
    year_with_suffix = (
        (get_four_digit_year_graph(deterministic=True) | three_digit_year) + delete_space + insert_space + year_suffix
    )
    graph |= year_with_suffix
    return graph.optimize()


def _get_two_digit_year(cardinal_graph, single_digits_graph):
    two_digit_year = NEMO_DIGIT ** (2) @ plurals._priority_union(cardinal_graph, single_digits_graph, NEMO_SIGMA)
    return two_digit_year


class DateFst(GraphFst):
    """
    Finite state transducer for classifying date, e.g.
        jan. 5, 2012 -> date { month: "january" day: "five" year: "twenty twelve" preserve_order: true }
        jan. 5 -> date { month: "january" day: "five" preserve_order: true }
        5 january 2012 -> date { day: "five" month: "january" year: "twenty twelve" preserve_order: true }
        2012-01-05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012.01.05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012/01/05 -> date { year: "twenty twelve" month: "january" day: "five" }
        2012 -> date { year: "twenty twelve" }

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            if False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool, lm: bool = False):
        super().__init__(name="date", kind="classify", deterministic=deterministic)

        # january
        month_graph = pynini.string_file(get_abs_path("data/date/month_name.tsv")).optimize()
        # January, JANUARY
        month_graph |= pynini.compose(TO_LOWER + pynini.closure(NEMO_CHAR), month_graph) | pynini.compose(
            TO_LOWER ** (2, ...), month_graph
        )

        # jan
        month_abbr_graph = pynini.string_file(get_abs_path("data/date/month_abbr.tsv")).optimize()
        # jan, Jan, JAN
        month_abbr_graph = (
            month_abbr_graph
            | pynini.compose(TO_LOWER + pynini.closure(NEMO_LOWER, 1), month_abbr_graph).optimize()
            | pynini.compose(TO_LOWER ** (2, ...), month_abbr_graph).optimize()
        ) + pynini.closure(pynutil.delete("."), 0, 1)
        month_graph |= month_abbr_graph.optimize()

        month_numbers_labels = pynini.string_file(get_abs_path("data/date/month_number.tsv")).optimize()
        cardinal_graph = cardinal.graph_hundred_component_at_least_one_none_zero_digit

        year_graph = _get_year_graph(cardinal_graph=cardinal_graph, deterministic=deterministic)

        # three_digit_year = (NEMO_DIGIT @ cardinal_graph) + insert_space + (NEMO_DIGIT ** 2) @ cardinal_graph
        # year_graph |= three_digit_year

        month_graph = pynutil.insert("month: \"") + month_graph + pynutil.insert("\"")
        month_numbers_graph = pynutil.insert("month: \"") + month_numbers_labels + pynutil.insert("\"")

        endings = ["rd", "th", "st", "nd"]
        endings += [x.upper() for x in endings]
        endings = pynini.union(*endings)

        day_graph = (
            pynutil.insert("day: \"")
            + pynini.closure(pynutil.delete("the "), 0, 1)
            + (
                ((pynini.union("1", "2") + NEMO_DIGIT) | NEMO_DIGIT | (pynini.accep("3") + pynini.union("0", "1")))
                + pynini.closure(pynutil.delete(endings), 0, 1)
            )
            @ cardinal_graph
            + pynutil.insert("\"")
        )

        two_digit_year = _get_two_digit_year(
            cardinal_graph=cardinal_graph, single_digits_graph=cardinal.single_digits_graph
        )
        two_digit_year = pynutil.insert("year: \"") + two_digit_year + pynutil.insert("\"")

        # if lm:
        #     two_digit_year = pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (3), two_digit_year)
        #     year_graph = pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (2), year_graph)
        #     year_graph |= pynini.compose(pynini.difference(NEMO_DIGIT, "0") + NEMO_DIGIT ** (4, ...), year_graph)

        graph_year = pynutil.insert(" year: \"") + pynutil.delete(" ") + year_graph + pynutil.insert("\"")
        graph_year |= (
            pynutil.insert(" year: \"")
            + pynini.accep(",")
            + pynini.closure(pynini.accep(" "), 0, 1)
            + year_graph
            + pynutil.insert("\"")
        )
        optional_graph_year = pynini.closure(graph_year, 0, 1)

        year_graph = pynutil.insert("year: \"") + year_graph + pynutil.insert("\"")

        graph_mdy = month_graph + (
            (delete_extra_space + day_graph)
            | (pynini.accep(" ") + day_graph)
            | graph_year
            | (delete_extra_space + day_graph + graph_year)
        )

        graph_mdy |= (
            month_graph
            + pynini.cross("-", " ")
            + day_graph
            + pynini.closure(((pynini.cross("-", " ") + NEMO_SIGMA) @ graph_year), 0, 1)
        )

        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_mdy |= (
                month_numbers_graph
                + delete_sep
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
                + delete_sep
                + insert_space
                + (year_graph | two_digit_year)
            )

        graph_dmy = day_graph + delete_extra_space + month_graph + optional_graph_year
        day_ex_month = (NEMO_DIGIT ** 2 - pynini.project(month_numbers_graph, "input")) @ day_graph
        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_dmy |= (
                day_ex_month
                + delete_sep
                + insert_space
                + month_numbers_graph
                + delete_sep
                + insert_space
                + (year_graph | two_digit_year)
            )

        graph_ymd = pynini.accep("")
        for x in ["-", "/", "."]:
            delete_sep = pynutil.delete(x)
            graph_ymd |= (
                (year_graph | two_digit_year)
                + delete_sep
                + insert_space
                + month_numbers_graph
                + delete_sep
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
            )

        final_graph = graph_mdy | graph_dmy

        if not deterministic or lm:
            final_graph += pynini.closure(pynutil.insert(" preserve_order: true"), 0, 1)
            m_sep_d = (
                month_numbers_graph
                + pynutil.delete(pynini.union("-", "/"))
                + insert_space
                + pynini.closure(pynutil.delete("0"), 0, 1)
                + day_graph
            )
            final_graph |= m_sep_d
        else:
            final_graph += pynutil.insert(" preserve_order: true")

        final_graph |= graph_ymd | year_graph

        if not deterministic or lm:
            ymd_to_mdy_graph = None
            ymd_to_dmy_graph = None
            mdy_to_dmy_graph = None
            md_to_dm_graph = None

            for month in [x[0] for x in load_labels(get_abs_path("data/date/month_name.tsv"))]:
                for day in [x[0] for x in load_labels(get_abs_path("data/date/day.tsv"))]:
                    ymd_to_mdy_curr = (
                        pynutil.insert("month: \"" + month + "\" day: \"" + day + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                        + pynutil.delete(" month: \"" + month + "\" day: \"" + day + "\"")
                    )

                    # YY-MM-DD -> MM-DD-YY
                    ymd_to_mdy_curr = pynini.compose(graph_ymd, ymd_to_mdy_curr)
                    ymd_to_mdy_graph = (
                        ymd_to_mdy_curr
                        if ymd_to_mdy_graph is None
                        else pynini.union(ymd_to_mdy_curr, ymd_to_mdy_graph)
                    )

                    ymd_to_dmy_curr = (
                        pynutil.insert("day: \"" + day + "\" month: \"" + month + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                        + pynutil.delete(" month: \"" + month + "\" day: \"" + day + "\"")
                    )

                    # YY-MM-DD -> DD-MM-YY
                    ymd_to_dmy_curr = pynini.compose(graph_ymd, ymd_to_dmy_curr).optimize()
                    ymd_to_dmy_graph = (
                        ymd_to_dmy_curr
                        if ymd_to_dmy_graph is None
                        else pynini.union(ymd_to_dmy_curr, ymd_to_dmy_graph)
                    )

                    mdy_to_dmy_curr = (
                        pynutil.insert("day: \"" + day + "\" month: \"" + month + "\" ")
                        + pynutil.delete("month: \"" + month + "\" day: \"" + day + "\" ")
                        + pynini.accep('year:')
                        + NEMO_SIGMA
                    ).optimize()
                    # MM-DD-YY -> verbalize as MM-DD-YY (February fourth 1991) or DD-MM-YY (the fourth of February 1991)
                    mdy_to_dmy_curr = pynini.compose(graph_mdy, mdy_to_dmy_curr).optimize()
                    mdy_to_dmy_graph = (
                        mdy_to_dmy_curr
                        if mdy_to_dmy_graph is None
                        else pynini.union(mdy_to_dmy_curr, mdy_to_dmy_graph).optimize()
                    ).optimize()

                    md_to_dm_curr = pynutil.insert("day: \"" + day + "\" month: \"" + month + "\"") + pynutil.delete(
                        "month: \"" + month + "\" day: \"" + day + "\""
                    )
                    md_to_dm_curr = pynini.compose(m_sep_d, md_to_dm_curr).optimize()

                    md_to_dm_graph = (
                        md_to_dm_curr
                        if md_to_dm_graph is None
                        else pynini.union(md_to_dm_curr, md_to_dm_graph).optimize()
                    ).optimize()

            final_graph |= mdy_to_dmy_graph | md_to_dm_graph | ymd_to_mdy_graph | ymd_to_dmy_graph

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()
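The date tagger emits tokens like `date { month: "january" day: "five" year: "twenty twelve" preserve_order: true }`, where `preserve_order: true` tells the verbalizer to keep the written field order. A hypothetical pure-Python sketch of reading such a token back (the fallback day-month-year order below is an arbitrary choice for illustration, not NeMo's verbalizer logic):

```python
import re

def verbalize_date_token(token: str) -> str:
    """Turn a serialized date token back into spoken words.  When
    preserve_order is set, fields are read in the order they were written;
    otherwise this sketch falls back to a fixed day-month-year order
    (an arbitrary choice for the example)."""
    m = re.fullmatch(r'date\s*\{(.*)\}', token.strip())
    if m is None:
        raise ValueError(f"not a date token: {token!r}")
    body = m.group(1)
    fields = dict(re.findall(r'(\w+):\s*"([^"]*)"', body))
    if "preserve_order: true" in body:
        order = re.findall(r'(\w+):\s*"', body)  # written order of quoted fields
    else:
        order = ["day", "month", "year"]
    return " ".join(fields[f] for f in order if f in fields)
```

For example, the `jan. 5, 2012` token above reads back as `january five twenty twelve`, while the `2012-01-05` token (no `preserve_order`) reads back day first.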
@@ -0,0 +1,129 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_SIGMA, TO_UPPER, GraphFst, get_abs_path
from pynini.lib import pynutil

delete_space = pynutil.delete(" ")
quantities = pynini.string_file(get_abs_path("data/number/thousand.tsv"))
quantities_abbr = pynini.string_file(get_abs_path("data/number/quantity_abbr.tsv"))
quantities_abbr |= TO_UPPER @ quantities_abbr


def get_quantity(
    decimal: 'pynini.FstLike', cardinal_up_to_hundred: 'pynini.FstLike', include_abbr: bool
) -> 'pynini.FstLike':
    """
    Returns FST that transforms either a cardinal or decimal followed by a quantity into a numeral,
    e.g. 1 million -> integer_part: "one" quantity: "million"
    e.g. 1.5 million -> integer_part: "one" fractional_part: "five" quantity: "million"

    Args:
        decimal: decimal FST
        cardinal_up_to_hundred: cardinal FST
    """
    quantity_wo_thousand = pynini.project(quantities, "input") - pynini.union("k", "K", "thousand")
    if include_abbr:
        quantity_wo_thousand |= pynini.project(quantities_abbr, "input") - pynini.union("k", "K", "thousand")
    res = (
        pynutil.insert("integer_part: \"")
        + cardinal_up_to_hundred
        + pynutil.insert("\"")
        + pynini.closure(pynutil.delete(" "), 0, 1)
        + pynutil.insert(" quantity: \"")
        + (quantity_wo_thousand @ (quantities | quantities_abbr))
        + pynutil.insert("\"")
    )
    if include_abbr:
        quantity = quantities | quantities_abbr
    else:
        quantity = quantities
    res |= (
        decimal
        + pynini.closure(pynutil.delete(" "), 0, 1)
        + pynutil.insert(" quantity: \"")
        + quantity
        + pynutil.insert("\"")
    )
    return res


class DecimalFst(GraphFst):
    """
    Finite state transducer for classifying decimal, e.g.
        -12.5006 billion -> decimal { negative: "true" integer_part: "12" fractional_part: "five o o six" quantity: "billion" }
        1 billion -> decimal { integer_part: "one" quantity: "billion" }

    Args:
        cardinal: CardinalFst
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool):
        super().__init__(name="decimal", kind="classify", deterministic=deterministic)

        cardinal_graph = cardinal.graph_with_and
        cardinal_graph_hundred_component_at_least_one_none_zero_digit = (
            cardinal.graph_hundred_component_at_least_one_none_zero_digit
        )

        self.graph = cardinal.single_digits_graph.optimize()

        if not deterministic:
            self.graph = self.graph | cardinal_graph

        point = pynutil.delete(".")
        optional_graph_negative = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)

        self.graph_fractional = pynutil.insert("fractional_part: \"") + self.graph + pynutil.insert("\"")
        self.graph_integer = pynutil.insert("integer_part: \"") + cardinal_graph + pynutil.insert("\"")
        final_graph_wo_sign = (
            pynini.closure(self.graph_integer + pynutil.insert(" "), 0, 1)
            + point
            + pynutil.insert(" ")
            + self.graph_fractional
        )

        quantity_w_abbr = get_quantity(
            final_graph_wo_sign, cardinal_graph_hundred_component_at_least_one_none_zero_digit, include_abbr=True
        )
        quantity_wo_abbr = get_quantity(
            final_graph_wo_sign, cardinal_graph_hundred_component_at_least_one_none_zero_digit, include_abbr=False
        )
        self.final_graph_wo_negative_w_abbr = final_graph_wo_sign | quantity_w_abbr
        self.final_graph_wo_negative = final_graph_wo_sign | quantity_wo_abbr

        # reduce options for non_deterministic and allow either "oh" or "zero", but not a combination
        if not deterministic:
            no_oh_zero = pynini.difference(
                NEMO_SIGMA,
                (NEMO_SIGMA + "oh" + NEMO_SIGMA + "zero" + NEMO_SIGMA)
                | (NEMO_SIGMA + "zero" + NEMO_SIGMA + "oh" + NEMO_SIGMA),
            ).optimize()
            no_zero_oh = pynini.difference(
                NEMO_SIGMA, NEMO_SIGMA + pynini.accep("zero") + NEMO_SIGMA + pynini.accep("oh") + NEMO_SIGMA
            ).optimize()

            self.final_graph_wo_negative |= pynini.compose(
                self.final_graph_wo_negative,
                pynini.cdrewrite(
                    pynini.cross("integer_part: \"zero\"", "integer_part: \"oh\""), NEMO_SIGMA, NEMO_SIGMA, NEMO_SIGMA
                ),
            )
            self.final_graph_wo_negative = pynini.compose(self.final_graph_wo_negative, no_oh_zero).optimize()
            self.final_graph_wo_negative = pynini.compose(self.final_graph_wo_negative, no_zero_oh).optimize()

        final_graph = optional_graph_negative + self.final_graph_wo_negative

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()
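`get_quantity` splits an input like `1.5 million` into integer part, fractional part, and quantity word. A hedged pure-Python sketch of that surface split (the `QUANTITIES` set is an assumption standing in for `data/number/thousand.tsv`, and the sketch returns raw digits where the real tagger emits spoken words such as "five o o six"):

```python
import re

# Illustrative quantity words; the real grammar reads them from
# data/number/thousand.tsv and data/number/quantity_abbr.tsv.
QUANTITIES = {"thousand", "million", "billion", "trillion"}

def tag_decimal_quantity(text: str):
    """Split an optional sign, integer digits, optional fractional digits,
    and an optional quantity word, mirroring the fields the decimal tagger
    emits.  Returns None if the text does not fit the pattern."""
    m = re.fullmatch(r'(-?)(\d+)(?:\.(\d+))?(?:\s*([A-Za-z]+))?', text.strip())
    if m is None:
        return None
    neg, integer, frac, quantity = m.groups()
    if quantity is not None and quantity.lower() not in QUANTITIES:
        return None
    fields = {}
    if neg:
        fields["negative"] = "true"
    fields["integer_part"] = integer
    if frac is not None:
        fields["fractional_part"] = frac
    if quantity is not None:
        fields["quantity"] = quantity.lower()
    return fields
```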
@@ -0,0 +1,87 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_ALPHA,
    NEMO_DIGIT,
    NEMO_SIGMA,
    GraphFst,
    get_abs_path,
    insert_space,
)
from pynini.lib import pynutil


class ElectronicFst(GraphFst):
    """
    Finite state transducer for classifying electronic tokens: URLs, email addresses, etc.
        e.g. cdf1@abc.edu -> tokens { electronic { username: "cdf1" domain: "abc.edu" } }

    Args:
        deterministic: if True will provide a single transduction option,
            if False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="electronic", kind="classify", deterministic=deterministic)

        accepted_symbols = pynini.project(pynini.string_file(get_abs_path("data/electronic/symbol.tsv")), "input")
        accepted_common_domains = pynini.project(
            pynini.string_file(get_abs_path("data/electronic/domain.tsv")), "input"
        )
        all_accepted_symbols = NEMO_ALPHA + pynini.closure(NEMO_ALPHA | NEMO_DIGIT | accepted_symbols)
        graph_symbols = pynini.string_file(get_abs_path("data/electronic/symbol.tsv")).optimize()

        username = (
            pynutil.insert("username: \"") + all_accepted_symbols + pynutil.insert("\"") + pynini.cross('@', ' ')
        )
        domain_graph = all_accepted_symbols + pynini.accep('.') + all_accepted_symbols + NEMO_ALPHA
        protocol_symbols = pynini.closure((graph_symbols | pynini.cross(":", "semicolon")) + pynutil.insert(" "))
        protocol_start = (pynini.cross("https", "HTTPS ") | pynini.cross("http", "HTTP ")) + (
            pynini.accep("://") @ protocol_symbols
        )
        protocol_file_start = pynini.accep("file") + insert_space + (pynini.accep(":///") @ protocol_symbols)

        protocol_end = pynini.cross("www", "WWW ") + pynini.accep(".") @ protocol_symbols
        protocol = protocol_file_start | protocol_start | protocol_end | (protocol_start + protocol_end)

        domain_graph = (
            pynutil.insert("domain: \"")
            + pynini.difference(domain_graph, pynini.project(protocol, "input") + NEMO_SIGMA)
            + pynutil.insert("\"")
        )
        domain_common_graph = (
            pynutil.insert("domain: \"")
            + pynini.difference(
                all_accepted_symbols
                + accepted_common_domains
                + pynini.closure(accepted_symbols + pynini.closure(NEMO_ALPHA | NEMO_DIGIT | accepted_symbols), 0, 1),
                pynini.project(protocol, "input") + NEMO_SIGMA,
            )
            + pynutil.insert("\"")
        )

        protocol = pynutil.insert("protocol: \"") + protocol + pynutil.insert("\"")
        # email
        graph = username + domain_graph
        # abc.com, abc.com/123-sm
        graph |= domain_common_graph
        # www.abc.com/sdafsdf, or https://www.abc.com/asdfad or www.abc.abc/asdfad
        graph |= protocol + pynutil.insert(" ") + domain_graph

        final_graph = self.add_tokens(graph)

        self.fst = final_graph.optimize()
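The email branch of `ElectronicFst` composes `username + domain_graph`, yielding the `username`/`domain` fields shown in the class docstring. A hypothetical pure-Python stand-in for that split (the character classes below are an assumption approximating `data/electronic/symbol.tsv`, which is not reproduced here):

```python
import re

def tag_email(text: str):
    """Split user@domain into the serialized electronic token shown in the
    ElectronicFst docstring.  Returns None for non-email input.  The allowed
    characters here only approximate the grammar's symbol set."""
    m = re.fullmatch(r'([A-Za-z][\w.+-]*)@([A-Za-z][\w-]*(?:\.[A-Za-z][\w-]*)+)', text.strip())
    if m is None:
        return None
    username, domain = m.groups()
    return f'electronic {{ username: "{username}" domain: "{domain}" }}'
```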
@@ -0,0 +1,55 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import GraphFst, get_abs_path
from pynini.lib import pynutil


class FractionFst(GraphFst):
    """
    Finite state transducer for classifying fraction
    "23 4/5" ->
    tokens { fraction { integer: "twenty three" numerator: "four" denominator: "five" } }
    "23 4/5th" ->
    tokens { fraction { integer: "twenty three" numerator: "four" denominator: "five" } }

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal, deterministic: bool = True):
        super().__init__(name="fraction", kind="classify", deterministic=deterministic)
        cardinal_graph = cardinal.graph

        integer = pynutil.insert("integer_part: \"") + cardinal_graph + pynutil.insert("\"")
        numerator = (
            pynutil.insert("numerator: \"") + cardinal_graph + (pynini.cross("/", "\" ") | pynini.cross(" / ", "\" "))
        )

        endings = ["rd", "th", "st", "nd"]
        endings += [x.upper() for x in endings]
        optional_end = pynini.closure(pynini.cross(pynini.union(*endings), ""), 0, 1)

        denominator = pynutil.insert("denominator: \"") + cardinal_graph + optional_end + pynutil.insert("\"")

        graph = pynini.closure(integer + pynini.accep(" "), 0, 1) + (numerator + denominator)
        graph |= pynini.closure(integer + (pynini.accep(" ") | pynutil.insert(" ")), 0, 1) + pynini.compose(
            pynini.string_file(get_abs_path("data/number/fraction.tsv")), (numerator + denominator)
        )

        self.graph = graph
        final_graph = self.add_tokens(self.graph)
        self.fst = final_graph.optimize()
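As an illustrative plain-Python sketch (not the NeMo API), the tagging performed by the FST above can be approximated with a regex. The tiny `words` lookup stands in for the cardinal graph and is an assumption of this sketch; the tagger itself emits the field name `integer_part`:

```python
import re

# Minimal stand-in for cardinal verbalization (the FST uses cardinal.graph).
words = {"4": "four", "5": "five", "23": "twenty three"}

# Optional integer part, numerator "/" denominator, optional ordinal suffix.
_FRACTION_RE = re.compile(r"^(?:(\d+) )?(\d+)/(\d+)(?:st|nd|rd|th)?$", re.IGNORECASE)


def tag_fraction(text: str):
    """Return the token form for simple fraction inputs, else None."""
    m = _FRACTION_RE.match(text)
    if m is None:
        return None
    integer, num, den = m.groups()
    parts = []
    if integer:
        parts.append(f'integer_part: "{words[integer]}"')
    parts.append(f'numerator: "{words[num]}"')
    parts.append(f'denominator: "{words[den]}"')
    return "tokens { fraction { " + " ".join(parts) + " } }"
```
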
@@ -0,0 +1,304 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_ALPHA,
    NEMO_DIGIT,
    NEMO_NON_BREAKING_SPACE,
    NEMO_SIGMA,
    NEMO_SPACE,
    NEMO_UPPER,
    SINGULAR_TO_PLURAL,
    TO_LOWER,
    GraphFst,
    convert_space,
    delete_space,
    delete_zero_or_one_space,
    insert_space,
)
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst as OrdinalTagger
from nemo_text_processing.text_normalization.en.taggers.whitelist import get_formats
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as OrdinalVerbalizer
from pynini.examples import plurals
from pynini.lib import pynutil


class MeasureFst(GraphFst):
    """
    Finite state transducer for classifying measure, suppletive aware, e.g.
        -12kg -> measure { negative: "true" cardinal { integer: "twelve" } units: "kilograms" }
        1kg -> measure { cardinal { integer: "one" } units: "kilogram" }
        .5kg -> measure { decimal { fractional_part: "five" } units: "kilograms" }

    Args:
        cardinal: CardinalFst
        decimal: DecimalFst
        fraction: FractionFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, decimal: GraphFst, fraction: GraphFst, deterministic: bool = True):
        super().__init__(name="measure", kind="classify", deterministic=deterministic)
        cardinal_graph = cardinal.graph_with_and | self.get_range(cardinal.graph_with_and)

        graph_unit = pynini.string_file(get_abs_path("data/measure/unit.tsv"))
        if not deterministic:
            graph_unit |= pynini.string_file(get_abs_path("data/measure/unit_alternatives.tsv"))

        graph_unit |= pynini.compose(
            pynini.closure(TO_LOWER, 1) + (NEMO_ALPHA | TO_LOWER) + pynini.closure(NEMO_ALPHA | TO_LOWER), graph_unit
        ).optimize()

        graph_unit_plural = convert_space(graph_unit @ SINGULAR_TO_PLURAL)
        graph_unit = convert_space(graph_unit)

        optional_graph_negative = pynini.closure(pynutil.insert("negative: ") + pynini.cross("-", "\"true\" "), 0, 1)

        graph_unit2 = (
            pynini.cross("/", "per") + delete_zero_or_one_space + pynutil.insert(NEMO_NON_BREAKING_SPACE) + graph_unit
        )

        optional_graph_unit2 = pynini.closure(
            delete_zero_or_one_space + pynutil.insert(NEMO_NON_BREAKING_SPACE) + graph_unit2, 0, 1,
        )

        unit_plural = (
            pynutil.insert("units: \"")
            + (graph_unit_plural + optional_graph_unit2 | graph_unit2)
            + pynutil.insert("\"")
        )

        unit_singular = (
            pynutil.insert("units: \"") + (graph_unit + optional_graph_unit2 | graph_unit2) + pynutil.insert("\"")
        )

        subgraph_decimal = (
            pynutil.insert("decimal { ")
            + optional_graph_negative
            + decimal.final_graph_wo_negative
            + delete_space
            + pynutil.insert(" } ")
            + unit_plural
        )

        # support radio FM/AM
        subgraph_decimal |= (
            pynutil.insert("decimal { ")
            + decimal.final_graph_wo_negative
            + delete_space
            + pynutil.insert(" } ")
            + pynutil.insert("units: \"")
            + pynini.union("AM", "FM")
            + pynutil.insert("\"")
        )

        subgraph_cardinal = (
            pynutil.insert("cardinal { ")
            + optional_graph_negative
            + pynutil.insert("integer: \"")
            + ((NEMO_SIGMA - "1") @ cardinal_graph)
            + delete_space
            + pynutil.insert("\"")
            + pynutil.insert(" } ")
            + unit_plural
        )

        subgraph_cardinal |= (
            pynutil.insert("cardinal { ")
            + optional_graph_negative
            + pynutil.insert("integer: \"")
            + pynini.cross("1", "one")
            + delete_space
            + pynutil.insert("\"")
            + pynutil.insert(" } ")
            + unit_singular
        )

        unit_graph = (
            pynutil.insert("cardinal { integer: \"-\" } units: \"")
            + pynini.cross(pynini.union("/", "per"), "per")
            + delete_zero_or_one_space
            + pynutil.insert(NEMO_NON_BREAKING_SPACE)
            + graph_unit
            + pynutil.insert("\" preserve_order: true")
        )

        decimal_dash_alpha = (
            pynutil.insert("decimal { ")
            + decimal.final_graph_wo_negative
            + pynini.cross('-', '')
            + pynutil.insert(" } units: \"")
            + pynini.closure(NEMO_ALPHA, 1)
            + pynutil.insert("\"")
        )

        decimal_times = (
            pynutil.insert("decimal { ")
            + decimal.final_graph_wo_negative
            + pynutil.insert(" } units: \"")
            + pynini.cross(pynini.union('x', "X"), 'x')
            + pynutil.insert("\"")
        )

        alpha_dash_decimal = (
            pynutil.insert("units: \"")
            + pynini.closure(NEMO_ALPHA, 1)
            + pynini.accep('-')
            + pynutil.insert("\"")
            + pynutil.insert(" decimal { ")
            + decimal.final_graph_wo_negative
            + pynutil.insert(" } preserve_order: true")
        )

        subgraph_fraction = (
            pynutil.insert("fraction { ") + fraction.graph + delete_space + pynutil.insert(" } ") + unit_plural
        )

        address = self.get_address_graph(cardinal)
        address = (
            pynutil.insert("units: \"address\" cardinal { integer: \"")
            + address
            + pynutil.insert("\" } preserve_order: true")
        )

        math_operations = pynini.string_file(get_abs_path("data/measure/math_operation.tsv"))
        delimiter = pynini.accep(" ") | pynutil.insert(" ")

        math = (
            (cardinal_graph | NEMO_ALPHA)
            + delimiter
            + math_operations
            + (delimiter | NEMO_ALPHA)
            + cardinal_graph
            + delimiter
            + pynini.cross("=", "equals")
            + delimiter
            + (cardinal_graph | NEMO_ALPHA)
        )

        math |= (
            (cardinal_graph | NEMO_ALPHA)
            + delimiter
            + pynini.cross("=", "equals")
            + delimiter
            + (cardinal_graph | NEMO_ALPHA)
            + delimiter
            + math_operations
            + delimiter
            + cardinal_graph
        )

        math = (
            pynutil.insert("units: \"math\" cardinal { integer: \"")
            + math
            + pynutil.insert("\" } preserve_order: true")
        )
        final_graph = (
            subgraph_decimal
            | subgraph_cardinal
            | unit_graph
            | decimal_dash_alpha
            | decimal_times
            | alpha_dash_decimal
            | subgraph_fraction
            | address
            | math
        )

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()

    def get_range(self, cardinal: GraphFst):
        """
        Returns range forms for measure tagger, e.g. 2-3, 2x3, 2*2

        Args:
            cardinal: cardinal GraphFst
        """
        range_graph = cardinal + pynini.cross(pynini.union("-", " - "), " to ") + cardinal

        for x in [" x ", "x"]:
            range_graph |= cardinal + pynini.cross(x, " by ") + cardinal
            if not self.deterministic:
                range_graph |= cardinal + pynini.cross(x, " times ") + cardinal

        for x in ["*", " * "]:
            range_graph |= cardinal + pynini.cross(x, " times ") + cardinal
        return range_graph.optimize()

    def get_address_graph(self, cardinal):
        """
        Finite state transducer for classifying serial.
        The serial is a combination of digits, letters and dashes, e.g.:
        2788 San Tomas Expy, Santa Clara, CA 95051 ->
            units: "address" cardinal
            { integer: "two seven eight eight San Tomas Expressway Santa Clara California nine five zero five one" }
            preserve_order: true
        """
        ordinal_verbalizer = OrdinalVerbalizer().graph
        ordinal_tagger = OrdinalTagger(cardinal=cardinal).graph
        ordinal_num = pynini.compose(
            pynutil.insert("integer: \"") + ordinal_tagger + pynutil.insert("\""), ordinal_verbalizer
        )

        address_num = NEMO_DIGIT ** (1, 2) @ cardinal.graph_hundred_component_at_least_one_none_zero_digit
        address_num += insert_space + NEMO_DIGIT ** 2 @ (
            pynini.closure(pynini.cross("0", "zero "), 0, 1)
            + cardinal.graph_hundred_component_at_least_one_none_zero_digit
        )
        # to handle the rest of the numbers
        address_num = pynini.compose(NEMO_DIGIT ** (3, 4), address_num)
        address_num = plurals._priority_union(address_num, cardinal.graph, NEMO_SIGMA)

        direction = (
            pynini.cross("E", "East")
            | pynini.cross("S", "South")
            | pynini.cross("W", "West")
            | pynini.cross("N", "North")
        ) + pynini.closure(pynutil.delete("."), 0, 1)

        direction = pynini.closure(pynini.accep(NEMO_SPACE) + direction, 0, 1)
        address_words = get_formats(get_abs_path("data/address/address_word.tsv"))
        address_words = (
            pynini.accep(NEMO_SPACE)
            + (pynini.closure(ordinal_num, 0, 1) | NEMO_UPPER + pynini.closure(NEMO_ALPHA, 1))
            + NEMO_SPACE
            + pynini.closure(NEMO_UPPER + pynini.closure(NEMO_ALPHA) + NEMO_SPACE)
            + address_words
        )

        city = pynini.closure(NEMO_ALPHA | pynini.accep(NEMO_SPACE), 1)
        city = pynini.closure(pynini.accep(",") + pynini.accep(NEMO_SPACE) + city, 0, 1)

        states = load_labels(get_abs_path("data/address/state.tsv"))

        additional_options = []
        for x, y in states:
            additional_options.append((x, f"{y[0]}.{y[1:]}"))
        states.extend(additional_options)
        state_graph = pynini.string_map(states)
        state = pynini.invert(state_graph)
        state = pynini.closure(pynini.accep(",") + pynini.accep(NEMO_SPACE) + state, 0, 1)

        zip_code = pynini.compose(NEMO_DIGIT ** 5, cardinal.single_digits_graph)
        zip_code = pynini.closure(pynini.closure(pynini.accep(","), 0, 1) + pynini.accep(NEMO_SPACE) + zip_code, 0, 1,)

        address = address_num + direction + address_words + pynini.closure(city + state + zip_code, 0, 1)

        address |= address_num + direction + address_words + pynini.closure(pynini.cross(".", ""), 0, 1)

        return address
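A minimal plain-Python analogue of `get_range` above: separators map to spoken connectors, with digit verbalization (which the FST delegates to the cardinal graph) deliberately left out of this sketch:

```python
def expand_range(text: str):
    """Expand "2-3", "2x3", "2*2" into spoken connectors, as the FST does.
    Digits stay digits here; the real graph verbalizes them via cardinal."""
    seps = [(" - ", " to "), ("-", " to "), (" x ", " by "),
            ("x", " by "), (" * ", " times "), ("*", " times ")]
    for sep, word in seps:
        head, found, tail = text.partition(sep)
        if found and head.isdigit() and tail.isdigit():
            return head + word + tail
    return None
```
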
@@ -0,0 +1,192 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_ALPHA,
    NEMO_DIGIT,
    NEMO_SIGMA,
    SINGULAR_TO_PLURAL,
    GraphFst,
    convert_space,
    insert_space,
)
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.lib import pynutil

min_singular = pynini.string_file(get_abs_path("data/money/currency_minor_singular.tsv"))
min_plural = pynini.string_file(get_abs_path("data/money/currency_minor_plural.tsv"))
maj_singular = pynini.string_file((get_abs_path("data/money/currency_major.tsv")))


class MoneyFst(GraphFst):
    """
    Finite state transducer for classifying money, suppletive aware, e.g.
        $12.05 -> money { integer_part: "twelve" currency_maj: "dollars" fractional_part: "five" currency_min: "cents" preserve_order: true }
        $12.0500 -> money { integer_part: "twelve" currency_maj: "dollars" fractional_part: "five" currency_min: "cents" preserve_order: true }
        $1 -> money { currency_maj: "dollar" integer_part: "one" }
        $1.00 -> money { currency_maj: "dollar" integer_part: "one" }
        $0.05 -> money { fractional_part: "five" currency_min: "cents" preserve_order: true }
        $1 million -> money { currency_maj: "dollars" integer_part: "one" quantity: "million" }
        $1.2 million -> money { currency_maj: "dollars" integer_part: "one" fractional_part: "two" quantity: "million" }
        $1.2320 -> money { currency_maj: "dollars" integer_part: "one" fractional_part: "two three two" }

    Args:
        cardinal: CardinalFst
        decimal: DecimalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, decimal: GraphFst, deterministic: bool = True):
        super().__init__(name="money", kind="classify", deterministic=deterministic)
        cardinal_graph = cardinal.graph_with_and
        graph_decimal_final = decimal.final_graph_wo_negative_w_abbr

        maj_singular_labels = load_labels(get_abs_path("data/money/currency_major.tsv"))
        maj_unit_plural = convert_space(maj_singular @ SINGULAR_TO_PLURAL)
        maj_unit_singular = convert_space(maj_singular)

        graph_maj_singular = pynutil.insert("currency_maj: \"") + maj_unit_singular + pynutil.insert("\"")
        graph_maj_plural = pynutil.insert("currency_maj: \"") + maj_unit_plural + pynutil.insert("\"")

        optional_delete_fractional_zeros = pynini.closure(
            pynutil.delete(".") + pynini.closure(pynutil.delete("0"), 1), 0, 1
        )

        graph_integer_one = pynutil.insert("integer_part: \"") + pynini.cross("1", "one") + pynutil.insert("\"")
        # only for decimals where third decimal after comma is non-zero or with quantity
        decimal_delete_last_zeros = (
            pynini.closure(NEMO_DIGIT | pynutil.delete(","))
            + pynini.accep(".")
            + pynini.closure(NEMO_DIGIT, 2)
            + (NEMO_DIGIT - "0")
            + pynini.closure(pynutil.delete("0"))
        )
        decimal_with_quantity = NEMO_SIGMA + NEMO_ALPHA

        graph_decimal = (
            graph_maj_plural + insert_space + (decimal_delete_last_zeros | decimal_with_quantity) @ graph_decimal_final
        )

        graph_integer = (
            pynutil.insert("integer_part: \"") + ((NEMO_SIGMA - "1") @ cardinal_graph) + pynutil.insert("\"")
        )

        graph_integer_only = graph_maj_singular + insert_space + graph_integer_one
        graph_integer_only |= graph_maj_plural + insert_space + graph_integer

        final_graph = (graph_integer_only + optional_delete_fractional_zeros) | graph_decimal

        # remove trailing zeros of non zero number in the first 2 digits and fill up to 2 digits
        # e.g. 2000 -> 20, 0200->02, 01 -> 01, 10 -> 10
        # not accepted: 002, 00, 0,
        two_digits_fractional_part = (
            pynini.closure(NEMO_DIGIT) + (NEMO_DIGIT - "0") + pynini.closure(pynutil.delete("0"))
        ) @ (
            (pynutil.delete("0") + (NEMO_DIGIT - "0"))
            | ((NEMO_DIGIT - "0") + pynutil.insert("0"))
            | ((NEMO_DIGIT - "0") + NEMO_DIGIT)
        )

        graph_min_singular = pynutil.insert(" currency_min: \"") + min_singular + pynutil.insert("\"")
        graph_min_plural = pynutil.insert(" currency_min: \"") + min_plural + pynutil.insert("\"")
        # format ** dollars ** cent
        decimal_graph_with_minor = None
        integer_graph_reordered = None
        decimal_default_reordered = None
        for curr_symbol, _ in maj_singular_labels:
            preserve_order = pynutil.insert(" preserve_order: true")
            integer_plus_maj = graph_integer + insert_space + pynutil.insert(curr_symbol) @ graph_maj_plural
            integer_plus_maj |= graph_integer_one + insert_space + pynutil.insert(curr_symbol) @ graph_maj_singular

            integer_plus_maj_with_comma = pynini.compose(
                NEMO_DIGIT - "0" + pynini.closure(NEMO_DIGIT | pynutil.delete(",")), integer_plus_maj
            )
            integer_plus_maj = pynini.compose(pynini.closure(NEMO_DIGIT) - "0", integer_plus_maj)
            integer_plus_maj |= integer_plus_maj_with_comma

            graph_fractional_one = two_digits_fractional_part @ pynini.cross("1", "one")
            graph_fractional_one = pynutil.insert("fractional_part: \"") + graph_fractional_one + pynutil.insert("\"")
            graph_fractional = (
                two_digits_fractional_part
                @ (pynini.closure(NEMO_DIGIT, 1, 2) - "1")
                @ cardinal.graph_hundred_component_at_least_one_none_zero_digit
            )
            graph_fractional = pynutil.insert("fractional_part: \"") + graph_fractional + pynutil.insert("\"")

            fractional_plus_min = graph_fractional + insert_space + pynutil.insert(curr_symbol) @ graph_min_plural
            fractional_plus_min |= (
                graph_fractional_one + insert_space + pynutil.insert(curr_symbol) @ graph_min_singular
            )

            decimal_graph_with_minor_curr = integer_plus_maj + pynini.cross(".", " ") + fractional_plus_min

            if not deterministic:
                decimal_graph_with_minor_curr |= pynutil.add_weight(
                    integer_plus_maj
                    + pynini.cross(".", " ")
                    + pynutil.insert("fractional_part: \"")
                    + two_digits_fractional_part @ cardinal.graph_hundred_component_at_least_one_none_zero_digit
                    + pynutil.insert("\""),
                    weight=0.0001,
                )
            default_fraction_graph = (decimal_delete_last_zeros | decimal_with_quantity) @ graph_decimal_final
            decimal_graph_with_minor_curr |= (
                pynini.closure(pynutil.delete("0"), 0, 1) + pynutil.delete(".") + fractional_plus_min
            )
            decimal_graph_with_minor_curr = (
                pynutil.delete(curr_symbol) + decimal_graph_with_minor_curr + preserve_order
            )

            decimal_graph_with_minor = (
                decimal_graph_with_minor_curr
                if decimal_graph_with_minor is None
                else pynini.union(decimal_graph_with_minor, decimal_graph_with_minor_curr).optimize()
            )

            if not deterministic:
                integer_graph_reordered_curr = (
                    pynutil.delete(curr_symbol) + integer_plus_maj + preserve_order
                ).optimize()

                integer_graph_reordered = (
                    integer_graph_reordered_curr
                    if integer_graph_reordered is None
                    else pynini.union(integer_graph_reordered, integer_graph_reordered_curr).optimize()
                )
                decimal_default_reordered_curr = (
                    pynutil.delete(curr_symbol)
                    + default_fraction_graph
                    + insert_space
                    + pynutil.insert(curr_symbol) @ graph_maj_plural
                )

                decimal_default_reordered = (
                    decimal_default_reordered_curr
                    if decimal_default_reordered is None
                    else pynini.union(decimal_default_reordered, decimal_default_reordered_curr)
                ).optimize()

        # weight for SH
        final_graph |= pynutil.add_weight(decimal_graph_with_minor, -0.0001)

        if not deterministic:
            final_graph |= integer_graph_reordered | decimal_default_reordered
            # to handle "$2.00" cases
            final_graph |= pynini.compose(
                NEMO_SIGMA + pynutil.delete(".") + pynini.closure(pynutil.delete("0"), 1), integer_graph_reordered
            )
        final_graph = self.add_tokens(final_graph.optimize())
        self.fst = final_graph.optimize()
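The `two_digits_fractional_part` composition above is compact but dense; as a hedged plain-Python sketch of the same normalization (per the comment in the code: `2000 -> 20, 0200 -> 02, 01 -> 01, 10 -> 10`, rejecting `002`, `00`, `0`):

```python
def two_digit_fraction(digits: str):
    """Normalize a monetary fractional part to exactly two digits.
    Strips trailing zeros, then right-pads a single surviving digit with "0".
    Returns None for rejected inputs whose nonzero part is not within the
    first two digits (e.g. "002", "00", "0")."""
    if not digits.isdigit():
        return None
    core = digits.rstrip("0")
    if not core or len(core) > 2:
        return None
    return core if len(core) == 2 else core + "0"
```
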
@@ -0,0 +1,61 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_DIGIT, GraphFst
from pynini.lib import pynutil


class OrdinalFst(GraphFst):
    """
    Finite state transducer for classifying ordinal, e.g.
        13th -> ordinal { integer: "thirteen" }

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool = True):
        super().__init__(name="ordinal", kind="classify", deterministic=deterministic)

        cardinal_graph = cardinal.graph
        cardinal_format = pynini.closure(NEMO_DIGIT | pynini.accep(","))
        st_format = (
            pynini.closure(cardinal_format + (NEMO_DIGIT - "1"), 0, 1)
            + pynini.accep("1")
            + pynutil.delete(pynini.union("st", "ST"))
        )
        nd_format = (
            pynini.closure(cardinal_format + (NEMO_DIGIT - "1"), 0, 1)
            + pynini.accep("2")
            + pynutil.delete(pynini.union("nd", "ND"))
        )
        rd_format = (
            pynini.closure(cardinal_format + (NEMO_DIGIT - "1"), 0, 1)
            + pynini.accep("3")
            + pynutil.delete(pynini.union("rd", "RD"))
        )
        th_format = pynini.closure(
            (NEMO_DIGIT - "1" - "2" - "3")
            | (cardinal_format + "1" + NEMO_DIGIT)
            | (cardinal_format + (NEMO_DIGIT - "1") + (NEMO_DIGIT - "1" - "2" - "3")),
            1,
        ) + pynutil.delete(pynini.union("th", "TH"))
        self.graph = (st_format | nd_format | rd_format | th_format) @ cardinal_graph
        final_graph = pynutil.insert("integer: \"") + self.graph + pynutil.insert("\"")
        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()
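The four suffix formats above encode which ordinal ending is legal after which final digits. A plain-Python sketch of the same rules (an illustrative analogue, not the NeMo API): `st` after a final 1, `nd` after 2, `rd` after 3, `th` otherwise, with 11/12/13 always taking `th`:

```python
import re
from typing import Optional

# Number (optionally comma-grouped) followed by an ordinal suffix, any case.
_ORDINAL_RE = re.compile(r"^(\d{1,3}(?:,\d{3})*|\d+)(st|nd|rd|th)$", re.IGNORECASE)


def strip_ordinal_suffix(token: str) -> Optional[str]:
    """Return the bare number when `token` is a well-formed ordinal, else None."""
    m = _ORDINAL_RE.match(token)
    if m is None:
        return None
    num, suffix = m.group(1), m.group(2).lower()
    last_two = num.replace(",", "")[-2:]
    if last_two in ("11", "12", "13"):
        expected = "th"  # 11th, 12th, 13th despite ending in 1/2/3
    else:
        expected = {"1": "st", "2": "nd", "3": "rd"}.get(last_two[-1], "th")
    return num if suffix == expected else None
```
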
@@ -0,0 +1,65 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from unicodedata import category

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_SPACE, NEMO_SIGMA, GraphFst
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.examples import plurals
from pynini.lib import pynutil


class PunctuationFst(GraphFst):
    """
    Finite state transducer for classifying punctuation
        e.g. a, -> tokens { name: "a" } tokens { name: "," }

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="punctuation", kind="classify", deterministic=deterministic)
        s = "!#%&\'()*+,-./:;<=>?@^_`{|}~\""

        punct_symbols_to_exclude = ["[", "]"]
        punct_unicode = [
            chr(i)
            for i in range(sys.maxunicode)
            if category(chr(i)).startswith("P") and chr(i) not in punct_symbols_to_exclude
        ]

        whitelist_symbols = load_labels(get_abs_path("data/whitelist/symbol.tsv"))
        whitelist_symbols = [x[0] for x in whitelist_symbols]
        self.punct_marks = [p for p in punct_unicode + list(s) if p not in whitelist_symbols]

        punct = pynini.union(*self.punct_marks)
        punct = pynini.closure(punct, 1)

        emphasis = (
            pynini.accep("<")
            + (
                (pynini.closure(NEMO_NOT_SPACE - pynini.union("<", ">"), 1) + pynini.closure(pynini.accep("/"), 0, 1))
                | (pynini.accep("/") + pynini.closure(NEMO_NOT_SPACE - pynini.union("<", ">"), 1))
            )
            + pynini.accep(">")
        )
        punct = plurals._priority_union(emphasis, punct, NEMO_SIGMA)

        self.graph = punct
        self.fst = (pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")).optimize()
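The mark list above is built by scanning the whole Unicode range for `P*`-category code points. That step is easy to reproduce standalone; this sketch skips the TSV whitelist subtraction (assumed out of scope here):

```python
import sys
from unicodedata import category

# ASCII punctuation the tagger adds on top of Unicode "P*" categories.
ASCII_PUNCT = "!#%&'()*+,-./:;<=>?@^_`{|}~\""


def punct_marks(exclude=("[", "]")):
    """Every Unicode code point whose category starts with "P", plus the
    ASCII set above, minus the excluded brackets. (The real class further
    removes symbols whitelisted in data/whitelist/symbol.tsv.)"""
    uni = (
        chr(i)
        for i in range(sys.maxunicode)
        if category(chr(i)).startswith("P") and chr(i) not in exclude
    )
    return set(uni) | (set(ASCII_PUNCT) - set(exclude))
```

Note that symbols like `<` or `~` are category `Sm`/`So`, not `P*`, which is why the explicit ASCII string is needed at all.
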
@@ -0,0 +1,102 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_DIGIT, GraphFst, convert_space
from pynini.lib import pynutil


class RangeFst(GraphFst):
    """
    This class is a composite class of two other class instances

    Args:
        time: composed tagger and verbalizer
        date: composed tagger and verbalizer
        cardinal: tagger
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: whether to use for hybrid LM
    """

    def __init__(
        self, time: GraphFst, date: GraphFst, cardinal: GraphFst, deterministic: bool = True, lm: bool = False,
    ):
        super().__init__(name="range", kind="classify", deterministic=deterministic)

        delete_space = pynini.closure(pynutil.delete(" "), 0, 1)

        approx = pynini.cross("~", "approximately")

        # TIME
        time_graph = time + delete_space + pynini.cross("-", " to ") + delete_space + time
        self.graph = time_graph | (approx + time)

        cardinal = cardinal.graph_with_and
        # YEAR
        date_year_four_digit = (NEMO_DIGIT ** 4 + pynini.closure(pynini.accep("s"), 0, 1)) @ date
        date_year_two_digit = (NEMO_DIGIT ** 2 + pynini.closure(pynini.accep("s"), 0, 1)) @ date
        year_to_year_graph = (
            date_year_four_digit
            + delete_space
            + pynini.cross("-", " to ")
            + delete_space
            + (date_year_four_digit | date_year_two_digit | (NEMO_DIGIT ** 2 @ cardinal))
        )
        mid_year_graph = pynini.accep("mid") + pynini.cross("-", " ") + (date_year_four_digit | date_year_two_digit)

        self.graph |= year_to_year_graph
        self.graph |= mid_year_graph

        # ADDITION
        range_graph = cardinal + pynini.closure(pynini.cross("+", " plus ") + cardinal, 1)
        range_graph |= cardinal + pynini.closure(pynini.cross(" + ", " plus ") + cardinal, 1)
        range_graph |= approx + cardinal
        range_graph |= cardinal + (pynini.cross("...", " ... ") | pynini.accep(" ... ")) + cardinal

        if not deterministic or lm:
            # cardinal ----
            cardinal_to_cardinal_graph = (
                cardinal + delete_space + pynini.cross("-", pynini.union(" to ", " minus ")) + delete_space + cardinal
            )

            range_graph |= cardinal_to_cardinal_graph | (
                cardinal + delete_space + pynini.cross(":", " to ") + delete_space + cardinal
            )

            # MULTIPLY
            for x in [" x ", "x"]:
                range_graph |= cardinal + pynini.closure(
                    pynini.cross(x, pynini.union(" by ", " times ")) + cardinal, 1
                )

            for x in ["*", " * "]:
                range_graph |= cardinal + pynini.closure(pynini.cross(x, " times ") + cardinal, 1)

            # supports "No. 12" -> "Number 12"
            range_graph |= (
                (pynini.cross(pynini.union("NO", "No"), "Number") | pynini.cross("no", "number"))
                + pynini.closure(pynini.union(". ", " "), 0, 1)
                + cardinal
            )

            for x in ["/", " / "]:
                range_graph |= cardinal + pynini.closure(pynini.cross(x, " divided by ") + cardinal, 1)

        self.graph |= range_graph

        self.graph = self.graph.optimize()
|
graph = pynutil.insert("name: \"") + convert_space(self.graph).optimize() + pynutil.insert("\"")
|
||||||
|
self.fst = graph.optimize()
|
||||||
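The range grammar above rewrites range punctuation into spoken connectives ("-" between numbers becomes "to", "+" becomes "plus", "~" becomes "approximately"). A rough regex-based sketch of the same idea in plain Python, purely for illustration (the helper name `verbalize_range` and the rule list are assumptions, not part of the NeMo grammar, which is a weighted FST):

```python
import re

# Illustrative rewrites mirroring a few RangeFst rules; the real grammar
# also verbalizes the surrounding numbers and handles years/times.
RANGE_REWRITES = [
    (re.compile(r"^~\s*"), "approximately "),
    (re.compile(r"\s*\+\s*"), " plus "),
    (re.compile(r"\s*\*\s*"), " times "),
    (re.compile(r"(?<=\d)\s*-\s*(?=\d)"), " to "),
]


def verbalize_range(token: str) -> str:
    """Apply the rewrites in order to a single token, e.g. "2-3" -> "2 to 3"."""
    for pattern, replacement in RANGE_REWRITES:
        token = pattern.sub(replacement, token)
    return token
```

Unlike this sketch, the FST applies all rewrites in a single composition and scores alternatives by weight.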
@@ -0,0 +1,114 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_ALPHA, NEMO_SIGMA, GraphFst
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.lib import pynutil


class RomanFst(GraphFst):
    """
    Finite state transducer for classifying roman numbers:
        e.g. "IV" -> tokens { roman { integer: "four" } }

    Args:
        deterministic: if True will provide a single transduction option,
            if False, multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True, lm: bool = False):
        super().__init__(name="roman", kind="classify", deterministic=deterministic)

        roman_dict = load_labels(get_abs_path("data/roman/roman_to_spoken.tsv"))
        default_graph = pynini.string_map(roman_dict).optimize()
        default_graph = pynutil.insert("integer: \"") + default_graph + pynutil.insert("\"")
        ordinal_limit = 19

        if deterministic:
            # exclude "I"
            start_idx = 1
        else:
            start_idx = 0

        graph_teens = pynini.string_map([x[0] for x in roman_dict[start_idx:ordinal_limit]]).optimize()

        # roman numerals up to ordinal_limit with a preceding name are converted to ordinal form
        names = get_names()
        graph = (
            pynutil.insert("key_the_ordinal: \"")
            + names
            + pynutil.insert("\"")
            + pynini.accep(" ")
            + graph_teens @ default_graph
        ).optimize()

        # single symbol roman numerals with preceding key words (multiple formats) are converted to cardinal form
        key_words = []
        for k_word in load_labels(get_abs_path("data/roman/key_word.tsv")):
            key_words.append(k_word)
            key_words.append([k_word[0][0].upper() + k_word[0][1:]])
            key_words.append([k_word[0].upper()])

        key_words = pynini.string_map(key_words).optimize()
        graph |= (
            pynutil.insert("key_cardinal: \"") + key_words + pynutil.insert("\"") + pynini.accep(" ") + default_graph
        ).optimize()

        if deterministic or lm:
            # two digit roman numerals up to 49
            roman_to_cardinal = pynini.compose(
                pynini.closure(NEMO_ALPHA, 2),
                (
                    pynutil.insert("default_cardinal: \"default\" ")
                    + (pynini.string_map([x[0] for x in roman_dict[:50]]).optimize()) @ default_graph
                ),
            )
            graph |= roman_to_cardinal
        elif not lm:
            # two or more digit roman numerals
            roman_to_cardinal = pynini.compose(
                pynini.difference(NEMO_SIGMA, "I"),
                (
                    pynutil.insert("default_cardinal: \"default\" integer: \"")
                    + pynini.string_map(roman_dict).optimize()
                    + pynutil.insert("\"")
                ),
            ).optimize()
            graph |= roman_to_cardinal

        # convert three digit roman or up with suffix to ordinal
        roman_to_ordinal = pynini.compose(
            pynini.closure(NEMO_ALPHA, 3),
            (pynutil.insert("default_ordinal: \"default\" ") + graph_teens @ default_graph + pynutil.delete("th")),
        )

        graph |= roman_to_ordinal
        graph = self.add_tokens(graph.optimize())

        self.fst = graph.optimize()


def get_names():
    """
    Returns the graph that matches common male and female names.
    """
    male_labels = load_labels(get_abs_path("data/roman/male.tsv"))
    female_labels = load_labels(get_abs_path("data/roman/female.tsv"))
    male_labels.extend([[x[0].upper()] for x in male_labels])
    female_labels.extend([[x[0].upper()] for x in female_labels])
    names = pynini.string_map(male_labels).optimize()
    names |= pynini.string_map(female_labels).optimize()
    return names
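RomanFst resolves a written numeral to its spoken integer through a lookup loaded from `data/roman/roman_to_spoken.tsv`. The mapping can be sketched in plain Python as a hypothetical stand-in for that TSV lookup, covering numerals within the tagger's `ordinal_limit` of 19 (function names are illustrative, not NeMo API):

```python
ROMAN_VALUES = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
SPOKEN = [
    "zero", "one", "two", "three", "four", "five", "six", "seven", "eight",
    "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
    "sixteen", "seventeen", "eighteen", "nineteen",
]


def roman_to_int(roman: str) -> int:
    """Standard subtractive-notation conversion, e.g. "XIX" -> 19."""
    total = 0
    for i, symbol in enumerate(roman):
        value = ROMAN_VALUES[symbol]
        # a smaller symbol before a larger one is subtracted (IV = 4)
        if i + 1 < len(roman) and ROMAN_VALUES[roman[i + 1]] > value:
            total -= value
        else:
            total += value
    return total


def roman_to_spoken(roman: str) -> str:
    """Spoken cardinal for numerals within the tagger's ordinal_limit (19)."""
    return SPOKEN[roman_to_int(roman)]
```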
@@ -0,0 +1,136 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_ALPHA,
    NEMO_DIGIT,
    NEMO_NOT_SPACE,
    NEMO_SIGMA,
    GraphFst,
    convert_space,
)
from nemo_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.examples import plurals
from pynini.lib import pynutil


class SerialFst(GraphFst):
    """
    Finite state transducer for classifying serials (handles only cases without delimiters,
    values with delimiters are handled by default).
    A serial is a combination of digits, letters and dashes, e.g.:
        c325b -> tokens { cardinal { integer: "c three two five b" } }

    Args:
        cardinal: CardinalFst tagger
        ordinal: OrdinalFst tagger
        deterministic: if True will provide a single transduction option,
            if False, multiple transductions are generated (used for audio-based normalization)
        lm: whether to use for hybrid LM
    """

    def __init__(self, cardinal: GraphFst, ordinal: GraphFst, deterministic: bool = True, lm: bool = False):
        super().__init__(name="integer", kind="classify", deterministic=deterministic)

        num_graph = pynini.compose(NEMO_DIGIT ** (6, ...), cardinal.single_digits_graph).optimize()
        num_graph |= pynini.compose(NEMO_DIGIT ** (1, 5), cardinal.graph).optimize()
        # to handle numbers starting with zero
        num_graph |= pynini.compose(
            pynini.accep("0") + pynini.closure(NEMO_DIGIT), cardinal.single_digits_graph
        ).optimize()
        # TODO: "#" doesn't work from the file
        symbols_graph = pynini.string_file(get_abs_path("data/whitelist/symbol.tsv")).optimize() | pynini.cross(
            "#", "hash"
        )
        num_graph |= symbols_graph

        if not self.deterministic and not lm:
            num_graph |= cardinal.single_digits_graph
            # also allow double digits to be pronounced as integer in serial number
            num_graph |= pynutil.add_weight(
                NEMO_DIGIT ** 2 @ cardinal.graph_hundred_component_at_least_one_none_zero_digit, weight=0.0001
            )

        # add space between letter and digit/symbol
        symbols = [x[0] for x in load_labels(get_abs_path("data/whitelist/symbol.tsv"))]
        symbols = pynini.union(*symbols)
        digit_symbol = NEMO_DIGIT | symbols

        graph_with_space = pynini.compose(
            pynini.cdrewrite(pynutil.insert(" "), NEMO_ALPHA | symbols, digit_symbol, NEMO_SIGMA),
            pynini.cdrewrite(pynutil.insert(" "), digit_symbol, NEMO_ALPHA | symbols, NEMO_SIGMA),
        )

        # serial graph with delimiter
        delimiter = pynini.accep("-") | pynini.accep("/") | pynini.accep(" ")
        if not deterministic:
            delimiter |= pynini.cross("-", " dash ") | pynini.cross("/", " slash ")

        alphas = pynini.closure(NEMO_ALPHA, 1)
        letter_num = alphas + delimiter + num_graph
        num_letter = pynini.closure(num_graph + delimiter, 1) + alphas
        next_alpha_or_num = pynini.closure(delimiter + (alphas | num_graph))
        next_alpha_or_num |= pynini.closure(
            delimiter
            + num_graph
            + plurals._priority_union(pynini.accep(" "), pynutil.insert(" "), NEMO_SIGMA).optimize()
            + alphas
        )

        serial_graph = letter_num + next_alpha_or_num
        serial_graph |= num_letter + next_alpha_or_num
        # numbers only with 2+ delimiters
        serial_graph |= (
            num_graph + delimiter + num_graph + delimiter + num_graph + pynini.closure(delimiter + num_graph)
        )
        # 2+ symbols
        serial_graph |= pynini.compose(NEMO_SIGMA + symbols + NEMO_SIGMA, num_graph + delimiter + num_graph)

        # exclude ordinal numbers from serial options
        serial_graph = pynini.compose(
            pynini.difference(NEMO_SIGMA, pynini.project(ordinal.graph, "input")), serial_graph
        ).optimize()

        serial_graph = pynutil.add_weight(serial_graph, 0.0001)
        serial_graph |= (
            pynini.closure(NEMO_NOT_SPACE, 1)
            + (pynini.cross("^2", " squared") | pynini.cross("^3", " cubed")).optimize()
        )

        # at least one serial graph with alphanumeric value and optional additional serial/num/alpha values
        serial_graph = (
            pynini.closure((serial_graph | num_graph | alphas) + delimiter)
            + serial_graph
            + pynini.closure(delimiter + (serial_graph | num_graph | alphas))
        )

        serial_graph |= pynini.compose(graph_with_space, serial_graph.optimize()).optimize()
        serial_graph = pynini.compose(pynini.closure(NEMO_NOT_SPACE, 2), serial_graph).optimize()

        # this is not to verbalize "/" as "slash" in cases like "import/export"
        serial_graph = pynini.compose(
            pynini.difference(
                NEMO_SIGMA, pynini.closure(NEMO_ALPHA, 1) + pynini.accep("/") + pynini.closure(NEMO_ALPHA, 1)
            ),
            serial_graph,
        )
        self.graph = serial_graph.optimize()
        graph = pynutil.insert("name: \"") + convert_space(self.graph).optimize() + pynutil.insert("\"")
        self.fst = graph.optimize()
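The serial grammar reads digits one at a time and passes letters through, e.g. `c325b` becomes "c three two five b". A minimal sketch of that verbalization in plain Python, assuming single-digit reading and dropped delimiters in deterministic mode (the function name is hypothetical and this ignores the grammar's weighted alternatives):

```python
DIGIT_WORDS = {
    "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four",
    "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine",
}


def verbalize_serial(serial: str) -> str:
    """Read digits one by one, keep letters, drop "-", "/" and space delimiters."""
    parts = []
    for ch in serial:
        if ch.isdigit():
            parts.append(DIGIT_WORDS[ch])
        elif ch in "-/ ":
            continue  # in non-deterministic mode the FST may also emit "dash"/"slash"
        else:
            parts.append(ch)
    return " ".join(parts)
```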
@@ -0,0 +1,133 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_ALPHA,
    NEMO_DIGIT,
    NEMO_SIGMA,
    GraphFst,
    delete_extra_space,
    delete_space,
    insert_space,
    plurals,
)
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.lib import pynutil


class TelephoneFst(GraphFst):
    """
    Finite state transducer for classifying telephone numbers, IP addresses, and SSNs.
    A telephone number consists of an optional country code, a number part, and an optional extension:
        country code (optional): +***
        number part: ***-***-****, or (***) ***-****
        extension (optional): 1-9999
    E.g.
    +1 123-123-5678-1 -> telephone { country_code: "one" number_part: "one two three, one two three, five six seven eight" extension: "one" }
    1-800-GO-U-HAUL -> telephone { country_code: "one" number_part: "one, eight hundred GO U HAUL" }

    Args:
        deterministic: if True will provide a single transduction option,
            if False, multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="telephone", kind="classify", deterministic=deterministic)

        add_separator = pynutil.insert(", ")  # between components
        zero = pynini.cross("0", "zero")
        if not deterministic:
            zero |= pynini.cross("0", pynini.union("o", "oh"))
        digit = pynini.invert(pynini.string_file(get_abs_path("data/number/digit.tsv"))).optimize() | zero

        telephone_prompts = pynini.string_file(get_abs_path("data/telephone/telephone_prompt.tsv"))
        country_code = (
            pynini.closure(telephone_prompts + delete_extra_space, 0, 1)
            + pynini.closure(pynini.cross("+", "plus "), 0, 1)
            + pynini.closure(digit + insert_space, 0, 2)
            + digit
            + pynutil.insert(",")
        )
        country_code |= telephone_prompts
        country_code = pynutil.insert("country_code: \"") + country_code + pynutil.insert("\"")
        country_code = country_code + pynini.closure(pynutil.delete("-"), 0, 1) + delete_space + insert_space

        area_part_default = pynini.closure(digit + insert_space, 2, 2) + digit
        area_part = pynini.cross("800", "eight hundred") | pynini.compose(
            pynini.difference(NEMO_SIGMA, "800"), area_part_default
        )

        area_part = (
            (area_part + (pynutil.delete("-") | pynutil.delete(".")))
            | (
                pynutil.delete("(")
                + area_part
                + ((pynutil.delete(")") + pynini.closure(pynutil.delete(" "), 0, 1)) | pynutil.delete(")-"))
            )
        ) + add_separator

        del_separator = pynini.closure(pynini.union("-", " ", "."), 0, 1)
        number_length = ((NEMO_DIGIT + del_separator) | (NEMO_ALPHA + del_separator)) ** 7
        number_words = pynini.closure(
            (NEMO_DIGIT @ digit) + (insert_space | (pynini.cross("-", ', ')))
            | NEMO_ALPHA
            | (NEMO_ALPHA + pynini.cross("-", ' '))
        )
        number_words |= pynini.closure(
            (NEMO_DIGIT @ digit) + (insert_space | (pynini.cross(".", ', ')))
            | NEMO_ALPHA
            | (NEMO_ALPHA + pynini.cross(".", ' '))
        )
        number_words = pynini.compose(number_length, number_words)
        number_part = area_part + number_words
        number_part = pynutil.insert("number_part: \"") + number_part + pynutil.insert("\"")
        extension = (
            pynutil.insert("extension: \"") + pynini.closure(digit + insert_space, 0, 3) + digit + pynutil.insert("\"")
        )
        extension = pynini.closure(insert_space + extension, 0, 1)

        graph = plurals._priority_union(country_code + number_part, number_part, NEMO_SIGMA).optimize()
        graph = plurals._priority_union(country_code + number_part + extension, graph, NEMO_SIGMA).optimize()
        graph = plurals._priority_union(number_part + extension, graph, NEMO_SIGMA).optimize()

        # ip
        ip_prompts = pynini.string_file(get_abs_path("data/telephone/ip_prompt.tsv"))
        digit_to_str_graph = digit + pynini.closure(pynutil.insert(" ") + digit, 0, 2)
        ip_graph = digit_to_str_graph + (pynini.cross(".", " dot ") + digit_to_str_graph) ** 3
        graph |= (
            pynini.closure(
                pynutil.insert("country_code: \"") + ip_prompts + pynutil.insert("\"") + delete_extra_space, 0, 1
            )
            + pynutil.insert("number_part: \"")
            + ip_graph.optimize()
            + pynutil.insert("\"")
        )
        # ssn
        ssn_prompts = pynini.string_file(get_abs_path("data/telephone/ssn_prompt.tsv"))
        three_digit_part = digit + (pynutil.insert(" ") + digit) ** 2
        two_digit_part = digit + pynutil.insert(" ") + digit
        four_digit_part = digit + (pynutil.insert(" ") + digit) ** 3
        ssn_separator = pynini.cross("-", ", ")
        ssn_graph = three_digit_part + ssn_separator + two_digit_part + ssn_separator + four_digit_part

        graph |= (
            pynini.closure(
                pynutil.insert("country_code: \"") + ssn_prompts + pynutil.insert("\"") + delete_extra_space, 0, 1
            )
            + pynutil.insert("number_part: \"")
            + ssn_graph.optimize()
            + pynutil.insert("\"")
        )

        final_graph = self.add_tokens(graph)
        self.fst = final_graph.optimize()
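The number part above reads each digit group individually and joins groups with ", " (the inserted `add_separator`). A simplified plain-Python sketch of that behavior for a dash-delimited number (the helper name is an assumption; the actual FST also handles parentheses, letters, prompts, and extensions):

```python
DIGIT_WORDS = {
    "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four",
    "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine",
}


def verbalize_number_part(number: str) -> str:
    """Read each dash-delimited group digit by digit, joining groups with ", "."""
    groups = number.split("-")
    spoken = [" ".join(DIGIT_WORDS[d] for d in group) for group in groups]
    return ", ".join(spoken)
```

For "123-123-5678" this yields the number_part shown in the class docstring example.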
@@ -0,0 +1,132 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_DIGIT,
    GraphFst,
    convert_space,
    delete_space,
    insert_space,
)
from nemo_text_processing.text_normalization.en.utils import (
    augment_labels_with_punct_at_end,
    get_abs_path,
    load_labels,
)
from pynini.lib import pynutil


class TimeFst(GraphFst):
    """
    Finite state transducer for classifying time, e.g.
        12:30 a.m. est -> time { hours: "twelve" minutes: "thirty" suffix: "a m" zone: "e s t" }
        2.30 a.m. -> time { hours: "two" minutes: "thirty" suffix: "a m" }
        02.30 a.m. -> time { hours: "two" minutes: "thirty" suffix: "a m" }
        2.00 a.m. -> time { hours: "two" suffix: "a m" }
        2 a.m. -> time { hours: "two" suffix: "a m" }
        02:00 -> time { hours: "two" }
        2:00 -> time { hours: "two" }
        10:00:05 a.m. -> time { hours: "ten" minutes: "zero" seconds: "five" suffix: "a m" }

    Args:
        cardinal: CardinalFst
        deterministic: if True will provide a single transduction option,
            if False, multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal: GraphFst, deterministic: bool = True):
        super().__init__(name="time", kind="classify", deterministic=deterministic)
        suffix_labels = load_labels(get_abs_path("data/time/suffix.tsv"))
        suffix_labels.extend(augment_labels_with_punct_at_end(suffix_labels))
        suffix_graph = pynini.string_map(suffix_labels)

        time_zone_graph = pynini.string_file(get_abs_path("data/time/zone.tsv"))

        # only used for < 1000 thousand -> 0 weight
        cardinal = cardinal.graph

        labels_hour = [str(x) for x in range(0, 24)]
        labels_minute_single = [str(x) for x in range(1, 10)]
        labels_minute_double = [str(x) for x in range(10, 60)]

        delete_leading_zero_to_double_digit = (NEMO_DIGIT + NEMO_DIGIT) | (
            pynini.closure(pynutil.delete("0"), 0, 1) + NEMO_DIGIT
        )

        graph_hour = delete_leading_zero_to_double_digit @ pynini.union(*labels_hour) @ cardinal

        graph_minute_single = pynini.union(*labels_minute_single) @ cardinal
        graph_minute_double = pynini.union(*labels_minute_double) @ cardinal

        final_graph_hour = pynutil.insert("hours: \"") + graph_hour + pynutil.insert("\"")
        final_graph_minute = (
            pynutil.insert("minutes: \"")
            + (pynini.cross("0", "o") + insert_space + graph_minute_single | graph_minute_double)
            + pynutil.insert("\"")
        )
        final_graph_second = (
            pynutil.insert("seconds: \"")
            + (pynini.cross("0", "o") + insert_space + graph_minute_single | graph_minute_double)
            + pynutil.insert("\"")
        )
        final_suffix = pynutil.insert("suffix: \"") + convert_space(suffix_graph) + pynutil.insert("\"")
        final_suffix_optional = pynini.closure(delete_space + insert_space + final_suffix, 0, 1)
        final_time_zone_optional = pynini.closure(
            delete_space
            + insert_space
            + pynutil.insert("zone: \"")
            + convert_space(time_zone_graph)
            + pynutil.insert("\""),
            0,
            1,
        )

        # 2:30 pm, 02:30, 2:00
        graph_hm = (
            final_graph_hour
            + pynutil.delete(":")
            + (pynutil.delete("00") | insert_space + final_graph_minute)
            + final_suffix_optional
            + final_time_zone_optional
        )

        # 10:30:05 pm,
        graph_hms = (
            final_graph_hour
            + pynutil.delete(":")
            + (pynini.cross("00", " minutes: \"zero\"") | insert_space + final_graph_minute)
            + pynutil.delete(":")
            + (pynini.cross("00", " seconds: \"zero\"") | insert_space + final_graph_second)
            + final_suffix_optional
            + final_time_zone_optional
        )

        # 2.xx pm/am
        graph_hm2 = (
            final_graph_hour
            + pynutil.delete(".")
            + (pynutil.delete("00") | insert_space + final_graph_minute)
            + delete_space
            + insert_space
            + final_suffix
            + final_time_zone_optional
        )
        # 2 pm est
        graph_h = final_graph_hour + delete_space + insert_space + final_suffix + final_time_zone_optional
        final_graph = (graph_hm | graph_h | graph_hm2 | graph_hms).optimize()

        final_graph = self.add_tokens(final_graph)
        self.fst = final_graph.optimize()
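The hour/minute rules above drop leading zeros from the hour, omit ":00" minutes, and read a leading zero in the minutes as "o" (the `pynini.cross("0", "o")` branch). A plain-Python sketch of that behavior, assuming "HH:MM" input only (names are illustrative, not NeMo API; the FST additionally handles seconds, suffixes, and time zones):

```python
ONES = [
    "zero", "one", "two", "three", "four", "five", "six", "seven", "eight",
    "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
    "sixteen", "seventeen", "eighteen", "nineteen",
]
TENS = {2: "twenty", 3: "thirty", 4: "forty", 5: "fifty"}


def spoken_number(n: int) -> str:
    """Spoken form for 0-59, enough for hours and minutes."""
    if n < 20:
        return ONES[n]
    tens, ones = divmod(n, 10)
    return TENS[tens] if ones == 0 else f"{TENS[tens]} {ONES[ones]}"


def verbalize_time(text: str) -> str:
    """"12:30" -> "twelve thirty"; "02:05" -> "two o five"; "2:00" -> "two"."""
    hours, minutes = text.split(":")
    parts = [spoken_number(int(hours))]
    if minutes != "00":
        if minutes.startswith("0"):
            # a leading zero in minutes is read as "o", mirroring pynini.cross("0", "o")
            parts.append(f"o {ONES[int(minutes)]}")
        else:
            parts.append(spoken_number(int(minutes)))
    return " ".join(parts)
```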
@@ -0,0 +1,201 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_WHITE_SPACE,
    GraphFst,
    delete_extra_space,
    delete_space,
    generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.abbreviation import AbbreviationFst
from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.taggers.date import DateFst
from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.en.taggers.range import RangeFst as RangeFst
from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst
from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst
from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.taggers.time import TimeFst
from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDateFst
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinalFst
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTimeFst
from pynini.lib import pynutil


class ClassifyFst(GraphFst):
    """
    Final class that composes all other classification grammars. This class can process an entire sentence including punctuation.
    For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
    More details on deployment at NeMo/tools/text_processing_deployment.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            if False, multiple options are generated (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
    """

    def __init__(
        self,
        input_case: str,
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
    ):
        super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

        far_file = None
        if cache_dir is not None and cache_dir != "None":
            os.makedirs(cache_dir, exist_ok=True)
            whitelist_file = os.path.basename(whitelist) if whitelist else ""
            far_file = os.path.join(
                cache_dir, f"en_tn_{deterministic}_deterministic_{input_case}_{whitelist_file}_tokenize.far"
            )
        if not overwrite_cache and far_file and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode="r")["tokenize_and_classify"]
        else:
            start_time = time.time()
            cardinal = CardinalFst(deterministic=deterministic)
            cardinal_graph = cardinal.fst

            start_time = time.time()
            ordinal = OrdinalFst(cardinal=cardinal, deterministic=deterministic)
            ordinal_graph = ordinal.fst

            start_time = time.time()
            decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic)
            decimal_graph = decimal.fst

            start_time = time.time()
            fraction = FractionFst(deterministic=deterministic, cardinal=cardinal)
            fraction_graph = fraction.fst

            start_time = time.time()
            measure = MeasureFst(cardinal=cardinal, decimal=decimal, fraction=fraction, deterministic=deterministic)
            measure_graph = measure.fst

            start_time = time.time()
            date_graph = DateFst(cardinal=cardinal, deterministic=deterministic).fst

            start_time = time.time()
            time_graph = TimeFst(cardinal=cardinal, deterministic=deterministic).fst

            start_time = time.time()
            telephone_graph = TelephoneFst(deterministic=deterministic).fst

            start_time = time.time()
            electonic_graph = ElectronicFst(deterministic=deterministic).fst

            start_time = time.time()
            money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, deterministic=deterministic).fst

            start_time = time.time()
            whitelist_graph = WhiteListFst(
                input_case=input_case, deterministic=deterministic, input_file=whitelist
            ).fst

            start_time = time.time()
            punctuation = PunctuationFst(deterministic=deterministic)
            punct_graph = punctuation.fst

            start_time = time.time()
            word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).fst

            start_time = time.time()
            serial_graph = SerialFst(cardinal=cardinal, ordinal=ordinal, deterministic=deterministic).fst

            start_time = time.time()
            v_time_graph = vTimeFst(deterministic=deterministic).fst
            v_ordinal_graph = vOrdinalFst(deterministic=deterministic)
            v_date_graph = vDateFst(ordinal=v_ordinal_graph, deterministic=deterministic).fst
            time_final = pynini.compose(time_graph, v_time_graph)
            date_final = pynini.compose(date_graph, v_date_graph)
            range_graph = RangeFst(
                time=time_final, date=date_final, cardinal=cardinal, deterministic=deterministic
            ).fst

            classify = (
                pynutil.add_weight(whitelist_graph, 1.01)
                | pynutil.add_weight(time_graph, 1.1)
                | pynutil.add_weight(date_graph, 1.09)
                | pynutil.add_weight(decimal_graph, 1.1)
||||||
|
| pynutil.add_weight(measure_graph, 1.1)
|
||||||
|
| pynutil.add_weight(cardinal_graph, 1.1)
|
||||||
|
| pynutil.add_weight(ordinal_graph, 1.1)
|
||||||
|
| pynutil.add_weight(money_graph, 1.1)
|
||||||
|
| pynutil.add_weight(telephone_graph, 1.1)
|
||||||
|
| pynutil.add_weight(electonic_graph, 1.1)
|
||||||
|
| pynutil.add_weight(fraction_graph, 1.1)
|
||||||
|
| pynutil.add_weight(range_graph, 1.1)
|
||||||
|
| pynutil.add_weight(serial_graph, 1.1001) # should be higher than the rest of the classes
|
||||||
|
)
|
||||||
|
|
||||||
|
roman_graph = RomanFst(deterministic=deterministic).fst
|
||||||
|
classify |= pynutil.add_weight(roman_graph, 1.1)
|
||||||
|
|
||||||
|
if not deterministic:
|
||||||
|
abbreviation_graph = AbbreviationFst(deterministic=deterministic).fst
|
||||||
|
classify |= pynutil.add_weight(abbreviation_graph, 100)
|
||||||
|
|
||||||
|
punct = pynutil.insert("tokens { ") + pynutil.add_weight(punct_graph, weight=2.1) + pynutil.insert(" }")
|
||||||
|
punct = pynini.closure(
|
||||||
|
pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
|
||||||
|
| (pynutil.insert(" ") + punct),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
classify |= pynutil.add_weight(word_graph, 100)
|
||||||
|
token = pynutil.insert("tokens { ") + classify + pynutil.insert(" }")
|
||||||
|
token_plus_punct = (
|
||||||
|
pynini.closure(punct + pynutil.insert(" ")) + token + pynini.closure(pynutil.insert(" ") + punct)
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = token_plus_punct + pynini.closure(
|
||||||
|
(
|
||||||
|
pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
|
||||||
|
| (pynutil.insert(" ") + punct + pynutil.insert(" "))
|
||||||
|
)
|
||||||
|
+ token_plus_punct
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = delete_space + graph + delete_space
|
||||||
|
graph |= punct
|
||||||
|
|
||||||
|
self.fst = graph.optimize()
|
||||||
|
|
||||||
|
if far_file:
|
||||||
|
generator_main(far_file, {"tokenize_and_classify": self.fst})
|
||||||
|
|
||||||
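The constructor above derives a deterministic FAR cache filename from its arguments before deciding whether to rebuild the grammar. A minimal plain-Python sketch of that naming scheme (the helper name and the example paths are illustrative, not part of the grammar):

```python
import os


def far_cache_path(cache_dir, input_case, deterministic, whitelist=None):
    """Mirror the .far cache-file naming used in __init__ above (sketch)."""
    if cache_dir is None or cache_dir == "None":
        return None  # caching disabled
    # only the basename of the whitelist file participates in the cache key
    whitelist_file = os.path.basename(whitelist) if whitelist else ""
    return os.path.join(
        cache_dir, f"en_tn_{deterministic}_deterministic_{input_case}_{whitelist_file}_tokenize.far"
    )
```

Because the flags are baked into the filename, changing `input_case`, `deterministic`, or the whitelist file yields a different cache entry instead of silently reusing a stale one.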
@@ -0,0 +1,228 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_CHAR,
    NEMO_DIGIT,
    NEMO_NOT_SPACE,
    NEMO_SIGMA,
    NEMO_WHITE_SPACE,
    GraphFst,
    delete_extra_space,
    delete_space,
    generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.taggers.date import DateFst
from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.en.taggers.range import RangeFst
from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst
from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst
from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.taggers.time import TimeFst
from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst as vCardinal
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDate
from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst as vDecimal
from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst as vElectronic
from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst as vFraction
from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst as vMeasure
from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst as vMoney
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinal
from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst as vRoman
from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst as vTelephone
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTime
from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst as vWord
from pynini.examples import plurals
from pynini.lib import pynutil

from nemo.utils import logging


class ClassifyFst(GraphFst):
    """
    Final class that composes all other classification grammars. This class can process an entire sentence including punctuation.
    For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
    More details on deployment can be found at NeMo/tools/text_processing_deployment.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True, provides a single transduction option;
            if False, multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
    """

    def __init__(
        self,
        input_case: str,
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = True,
        whitelist: str = None,
    ):
        super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

        far_file = None
        if cache_dir is not None and cache_dir != 'None':
            os.makedirs(cache_dir, exist_ok=True)
            whitelist_file = os.path.basename(whitelist) if whitelist else ""
            far_file = os.path.join(
                cache_dir, f"_{input_case}_en_tn_{deterministic}_deterministic{whitelist_file}_lm.far"
            )
        if not overwrite_cache and far_file and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode='r')['tokenize_and_classify']
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(self.fst, no_digits).optimize()
            logging.info(f'ClassifyFst.fst was restored from {far_file}.')
        else:
            logging.info('Creating ClassifyFst grammars. This might take some time...')
            # TAGGERS
            cardinal = CardinalFst(deterministic=True, lm=True)
            cardinal_tagger = cardinal
            cardinal_graph = cardinal.fst

            ordinal = OrdinalFst(cardinal=cardinal, deterministic=True)
            ordinal_graph = ordinal.fst

            decimal = DecimalFst(cardinal=cardinal, deterministic=True)
            decimal_graph = decimal.fst
            fraction = FractionFst(deterministic=True, cardinal=cardinal)
            fraction_graph = fraction.fst

            measure = MeasureFst(cardinal=cardinal, decimal=decimal, fraction=fraction, deterministic=True)
            measure_graph = measure.fst
            date = DateFst(cardinal=cardinal, deterministic=True, lm=True)
            date_graph = date.fst
            punctuation = PunctuationFst(deterministic=True)
            punct_graph = punctuation.graph
            word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).graph
            time_graph = TimeFst(cardinal=cardinal, deterministic=True).fst
            telephone_graph = TelephoneFst(deterministic=True).fst
            electronic_graph = ElectronicFst(deterministic=True).fst
            money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, deterministic=False).fst
            whitelist = WhiteListFst(input_case=input_case, deterministic=False, input_file=whitelist)
            whitelist_graph = whitelist.graph
            serial_graph = SerialFst(cardinal=cardinal, ordinal=ordinal, deterministic=deterministic, lm=True).fst

            # VERBALIZERS
            cardinal = vCardinal(deterministic=True)
            v_cardinal_graph = cardinal.fst
            decimal = vDecimal(cardinal=cardinal, deterministic=True)
            v_decimal_graph = decimal.fst
            ordinal = vOrdinal(deterministic=True)
            v_ordinal_graph = ordinal.fst
            fraction = vFraction(deterministic=True, lm=True)
            v_fraction_graph = fraction.fst
            v_telephone_graph = vTelephone(deterministic=True).fst
            v_electronic_graph = vElectronic(deterministic=True).fst
            measure = vMeasure(decimal=decimal, cardinal=cardinal, fraction=fraction, deterministic=False)
            v_measure_graph = measure.fst
            v_time_graph = vTime(deterministic=True).fst
            v_date_graph = vDate(ordinal=ordinal, deterministic=deterministic, lm=True).fst
            v_money_graph = vMoney(decimal=decimal, deterministic=deterministic).fst
            v_roman_graph = vRoman(deterministic=deterministic).fst
            v_word_graph = vWord(deterministic=deterministic).fst

            cardinal_or_date_final = plurals._priority_union(date_graph, cardinal_graph, NEMO_SIGMA)
            cardinal_or_date_final = pynini.compose(cardinal_or_date_final, (v_cardinal_graph | v_date_graph))

            time_final = pynini.compose(time_graph, v_time_graph)
            ordinal_final = pynini.compose(ordinal_graph, v_ordinal_graph)
            sem_w = 1
            word_w = 100
            punct_w = 2
            classify_and_verbalize = (
                pynutil.add_weight(time_final, sem_w)
                | pynutil.add_weight(pynini.compose(decimal_graph, v_decimal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(measure_graph, v_measure_graph), sem_w)
                | pynutil.add_weight(ordinal_final, sem_w)
                | pynutil.add_weight(pynini.compose(telephone_graph, v_telephone_graph), sem_w)
                | pynutil.add_weight(pynini.compose(electronic_graph, v_electronic_graph), sem_w)
                | pynutil.add_weight(pynini.compose(fraction_graph, v_fraction_graph), sem_w)
                | pynutil.add_weight(pynini.compose(money_graph, v_money_graph), sem_w)
                | pynutil.add_weight(cardinal_or_date_final, sem_w)
                | pynutil.add_weight(whitelist_graph, sem_w)
                | pynutil.add_weight(
                    pynini.compose(serial_graph, v_word_graph), 1.1001
                )  # should be higher than the rest of the classes
            ).optimize()

            roman_graph = RomanFst(deterministic=deterministic, lm=True).fst
            # the weight matches the word_graph weight for "I" cases in long sentences with multiple semiotic tokens
            classify_and_verbalize |= pynutil.add_weight(pynini.compose(roman_graph, v_roman_graph), sem_w)

            date_final = pynini.compose(date_graph, v_date_graph)
            range_graph = RangeFst(
                time=time_final, cardinal=cardinal_tagger, date=date_final, deterministic=deterministic
            ).fst
            classify_and_verbalize |= pynutil.add_weight(pynini.compose(range_graph, v_word_graph), sem_w)
            classify_and_verbalize = pynutil.insert("< ") + classify_and_verbalize + pynutil.insert(" >")
            classify_and_verbalize |= pynutil.add_weight(word_graph, word_w)

            punct_only = pynutil.add_weight(punct_graph, weight=punct_w)
            punct = pynini.closure(
                pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                | (pynutil.insert(" ") + punct_only),
                1,
            )

            def get_token_sem_graph(classify_and_verbalize):
                token_plus_punct = (
                    pynini.closure(punct + pynutil.insert(" "))
                    + classify_and_verbalize
                    + pynini.closure(pynutil.insert(" ") + punct)
                )

                graph = token_plus_punct + pynini.closure(
                    (
                        pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                        | (pynutil.insert(" ") + punct + pynutil.insert(" "))
                    )
                    + token_plus_punct
                )

                graph |= punct_only + pynini.closure(punct)
                graph = delete_space + graph + delete_space

                remove_extra_spaces = pynini.closure(NEMO_NOT_SPACE, 1) + pynini.closure(
                    delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1)
                )
                remove_extra_spaces |= (
                    pynini.closure(pynutil.delete(" "), 1)
                    + pynini.closure(NEMO_NOT_SPACE, 1)
                    + pynini.closure(delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1))
                )

                graph = pynini.compose(graph.optimize(), remove_extra_spaces).optimize()
                return graph

            self.fst = get_token_sem_graph(classify_and_verbalize)
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(self.fst, no_digits).optimize()

            if far_file:
                generator_main(far_file, {"tokenize_and_classify": self.fst})
                logging.info(f'ClassifyFst grammars are saved to {far_file}.')
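The weighting scheme above (`sem_w = 1`, `word_w = 100`, `punct_w = 2`) makes the shortest path through the weighted union prefer any semiotic reading over the plain-word fallback. A toy plain-Python sketch of that tie-breaking, with hypothetical candidate readings standing in for the real FST paths:

```python
SEM_W, WORD_W, PUNCT_W = 1, 100, 2  # mirror the weights in the grammar above


def pick_reading(candidates):
    """candidates: list of (weight, reading) pairs.

    The lowest total weight wins, mimicking shortest-path selection over
    the weighted union of tagger+verbalizer graphs.
    """
    return min(candidates, key=lambda c: c[0])[1]


# "2019" has both a semiotic (date) reading and a verbatim word fallback;
# the date path carries weight SEM_W and therefore beats the word path.
candidates = [
    (SEM_W, "twenty nineteen"),  # date tagger composed with its verbalizer
    (WORD_W, "2019"),            # word-graph fallback
]
```

The fallback weight of 100 is deliberately far above every semiotic weight, so the word graph only fires when no semiotic grammar accepts the token at all.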
@@ -0,0 +1,229 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_CHAR,
    NEMO_DIGIT,
    NEMO_NOT_SPACE,
    NEMO_WHITE_SPACE,
    GraphFst,
    delete_extra_space,
    delete_space,
    generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.abbreviation import AbbreviationFst
from nemo_text_processing.text_normalization.en.taggers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.taggers.date import DateFst
from nemo_text_processing.text_normalization.en.taggers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.taggers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.taggers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.taggers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.taggers.money import MoneyFst
from nemo_text_processing.text_normalization.en.taggers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from nemo_text_processing.text_normalization.en.taggers.range import RangeFst
from nemo_text_processing.text_normalization.en.taggers.roman import RomanFst
from nemo_text_processing.text_normalization.en.taggers.serial import SerialFst
from nemo_text_processing.text_normalization.en.taggers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.taggers.time import TimeFst
from nemo_text_processing.text_normalization.en.taggers.whitelist import WhiteListFst
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.abbreviation import AbbreviationFst as vAbbreviation
from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst as vCardinal
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDate
from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst as vDecimal
from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst as vElectronic
from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst as vFraction
from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst as vMeasure
from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst as vMoney
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinal
from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst as vRoman
from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst as vTelephone
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTime
from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst as vWord
from pynini.lib import pynutil

from nemo.utils import logging


class ClassifyFst(GraphFst):
    """
    Final class that composes all other classification grammars. This class can process an entire sentence including punctuation.
    For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
    More details on deployment can be found at NeMo/tools/text_processing_deployment.

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True, provides a single transduction option;
            if False, multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
    """

    def __init__(
        self,
        input_case: str,
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = True,
        whitelist: str = None,
    ):
        super().__init__(name="tokenize_and_classify", kind="classify", deterministic=deterministic)

        far_file = None
        if cache_dir is not None and cache_dir != 'None':
            os.makedirs(cache_dir, exist_ok=True)
            whitelist_file = os.path.basename(whitelist) if whitelist else ""
            far_file = os.path.join(
                cache_dir, f"_{input_case}_en_tn_{deterministic}_deterministic{whitelist_file}.far"
            )
        if not overwrite_cache and far_file and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode='r')['tokenize_and_classify']
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(self.fst, no_digits).optimize()
            logging.info(f'ClassifyFst.fst was restored from {far_file}.')
        else:
            logging.info('Creating ClassifyFst grammars. This might take some time...')
            # TAGGERS
            cardinal = CardinalFst(deterministic=deterministic)
            cardinal_graph = cardinal.fst

            ordinal = OrdinalFst(cardinal=cardinal, deterministic=deterministic)
            deterministic_ordinal = OrdinalFst(cardinal=cardinal, deterministic=True)
            ordinal_graph = ordinal.fst

            decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic)
            decimal_graph = decimal.fst
            fraction = FractionFst(deterministic=deterministic, cardinal=cardinal)
            fraction_graph = fraction.fst

            measure = MeasureFst(cardinal=cardinal, decimal=decimal, fraction=fraction, deterministic=deterministic)
            measure_graph = measure.fst
            date_graph = DateFst(cardinal=cardinal, deterministic=deterministic).fst
            punctuation = PunctuationFst(deterministic=True)
            punct_graph = punctuation.graph
            word_graph = WordFst(punctuation=punctuation, deterministic=deterministic).graph
            time_graph = TimeFst(cardinal=cardinal, deterministic=deterministic).fst
            telephone_graph = TelephoneFst(deterministic=deterministic).fst
            electronic_graph = ElectronicFst(deterministic=deterministic).fst
            money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, deterministic=deterministic).fst
            whitelist = WhiteListFst(input_case=input_case, deterministic=deterministic, input_file=whitelist)
            whitelist_graph = whitelist.graph
            serial_graph = SerialFst(cardinal=cardinal, ordinal=deterministic_ordinal, deterministic=deterministic).fst

            # VERBALIZERS
            cardinal = vCardinal(deterministic=deterministic)
            v_cardinal_graph = cardinal.fst
            decimal = vDecimal(cardinal=cardinal, deterministic=deterministic)
            v_decimal_graph = decimal.fst
            ordinal = vOrdinal(deterministic=deterministic)
            v_ordinal_graph = ordinal.fst
            fraction = vFraction(deterministic=deterministic)
            v_fraction_graph = fraction.fst
            v_telephone_graph = vTelephone(deterministic=deterministic).fst
            v_electronic_graph = vElectronic(deterministic=deterministic).fst
            measure = vMeasure(decimal=decimal, cardinal=cardinal, fraction=fraction, deterministic=deterministic)
            v_measure_graph = measure.fst
            v_time_graph = vTime(deterministic=deterministic).fst
            v_date_graph = vDate(ordinal=ordinal, deterministic=deterministic).fst
            v_money_graph = vMoney(decimal=decimal, deterministic=deterministic).fst
            v_roman_graph = vRoman(deterministic=deterministic).fst
            v_abbreviation = vAbbreviation(deterministic=deterministic).fst

            det_v_time_graph = vTime(deterministic=True).fst
            det_v_date_graph = vDate(ordinal=vOrdinal(deterministic=True), deterministic=True).fst
            time_final = pynini.compose(time_graph, det_v_time_graph)
            date_final = pynini.compose(date_graph, det_v_date_graph)
            range_graph = RangeFst(
                time=time_final, date=date_final, cardinal=CardinalFst(deterministic=True), deterministic=deterministic
            ).fst
            v_word_graph = vWord(deterministic=deterministic).fst

            sem_w = 1
            word_w = 100
            punct_w = 2
            classify_and_verbalize = (
                pynutil.add_weight(whitelist_graph, sem_w)
                | pynutil.add_weight(pynini.compose(time_graph, v_time_graph), sem_w)
                | pynutil.add_weight(pynini.compose(decimal_graph, v_decimal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(measure_graph, v_measure_graph), sem_w)
                | pynutil.add_weight(pynini.compose(cardinal_graph, v_cardinal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(ordinal_graph, v_ordinal_graph), sem_w)
                | pynutil.add_weight(pynini.compose(telephone_graph, v_telephone_graph), sem_w)
                | pynutil.add_weight(pynini.compose(electronic_graph, v_electronic_graph), sem_w)
                | pynutil.add_weight(pynini.compose(fraction_graph, v_fraction_graph), sem_w)
                | pynutil.add_weight(pynini.compose(money_graph, v_money_graph), sem_w)
                | pynutil.add_weight(word_graph, word_w)
                | pynutil.add_weight(pynini.compose(date_graph, v_date_graph), sem_w - 0.01)
                | pynutil.add_weight(pynini.compose(range_graph, v_word_graph), sem_w)
                | pynutil.add_weight(
                    pynini.compose(serial_graph, v_word_graph), 1.1001
                )  # should be higher than the rest of the classes
            ).optimize()

            if not deterministic:
                roman_graph = RomanFst(deterministic=deterministic).fst
                # the weight matches the word_graph weight for "I" cases in long sentences with multiple semiotic tokens
                classify_and_verbalize |= pynutil.add_weight(pynini.compose(roman_graph, v_roman_graph), word_w)

                abbreviation_graph = AbbreviationFst(whitelist=whitelist, deterministic=deterministic).fst
                classify_and_verbalize |= pynutil.add_weight(
                    pynini.compose(abbreviation_graph, v_abbreviation), word_w
                )

            punct_only = pynutil.add_weight(punct_graph, weight=punct_w)
            punct = pynini.closure(
                pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                | (pynutil.insert(" ") + punct_only),
                1,
            )

            token_plus_punct = (
                pynini.closure(punct + pynutil.insert(" "))
                + classify_and_verbalize
                + pynini.closure(pynutil.insert(" ") + punct)
            )

            graph = token_plus_punct + pynini.closure(
                (
                    pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
                    | (pynutil.insert(" ") + punct + pynutil.insert(" "))
                )
                + token_plus_punct
            )

            graph |= punct_only + pynini.closure(punct)
            graph = delete_space + graph + delete_space

            remove_extra_spaces = pynini.closure(NEMO_NOT_SPACE, 1) + pynini.closure(
                delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1)
            )
            remove_extra_spaces |= (
                pynini.closure(pynutil.delete(" "), 1)
                + pynini.closure(NEMO_NOT_SPACE, 1)
                + pynini.closure(delete_extra_space + pynini.closure(NEMO_NOT_SPACE, 1))
            )

            graph = pynini.compose(graph.optimize(), remove_extra_spaces).optimize()
            self.fst = graph
            no_digits = pynini.closure(pynini.difference(NEMO_CHAR, NEMO_DIGIT))
            self.fst_no_digits = pynini.compose(graph, no_digits).optimize()

            if far_file:
                generator_main(far_file, {"tokenize_and_classify": self.fst})
                logging.info(f'ClassifyFst grammars are saved to {far_file}.')
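The `fst_no_digits` variant above composes the grammar with the closure of non-digit characters, so any output path that still contains a digit is rejected, leaving only fully verbalized readings. The same filter as a plain-Python sketch over candidate output strings (the strings are illustrative):

```python
def reject_digit_outputs(outputs):
    """Keep only fully verbalized outputs, i.e. those with no digits left,
    mirroring the compose-with-no_digits filter applied to the grammar."""
    return [s for s in outputs if not any(ch.isdigit() for ch in s)]


# "2019" survives the plain grammar but is dropped by the no-digits variant.
outputs = reject_digit_outputs(["twenty nineteen", "2019", "jan. first"])
```

This is useful for audio-based normalization, where a candidate that still contains raw digits cannot be aligned against spoken audio.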
@@ -0,0 +1,151 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_CHAR,
    NEMO_NOT_SPACE,
    NEMO_SIGMA,
    NEMO_UPPER,
    SINGULAR_TO_PLURAL,
    GraphFst,
    convert_space,
)
from nemo_text_processing.text_normalization.en.taggers.roman import get_names
from nemo_text_processing.text_normalization.en.utils import (
    augment_labels_with_punct_at_end,
    get_abs_path,
    load_labels,
)
from pynini.lib import pynutil


class WhiteListFst(GraphFst):
    """
    Finite state transducer for classifying whitelist, e.g.
        misses -> tokens { name: "mrs" }
    for non-deterministic case: "Dr. Abc" ->
        tokens { name: "drive" } tokens { name: "Abc" }
        tokens { name: "doctor" } tokens { name: "Abc" }
        tokens { name: "Dr." } tokens { name: "Abc" }
    This class has highest priority among all classifier grammars. Whitelisted tokens are defined and loaded from "data/whitelist.tsv".

    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        input_file: path to a file with whitelist replacements
    """

    def __init__(self, input_case: str, deterministic: bool = True, input_file: str = None):
        super().__init__(name="whitelist", kind="classify", deterministic=deterministic)

        def _get_whitelist_graph(input_case, file, keep_punct_add_end: bool = False):
            whitelist = load_labels(file)
            if input_case == "lower_cased":
                whitelist = [[x.lower(), y] for x, y in whitelist]
            else:
                whitelist = [[x, y] for x, y in whitelist]

            if keep_punct_add_end:
                whitelist.extend(augment_labels_with_punct_at_end(whitelist))

            graph = pynini.string_map(whitelist)
            return graph

        graph = _get_whitelist_graph(input_case, get_abs_path("data/whitelist/tts.tsv"))
        graph |= _get_whitelist_graph(input_case, get_abs_path("data/whitelist/UK_to_US.tsv"))  # Jiayu 2022.10
        graph |= pynini.compose(
            pynini.difference(NEMO_SIGMA, pynini.accep("/")).optimize(),
            _get_whitelist_graph(input_case, get_abs_path("data/whitelist/symbol.tsv")),
        ).optimize()

        if deterministic:
            names = get_names()
            graph |= (
                pynini.cross(pynini.union("st", "St", "ST"), "Saint")
                + pynini.closure(pynutil.delete("."))
                + pynini.accep(" ")
                + names
            )
        else:
            graph |= _get_whitelist_graph(
                input_case, get_abs_path("data/whitelist/alternatives.tsv"), keep_punct_add_end=True
            )

            for x in [".", ". "]:
                graph |= (
                    NEMO_UPPER
                    + pynini.closure(pynutil.delete(x) + NEMO_UPPER, 2)
                    + pynini.closure(pynutil.delete("."), 0, 1)
                )

        if not deterministic:
            multiple_forms_whitelist_graph = get_formats(get_abs_path("data/whitelist/alternatives_all_format.tsv"))
            graph |= multiple_forms_whitelist_graph

            graph_unit = pynini.string_file(get_abs_path("data/measure/unit.tsv")) | pynini.string_file(
                get_abs_path("data/measure/unit_alternatives.tsv")
            )
            graph_unit_plural = graph_unit @ SINGULAR_TO_PLURAL
            units_graph = pynini.compose(NEMO_CHAR ** (3, ...), convert_space(graph_unit | graph_unit_plural))
            graph |= units_graph

            # convert to states only if comma is present before the abbreviation to avoid converting all caps words,
            # e.g. "IN", "OH", "OK"
            # TODO or only exclude above?
            states = load_labels(get_abs_path("data/address/state.tsv"))
            additional_options = []
            for x, y in states:
                if input_case == "lower_cased":
                    x = x.lower()
                additional_options.append((x, f"{y[0]}.{y[1:]}"))
                if not deterministic:
                    additional_options.append((x, f"{y[0]}.{y[1:]}."))

            states.extend(additional_options)
            state_graph = pynini.string_map(states)
            graph |= pynini.closure(NEMO_NOT_SPACE, 1) + pynini.union(", ", ",") + pynini.invert(state_graph).optimize()

        if input_file:
            whitelist_provided = _get_whitelist_graph(input_case, input_file)
            if not deterministic:
                graph |= whitelist_provided
            else:
                graph = whitelist_provided

        self.graph = (convert_space(graph)).optimize()

        self.fst = (pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")).optimize()


def get_formats(input_f, input_case="cased", is_default=True):
    """
    Adds various abbreviation format options to the list of acceptable input forms
    """
    multiple_formats = load_labels(input_f)
    additional_options = []
    for x, y in multiple_formats:
        if input_case == "lower_cased":
            x = x.lower()
        additional_options.append((f"{x}.", y))  # default "dr" -> doctor, this includes period "dr." -> doctor
        additional_options.append((f"{x[0].upper() + x[1:]}", f"{y[0].upper() + y[1:]}"))  # "Dr" -> Doctor
        additional_options.append((f"{x[0].upper() + x[1:]}.", f"{y[0].upper() + y[1:]}"))  # "Dr." -> Doctor
    multiple_formats.extend(additional_options)

    if not is_default:
        multiple_formats = [(x, f"|raw_start|{x}|raw_end||norm_start|{y}|norm_end|") for (x, y) in multiple_formats]

    multiple_formats = pynini.string_map(multiple_formats)
    return multiple_formats
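The casing/period expansion performed by `get_formats` can be illustrated without pynini. This is a pure-Python sketch of the same label expansion (the function name `expand_formats` and the sample pair are illustrative, not part of the original code):

```python
def expand_formats(pairs):
    # Sketch of get_formats' expansion: for each ("dr", "doctor") pair,
    # also accept "dr.", "Dr", and "Dr." with matching capitalization.
    out = list(pairs)
    for x, y in pairs:
        out.append((f"{x}.", y))                                        # "dr." -> "doctor"
        out.append((x[0].upper() + x[1:], y[0].upper() + y[1:]))        # "Dr"  -> "Doctor"
        out.append((x[0].upper() + x[1:] + ".", y[0].upper() + y[1:]))  # "Dr." -> "Doctor"
    return out


mapping = dict(expand_formats([("dr", "doctor")]))
print(mapping["Dr."])  # Doctor
```

In the real grammar, the expanded pairs are then compiled into an FST with `pynini.string_map`.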
@@ -0,0 +1,90 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    MIN_NEG_WEIGHT,
    NEMO_ALPHA,
    NEMO_DIGIT,
    NEMO_NOT_SPACE,
    NEMO_SIGMA,
    GraphFst,
    convert_space,
    get_abs_path,
)
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from pynini.examples import plurals
from pynini.lib import pynutil


class WordFst(GraphFst):
    """
    Finite state transducer for classifying word. Considers sentence boundary exceptions.
        e.g. sleep -> tokens { name: "sleep" }

    Args:
        punctuation: PunctuationFst
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, punctuation: GraphFst, deterministic: bool = True):
        super().__init__(name="word", kind="classify", deterministic=deterministic)

        punct = PunctuationFst().graph
        default_graph = pynini.closure(pynini.difference(NEMO_NOT_SPACE, punct.project("input")), 1)
        symbols_to_exclude = (pynini.union("$", "€", "₩", "£", "¥", "#", "%") | NEMO_DIGIT).optimize()
        graph = pynini.closure(pynini.difference(NEMO_NOT_SPACE, symbols_to_exclude), 1)
        graph = pynutil.add_weight(graph, MIN_NEG_WEIGHT) | default_graph

        # leave phones of format [HH AH0 L OW1] untouched
        phoneme_unit = pynini.closure(NEMO_ALPHA, 1) + pynini.closure(NEMO_DIGIT)
        phoneme = (
            pynini.accep(pynini.escape("["))
            + pynini.closure(phoneme_unit + pynini.accep(" "))
            + phoneme_unit
            + pynini.accep(pynini.escape("]"))
        )

        # leave IPA phones of format [ˈdoʊv] untouched, single words and sentences with punctuation marks allowed
        punct_marks = pynini.union(*punctuation.punct_marks).optimize()
        stress = pynini.union("ˈ", "'", "ˌ")
        ipa_phoneme_unit = pynini.string_file(get_abs_path("data/whitelist/ipa_symbols.tsv"))
        # word in IPA form
        ipa_phonemes = (
            pynini.closure(stress, 0, 1)
            + pynini.closure(ipa_phoneme_unit, 1)
            + pynini.closure(stress | ipa_phoneme_unit)
        )
        # allow sentences of words in IPA format separated with spaces or punct marks
        delim = (punct_marks | pynini.accep(" ")) ** (1, ...)
        ipa_phonemes = ipa_phonemes + pynini.closure(delim + ipa_phonemes) + pynini.closure(delim, 0, 1)
        ipa_phonemes = (pynini.accep(pynini.escape("[")) + ipa_phonemes + pynini.accep(pynini.escape("]"))).optimize()

        if not deterministic:
            phoneme = (
                pynini.accep(pynini.escape("["))
                + pynini.closure(pynini.accep(" "), 0, 1)
                + pynini.closure(phoneme_unit + pynini.accep(" "))
                + phoneme_unit
                + pynini.closure(pynini.accep(" "), 0, 1)
                + pynini.accep(pynini.escape("]"))
            ).optimize()
            ipa_phonemes = (
                pynini.accep(pynini.escape("[")) + ipa_phonemes + pynini.accep(pynini.escape("]"))
            ).optimize()

        phoneme |= ipa_phonemes
        self.graph = plurals._priority_union(convert_space(phoneme.optimize()), graph, NEMO_SIGMA)
        self.fst = (pynutil.insert("name: \"") + self.graph + pynutil.insert("\"")).optimize()
@@ -0,0 +1,60 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import os


def get_abs_path(rel_path):
    """
    Get absolute path

    Args:
        rel_path: relative path to this file

    Returns absolute path
    """
    return os.path.dirname(os.path.abspath(__file__)) + '/' + rel_path


def load_labels(abs_path):
    """
    loads relative path file as dictionary

    Args:
        abs_path: absolute path

    Returns dictionary of mappings
    """
    label_tsv = open(abs_path, encoding="utf-8")
    labels = list(csv.reader(label_tsv, delimiter="\t"))
    return labels


def augment_labels_with_punct_at_end(labels):
    """
    augments labels: if key ends on a punctuation that value does not have, add a new label
    where the value maintains the punctuation

    Args:
        labels : input labels
    Returns:
        additional labels
    """
    res = []
    for label in labels:
        if len(label) > 1:
            if label[0][-1] == "." and label[1][-1] != ".":
                res.append([label[0], label[1] + "."] + label[2:])
    return res
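The punctuation-augmentation helper is easy to exercise stand-alone. This sketch inlines the same logic with a toy two-column TSV of the kind `load_labels` parses (the sample data is illustrative):

```python
import csv
import io


def augment_labels_with_punct_at_end(labels):
    # Same logic as the helper above: when a key keeps a trailing "."
    # that its value lacks, emit an extra label that restores it.
    res = []
    for label in labels:
        if len(label) > 1:
            if label[0][-1] == "." and label[1][-1] != ".":
                res.append([label[0], label[1] + "."] + label[2:])
    return res


# A two-column TSV such as the whitelist files parsed by load_labels.
tsv = "Dr.\tdoctor\nmrs\tmisses\n"
labels = list(csv.reader(io.StringIO(tsv), delimiter="\t"))
extra = augment_labels_with_punct_at_end(labels)
print(extra)  # [['Dr.', 'doctor.']]
```

Only the `Dr.`/`doctor` pair is augmented; `mrs`/`misses` has no trailing period to restore.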
@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,35 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst
from pynini.lib import pynutil


class AbbreviationFst(GraphFst):
    """
    Finite state transducer for verbalizing abbreviations
        e.g. tokens { abbreviation { value: "A B C" } } -> "ABC"

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="abbreviation", kind="verbalize", deterministic=deterministic)

        graph = pynutil.delete("value: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"")
        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()
@@ -0,0 +1,45 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space
from pynini.lib import pynutil


class CardinalFst(GraphFst):
    """
    Finite state transducer for verbalizing cardinal, e.g.
        cardinal { negative: "true" integer: "23" } -> minus twenty three

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="cardinal", kind="verbalize", deterministic=deterministic)

        self.optional_sign = pynini.cross("negative: \"true\"", "minus ")
        if not deterministic:
            self.optional_sign |= pynini.cross("negative: \"true\"", "negative ")
        self.optional_sign = pynini.closure(self.optional_sign + delete_space, 0, 1)

        integer = pynini.closure(NEMO_NOT_QUOTE)

        self.integer = delete_space + pynutil.delete("\"") + integer + pynutil.delete("\"")
        integer = pynutil.delete("integer:") + self.integer

        self.numbers = self.optional_sign + integer
        delete_tokens = self.delete_tokens(self.numbers)
        self.fst = delete_tokens.optimize()
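The string transformation this verbalizer's FST encodes can be sketched in plain Python: strip the protobuf-style wrapper and prepend "minus " when the negative flag is present. This is an illustration of the behavior only, not the FST itself (the function name and regex parsing are hypothetical):

```python
import re


def verbalize_cardinal(token):
    # Sketch of what CardinalFst encodes: pull the quoted integer words
    # out of the token, prepend "minus " when negative: "true" appears.
    m = re.search(r'integer:\s*"([^"]*)"', token)
    assert m, "no integer field"
    sign = "minus " if re.search(r'negative:\s*"true"', token) else ""
    return sign + m.group(1)


print(verbalize_cardinal('cardinal { negative: "true" integer: "twenty three" }'))
# minus twenty three
```

Note the real grammar receives the integer already spelled out by the tagger; the verbalizer only unwraps fields and handles the sign.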
@@ -0,0 +1,101 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_NOT_QUOTE,
    NEMO_SIGMA,
    GraphFst,
    delete_extra_space,
    delete_space,
)
from pynini.examples import plurals
from pynini.lib import pynutil


class DateFst(GraphFst):
    """
    Finite state transducer for verbalizing date, e.g.
        date { month: "february" day: "five" year: "twenty twelve" preserve_order: true } -> february fifth twenty twelve
        date { day: "five" month: "february" year: "twenty twelve" preserve_order: true } -> the fifth of february twenty twelve

    Args:
        ordinal: OrdinalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, ordinal: GraphFst, deterministic: bool = True, lm: bool = False):
        super().__init__(name="date", kind="verbalize", deterministic=deterministic)

        month = pynini.closure(NEMO_NOT_QUOTE, 1)
        day_cardinal = (
            pynutil.delete("day:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )
        day = day_cardinal @ ordinal.suffix

        month = pynutil.delete("month:") + delete_space + pynutil.delete("\"") + month + pynutil.delete("\"")

        year = (
            pynutil.delete("year:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + delete_space
            + pynutil.delete("\"")
        )

        # month (day) year
        graph_mdy = (
            month + pynini.closure(delete_extra_space + day, 0, 1) + pynini.closure(delete_extra_space + year, 0, 1)
        )
        # may 5 -> may five
        if not deterministic and not lm:
            graph_mdy |= (
                month
                + pynini.closure(delete_extra_space + day_cardinal, 0, 1)
                + pynini.closure(delete_extra_space + year, 0, 1)
            )

        # day month year
        graph_dmy = (
            pynutil.insert("the ")
            + day
            + delete_extra_space
            + pynutil.insert("of ")
            + month
            + pynini.closure(delete_extra_space + year, 0, 1)
        )

        optional_preserve_order = pynini.closure(
            pynutil.delete("preserve_order:") + delete_space + pynutil.delete("true") + delete_space
            | pynutil.delete("field_order:")
            + delete_space
            + pynutil.delete("\"")
            + NEMO_NOT_QUOTE
            + pynutil.delete("\"")
            + delete_space
        )

        final_graph = (
            (plurals._priority_union(graph_mdy, pynutil.add_weight(graph_dmy, 0.0001), NEMO_SIGMA) | year)
            + delete_space
            + optional_preserve_order
        )
        delete_tokens = self.delete_tokens(final_graph)
        self.fst = delete_tokens.optimize()
@@ -0,0 +1,67 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space
from pynini.lib import pynutil


class DecimalFst(GraphFst):
    """
    Finite state transducer for verbalizing decimal, e.g.
        decimal { negative: "true" integer_part: "twelve" fractional_part: "five o o six" quantity: "billion" } -> minus twelve point five o o six billion

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, cardinal, deterministic: bool = True):
        super().__init__(name="decimal", kind="verbalize", deterministic=deterministic)
        self.optional_sign = pynini.cross("negative: \"true\"", "minus ")
        if not deterministic:
            self.optional_sign |= pynini.cross("negative: \"true\"", "negative ")
        self.optional_sign = pynini.closure(self.optional_sign + delete_space, 0, 1)
        self.integer = pynutil.delete("integer_part:") + cardinal.integer
        self.optional_integer = pynini.closure(self.integer + delete_space + insert_space, 0, 1)
        self.fractional_default = (
            pynutil.delete("fractional_part:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )

        self.fractional = pynutil.insert("point ") + self.fractional_default

        self.quantity = (
            delete_space
            + insert_space
            + pynutil.delete("quantity:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )
        self.optional_quantity = pynini.closure(self.quantity, 0, 1)

        graph = self.optional_sign + (
            self.integer
            | (self.integer + self.quantity)
            | (self.optional_integer + self.fractional + self.optional_quantity)
        )

        self.numbers = graph
        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()
@@ -0,0 +1,97 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_NOT_QUOTE,
    NEMO_NOT_SPACE,
    NEMO_SIGMA,
    TO_UPPER,
    GraphFst,
    delete_extra_space,
    delete_space,
    insert_space,
)
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.examples import plurals
from pynini.lib import pynutil


class ElectronicFst(GraphFst):
    """
    Finite state transducer for verbalizing electronic
        e.g. tokens { electronic { username: "cdf1" domain: "abc.edu" } } -> c d f one at a b c dot e d u

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="electronic", kind="verbalize", deterministic=deterministic)
        graph_digit_no_zero = pynini.invert(pynini.string_file(get_abs_path("data/number/digit.tsv"))).optimize()
        graph_zero = pynini.cross("0", "zero")

        if not deterministic:
            graph_zero |= pynini.cross("0", "o") | pynini.cross("0", "oh")

        graph_digit = graph_digit_no_zero | graph_zero
        graph_symbols = pynini.string_file(get_abs_path("data/electronic/symbol.tsv")).optimize()

        default_chars_symbols = pynini.cdrewrite(
            pynutil.insert(" ") + (graph_symbols | graph_digit) + pynutil.insert(" "), "", "", NEMO_SIGMA
        )
        default_chars_symbols = pynini.compose(
            pynini.closure(NEMO_NOT_SPACE), default_chars_symbols.optimize()
        ).optimize()

        user_name = (
            pynutil.delete("username:")
            + delete_space
            + pynutil.delete("\"")
            + default_chars_symbols
            + pynutil.delete("\"")
        )

        domain_common = pynini.string_file(get_abs_path("data/electronic/domain.tsv"))

        domain = (
            default_chars_symbols
            + insert_space
            + plurals._priority_union(
                domain_common, pynutil.add_weight(pynini.cross(".", "dot"), weight=0.0001), NEMO_SIGMA
            )
            + pynini.closure(
                insert_space + (pynini.cdrewrite(TO_UPPER, "", "", NEMO_SIGMA) @ default_chars_symbols), 0, 1
            )
        )
        domain = (
            pynutil.delete("domain:")
            + delete_space
            + pynutil.delete("\"")
            + domain
            + delete_space
            + pynutil.delete("\"")
        ).optimize()

        protocol = pynutil.delete("protocol: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"")
        graph = (
            pynini.closure(protocol + delete_space, 0, 1)
            + pynini.closure(user_name + delete_space + pynutil.insert(" at ") + delete_space, 0, 1)
            + domain
            + delete_space
        ).optimize() @ pynini.cdrewrite(delete_extra_space, "", "", NEMO_SIGMA)

        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()
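The expansion the ElectronicFst docstring describes can be sketched without pynini: spell characters one by one, read digits as words, "." as "dot", and join user and domain with "at". This is an illustration only; the real grammar additionally consults the symbol and domain TSV tables, which are omitted here:

```python
def spell_electronic(username, domain):
    # Sketch of the docstring example: each character is spelled out,
    # digits become words, "." becomes "dot", user/domain joined by "at".
    digits = {"0": "zero", "1": "one", "2": "two", "3": "three", "4": "four",
              "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine"}

    def spell(s):
        return " ".join(digits.get(c, "dot" if c == "." else c) for c in s)

    return f"{spell(username)} at {spell(domain)}"


print(spell_electronic("cdf1", "abc.edu"))  # c d f one at a b c dot e d u
```

The output matches the docstring example for `username: "cdf1" domain: "abc.edu"`.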
@@ -0,0 +1,88 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pynini
|
||||||
|
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, NEMO_SIGMA, GraphFst, insert_space
|
||||||
|
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst
|
||||||
|
from pynini.examples import plurals
|
||||||
|
from pynini.lib import pynutil
|
||||||
|
|
||||||
|
|
||||||
|
class FractionFst(GraphFst):
|
||||||
|
"""
|
||||||
|
Finite state transducer for verbalizing fraction
|
||||||
|
e.g. tokens { fraction { integer: "twenty three" numerator: "four" denominator: "five" } } ->
|
||||||
|
twenty three and four fifth
|
||||||
|
|
||||||
|
Args:
|
||||||
|
deterministic: if True will provide a single transduction option,
|
||||||
|
for False multiple transduction are generated (used for audio-based normalization)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, deterministic: bool = True, lm: bool = False):
|
||||||
|
super().__init__(name="fraction", kind="verbalize", deterministic=deterministic)
|
||||||
|
suffix = OrdinalFst().suffix
|
||||||
|
|
||||||
|
        integer = pynutil.delete("integer_part: \"") + pynini.closure(NEMO_NOT_QUOTE) + pynutil.delete("\" ")
        denominator_one = pynini.cross("denominator: \"one\"", "over one")
        denominator_half = pynini.cross("denominator: \"two\"", "half")
        denominator_quarter = pynini.cross("denominator: \"four\"", "quarter")

        denominator_rest = (
            pynutil.delete("denominator: \"") + pynini.closure(NEMO_NOT_QUOTE) @ suffix + pynutil.delete("\"")
        )

        denominators = plurals._priority_union(
            denominator_one,
            plurals._priority_union(
                denominator_half,
                plurals._priority_union(denominator_quarter, denominator_rest, NEMO_SIGMA),
                NEMO_SIGMA,
            ),
            NEMO_SIGMA,
        ).optimize()
        if not deterministic:
            denominators |= pynutil.delete("denominator: \"") + (pynini.accep("four") @ suffix) + pynutil.delete("\"")

        numerator_one = pynutil.delete("numerator: \"") + pynini.accep("one") + pynutil.delete("\" ")
        numerator_one = numerator_one + insert_space + denominators
        numerator_rest = (
            pynutil.delete("numerator: \"")
            + (pynini.closure(NEMO_NOT_QUOTE) - pynini.accep("one"))
            + pynutil.delete("\" ")
        )
        numerator_rest = numerator_rest + insert_space + denominators
        numerator_rest @= pynini.cdrewrite(
            plurals._priority_union(pynini.cross("half", "halves"), pynutil.insert("s"), NEMO_SIGMA),
            "",
            "[EOS]",
            NEMO_SIGMA,
        )

        graph = numerator_one | numerator_rest

        conjunction = pynutil.insert("and ")
        if not deterministic and not lm:
            conjunction = pynini.closure(conjunction, 0, 1)

        integer = pynini.closure(integer + insert_space + conjunction, 0, 1)

        graph = integer + graph
        graph @= pynini.cdrewrite(
            pynini.cross("and one half", "and a half") | pynini.cross("over ones", "over one"), "", "[EOS]", NEMO_SIGMA
        )

        self.graph = graph
        delete_tokens = self.delete_tokens(self.graph)
        self.fst = delete_tokens.optimize()
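The nested `plurals._priority_union` calls above give the special denominators ("over one", "half", "quarter") precedence over the generic ordinal suffix, and the final `cdrewrite` pluralizes the denominator when the numerator is not "one". A plain-Python sketch of that precedence and pluralization (illustrative stand-in only, not pynini; covers just the special-denominator cases):

```python
# Special denominators that win over the generic ordinal-suffix path.
SPECIAL = {"one": "over one", "two": "half", "four": "quarter"}

def verbalize_fraction(numerator: str, denominator: str) -> str:
    """Mirror the priority union, then pluralize for numerators != 'one'."""
    denom = SPECIAL.get(denominator, denominator + "th")  # crude ordinal stand-in
    if numerator == "one":
        return f"one {denom}"
    # mirrors the cdrewrite: "half" -> "halves", otherwise append "s"
    denom = "halves" if denom == "half" else denom + "s"
    return f"{numerator} {denom}"

print(verbalize_fraction("one", "two"))    # one half
print(verbalize_fraction("three", "four")) # three quarters
print(verbalize_fraction("five", "two"))   # five halves
```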
@@ -0,0 +1,102 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space
from pynini.lib import pynutil


class MeasureFst(GraphFst):
    """
    Finite state transducer for verbalizing measure, e.g.
        measure { negative: "true" cardinal { integer: "twelve" } units: "kilograms" } -> minus twelve kilograms
        measure { decimal { integer_part: "twelve" fractional_part: "five" } units: "kilograms" } -> twelve point five kilograms
        tokens { measure { units: "covid" decimal { integer_part: "nineteen" fractional_part: "five" } } } -> covid nineteen point five

    Args:
        decimal: DecimalFst
        cardinal: CardinalFst
        fraction: FractionFst
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, decimal: GraphFst, cardinal: GraphFst, fraction: GraphFst, deterministic: bool = True):
        super().__init__(name="measure", kind="verbalize", deterministic=deterministic)
        optional_sign = cardinal.optional_sign
        unit = (
            pynutil.delete("units: \"")
            + pynini.difference(pynini.closure(NEMO_NOT_QUOTE, 1), pynini.union("address", "math"))
            + pynutil.delete("\"")
            + delete_space
        )

        if not deterministic:
            unit |= pynini.compose(unit, pynini.cross(pynini.union("inch", "inches"), "\""))

        graph_decimal = (
            pynutil.delete("decimal {")
            + delete_space
            + optional_sign
            + delete_space
            + decimal.numbers
            + delete_space
            + pynutil.delete("}")
        )
        graph_cardinal = (
            pynutil.delete("cardinal {")
            + delete_space
            + optional_sign
            + delete_space
            + cardinal.numbers
            + delete_space
            + pynutil.delete("}")
        )

        graph_fraction = (
            pynutil.delete("fraction {") + delete_space + fraction.graph + delete_space + pynutil.delete("}")
        )

        graph = (graph_cardinal | graph_decimal | graph_fraction) + delete_space + insert_space + unit

        # SH adds "preserve_order: true" by default
        preserve_order = pynutil.delete("preserve_order:") + delete_space + pynutil.delete("true") + delete_space
        graph |= unit + insert_space + (graph_cardinal | graph_decimal) + delete_space + pynini.closure(preserve_order)
        # for only unit
        graph |= (
            pynutil.delete("cardinal { integer: \"-\"")
            + delete_space
            + pynutil.delete("}")
            + delete_space
            + unit
            + pynini.closure(preserve_order)
        )
        address = (
            pynutil.delete("units: \"address\" ")
            + delete_space
            + graph_cardinal
            + delete_space
            + pynini.closure(preserve_order)
        )
        math = (
            pynutil.delete("units: \"math\" ")
            + delete_space
            + graph_cardinal
            + delete_space
            + pynini.closure(preserve_order)
        )
        graph |= address | math

        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()
@@ -0,0 +1,71 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_NOT_QUOTE,
    GraphFst,
    delete_extra_space,
    delete_preserve_order,
)
from pynini.lib import pynutil


class MoneyFst(GraphFst):
    """
    Finite state transducer for verbalizing money, e.g.
        money { integer_part: "twelve" fractional_part: "o five" currency: "dollars" } -> twelve o five dollars

    Args:
        decimal: DecimalFst
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, decimal: GraphFst, deterministic: bool = True):
        super().__init__(name="money", kind="verbalize", deterministic=deterministic)
        keep_space = pynini.accep(" ")
        maj = pynutil.delete("currency_maj: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"")
        min = pynutil.delete("currency_min: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"")

        fractional_part = (
            pynutil.delete("fractional_part: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"")
        )

        integer_part = decimal.integer

        # *** currency_maj
        graph_integer = integer_part + keep_space + maj

        # *** currency_maj + (***) | ((and) *** current_min)
        fractional = fractional_part + delete_extra_space + min

        if not deterministic:
            fractional |= pynutil.insert("and ") + fractional

        graph_integer_with_minor = integer_part + keep_space + maj + keep_space + fractional + delete_preserve_order

        # *** point *** currency_maj
        graph_decimal = decimal.numbers + keep_space + maj

        # *** current_min
        graph_minor = fractional_part + delete_extra_space + min + delete_preserve_order

        graph = graph_integer | graph_integer_with_minor | graph_decimal | graph_minor

        if not deterministic:
            graph |= graph_integer + delete_preserve_order

        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()
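The alternatives unioned into `graph` above correspond to: integer + major currency, integer + major + fractional + minor (with an optional inserted "and" in the non-deterministic case), decimal + major, and fractional + minor alone. A plain-Python sketch of the integer-with-minor branch and the "and" insertion (illustrative only; the function name and list-of-options interface are assumptions, not pynini):

```python
def verbalize_money(integer_part: str, maj: str,
                    fractional_part: str = "", minor: str = "",
                    deterministic: bool = True) -> list:
    """Enumerate the verbalizations MoneyFst encodes as alternative FST paths."""
    options = []
    if fractional_part:
        options.append(f"{integer_part} {maj} {fractional_part} {minor}")
        if not deterministic:
            # mirrors: fractional |= pynutil.insert("and ") + fractional
            options.append(f"{integer_part} {maj} and {fractional_part} {minor}")
    else:
        options.append(f"{integer_part} {maj}")
    return options

print(verbalize_money("twelve", "dollars", "five", "cents", deterministic=False))
# ['twelve dollars five cents', 'twelve dollars and five cents']
```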
@@ -0,0 +1,53 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, NEMO_SIGMA, GraphFst, delete_space
from nemo_text_processing.text_normalization.en.utils import get_abs_path
from pynini.lib import pynutil


class OrdinalFst(GraphFst):
    """
    Finite state transducer for verbalizing ordinal, e.g.
        ordinal { integer: "thirteen" } -> thirteenth

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="ordinal", kind="verbalize", deterministic=deterministic)

        graph_digit = pynini.string_file(get_abs_path("data/ordinal/digit.tsv")).invert()
        graph_teens = pynini.string_file(get_abs_path("data/ordinal/teen.tsv")).invert()

        graph = (
            pynutil.delete("integer:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )
        convert_rest = pynutil.insert("th")

        suffix = pynini.cdrewrite(
            graph_digit | graph_teens | pynini.cross("ty", "tieth") | convert_rest, "", "[EOS]", NEMO_SIGMA,
        ).optimize()
        self.graph = pynini.compose(graph, suffix)
        self.suffix = suffix
        delete_tokens = self.delete_tokens(self.graph)
        self.fst = delete_tokens.optimize()
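The `suffix` rewrite resolves in priority order: irregular entries from the digit/teen tables first, then "ty" -> "tieth" for tens, then a default "th", always anchored at the end of the string (`[EOS]`). A plain-Python sketch of that precedence (the `IRREGULAR` table here is an illustrative stand-in, not the actual contents of the .tsv files):

```python
# Illustrative stand-in for data/ordinal/digit.tsv and teen.tsv (inverted).
IRREGULAR = {
    "one": "first", "two": "second", "three": "third", "five": "fifth",
    "eight": "eighth", "nine": "ninth", "twelve": "twelfth",
}

def ordinal_suffix(cardinal: str) -> str:
    """Rewrite only the final word, mirroring the [EOS]-anchored cdrewrite."""
    words = cardinal.split()
    last = words[-1]
    if last in IRREGULAR:            # digit/teen table entries win
        words[-1] = IRREGULAR[last]
    elif last.endswith("ty"):        # "twenty" -> "twentieth"
        words[-1] = last[:-2] + "tieth"
    else:                            # default: append "th"
        words[-1] = last + "th"
    return " ".join(words)

print(ordinal_suffix("thirteen"))    # thirteenth
print(ordinal_suffix("twenty"))      # twentieth
print(ordinal_suffix("twenty one"))  # twenty first
```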
@@ -0,0 +1,180 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    MIN_NEG_WEIGHT,
    NEMO_ALPHA,
    NEMO_CHAR,
    NEMO_SIGMA,
    NEMO_SPACE,
    generator_main,
)
from nemo_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
from pynini.lib import pynutil


class PostProcessingFst:
    """
    Finite state transducer that post-processes an entire sentence after verbalization is complete, e.g.
    removes extra spaces around punctuation marks " ( one hundred and twenty three ) " -> "(one hundred and twenty three)"

    Args:
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
    """

    def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):

        far_file = None
        if cache_dir is not None and cache_dir != "None":
            os.makedirs(cache_dir, exist_ok=True)
            far_file = os.path.join(cache_dir, "en_tn_post_processing.far")
        if not overwrite_cache and far_file and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode="r")["post_process_graph"]
        else:
            self.set_punct_dict()
            self.fst = self.get_punct_postprocess_graph()

            if far_file:
                generator_main(far_file, {"post_process_graph": self.fst})

    def set_punct_dict(self):
        self.punct_marks = {
            "'": [
                "'",
                '´',
                'ʹ',
                'ʻ',
                'ʼ',
                'ʽ',
                'ʾ',
                'ˈ',
                'ˊ',
                'ˋ',
                '˴',
                'ʹ',
                '΄',
                '՚',
                '՝',
                'י',
                '׳',
                'ߴ',
                'ߵ',
                'ᑊ',
                'ᛌ',
                '᾽',
                '᾿',
                '`',
                '´',
                '῾',
                '‘',
                '’',
                '‛',
                '′',
                '‵',
                'ꞌ',
                '＇',
                '｀',
                '𖽑',
                '𖽒',
            ],
        }

    def get_punct_postprocess_graph(self):
        """
        Returns graph to post process punctuation marks.

        {``} quotes are converted to {"}. Note, if there are spaces around single quote {'}, they will be kept.
        By default, a space is added after a punctuation mark, and spaces are removed before punctuation marks.
        """
        punct_marks_all = PunctuationFst().punct_marks

        # no_space_before_punct assume no space before them
        quotes = ["'", "\"", "``", "«"]
        dashes = ["-", "—"]
        brackets = ["<", "{", "("]
        open_close_single_quotes = [
            ("`", "`"),
        ]

        open_close_double_quotes = [('"', '"'), ("``", "``"), ("“", "”")]
        open_close_symbols = open_close_single_quotes + open_close_double_quotes
        allow_space_before_punct = ["&"] + quotes + dashes + brackets + [k[0] for k in open_close_symbols]

        no_space_before_punct = [m for m in punct_marks_all if m not in allow_space_before_punct]
        no_space_before_punct = pynini.union(*no_space_before_punct)
        no_space_after_punct = pynini.union(*brackets)
        delete_space = pynutil.delete(" ")
        delete_space_optional = pynini.closure(delete_space, 0, 1)

        # non_punct allows space
        # delete space before no_space_before_punct marks, if present
        non_punct = pynini.difference(NEMO_CHAR, no_space_before_punct).optimize()
        graph = (
            pynini.closure(non_punct)
            + pynini.closure(
                no_space_before_punct | pynutil.add_weight(delete_space + no_space_before_punct, MIN_NEG_WEIGHT)
            )
            + pynini.closure(non_punct)
        )
        graph = pynini.closure(graph).optimize()
        graph = pynini.compose(
            graph, pynini.cdrewrite(pynini.cross("``", '"'), "", "", NEMO_SIGMA).optimize()
        ).optimize()

        # remove space after no_space_after_punct (even if there are no matching closing brackets)
        no_space_after_punct = pynini.cdrewrite(delete_space, no_space_after_punct, NEMO_SIGMA, NEMO_SIGMA).optimize()
        graph = pynini.compose(graph, no_space_after_punct).optimize()

        # remove space around text in quotes
        single_quote = pynutil.add_weight(pynini.accep("`"), MIN_NEG_WEIGHT)
        double_quotes = pynutil.add_weight(pynini.accep('"'), MIN_NEG_WEIGHT)
        quotes_graph = (
            single_quote + delete_space_optional + NEMO_ALPHA + NEMO_SIGMA + delete_space_optional + single_quote
        ).optimize()

        # this is to make sure multiple quotes are tagged from right to left without skipping any quotes in the left
        not_alpha = pynini.difference(NEMO_CHAR, NEMO_ALPHA).optimize() | pynutil.add_weight(
            NEMO_SPACE, MIN_NEG_WEIGHT
        )
        end = pynini.closure(pynutil.add_weight(not_alpha, MIN_NEG_WEIGHT))
        quotes_graph |= (
            double_quotes
            + delete_space_optional
            + NEMO_ALPHA
            + NEMO_SIGMA
            + delete_space_optional
            + double_quotes
            + end
        )

        quotes_graph = pynutil.add_weight(quotes_graph, MIN_NEG_WEIGHT)
        quotes_graph = NEMO_SIGMA + pynini.closure(NEMO_SIGMA + quotes_graph + NEMO_SIGMA)

        graph = pynini.compose(graph, quotes_graph).optimize()

        # remove space between a word and a single quote followed by s
        remove_space_around_single_quote = pynini.cdrewrite(
            delete_space_optional + pynini.union(*self.punct_marks["'"]) + delete_space,
            NEMO_ALPHA,
            pynini.union("s ", "s[EOS]"),
            NEMO_SIGMA,
        )

        graph = pynini.compose(graph, remove_space_around_single_quote).optimize()
        return graph
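The weighted compositions above boil down to: delete spaces before most punctuation marks, delete spaces after opening brackets, and trim the edges. A rough regex sketch of that core behavior (illustrative only; the FST additionally handles quote pairing and the weighting this ignores):

```python
import re

def post_process(text: str) -> str:
    """Drop spaces before closing punctuation and after opening brackets,
    roughly what the weighted delete_space compositions achieve."""
    text = re.sub(r"\s+([.,!?;:)\]}])", r"\1", text)  # no space before these
    text = re.sub(r"([(\[{<])\s+", r"\1", text)       # no space after these
    return text.strip()

print(post_process(" ( one hundred and twenty three ) "))
# (one hundred and twenty three)
```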
@@ -0,0 +1,68 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst
from pynini.lib import pynutil


class RomanFst(GraphFst):
    """
    Finite state transducer for verbalizing roman numerals
        e.g. tokens { roman { integer: "one" } } -> one

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="roman", kind="verbalize", deterministic=deterministic)
        suffix = OrdinalFst().suffix

        cardinal = pynini.closure(NEMO_NOT_QUOTE)
        ordinal = pynini.compose(cardinal, suffix)

        graph = (
            pynutil.delete("key_cardinal: \"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
            + pynini.accep(" ")
            + pynutil.delete("integer: \"")
            + cardinal
            + pynutil.delete("\"")
        ).optimize()

        graph |= (
            pynutil.delete("default_cardinal: \"default\" integer: \"") + cardinal + pynutil.delete("\"")
        ).optimize()

        graph |= (
            pynutil.delete("default_ordinal: \"default\" integer: \"") + ordinal + pynutil.delete("\"")
        ).optimize()

        graph |= (
            pynutil.delete("key_the_ordinal: \"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
            + pynini.accep(" ")
            + pynutil.delete("integer: \"")
            + pynini.closure(pynutil.insert("the "), 0, 1)
            + ordinal
            + pynutil.delete("\"")
        ).optimize()

        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()
@@ -0,0 +1,63 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_NOT_QUOTE, GraphFst, delete_space, insert_space
from pynini.lib import pynutil


class TelephoneFst(GraphFst):
    """
    Finite state transducer for verbalizing telephone numbers, e.g.
        telephone { country_code: "one" number_part: "one two three, one two three, five six seven eight" extension: "one" }
        -> one, one two three, one two three, five six seven eight, one

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="telephone", kind="verbalize", deterministic=deterministic)

        optional_country_code = pynini.closure(
            pynutil.delete("country_code: \"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
            + delete_space
            + insert_space,
            0,
            1,
        )

        number_part = (
            pynutil.delete("number_part: \"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynini.closure(pynutil.add_weight(pynutil.delete(" "), -0.0001), 0, 1)
            + pynutil.delete("\"")
        )

        optional_extension = pynini.closure(
            delete_space
            + insert_space
            + pynutil.delete("extension: \"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\""),
            0,
            1,
        )

        graph = optional_country_code + number_part + optional_extension
        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()
@@ -0,0 +1,102 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    NEMO_NOT_QUOTE,
    NEMO_SIGMA,
    GraphFst,
    delete_space,
    insert_space,
)
from pynini.lib import pynutil


class TimeFst(GraphFst):
    """
    Finite state transducer for verbalizing time, e.g.
        time { hours: "twelve" minutes: "thirty" suffix: "a m" zone: "e s t" } -> twelve thirty a m e s t
        time { hours: "twelve" } -> twelve o'clock

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="time", kind="verbalize", deterministic=deterministic)
        hour = (
            pynutil.delete("hours:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )
        minute = (
            pynutil.delete("minutes:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )
        suffix = (
            pynutil.delete("suffix:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )
        optional_suffix = pynini.closure(delete_space + insert_space + suffix, 0, 1)
        zone = (
            pynutil.delete("zone:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )
        optional_zone = pynini.closure(delete_space + insert_space + zone, 0, 1)
        second = (
            pynutil.delete("seconds:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_NOT_QUOTE, 1)
            + pynutil.delete("\"")
        )
        graph_hms = (
            hour
            + pynutil.insert(" hours ")
            + delete_space
            + minute
            + pynutil.insert(" minutes and ")
            + delete_space
            + second
            + pynutil.insert(" seconds")
            + optional_suffix
            + optional_zone
        )
        graph_hms @= pynini.cdrewrite(
            pynutil.delete("o ")
            | pynini.cross("one minutes", "one minute")
            | pynini.cross("one seconds", "one second")
            | pynini.cross("one hours", "one hour"),
            pynini.union(" ", "[BOS]"),
            "",
            NEMO_SIGMA,
        )
        graph = hour + delete_space + insert_space + minute + optional_suffix + optional_zone
        graph |= hour + insert_space + pynutil.insert("o'clock") + optional_zone
        graph |= hour + delete_space + insert_space + suffix + optional_zone
        graph |= graph_hms
        delete_tokens = self.delete_tokens(graph)
        self.fst = delete_tokens.optimize()
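`graph_hms` glues the hour/minute/second fields together with inserted unit words, and the following `cdrewrite` fixes singular forms ("one minutes" -> "one minute", etc.). A rough plain-Python equivalent of that glue-then-rewrite step (illustrative, not pynini; the "o " deletion for oh-numbers is omitted):

```python
def verbalize_hms(hours: str, minutes: str, seconds: str) -> str:
    """Join fields the way graph_hms does, then fix singulars as the
    cdrewrite ('one minutes' -> 'one minute', etc.) would."""
    text = f"{hours} hours {minutes} minutes and {seconds} seconds"
    for unit in ("hours", "minutes", "seconds"):
        text = text.replace(f"one {unit}", f"one {unit[:-1]}")
    return text

print(verbalize_hms("one", "thirty", "ten"))
# one hour thirty minutes and ten seconds
```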
@@ -0,0 +1,82 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_text_processing.text_normalization.en.graph_utils import GraphFst
from nemo_text_processing.text_normalization.en.verbalizers.abbreviation import AbbreviationFst
from nemo_text_processing.text_normalization.en.verbalizers.cardinal import CardinalFst
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst
from nemo_text_processing.text_normalization.en.verbalizers.decimal import DecimalFst
from nemo_text_processing.text_normalization.en.verbalizers.electronic import ElectronicFst
from nemo_text_processing.text_normalization.en.verbalizers.fraction import FractionFst
from nemo_text_processing.text_normalization.en.verbalizers.measure import MeasureFst
from nemo_text_processing.text_normalization.en.verbalizers.money import MoneyFst
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst
from nemo_text_processing.text_normalization.en.verbalizers.roman import RomanFst
from nemo_text_processing.text_normalization.en.verbalizers.telephone import TelephoneFst
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst
from nemo_text_processing.text_normalization.en.verbalizers.whitelist import WhiteListFst


class VerbalizeFst(GraphFst):
    """
    Composes other verbalizer grammars.
    For deployment, this grammar will be compiled and exported to an OpenFst Finite State Archive (FAR) file.
    More details on deployment at NeMo/tools/text_processing_deployment.

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="verbalize", kind="verbalize", deterministic=deterministic)
        cardinal = CardinalFst(deterministic=deterministic)
        cardinal_graph = cardinal.fst
        decimal = DecimalFst(cardinal=cardinal, deterministic=deterministic)
        decimal_graph = decimal.fst
        ordinal = OrdinalFst(deterministic=deterministic)
        ordinal_graph = ordinal.fst
        fraction = FractionFst(deterministic=deterministic)
        fraction_graph = fraction.fst
        telephone_graph = TelephoneFst(deterministic=deterministic).fst
        electronic_graph = ElectronicFst(deterministic=deterministic).fst
        measure = MeasureFst(decimal=decimal, cardinal=cardinal, fraction=fraction, deterministic=deterministic)
        measure_graph = measure.fst
        time_graph = TimeFst(deterministic=deterministic).fst
        date_graph = DateFst(ordinal=ordinal, deterministic=deterministic).fst
        money_graph = MoneyFst(decimal=decimal, deterministic=deterministic).fst
        whitelist_graph = WhiteListFst(deterministic=deterministic).fst

        graph = (
            time_graph
            | date_graph
            | money_graph
            | measure_graph
            | ordinal_graph
            | decimal_graph
            | cardinal_graph
            | telephone_graph
            | electronic_graph
            | fraction_graph
            | whitelist_graph
        )

        roman_graph = RomanFst(deterministic=deterministic).fst
        graph |= roman_graph

        if not deterministic:
            abbreviation_graph = AbbreviationFst(deterministic=deterministic).fst
            graph |= abbreviation_graph

        self.fst = graph
@@ -0,0 +1,75 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import (
    GraphFst,
    delete_extra_space,
    delete_space,
    generator_main,
)
from nemo_text_processing.text_normalization.en.verbalizers.verbalize import VerbalizeFst
from nemo_text_processing.text_normalization.en.verbalizers.word import WordFst
from pynini.lib import pynutil


class VerbalizeFinalFst(GraphFst):
    """
    Finite state transducer that verbalizes an entire sentence, e.g.
    tokens { name: "its" } tokens { time { hours: "twelve" minutes: "thirty" } } tokens { name: "now" } tokens { name: "." } -> its twelve thirty now .

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
    """

    def __init__(self, deterministic: bool = True, cache_dir: str = None, overwrite_cache: bool = False):
        super().__init__(name="verbalize_final", kind="verbalize", deterministic=deterministic)

        far_file = None
        if cache_dir is not None and cache_dir != "None":
            os.makedirs(cache_dir, exist_ok=True)
            far_file = os.path.join(cache_dir, f"en_tn_{deterministic}_deterministic_verbalizer.far")
        if not overwrite_cache and far_file and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode="r")["verbalize"]
        else:
            verbalize = VerbalizeFst(deterministic=deterministic).fst
            word = WordFst(deterministic=deterministic).fst
            types = verbalize | word

            if deterministic:
                graph = (
                    pynutil.delete("tokens")
                    + delete_space
                    + pynutil.delete("{")
                    + delete_space
                    + types
                    + delete_space
                    + pynutil.delete("}")
                )
            else:
                graph = delete_space + types + delete_space

            graph = delete_space + pynini.closure(graph + delete_extra_space) + graph + delete_space

            self.fst = graph.optimize()
            if far_file:
                generator_main(far_file, {"verbalize": self.fst})
@@ -0,0 +1,39 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_CHAR, NEMO_SIGMA, GraphFst, delete_space
from pynini.lib import pynutil


class WhiteListFst(GraphFst):
    """
    Finite state transducer for verbalizing whitelist
        e.g. tokens { name: "misses" } -> misses

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="whitelist", kind="verbalize", deterministic=deterministic)
        graph = (
            pynutil.delete("name:")
            + delete_space
            + pynutil.delete("\"")
            + pynini.closure(NEMO_CHAR - " ", 1)
            + pynutil.delete("\"")
        )
        graph = graph @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", NEMO_SIGMA)
        self.fst = graph.optimize()
@@ -0,0 +1,35 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pynini
from nemo_text_processing.text_normalization.en.graph_utils import NEMO_CHAR, NEMO_SIGMA, GraphFst, delete_space
from pynini.lib import pynutil


class WordFst(GraphFst):
    """
    Finite state transducer for verbalizing word
        e.g. tokens { name: "sleep" } -> sleep

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="word", kind="verbalize", deterministic=deterministic)
        chars = pynini.closure(NEMO_CHAR - " ", 1)
        char = pynutil.delete("name:") + delete_space + pynutil.delete("\"") + chars + pynutil.delete("\"")
        graph = char @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", NEMO_SIGMA)

        self.fst = graph.optimize()
@@ -0,0 +1,479 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import re
from argparse import ArgumentParser
from collections import OrderedDict
from math import factorial
from time import perf_counter
from typing import Dict, List, Union

import pynini
import regex
from nemo_text_processing.text_normalization.data_loader_utils import (
    load_file,
    post_process_punct,
    pre_process,
    write_file,
)
from nemo_text_processing.text_normalization.token_parser import PRESERVE_ORDER_KEY, TokenParser
from pynini.lib.rewrite import top_rewrite
from tqdm import tqdm

SPACE_DUP = re.compile(' {2,}')


class Normalizer:
    """
    Normalizer class that converts text from written to spoken form.
    Useful for TTS preprocessing.

    Args:
        input_case: expected input capitalization
        lang: language specifying the TN rules, by default: English
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
        post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
            Note: punct_post_process flag in normalize() supports all languages.
    """

    def __init__(
        self,
        input_case: str,
        lang: str = 'en',
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
        lm: bool = False,
        post_process: bool = True,
    ):
        assert input_case in ["lower_cased", "cased"]

        self.post_processor = None

        if lang == "en":
            from nemo_text_processing.text_normalization.en.verbalizers.post_processing import PostProcessingFst
            from nemo_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst

            if post_process:
                self.post_processor = PostProcessingFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache)

            if deterministic:
                from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
            else:
                if lm:
                    from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify_lm import ClassifyFst
                else:
                    from nemo_text_processing.text_normalization.en.taggers.tokenize_and_classify_with_audio import (
                        ClassifyFst,
                    )

        elif lang == 'ru':
            # Ru TN only supports non-deterministic cases and produces multiple normalization options,
            # use normalize_with_audio.py
            from nemo_text_processing.text_normalization.ru.taggers.tokenize_and_classify import ClassifyFst
            from nemo_text_processing.text_normalization.ru.verbalizers.verbalize_final import VerbalizeFinalFst
        elif lang == 'de':
            from nemo_text_processing.text_normalization.de.taggers.tokenize_and_classify import ClassifyFst
            from nemo_text_processing.text_normalization.de.verbalizers.verbalize_final import VerbalizeFinalFst
        elif lang == 'es':
            from nemo_text_processing.text_normalization.es.taggers.tokenize_and_classify import ClassifyFst
            from nemo_text_processing.text_normalization.es.verbalizers.verbalize_final import VerbalizeFinalFst
        self.tagger = ClassifyFst(
            input_case=input_case,
            deterministic=deterministic,
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            whitelist=whitelist,
        )

        self.verbalizer = VerbalizeFinalFst(
            deterministic=deterministic, cache_dir=cache_dir, overwrite_cache=overwrite_cache
        )

        self.parser = TokenParser()
        self.lang = lang

        # Moses-based punctuation processor; set only when the NeMo NLP collection is available
        self.processor = None

    def __process_batch(self, batch, verbose, punct_pre_process, punct_post_process):
        """
        Normalizes batch of text sequences
        Args:
            batch: list of texts
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to do punctuation pre processing
            punct_post_process: whether to do punctuation post processing
        """
        normalized_lines = [
            self.normalize(
                text, verbose=verbose, punct_pre_process=punct_pre_process, punct_post_process=punct_post_process
            )
            for text in tqdm(batch)
        ]
        return normalized_lines

    def _estimate_number_of_permutations_in_nested_dict(
        self, token_group: Dict[str, Union[OrderedDict, str, bool]]
    ) -> int:
        num_perms = 1
        for k, inner in token_group.items():
            if isinstance(inner, dict):
                num_perms *= self._estimate_number_of_permutations_in_nested_dict(inner)
        num_perms *= factorial(len(token_group))
        return num_perms

    def _split_tokens_to_reduce_number_of_permutations(
        self, tokens: List[dict], max_number_of_permutations_per_split: int = 729
    ) -> List[List[dict]]:
        """
        Splits a sequence of tokens into smaller sequences of tokens in a way that the maximum number of composite
        tokens permutations does not exceed ``max_number_of_permutations_per_split``.

        For example,

        .. code-block:: python
            tokens = [
                {"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}},
                {"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}},
            ]
            split = normalizer._split_tokens_to_reduce_number_of_permutations(
                tokens, max_number_of_permutations_per_split=6
            )
            assert split == [
                [{"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}}],
                [{"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}}],
            ]

        Date tokens contain 3 items each which gives 6 permutations for every date. Since there are 2 dates, the total
        number of permutations would be ``6 * 6 == 36``. Parameter ``max_number_of_permutations_per_split`` equals 6,
        so the input sequence of tokens is split into 2 smaller sequences.

        Args:
            tokens (:obj:`List[dict]`): a list of dictionaries, possibly nested.
            max_number_of_permutations_per_split (:obj:`int`, `optional`, defaults to :obj:`729`): a maximum number
                of permutations which can be generated from input sequence of tokens.

        Returns:
            :obj:`List[List[dict]]`: a list of smaller sequences of tokens resulting from ``tokens`` split.
        """
        splits = []
        prev_end_of_split = 0
        current_number_of_permutations = 1
        for i, token_group in enumerate(tokens):
            n = self._estimate_number_of_permutations_in_nested_dict(token_group)
            if n * current_number_of_permutations > max_number_of_permutations_per_split:
                splits.append(tokens[prev_end_of_split:i])
                prev_end_of_split = i
                current_number_of_permutations = 1
            if n > max_number_of_permutations_per_split:
                raise ValueError(
                    f"Could not split token list with respect to condition that every split can generate number of "
                    f"permutations less or equal to "
                    f"`max_number_of_permutations_per_split={max_number_of_permutations_per_split}`. "
                    f"There is an unsplittable token group that generates more than "
                    f"{max_number_of_permutations_per_split} permutations. Try to increase "
                    f"`max_number_of_permutations_per_split` parameter."
                )
            current_number_of_permutations *= n
        splits.append(tokens[prev_end_of_split:])
        assert sum([len(s) for s in splits]) == len(tokens)
        return splits

    def normalize(
        self, text: str, verbose: bool = False, punct_pre_process: bool = False, punct_post_process: bool = False
    ) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
            punct_post_process: whether to normalize punctuation

        Returns: spoken form
        """

        original_text = text
        if punct_pre_process:
            text = pre_process(text)
        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)
        tagged_lattice = self.find_tags(text)
        tagged_text = self.select_tag(tagged_lattice)
        if verbose:
            print(tagged_text)
        self.parser(tagged_text)
        tokens = self.parser.parse()
        split_tokens = self._split_tokens_to_reduce_number_of_permutations(tokens)
        output = ""
        for s in split_tokens:
            tags_reordered = self.generate_permutations(s)
            verbalizer_lattice = None
            for tagged_text in tags_reordered:
                tagged_text = pynini.escape(tagged_text)

                verbalizer_lattice = self.find_verbalizer(tagged_text)
                if verbalizer_lattice.num_states() != 0:
                    break
            if verbalizer_lattice is None:
                raise ValueError(f"No permutations were generated from tokens {s}")
            output += ' ' + self.select_verbalizer(verbalizer_lattice)
        output = SPACE_DUP.sub(' ', output[1:])

        if self.lang == "en" and hasattr(self, 'post_processor'):
            output = self.post_process(output)

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                output = self.processor.moses_detokenizer.detokenize([output], unescape=False)
                output = post_process_punct(input=original_text, normalized_text=output)
            else:
                print("NEMO_NLP collection is not available: skipping punctuation post_processing")

        return output

    def split_text_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences.

        Args:
            text: text

        Returns list of sentences
        """
        lower_case_unicode = ''
        upper_case_unicode = ''
        if self.lang == "ru":
            lower_case_unicode = '\u0430-\u04FF'
            upper_case_unicode = '\u0410-\u042F'

        # Read and split transcript by utterance (roughly, sentences)
        split_pattern = f"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”|\!”)\s(?![0-9]+[a-z]*\.)"

        sentences = regex.split(split_pattern, text)
        return sentences

    def _permute(self, d: OrderedDict) -> List[str]:
        """
        Creates reorderings of dictionary elements and serializes as strings

        Args:
            d: (nested) dictionary of key value pairs

        Return permutations of different string serializations of key value pairs
        """
        l = []
        if PRESERVE_ORDER_KEY in d.keys():
            d_permutations = [d.items()]
        else:
            d_permutations = itertools.permutations(d.items())
        for perm in d_permutations:
            subl = [""]
            for k, v in perm:
                if isinstance(v, str):
                    subl = ["".join(x) for x in itertools.product(subl, [f"{k}: \"{v}\" "])]
                elif isinstance(v, OrderedDict):
                    rec = self._permute(v)
                    subl = ["".join(x) for x in itertools.product(subl, [f" {k} {{ "], rec, [f" }} "])]
                elif isinstance(v, bool):
                    subl = ["".join(x) for x in itertools.product(subl, [f"{k}: true "])]
                else:
                    raise ValueError()
            l.extend(subl)
        return l

    def generate_permutations(self, tokens: List[dict]):
        """
        Generates permutations of string serializations of list of dictionaries

        Args:
            tokens: list of dictionaries

        Returns string serialization of list of dictionaries
        """

        def _helper(prefix: str, tokens: List[dict], idx: int):
            """
            Generates permutations of string serializations of given dictionary

            Args:
                tokens: list of dictionaries
                prefix: prefix string
                idx: index of next dictionary

            Returns string serialization of dictionary
            """
            if idx == len(tokens):
                yield prefix
                return
            token_options = self._permute(tokens[idx])
            for token_option in token_options:
                yield from _helper(prefix + token_option, tokens, idx + 1)

        return _helper("", tokens, 0)

    def find_tags(self, text: str) -> 'pynini.FstLike':
        """
        Given text use tagger Fst to tag text

        Args:
            text: sentence

        Returns: tagged lattice
        """
        lattice = text @ self.tagger.fst
        return lattice

    def select_tag(self, lattice: 'pynini.FstLike') -> str:
        """
        Given tagged lattice return shortest path

        Args:
            lattice: tagged lattice

        Returns: shortest path
        """
        tagged_text = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
        return tagged_text

    def find_verbalizer(self, tagged_text: str) -> 'pynini.FstLike':
        """
        Given tagged text creates verbalization lattice
        This is context-independent.

        Args:
            tagged_text: input text

        Returns: verbalized lattice
        """
        lattice = tagged_text @ self.verbalizer.fst
        return lattice

    def select_verbalizer(self, lattice: 'pynini.FstLike') -> str:
        """
        Given verbalized lattice return shortest path

        Args:
            lattice: verbalization lattice

        Returns: shortest path
        """
        output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
        # lattice = output @ self.verbalizer.punct_graph
        # output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
        return output

    def post_process(self, normalized_text: 'pynini.FstLike') -> str:
        """
        Runs post processing graph on normalized text

        Args:
            normalized_text: normalized text

        Returns: shortest path
        """
        normalized_text = normalized_text.strip()
        if not normalized_text:
            return normalized_text
        normalized_text = pynini.escape(normalized_text)

        if self.post_processor is not None:
            normalized_text = top_rewrite(normalized_text, self.post_processor.fst)
        return normalized_text


def parse_args():
    parser = ArgumentParser()
    input = parser.add_mutually_exclusive_group()
    input.add_argument("--text", dest="input_string", help="input string", type=str)
    input.add_argument("--input_file", dest="input_file", help="input file path", type=str)
    parser.add_argument('--output_file', dest="output_file", help="output file path", type=str)
    parser.add_argument("--language", help="language", choices=["en", "de", "es"], default="en", type=str)
    parser.add_argument(
        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
    )
    parser.add_argument("--verbose", help="print info for debugging", action='store_true')
    parser.add_argument(
        "--punct_post_process",
        help="set to True to enable punctuation post processing to match input.",
        action="store_true",
    )
    parser.add_argument(
        "--punct_pre_process", help="set to True to enable punctuation pre processing", action="store_true"
    )
    parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
    parser.add_argument("--whitelist", help="path to a file with whitelist", default=None, type=str)
    parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    return parser.parse_args()


if __name__ == "__main__":
    start_time = perf_counter()

    args = parse_args()
    whitelist = os.path.abspath(args.whitelist) if args.whitelist else None

    if not args.input_string and not args.input_file:
        raise ValueError("Either `--text` or `--input_file` required")

    normalizer = Normalizer(
        input_case=args.input_case,
        cache_dir=args.cache_dir,
        overwrite_cache=args.overwrite_cache,
        whitelist=whitelist,
        lang=args.language,
    )
    if args.input_string:
        print(
            normalizer.normalize(
                args.input_string,
                verbose=args.verbose,
                punct_pre_process=args.punct_pre_process,
                punct_post_process=args.punct_post_process,
            )
        )
    elif args.input_file:
        print("Loading data: " + args.input_file)
        data = load_file(args.input_file)

        print("- Data: " + str(len(data)) + " sentences")
        normalizer_prediction = normalizer.normalize_list(
            data,
            verbose=args.verbose,
            punct_pre_process=args.punct_pre_process,
            punct_post_process=args.punct_post_process,
        )
        if args.output_file:
            write_file(args.output_file, normalizer_prediction)
            print(f"- Normalized. Writing out to {args.output_file}")
        else:
            print(normalizer_prediction)

    print(f"Execution time: {perf_counter() - start_time:.02f} sec")
@@ -0,0 +1,543 @@
|
|||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from glob import glob
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import pynini
|
||||||
|
from joblib import Parallel, delayed
|
||||||
|
from nemo_text_processing.text_normalization.data_loader_utils import post_process_punct, pre_process
|
||||||
|
from nemo_text_processing.text_normalization.normalize import Normalizer
|
||||||
|
from pynini.lib import rewrite
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nemo.collections.asr.metrics.wer import word_error_rate
|
||||||
|
from nemo.collections.asr.models import ASRModel
|
||||||
|
|
||||||
|
ASR_AVAILABLE = True
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
ASR_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
The script provides multiple normalization options and chooses the best one that minimizes CER of the ASR output
|
||||||
|
(most of the semiotic classes use deterministic=False flag).
|
||||||
|
|
||||||
|
To run this script with a .json manifest file, the manifest file should contain the following fields:
|
||||||
|
"audio_data" - path to the audio file
|
||||||
|
"text" - raw text
|
||||||
|
"pred_text" - ASR model prediction
|
||||||
|
|
||||||
|
See https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py on how to add ASR predictions
|
||||||
|
|
||||||
|
When the manifest is ready, run:
|
||||||
|
python normalize_with_audio.py \
|
||||||
|
--audio_data PATH/TO/MANIFEST.JSON \
|
||||||
|
--language en
|
||||||
|
|
||||||
|
|
||||||
|
To run with a single audio file, specify path to audio and text with:
|
||||||
|
python normalize_with_audio.py \
|
||||||
|
--audio_data PATH/TO/AUDIO.WAV \
|
||||||
|
--language en \
|
||||||
|
--text raw text OR PATH/TO/.TXT/FILE
|
||||||
|
--model QuartzNet15x5Base-En \
|
||||||
|
--verbose
|
||||||
|
|
||||||
|
To see possible normalization options for a text input without an audio file (could be used for debugging), run:
|
||||||
|
python python normalize_with_audio.py --text "RAW TEXT"
|
||||||
|
|
||||||
|
Specify `--cache_dir` to generate .far grammars once and re-used them for faster inference
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizerWithAudio(Normalizer):
|
||||||
|
"""
|
||||||
|
Normalizer class that converts text from written to spoken form.
|
||||||
|
Useful for TTS preprocessing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_case: expected input capitalization
|
||||||
|
lang: language
|
||||||
|
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
|
||||||
|
overwrite_cache: set to True to overwrite .far files
|
||||||
|
whitelist: path to a file with whitelist replacements
|
||||||
|
post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
|
||||||
|
Note: punct_post_process flag in normalize() supports all languages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_case: str,
|
||||||
|
lang: str = 'en',
|
||||||
|
cache_dir: str = None,
|
||||||
|
overwrite_cache: bool = False,
|
||||||
|
whitelist: str = None,
|
||||||
|
lm: bool = False,
|
||||||
|
post_process: bool = True,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
input_case=input_case,
|
||||||
|
lang=lang,
|
||||||
|
deterministic=False,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
overwrite_cache=overwrite_cache,
|
||||||
|
whitelist=whitelist,
|
||||||
|
lm=lm,
|
||||||
|
post_process=post_process,
|
||||||
|
)
|
||||||
|
self.lm = lm
|
||||||
|
|
||||||
|
    def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
            punct_post_process: whether to normalize punctuation
            verbose: whether to print intermediate meta information

        Returns:
            normalized text options (usually there are multiple ways of normalizing a given semiotic class)
        """
        if len(text.split()) > 500:
            raise ValueError(
                "Your input is too long. Please split up the input into sentences, "
                "or strings with fewer than 500 words"
            )

        original_text = text
        text = pre_process(text)  # to handle []

        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)

        if self.lm:
            if self.lang not in ["en"]:
                raise ValueError(f"{self.lang} is not supported in LM mode")

            if self.lang == "en":
                # this is to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:
                    lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
                else:
                    try:
                        lattice = rewrite.rewrite_lattice(text, self.tagger.fst_no_digits)
                    except pynini.lib.rewrite.Error:
                        lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
                lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
                tagged_texts = [(x[1], float(x[2])) for x in lattice.paths().items()]
                tagged_texts.sort(key=lambda x: x[1])
                tagged_texts, weights = list(zip(*tagged_texts))
        else:
            tagged_texts = self._get_tagged_text(text, n_tagged)
        # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between
        if self.lang == "en":
            normalized_texts = tagged_texts
            normalized_texts = [self.post_process(text) for text in normalized_texts]
        else:
            normalized_texts = []
            for tagged_text in tagged_texts:
                self._verbalize(tagged_text, normalized_texts, verbose=verbose)

        if len(normalized_texts) == 0:
            raise ValueError("Failed to produce any normalization options")

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
                normalized_texts = [
                    post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
                ]

        if self.lm:
            remove_dup = sorted(list(set(zip(normalized_texts, weights))), key=lambda x: x[1])
            normalized_texts, weights = zip(*remove_dup)
            return list(normalized_texts), weights

        normalized_texts = set(normalized_texts)
        return normalized_texts

    def _get_tagged_text(self, text, n_tagged):
        """
        Returns text after tokenize and classify

        Args:
            text: input text
            n_tagged: number of tagged options to consider, -1 - return all possible tagged options
        """
        if n_tagged == -1:
            if self.lang == "en":
                # this is to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:
                    tagged_texts = rewrite.rewrites(text, self.tagger.fst)
                else:
                    try:
                        tagged_texts = rewrite.rewrites(text, self.tagger.fst_no_digits)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.rewrites(text, self.tagger.fst)
            else:
                tagged_texts = rewrite.rewrites(text, self.tagger.fst)
        else:
            if self.lang == "en":
                # this is to keep arpabet phonemes in the list of options
                if "[" in text and "]" in text:
                    tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
                else:
                    try:
                        # try the tagger graph that produces output without digits
                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
            else:
                tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
        return tagged_texts

    def _verbalize(self, tagged_text: str, normalized_texts: List[str], verbose: bool = False):
        """
        Verbalizes tagged text

        Args:
            tagged_text: text with tags
            normalized_texts: list of possible normalization options
            verbose: if true, prints intermediate classification results
        """

        def get_verbalized_text(tagged_text):
            return rewrite.rewrites(tagged_text, self.verbalizer.fst)

        self.parser(tagged_text)
        tokens = self.parser.parse()
        tags_reordered = self.generate_permutations(tokens)
        for tagged_text_reordered in tags_reordered:
            try:
                tagged_text_reordered = pynini.escape(tagged_text_reordered)
                normalized_texts.extend(get_verbalized_text(tagged_text_reordered))
                if verbose:
                    print(tagged_text_reordered)
            except pynini.lib.rewrite.Error:
                continue

    def select_best_match(
        self,
        normalized_texts: List[str],
        input_text: str,
        pred_text: str,
        verbose: bool = False,
        remove_punct: bool = False,
        cer_threshold: int = 100,
    ):
        """
        Selects the best normalization option based on the lowest CER

        Args:
            normalized_texts: normalized text options
            input_text: input text
            pred_text: ASR model transcript of the audio file corresponding to the normalized text
            verbose: whether to print intermediate meta information
            remove_punct: whether to remove punctuation before calculating CER
            cer_threshold: if the CER for pred_text is above cer_threshold, no normalization will be performed

        Returns:
            normalized text with the lowest CER and the CER value
        """
        if pred_text == "":
            return input_text, cer_threshold

        normalized_texts_cer = calculate_cer(normalized_texts, pred_text, remove_punct)
        normalized_texts_cer = sorted(normalized_texts_cer, key=lambda x: x[1])
        normalized_text, cer = normalized_texts_cer[0]

        if cer > cer_threshold:
            return input_text, cer

        if verbose:
            print('-' * 30)
            for option in normalized_texts:
                print(option)
            print('-' * 30)
        return normalized_text, cer


def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
    """
    Calculates character error rate (CER)

    Args:
        normalized_texts: normalized text options
        pred_text: ASR model output
        remove_punct: whether to remove punctuation before calculating CER

    Returns: normalized options with corresponding CER
    """
    normalized_options = []
    for text in normalized_texts:
        text_clean = text.replace('-', ' ').lower()
        if remove_punct:
            for punct in "!?:;,.-()*+-/<=>@^_":
                text_clean = text_clean.replace(punct, "")
        cer = round(word_error_rate([pred_text], [text_clean], use_cer=True) * 100, 2)
        normalized_options.append((text, cer))
    return normalized_options


def get_asr_model(asr_model):
    """
    Returns ASR Model

    Args:
        asr_model: pre-trained NeMo ASR model name or path to a model checkpoint
    """
    if os.path.exists(asr_model):
        asr_model = ASRModel.restore_from(asr_model)
    elif asr_model in ASRModel.get_available_model_names():
        asr_model = ASRModel.from_pretrained(asr_model)
    else:
        raise ValueError(
            f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}'
        )
    return asr_model


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--text", help="input string or path to a .txt file", default=None, type=str)
    parser.add_argument(
        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
    )
    parser.add_argument(
        "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str
    )
    parser.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest")
    parser.add_argument(
        '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
    )
    parser.add_argument(
        "--n_tagged",
        type=int,
        default=30,
        help="number of tagged options to consider, -1 - return all possible tagged options",
    )
    parser.add_argument("--verbose", help="print info for debugging", action="store_true")
    parser.add_argument(
        "--no_remove_punct_for_cer",
        help="Set to True to NOT remove punctuation before calculating CER",
        action="store_true",
    )
    parser.add_argument(
        "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
    )
    parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
    parser.add_argument("--whitelist", help="path to a file with whitelist replacements", default=None, type=str)
    parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    parser.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs")
    parser.add_argument(
        "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now."
    )
    parser.add_argument(
        "--cer_threshold",
        default=100,
        type=int,
        help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
    )
    parser.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process")
    return parser.parse_args()


def _normalize_line(
    normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold
):
    line = json.loads(line)
    pred_text = line["pred_text"]

    normalized_texts = normalizer.normalize(
        text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process,
    )

    normalized_texts = set(normalized_texts)
    normalized_text, cer = normalizer.select_best_match(
        normalized_texts=normalized_texts,
        input_text=line["text"],
        pred_text=pred_text,
        verbose=verbose,
        remove_punct=remove_punct,
        cer_threshold=cer_threshold,
    )
    line["nemo_normalized"] = normalized_text
    line["CER_nemo_normalized"] = cer
    return line


def normalize_manifest(
    normalizer,
    audio_data: str,
    n_jobs: int,
    n_tagged: int,
    remove_punct: bool,
    punct_post_process: bool,
    batch_size: int,
    cer_threshold: int,
):
    """
    Args:
        audio_data: path to .json manifest file.
    """

    def __process_batch(batch_idx: int, batch: List[str], dir_name: str):
        """
        Normalizes a batch of text sequences

        Args:
            batch_idx: batch index
            batch: list of texts
            dir_name: path to output directory to save results
        """
        normalized_lines = [
            _normalize_line(
                normalizer,
                n_tagged,
                verbose=False,
                line=line,
                remove_punct=remove_punct,
                punct_post_process=punct_post_process,
                cer_threshold=cer_threshold,
            )
            for line in tqdm(batch)
        ]

        with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out:
            for line in normalized_lines:
                f_out.write(json.dumps(line, ensure_ascii=False) + '\n')

        print(f"Batch -- {batch_idx} -- is complete")

    manifest_out = audio_data.replace('.json', '_normalized.json')
    with open(audio_data, 'r') as f:
        lines = f.readlines()

    print(f'Normalizing {len(lines)} lines of {audio_data}...')

    # to save intermediate results to a file
    batch = min(len(lines), batch_size)

    tmp_dir = manifest_out.replace(".json", "_parts")
    os.makedirs(tmp_dir, exist_ok=True)

    Parallel(n_jobs=n_jobs)(
        delayed(__process_batch)(idx, lines[i : i + batch], tmp_dir)
        for idx, i in enumerate(range(0, len(lines), batch))
    )

    # aggregate all intermediate files
    with open(manifest_out, "w") as f_out:
        for batch_f in sorted(glob(f"{tmp_dir}/*.json")):
            with open(batch_f, "r") as f_in:
                lines = f_in.read()
                f_out.write(lines)

    print(f'Normalized version saved at {manifest_out}')


if __name__ == "__main__":
    args = parse_args()

    if not ASR_AVAILABLE and args.audio_data:
        raise ValueError("NeMo ASR collection is not installed.")
    start = time.time()
    args.whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
    if args.text is not None:
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
            cache_dir=args.cache_dir,
            overwrite_cache=args.overwrite_cache,
            whitelist=args.whitelist,
            lm=args.lm,
        )

        if os.path.exists(args.text):
            with open(args.text, 'r') as f:
                args.text = f.read().strip()
        normalized_texts = normalizer.normalize(
            text=args.text,
            verbose=args.verbose,
            n_tagged=args.n_tagged,
            punct_post_process=not args.no_punct_post_process,
        )

        if not normalizer.lm:
            normalized_texts = set(normalized_texts)
        if args.audio_data:
            asr_model = get_asr_model(args.model)
            pred_text = asr_model.transcribe([args.audio_data])[0]
            normalized_text, cer = normalizer.select_best_match(
                normalized_texts=normalized_texts,
                pred_text=pred_text,
                input_text=args.text,
                verbose=args.verbose,
                remove_punct=not args.no_remove_punct_for_cer,
                cer_threshold=args.cer_threshold,
            )
            print(f"Transcript: {pred_text}")
            print(f"Normalized: {normalized_text}")
        else:
            print("Normalization options:")
            for norm_text in normalized_texts:
                print(norm_text)
    elif not os.path.exists(args.audio_data):
        raise ValueError(f"{args.audio_data} not found.")
    elif args.audio_data.endswith('.json'):
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
            cache_dir=args.cache_dir,
            overwrite_cache=args.overwrite_cache,
            whitelist=args.whitelist,
        )
        normalize_manifest(
            normalizer=normalizer,
            audio_data=args.audio_data,
            n_jobs=args.n_jobs,
            n_tagged=args.n_tagged,
            remove_punct=not args.no_remove_punct_for_cer,
            punct_post_process=not args.no_punct_post_process,
            batch_size=args.batch_size,
            cer_threshold=args.cer_threshold,
        )
    else:
        raise ValueError(
            "Provide either a path to a .json manifest in '--audio_data' OR "
            "'--audio_data' path to an audio file and '--text' path to a text file OR "
            "'--text' string text (for debugging without audio)"
        )
    print(f'Execution time: {round((time.time() - start)/60, 2)} min.')
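The selection step above picks, among the candidate normalizations, the one whose character error rate (CER) against the ASR transcript is lowest (`word_error_rate` here comes from NeMo). The idea can be sketched stand-alone with a hypothetical Levenshtein-based CER helper; this is a simplified illustration, not NeMo's implementation:

```python
def edit_distance(a: str, b: str) -> int:
    # single-row dynamic-programming Levenshtein distance
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
    return dp[len(b)]


def cer(pred: str, ref: str) -> float:
    # character error rate as a percentage of the reference length
    return 100.0 * edit_distance(pred, ref) / max(len(ref), 1)


def select_best(options, pred_text):
    # lowest-CER candidate wins, mirroring the select_best_match logic above
    scored = sorted((cer(pred_text, opt.lower()), opt) for opt in options)
    return scored[0][1], scored[0][0]
```

For example, given the candidates `["twelve kilograms", "12 kg"]` and the transcript `"twelve kilograms"`, `select_best` returns the spoken-form candidate with CER 0.0.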
@@ -0,0 +1,117 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from nemo_text_processing.text_normalization.data_loader_utils import (
    evaluate,
    known_types,
    load_files,
    training_data_to_sentences,
    training_data_to_tokens,
)
from nemo_text_processing.text_normalization.normalize import Normalizer


'''
Runs evaluation on data in the format: <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
like the Google text normalization data https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
'''


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--input", help="input file path", type=str)
    parser.add_argument("--lang", help="language", choices=['en'], default="en", type=str)
    parser.add_argument(
        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
    )
    parser.add_argument(
        "--cat",
        dest="category",
        help="focus on class only (" + ", ".join(known_types) + ")",
        type=str,
        default=None,
        choices=known_types,
    )
    parser.add_argument("--filter", action='store_true', help="clean data for normalization purposes")
    return parser.parse_args()


if __name__ == "__main__":
    # Example usage:
    # python run_evaluate.py --input=<INPUT> --cat=<CATEGORY> --filter
    args = parse_args()
    if args.lang == 'en':
        from nemo_text_processing.text_normalization.en.clean_eval_data import filter_loaded_data
    file_path = args.input
    normalizer = Normalizer(input_case=args.input_case, lang=args.lang)

    print("Loading training data: " + file_path)
    training_data = load_files([file_path])

    if args.filter:
        training_data = filter_loaded_data(training_data)

    if args.category is None:
        print("Sentence level evaluation...")
        sentences_un_normalized, sentences_normalized, _ = training_data_to_sentences(training_data)
        print("- Data: " + str(len(sentences_normalized)) + " sentences")
        sentences_prediction = normalizer.normalize_list(sentences_un_normalized)
        print("- Normalized. Evaluating...")
        sentences_accuracy = evaluate(
            preds=sentences_prediction, labels=sentences_normalized, input=sentences_un_normalized
        )
        print("- Accuracy: " + str(sentences_accuracy))

    print("Token level evaluation...")
    tokens_per_type = training_data_to_tokens(training_data, category=args.category)
    token_accuracy = {}
    for token_type in tokens_per_type:
        print("- Token type: " + token_type)
        tokens_un_normalized, tokens_normalized = tokens_per_type[token_type]
        print("  - Data: " + str(len(tokens_normalized)) + " tokens")
        tokens_prediction = normalizer.normalize_list(tokens_un_normalized)
        print("  - Normalized. Evaluating...")
        token_accuracy[token_type] = evaluate(
            preds=tokens_prediction, labels=tokens_normalized, input=tokens_un_normalized
        )
        print("  - Accuracy: " + str(token_accuracy[token_type]))
    token_count_per_type = {token_type: len(tokens_per_type[token_type][0]) for token_type in tokens_per_type}
    token_weighted_accuracy = [
        token_count_per_type[token_type] * accuracy for token_type, accuracy in token_accuracy.items()
    ]
    print("- Accuracy: " + str(sum(token_weighted_accuracy) / sum(token_count_per_type.values())))
    print("  - Total: " + str(sum(token_count_per_type.values())), '\n')

    for token_type in token_accuracy:
        if token_type not in known_types:
            raise ValueError("Unexpected token type: " + token_type)

    if args.category is None:
        c1 = ['Class', 'sent level'] + known_types
        c2 = ['Num Tokens', len(sentences_normalized)] + [
            token_count_per_type[known_type] if known_type in tokens_per_type else '0' for known_type in known_types
        ]
        c3 = ['Normalization', sentences_accuracy] + [
            token_accuracy[known_type] if known_type in token_accuracy else '0' for known_type in known_types
        ]

        for i in range(len(c1)):
            print(f'{str(c1[i]):10s} | {str(c2[i]):10s} | {str(c3[i]):5s}')
    else:
        print(f'numbers\t{token_count_per_type[args.category]}')
        print(f'Normalization\t{token_accuracy[args.category]}')
@@ -0,0 +1,192 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import string
from collections import OrderedDict
from typing import Dict, List, Union

PRESERVE_ORDER_KEY = "preserve_order"
EOS = "<EOS>"


class TokenParser:
    """
    Parses tokenized/classified text, e.g. 'tokens { money { integer: "20" currency: "$" } } tokens { name: "left" }'

    Args:
        text: tokenized text
    """

    def __call__(self, text):
        """
        Setup function

        Args:
            text: text to be parsed
        """
        self.text = text
        self.len_text = len(text)
        self.char = text[0]  # cannot handle an empty string
        self.index = 0

    def parse(self) -> List[dict]:
        """
        Main function. Implements grammar:
            A -> space F space F space F ... space

        Returns a list of dictionaries
        """
        l = list()
        while self.parse_ws():
            token = self.parse_token()
            if not token:
                break
            l.append(token)
        return l

    def parse_token(self) -> Dict[str, Union[str, dict]]:
        """
        Implements grammar:
            F -> no_space KG no_space

        Returns: K, G as dictionary values
        """
        d = OrderedDict()
        key = self.parse_string_key()
        if key is None:
            return None
        self.parse_ws()
        if key == PRESERVE_ORDER_KEY:
            self.parse_char(":")
            self.parse_ws()
            value = self.parse_chars("true")
        else:
            value = self.parse_token_value()

        d[key] = value
        return d

    def parse_token_value(self) -> Union[str, dict]:
        """
        Implements grammar:
            G -> no_space :"VALUE" no_space | no_space {A} no_space

        Returns: string or dictionary
        """
        if self.char == ":":
            self.parse_char(":")
            self.parse_ws()
            self.parse_char("\"")
            value_string = self.parse_string_value()
            self.parse_char("\"")
            return value_string
        elif self.char == "{":
            d = OrderedDict()
            self.parse_char("{")
            list_token_dicts = self.parse()
            # flatten tokens
            for tok_dict in list_token_dicts:
                for k, v in tok_dict.items():
                    d[k] = v
            self.parse_char("}")
            return d
        else:
            raise ValueError(f"Unexpected character: {self.char}")

    def parse_char(self, exp) -> bool:
        """
        Parses a character

        Args:
            exp: character to read in

        Returns true if successful
        """
        assert self.char == exp
        self.read()
        return True

    def parse_chars(self, exp) -> bool:
        """
        Parses characters

        Args:
            exp: characters to read in

        Returns true if successful
        """
        ok = False
        for x in exp:
            ok |= self.parse_char(x)
        return ok

    def parse_string_key(self) -> str:
        """
        Parses a string key, which can only contain ASCII letters and '_' characters

        Returns the parsed string key
        """
        assert self.char not in string.whitespace and self.char != EOS
        incl_criterium = string.ascii_letters + "_"
        l = []
        while self.char in incl_criterium:
            l.append(self.char)
            if not self.read():
                raise ValueError("Unexpected end of input while parsing a key")

        if not l:
            return None
        return "".join(l)

    def parse_string_value(self) -> str:
        """
        Parses a string value, which ends with a quote followed by a space

        Returns the parsed string value
        """
        assert self.char not in string.whitespace and self.char != EOS
        l = []
        while self.char != "\"" or self.text[self.index + 1] != " ":
            l.append(self.char)
            if not self.read():
                raise ValueError("Unexpected end of input while parsing a value")

        if not l:
            return None
        return "".join(l)

    def parse_ws(self):
        """
        Deletes whitespaces.

        Returns true if not EOS after parsing
        """
        not_eos = self.char != EOS
        while not_eos and self.char == " ":
            not_eos = self.read()
        return not_eos

    def read(self):
        """
        Reads in the next char.

        Returns true if not EOS
        """
        if self.index < self.len_text - 1:  # should be unique
            self.index += 1
            self.char = self.text[self.index]
            return True
        self.char = EOS
        return False
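The parser above consumes strings such as `money { integer: "20" currency: "$" }`: `key: "value"` pairs and nested `{ ... }` groups. A minimal stand-alone sketch of the same grammar (simplified so a value ends at the first closing quote, rather than a quote followed by a space) looks like this:

```python
def parse_tagged(s: str) -> dict:
    """Minimal sketch of the TokenParser grammar: `key: "value"` pairs and
    nested `{ ... }` groups. Simplified: a value ends at the first quote."""
    pos = 0

    def skip_ws():
        nonlocal pos
        while pos < len(s) and s[pos] == ' ':
            pos += 1

    def group() -> dict:
        nonlocal pos
        d = {}
        while True:
            skip_ws()
            if pos >= len(s) or s[pos] == '}':
                return d
            start = pos
            while pos < len(s) and (s[pos].isalpha() or s[pos] == '_'):
                pos += 1
            key = s[start:pos]
            if not key:
                return d
            skip_ws()
            if s[pos] == ':':          # key: "value"
                pos += 1
                skip_ws()
                end = s.index('"', pos + 1)
                d[key] = s[pos + 1:end]
                pos = end + 1
            elif s[pos] == '{':        # key { nested ... }
                pos += 1
                d[key] = group()
                pos += 1               # consume the closing '}'

    return group()
```

For example, `parse_tagged('money { integer: "20" currency: "$" } name: "left"')` yields `{'money': {'integer': '20', 'currency': '$'}, 'name': 'left'}`, the same flattened-dict shape TokenParser produces.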
100
utils/speechio/textnorm_en.py
Normal file
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright 2022 Ruiqi WANG, Jinpeng LI, Jiayu DU
#
# only tested and validated on pynini v2.1.5, installed via 'conda install -c conda-forge pynini'
# pynini v2.1.0 doesn't work
#

import argparse
import os
import string
import sys

from nemo_text_processing.text_normalization.normalize import Normalizer


def read_interjections(filepath):
    interjections = []
    with open(filepath) as f:
        for line in f:
            words = [x.strip() for x in line.split(',')]
            interjections += words + [w.upper() for w in words] + [w.lower() for w in words]
    return list(set(interjections))  # deduplicated


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
    p.add_argument('ofile', help='output filename')
    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
    args = p.parse_args()

nemo_tn_en = Normalizer(input_case='lower_cased', lang='en')
|
||||||
|
|
||||||
|
itj = read_interjections(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'interjections_en.csv'))
|
||||||
|
itj_map = {x: True for x in itj}
|
||||||
|
|
||||||
|
certain_single_quote_items = ["\"'", "'?", "'!", "'.", "?'", "!'", ".'", "''", "<BOS>'", "'<EOS>"]
|
||||||
|
single_quote_removed_items = [x.replace("'", '') for x in certain_single_quote_items]
|
||||||
|
|
||||||
|
puncts_to_remove = string.punctuation.replace("'", '') + "—–“”"
|
||||||
|
puncts_trans = str.maketrans(puncts_to_remove, ' ' * len(puncts_to_remove), '')
|
||||||
|
|
||||||
|
n = 0
|
||||||
|
with open(args.ifile, 'r', encoding='utf8') as fi, open(args.ofile, 'w+', encoding='utf8') as fo:
|
||||||
|
for line in fi:
|
||||||
|
if args.has_key:
|
||||||
|
cols = line.strip().split(maxsplit=1)
|
||||||
|
key, text = cols[0].strip(), cols[1].strip() if len(cols) == 2 else ''
|
||||||
|
else:
|
||||||
|
text = line.strip()
|
||||||
|
|
||||||
|
text = text.replace("‘", "'").replace("’", "'")
|
||||||
|
|
||||||
|
# nemo text normalization
|
||||||
|
# modifications to NeMo:
|
||||||
|
# 1. added UK to US conversion: nemo_text_processing/text_normalization/en/data/whitelist/UK_to_US.tsv
|
||||||
|
# 2. swith 'oh' to 'o' in year TN to avoid confusion with interjections, e.g.:
|
||||||
|
# 1805: eighteen oh five -> eighteen o five
|
||||||
|
text = nemo_tn_en.normalize(text.lower())
|
||||||
|
|
||||||
|
# Punctuations
|
||||||
|
# NOTE(2022.10 Jiayu):
|
||||||
|
# Single quote removal is not perfect.
|
||||||
|
# ' needs to be reserved for:
|
||||||
|
# Abbreviations:
|
||||||
|
# I'm, don't, she'd, 'cause, Sweet Child o' Mine, Guns N' Roses, ...
|
||||||
|
# Possessions:
|
||||||
|
# John's, the king's, parents', ...
|
||||||
|
text = '<BOS>' + text + '<EOS>'
|
||||||
|
for x, y in zip(certain_single_quote_items, single_quote_removed_items):
|
||||||
|
text = text.replace(x, y)
|
||||||
|
text = text.replace('<BOS>', '').replace('<EOS>', '')
|
||||||
|
|
||||||
|
text = text.translate(puncts_trans).replace(" ' ", " ")
|
||||||
|
|
||||||
|
# Interjections
|
||||||
|
text = ' '.join([x for x in text.strip().split() if x not in itj_map])
|
||||||
|
|
||||||
|
# Cases
|
||||||
|
if args.to_upper and args.to_lower:
|
||||||
|
sys.stderr.write('text norm: to_upper OR to_lower?')
|
||||||
|
exit(1)
|
||||||
|
if args.to_upper:
|
||||||
|
text = text.upper()
|
||||||
|
if args.to_lower:
|
||||||
|
text = text.lower()
|
||||||
|
|
||||||
|
if args.has_key:
|
||||||
|
print(key + '\t' + text, file=fo)
|
||||||
|
else:
|
||||||
|
print(text, file=fo)
|
||||||
|
|
||||||
|
n += 1
|
||||||
|
if n % args.log_interval == 0:
|
||||||
|
print(f'text norm: {n} lines done.', file=sys.stderr)
|
||||||
|
print(f'text norm: {n} lines done in total.', file=sys.stderr)
|
||||||
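The `<BOS>`/`<EOS>` sentinel trick in the loop above can be exercised in isolation. This is a minimal standalone sketch (the function name is illustrative, not part of the repo): apostrophes adjacent to quote/punctuation contexts are dropped, while intra-word apostrophes (don't, John's) survive.

```python
def strip_certain_single_quotes(text):
    # Same item list as certain_single_quote_items in the script above;
    # the sentinels let "'" at line start/end be matched like punctuation.
    certain = ["\"'", "'?", "'!", "'.", "?'", "!'", ".'", "''", "<BOS>'", "'<EOS>"]
    text = '<BOS>' + text + '<EOS>'
    for item in certain:
        text = text.replace(item, item.replace("'", ''))
    return text.replace('<BOS>', '').replace('<EOS>', '')
```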
1204
utils/speechio/textnorm_zh.py
Normal file
File diff suppressed because it is too large
Load Diff
160
utils/tokenizer.py
Normal file
@@ -0,0 +1,160 @@
import os
import re
import subprocess
from enum import Enum
from typing import List

import nltk
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

from utils.logger import logger

lemmatizer = WordNetLemmatizer()


class TokenizerType(str, Enum):
    word = "word"
    whitespace = "whitespace"


class LangType(str, Enum):
    zh = "zh"
    en = "en"


TOKENIZER_MAPPING = dict()
TOKENIZER_MAPPING['zh'] = TokenizerType.word
TOKENIZER_MAPPING['en'] = TokenizerType.whitespace
TOKENIZER_MAPPING['ru'] = TokenizerType.whitespace
TOKENIZER_MAPPING['ar'] = TokenizerType.whitespace
TOKENIZER_MAPPING['tr'] = TokenizerType.whitespace
TOKENIZER_MAPPING['es'] = TokenizerType.whitespace
TOKENIZER_MAPPING['pt'] = TokenizerType.whitespace
TOKENIZER_MAPPING['id'] = TokenizerType.whitespace
TOKENIZER_MAPPING['he'] = TokenizerType.whitespace
TOKENIZER_MAPPING['ja'] = TokenizerType.word
TOKENIZER_MAPPING['pl'] = TokenizerType.whitespace
TOKENIZER_MAPPING['de'] = TokenizerType.whitespace
TOKENIZER_MAPPING['fr'] = TokenizerType.whitespace
TOKENIZER_MAPPING['nl'] = TokenizerType.whitespace
TOKENIZER_MAPPING['el'] = TokenizerType.whitespace
TOKENIZER_MAPPING['vi'] = TokenizerType.whitespace
TOKENIZER_MAPPING['th'] = TokenizerType.whitespace
TOKENIZER_MAPPING['it'] = TokenizerType.whitespace
TOKENIZER_MAPPING['fa'] = TokenizerType.whitespace
TOKENIZER_MAPPING['ti'] = TokenizerType.word


class Tokenizer:
    @classmethod
    def norm_and_tokenize(cls, sentences: List[str], lang: str = None):
        tokenizer = TOKENIZER_MAPPING.get(lang, None)
        sentences = cls.replace_general_punc(sentences, tokenizer)
        sentences = cls.norm(sentences, lang)
        return cls.tokenize(sentences, lang)

    @classmethod
    def tokenize(cls, sentences: List[str], lang: str = None):
        tokenizer = TOKENIZER_MAPPING.get(lang, None)
        # sentences = cls.replace_general_punc(sentences, tokenizer)
        if tokenizer == TokenizerType.word:
            return [[ch for ch in sentence] for sentence in sentences]
        elif tokenizer == TokenizerType.whitespace:
            return [re.findall(r"\w+", sentence.lower()) for sentence in sentences]
        else:
            logger.error("No matching tokenizer found")
            exit(-1)

    @classmethod
    def norm(cls, sentences: List[str], lang: LangType = None):
        if lang == "zh":
            from utils.speechio import textnorm_zh as textnorm

            normalizer = textnorm.TextNorm(
                to_banjiao=True,
                to_upper=True,
                to_lower=False,
                remove_fillers=True,
                remove_erhua=False,  # unlike batch recognition, this is set to False here
                check_chars=False,
                remove_space=False,
                cc_mode="",
            )
            return [normalizer(sentence) for sentence in sentences]
        elif lang == "en":
            # pwd = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            # with open('./predict.txt', 'w', encoding='utf-8') as fp:
            #     for idx, sentence in enumerate(sentences):
            #         fp.write('%s\t%s\n' % (idx, sentence))
            # subprocess.run(
            #     f'PYTHONPATH={pwd}/utils/speechio python {pwd}/utils/speechio/textnorm_en.py --has_key --to_upper ./predict.txt ./predict_norm.txt',
            #     shell=True,
            #     check=True,
            # )
            # sentence_norm = []
            # with open('./predict_norm.txt', 'r', encoding='utf-8') as fp:
            #     for line in fp.readlines():
            #         line_split_result = line.strip().split('\t', 1)
            #         if len(line_split_result) >= 2:
            #             sentence_norm.append(line_split_result[1])
            #         else:
            #             sentence_norm.append("")
            # # the sentence may be empty after norm
            # return sentence_norm

            # sentence_norm = []
            # for sentence in sentences:
            #     doc = _nlp_en(sentence)
            #     # keep words; drop punctuation, digits, special symbols; lemmatize
            #     tokens = [token.lemma_ for token in doc if token.is_alpha]
            #     tokens = [t.upper() for t in tokens]  # per the original logic, to_upper=True
            #     sentence_norm.append(" ".join(tokens))
            # return sentence_norm
            result = []
            for sentence in sentences:
                sentence = re.sub(r"[^a-zA-Z\s]", "", sentence)
                tokens = word_tokenize(sentence)
                tokens = [lemmatizer.lemmatize(t) for t in tokens]
                # if to_upper:
                #     tokens = [t.upper() for t in tokens]
                result.append(" ".join(tokens))
            return result
        else:
            punc = "!?。"#$%&'()*+,-/:;<=>[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏.`! #$%^&*()_+-=|';\":/.,?><~·!#¥%……&*()——+-=“:’;、。,?》《{}"
            return [sentence.translate(str.maketrans(dict.fromkeys(punc, " "))).lower() for sentence in sentences]

    @classmethod
    def replace_general_punc(cls, sentences: List[str], tokenizer: TokenizerType, language: str = None) -> List[str]:
        """Replaces the original function utils.metrics.cut_sentence."""
        if language:
            tokenizer = TOKENIZER_MAPPING.get(language)
        general_puncs = [
            "······",
            "......",
            "。",
            ",",
            "?",
            "!",
            ";",
            ":",
            "...",
            ".",
            ",",
            "?",
            "!",
            ";",
            ":",
        ]
        if tokenizer == TokenizerType.whitespace:
            replacer = " "
        else:
            replacer = ""
        trans = str.maketrans(dict.fromkeys("".join(general_puncs), replacer))
        ret_sentences = [""] * len(sentences)
        for i, sentence in enumerate(sentences):
            sentence = sentence.translate(trans)
            sentence = sentence.strip()
            sentence = sentence.lower()
            ret_sentences[i] = sentence
        return ret_sentences
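The two `Tokenizer.tokenize` branches above reduce to one-liners that can be sanity-checked standalone (function names here are illustrative, not part of the repo):

```python
import re

def char_tokenize(sentences):
    # TokenizerType.word branch: one token per character (zh, ja, ti)
    return [[ch for ch in sentence] for sentence in sentences]

def whitespace_tokenize(sentences):
    # TokenizerType.whitespace branch: lowercase, then keep word-character runs
    return [re.findall(r"\w+", sentence.lower()) for sentence in sentences]
```

Note that `\w` in Python matches Unicode word characters by default, so the whitespace branch still yields tokens for non-Latin scripts; the per-character branch is what guarantees character-level scoring for zh/ja.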